diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..176a458f9 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..1f011157e --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,12 @@ +# These are supported funding model platforms + +github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +custom: []# Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.github/ISSUE_TEMPLATE/blank_issue.yml b/.github/ISSUE_TEMPLATE/blank_issue.yml new file mode 100644 index 000000000..bbd855958 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/blank_issue.yml @@ -0,0 +1,12 @@ +name: Blank Issue +description: Submit an issue about Tensorflow.NET. +labels: [Blank Issue] +body: + - type: textarea + id: description + attributes: + label: Description + description: Please describe the issue here. + placeholder: Description + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 000000000..14e237951 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,48 @@ +name: BUG Report +description: Report a BUG of Tensorflow.NET. +title: "[BUG Report]: " +labels: [bug-report] +body: + - type: markdown + attributes: + value: | + We welcome bug reports! Any unexpected behavior could be a BUG and this template help us gather the information to fix it. + - type: textarea + id: background + attributes: + label: Description + description: Please share a clear and concise description of the problem. + placeholder: Description + validations: + required: true + - type: textarea + id: repro-steps + attributes: + label: Reproduction Steps + description: | + Please include minimal steps to reproduce the problem if possible. E.g.: the smallest possible code snippet; or a small project, with steps to run it. It will greatly help us to locate the reason of the problem. + placeholder: Minimal Reproduction + validations: + required: false + - type: textarea + id: known-workarounds + attributes: + label: Known Workarounds + description: | + Please provide a description of any known workarounds. + placeholder: Known Workarounds + validations: + required: false + - type: textarea + id: configuration + attributes: + label: Configuration and Other Information + description: | + Please provide more information on your configuration: + * Which version of Tensorflow.NET is the code depending on? + * Which version of .NET runtime is the code running on? + * What is the OS? + * Any other information about this problem? + placeholder: Configuration + validations: + required: false \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/documention_issue.yml b/.github/ISSUE_TEMPLATE/documention_issue.yml new file mode 100644 index 000000000..f8a04e40f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documention_issue.yml @@ -0,0 +1,30 @@ +name: Documentation Issue +description: Report an issue about Tensorflow.NET ducumention or require a documention. +title: "[Documention Issue]: " +labels: [Documention Issue] +body: + - type: markdown + attributes: + value: | + Welcome to suggest to Tensorflow.NET documention! This template will help us gather the information we need to improve it. + - type: textarea + id: brief-description + attributes: + label: Brief Description + description: Please describe the problem or the requst for new documention here. + placeholder: Description + validations: + required: true + - type: textarea + id: alternatives + attributes: + label: Alternatives + description: | + Please provide some alternative information here, if any. + placeholder: Alternatives + validations: + required: false + - type: markdown + attributes: + value: | + Thanks for your contributing! diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 000000000..9ce3f1663 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,50 @@ +name: Feature Request +description: Request/Propose a new feature of Tensorflow.NET. +title: "[Feature Request]: " +labels: [feature-request] +body: + - type: markdown + attributes: + value: | + We welcome feature proposal/request! This template will help us gather the information we need to implement the new feature. + - type: textarea + id: background + attributes: + label: Background and Feature Description + description: Please describe the purpose and value of the new feature here. If the feature is linked to a specific problem, please describe it or put the link here. + placeholder: Purpose + validations: + required: true + - type: textarea + id: api-proposal + attributes: + label: API Definition and Usage + description: | + Please tell us the new API related to the requested feature, if any. + placeholder: API declaration (no method bodies) + value: | + ```cs + public Tensor NewFunc(Tensor x, int y); + + var result = NewFunc(input, index); + ``` + validations: + required: false + - type: textarea + id: alternatives + attributes: + label: Alternatives + description: | + Please provide some alternative information of the feature, if any. For example, if you request a feature which depends on a specific device, please provide the device information. + placeholder: Alternatives + validations: + required: false + - type: textarea + id: risks + attributes: + label: Risks + description: | + Please mention any risks that to your knowledge the API proposal might entail, such as breaking changes, performance regressions, etc. + placeholder: Risks + validations: + required: false \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/performance_issue.yml b/.github/ISSUE_TEMPLATE/performance_issue.yml new file mode 100644 index 000000000..cbe86d329 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/performance_issue.yml @@ -0,0 +1,48 @@ +name: Performance Issue +description: Submit an issue about performance problem or regression of Tensorflow.NET. +title: "[Performance Issue]: " +labels: [Performance Issue] +body: + - type: markdown + attributes: + value: | + We welcome issues about Tensorflow.NET performance! This template will help us gather the information we need to locate the problem improve the performance. + - type: textarea + id: brief-description + attributes: + label: Brief Description + description: Please give a brief description about the performance issue here. + placeholder: Description + validations: + required: true + - type: textarea + id: device-and-context + attributes: + label: Device and Context + description: | + Please describe the device and context you used when you encounter the performance problem/regression. + placeholder: Device and Context + validations: + required: true + - type: textarea + id: benchmark + attributes: + label: Benchmark + description: | + We will appreciate it if you'd like to provide benchmark comparison of the performance issue. + placeholder: Benchmark + validations: + required: false + - type: textarea + id: alternatives + attributes: + label: Alternatives + description: | + Please provide some alternative information of the performance issue here, if any. For example, we'll appreciate it if you'd like to provide the the code to reproduce the performance problem. + placeholder: Alternatives + validations: + required: false + - type: markdown + attributes: + value: | + Thanks for your contributing! diff --git a/.github/ISSUE_TEMPLATE/question.yml b/.github/ISSUE_TEMPLATE/question.yml new file mode 100644 index 000000000..ca38be340 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.yml @@ -0,0 +1,30 @@ +name: Question +description: Ask any question about Tensorflow.NET and discuss with community members. +title: "[Question]: " +labels: [Question] +body: + - type: markdown + attributes: + value: | + Any question about Tensorflow.NET is welcomed! This template will help us get your point. + - type: textarea + id: description + attributes: + label: Description + description: Please describe your question here. + placeholder: Description + validations: + required: true + - type: textarea + id: alternatives + attributes: + label: Alternatives + description: | + Please provide some alternative information here, if any. + placeholder: Alternatives + validations: + required: false + - type: markdown + attributes: + value: | + We are always willing to answer your questions! diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml new file mode 100644 index 000000000..9fd34fc49 --- /dev/null +++ b/.github/workflows/build_and_test.yml @@ -0,0 +1,66 @@ +# This workflow will build a .NET project +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-net + +name: build_and_test + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + types: ["opened", "reopened", "synchronize", "ready_for_review", "auto_merge_enabled"] + +jobs: + windows: + + runs-on: windows-latest + + steps: + - uses: actions/checkout@v3 + - name: Setup .NET 6 + uses: actions/setup-dotnet@v3 + with: + dotnet-version: 6.0.x + - name: Restore dependencies + run: dotnet restore + - name: Build CPU version + run: dotnet build --no-restore + - name: Test CPU version + run: dotnet test --no-build --verbosity normal + - name: uninstall redist cpu for unit tests + run: dotnet remove tools/Tensorflow.UnitTest.RedistHolder package SciSharp.TensorFlow.Redist + - name: install redist gpu for unit tests + run: dotnet add tools/Tensorflow.UnitTest.RedistHolder package SciSharp.TensorFlow.Redist-Windows-GPU + - name: Restore dependencies + run: dotnet restore + - name: Build GPU version + run: dotnet build --no-restore +# - name: Test GPU version +# run: dotnet test --no-build --verbosity normal + + linux: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Setup .NET + uses: actions/setup-dotnet@v3 + with: + dotnet-version: 6.0.x + - name: Restore dependencies + run: dotnet restore + - name: Build CPU version + run: dotnet build --no-restore + - name: Test CPU version + run: dotnet test --no-build --verbosity normal + - name: uninstall redist cpu for unit tests + run: dotnet remove tools/Tensorflow.UnitTest.RedistHolder package SciSharp.TensorFlow.Redist + - name: install redist gpu for unit tests + run: dotnet add tools/Tensorflow.UnitTest.RedistHolder package SciSharp.TensorFlow.Redist-Linux-GPU + - name: Restore dependencies + run: dotnet restore + - name: Build GPU version + run: dotnet build --no-restore +# - name: Test GPU version +# run: dotnet test --no-build --verbosity normal diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 000000000..02601764c --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,62 @@ +name: auto-release + +on: + workflow_run: + workflows: ["release-prepare"] + types: + - completed + +env: + MYGET_API_TOKEN: ${{ SECRETS.MYGET_API_KEY }} + GITHUB_TOKEN: ${{ SECRETS.RINNE_GITHUB_TOKEN }} + +jobs: + release_to_myget: + runs-on: windows-latest +# needs: run-semantic-release + + steps: + - uses: actions/checkout@v3 + - name: Setup .NET 6.0.x SDK + uses: actions/setup-dotnet@v3 + with: + dotnet-version: 6.0.x + + - name: Check .NET info + run: dotnet --info + + - name: Install dependencies + run: dotnet restore + + - name: Build solution + run: dotnet build -c Release --no-restore + + - name: Pack packages + run: | + git fetch --unshallow; + git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*"; + git fetch origin; + $LastTag = git describe --tags; + $DroppedTag = ($LastTag).TrimStart('v'); + echo "Last tag is: $DroppedTag"; + $Suffix = "-nightly" + $Version = "${DroppedTag}${Suffix}"; + echo "Publishing version: $Version"; + dotnet pack ./src/TensorFlowNET.Core/Tensorflow.Binding.csproj -c Release -o packages /p:PackageVersion=$Version /p:Version=$Version; + dotnet pack ./src/TensorFlowNET.Keras/Tensorflow.Keras.csproj -c Release -o packages /p:PackageVersion=$Version /p:Version=$Version; + dotnet pack ./src/TensorflowNET.Hub/Tensorflow.Hub.csproj -c Release -o packages /p:PackageVersion=$Version /p:Version=$Version; + + if($LastExitCode -ne 0) + { + Write-Warning -Message "Pack packages warming, last exit code is ${LastExitCode}." + $LastExitCode = 0; + } + + - name: Upload packages artifacts + uses: actions/upload-artifact@v4.0.0 + with: + name: "drop-ci-packages" + path: './packages' + + - name: Push TensorFlow.NET to myget.org + run: dotnet nuget push .\packages\TensorFlow*.nupkg --source https://www.myget.org/F/scisharp/api/v3/index.json -k ${{ secrets.MYGET_API_KEY }} --skip-duplicate diff --git a/.github/workflows/release_prepare.yml b/.github/workflows/release_prepare.yml new file mode 100644 index 000000000..b21c6665c --- /dev/null +++ b/.github/workflows/release_prepare.yml @@ -0,0 +1,46 @@ +name: release-prepare + +on: + pull_request: + branches: + - master + types: [ closed ] + +env: + MYGET_API_TOKEN: ${{ SECRETS.MYGET_API_KEY }} + GITHUB_TOKEN: ${{ SECRETS.RINNE_GITHUB_TOKEN }} + +jobs: + build: + if: contains(github.event.pull_request.labels.*.name, 'auto-release') + runs-on: windows-latest + + steps: + - uses: actions/checkout@v3 + - name: Setup .NET 6.0.x SDK + uses: actions/setup-dotnet@v3 + with: + dotnet-version: 6.0.x + + - name: Check .NET info + run: dotnet --info + + - name: Install dependencies + run: dotnet restore + + - name: Build solution + run: dotnet build -c Release --no-restore + +# run-semantic-release: +# runs-on: ubuntu-latest +# needs: build + +# steps: +# - name: Checkout +# uses: actions/checkout@v2 + +# - name: Run semantic-release +# run: | +# export PATH=$PATH:$(yarn global bin) +# yarn global add semantic-release@17.4.3 +# semantic-release \ No newline at end of file diff --git a/.github/workflows/semantic.yml b/.github/workflows/semantic.yml new file mode 100644 index 000000000..db8c06a3e --- /dev/null +++ b/.github/workflows/semantic.yml @@ -0,0 +1,17 @@ +name: Semantic + +on: + pull_request: + branches: [ "master" ] + +jobs: + semantic-pull-request: + name: Semantic check + runs-on: windows-latest + steps: + - name: semantic-pull-request + uses: amannn/action-semantic-pull-request@v4 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + validateSingleCommit: true diff --git a/.gitignore b/.gitignore index 1a6a75a22..231d8379a 100644 --- a/.gitignore +++ b/.gitignore @@ -62,7 +62,6 @@ StyleCopReport.xml *_p.c *_i.h *.ilk -*.meta *.obj *.iobj *.pch @@ -328,9 +327,15 @@ ASALocalRun/ # MFractors (Xamarin productivity tool) working folder .mfractor/ -/tensorflowlib/win7-x64/native/libtensorflow.dll -/tensorflowlib/osx/native/libtensorflow_framework.dylib -/tensorflowlib/osx/native/libtensorflow.dylib -/tensorflowlib/linux/native/libtensorflow_framework.so -/tensorflowlib/linux/native/libtensorflow.so -/src/TensorFlowNET.Core/tensorflow.dll +/docs/build +src/TensorFlowNET.Native/bazel-* +src/TensorFlowNET.Native/c_api.h +/.vscode +test/TensorFlowNET.Examples/mnist + + +# training model resources +.resources +/redist +*.xml +*.xsd diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 000000000..ee3236a46 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,3 @@ +# You can find more information about CODEOWNERS here: https://help.github.com/en/articles/about-code-owners +# These owners will be the default owners for everything in the repo. +* @Oceania2018 \ No newline at end of file diff --git a/Directory.Build.props b/Directory.Build.props new file mode 100644 index 000000000..065690ec9 --- /dev/null +++ b/Directory.Build.props @@ -0,0 +1,17 @@ + + + + + + true + $(NoWarn),1573,1591,1712 + + + diff --git a/Directory.Build.targets b/Directory.Build.targets new file mode 100644 index 000000000..341027f3c --- /dev/null +++ b/Directory.Build.targets @@ -0,0 +1,3 @@ + + + diff --git a/README.md b/README.md index a0004962e..75cad0aa7 100644 --- a/README.md +++ b/README.md @@ -1,47 +1,261 @@ -# TensorFlow.NET -TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). +![logo](docs/assets/tf.net.logo.png) +**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. TensorFlow.NET has built-in Keras high-level interface and is released as an independent package [TensorFlow.Keras](https://www.nuget.org/packages/TensorFlow.Keras/). + +[![Discord](https://img.shields.io/discord/1106946823282761851?label=Discord)](https://discord.gg/qRVm82fKTS) +[![QQ群聊](https://img.shields.io/static/v1?label=QQ&message=群聊&color=brightgreen)](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=sN9VVMwbWjs5L0ATpizKKxOcZdEPMrp8&authKey=RLDw41bLTrEyEgZZi%2FzT4pYk%2BwmEFgFcrhs8ZbkiVY7a4JFckzJefaYNW6Lk4yPX&noverify=0&group_code=985366726) [![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community) -![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/tensorflow-net-p7kmsjyo10ey?svg=true) +[![CI Status](https://github.com/SciSharp/TensorFlow.NET/actions/workflows/build_and_test.yml/badge.svg)](https://github.com/SciSharp/TensorFlow.NET/actions/workflows/build_and_test.yml) +[![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) +[![TensorFlow.NET Badge](https://img.shields.io/nuget/v/TensorFlow.NET?label=TensorFlow.NET)](https://www.nuget.org/packages/TensorFlow.NET) +[![TensorFlow.Keras Badge](https://img.shields.io/nuget/v/TensorFlow.Keras?label=TensorFlow.Keras)](https://www.nuget.org/packages/TensorFlow.Keras) +[![MyGet Badge](https://img.shields.io/badge/dynamic/json?color=purple&label=Nightly%20Release&prefix=myget-v&query=items%5B0%5D.lower&url=https%3A%2F%2Fwww.myget.org%2FF%2Fscisharp%2Fapi%2Fv3%2Fregistration1%2Ftensorflow.net%2Findex.json)](https://www.myget.org/feed/scisharp/package/nuget/Tensorflow.NET) +[![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US) +[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/javiercp/BinderTF.NET/master?urlpath=lab) + +English | [中文](docs/README-CN.md) + +> [!IMPORTANT] +> We're happy that our work on tensorflow.net has attracted many users. However, at this time, none of the main maintainers of this repo is available for new features and bug fix. We won't refuse PRs and will help to review them. +> +> If you would like to be a contributor or maintainer of tensorflow.net, we'd like to help you to start up. +> +> We feel sorry for that and we'll resume the maintaining for this project once one of us has bandwidth for it. +> + +*master branch and v0.100.x is corresponding to tensorflow v2.10, v0.6x branch is from tensorflow v2.6, v0.15-tensorflow1.15 is from tensorflow1.15. Please add `https://www.myget.org/F/scisharp/api/v3/index.json` to nuget source to use nightly release.* -TensorFlow.NET is a member project of SciSharp stack. ![tensors_flowing](docs/assets/tensors_flowing.gif) -### How to use -Download the pre-compiled dll [here](tensorflowlib) and place it in the bin folder. +## Why Tensorflow.NET ? -Import tensorflow.net. -```cs -using Tensorflow; +`SciSharp STACK`'s mission is to bring popular data science technology into the .NET world and to provide .NET developers with a powerful Machine Learning tool set without reinventing the wheel. Since the APIs are kept as similar as possible you can immediately adapt any existing TensorFlow code in C# or F# with a zero learning curve. Take a look at a comparison picture and see how comfortably a TensorFlow/Python script translates into a C# program with TensorFlow.NET. + +![python vs csharp](docs/assets/syntax-comparision.png) + +SciSharp's philosophy allows a large number of machine learning code written in Python to be quickly migrated to .NET, enabling .NET developers to use cutting edge machine learning models and access a vast number of TensorFlow resources which would not be possible without this project. + +In comparison to other projects, like for instance [TensorFlowSharp](https://www.nuget.org/packages/TensorFlowSharp/) which only provide TensorFlow's low-level C++ API and can only run models that were built using Python, Tensorflow.NET makes it possible to build the pipeline of training and inference with pure C# and F#. Besides, Tensorflow.NET provides binding of Tensorflow.Keras to make it easy to transfer your code from python to .NET. + +[ML.NET](https://github.com/dotnet/machinelearning) also take Tensorflow.NET as one of the backends to train and infer your model, which provides better integration with .NET. + +## Documention + +Introduction and simple examples:[Tensorflow.NET Documents](https://scisharp.github.io/tensorflow-net-docs) + +Detailed documention:[The Definitive Guide to Tensorflow.NET](https://tensorflownet.readthedocs.io/en/latest/FrontCover.html) + +Examples:[TensorFlow.NET Examples](https://github.com/SciSharp/TensorFlow.NET-Examples) + +Troubleshooting of running example or installation:[Tensorflow.NET FAQ](tensorflowlib/README.md) + +## Usage + +### Installation + +You can search the package name in NuGet Manager, or use the commands below in package manager console. + +The installation contains two parts, the first is the main body: + +```sh +### Install Tensorflow.NET +PM> Install-Package TensorFlow.NET + +### Install Tensorflow.Keras +PM> Install-Package TensorFlow.Keras ``` -Add two constants. -```cs -// Create a Constant op -var a = tf.constant(4.0f); -var b = tf.constant(5.0f); -var c = tf.add(a, b); +The second part is the computing support part. Only one of the following packages is needed, depending on your device and system. -using (var sess = tf.Session()) -{ - var o = sess.run(c); -} ``` +### CPU version for Windows and Linux +PM> Install-Package SciSharp.TensorFlow.Redist + +### CPU version for MacOS +PM> Install-Package SciSharp.TensorFlow.Redist-OSX + +### GPU version for Windows (CUDA and cuDNN are required) +PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU + +### GPU version for Linux (CUDA and cuDNN are required) +PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU +``` + + +Two simple examples are given here to introduce the basic usage of Tensorflow.NET. As you can see, it's easy to write C# code just like that in Python. + +### Example - Linear Regression in `Eager` mode + +```csharp +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using Tensorflow; +using Tensorflow.NumPy; + +// Parameters +var training_steps = 1000; +var learning_rate = 0.01f; +var display_step = 100; + +// Sample data +var X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, + 7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f); +var Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f, + 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); +var n_samples = X.shape[0]; -Feed placeholder. -```cs -// Create a placeholder op -var a = tf.placeholder(tf.float32); -var b = tf.placeholder(tf.float32); -var c = tf.add(a, b); +// We can set a fixed init value in order to demo +var W = tf.Variable(-0.06f, name: "weight"); +var b = tf.Variable(-0.73f, name: "bias"); +var optimizer = keras.optimizers.SGD(learning_rate); -using(var sess = tf.Session()) +// Run training for the given number of steps. +foreach (var step in range(1, training_steps + 1)) { - var feed_dict = new Dictionary(); - feed_dict.Add(a, 3.0f); - feed_dict.Add(b, 2.0f); + // Run the optimization to update W and b values. + // Wrap computation inside a GradientTape for automatic differentiation. + using var g = tf.GradientTape(); + // Linear regression (Wx + b). + var pred = W * X + b; + // Mean square error. + var loss = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples); + // should stop recording + // Compute gradients. + var gradients = g.gradient(loss, (W, b)); - var o = sess.run(c, feed_dict); + // Update W and b following gradients. + optimizer.apply_gradients(zip(gradients, (W, b))); + + if (step % display_step == 0) + { + pred = W * X + b; + loss = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples); + print($"step: {step}, loss: {loss.numpy()}, W: {W.numpy()}, b: {b.numpy()}"); + } } ``` + +Run this example in [Jupyter Notebook](https://github.com/SciSharp/SciSharpCube). + +### Example - Toy version of `ResNet` in `Keras` functional API + +```csharp +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using Tensorflow; +using Tensorflow.NumPy; + +var layers = keras.layers; +// input layer +var inputs = keras.Input(shape: (32, 32, 3), name: "img"); +// convolutional layer +var x = layers.Conv2D(32, 3, activation: "relu").Apply(inputs); +x = layers.Conv2D(64, 3, activation: "relu").Apply(x); +var block_1_output = layers.MaxPooling2D(3).Apply(x); +x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_1_output); +x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x); +var block_2_output = layers.Add().Apply(new Tensors(x, block_1_output)); +x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_2_output); +x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x); +var block_3_output = layers.Add().Apply(new Tensors(x, block_2_output)); +x = layers.Conv2D(64, 3, activation: "relu").Apply(block_3_output); +x = layers.GlobalAveragePooling2D().Apply(x); +x = layers.Dense(256, activation: "relu").Apply(x); +x = layers.Dropout(0.5f).Apply(x); +// output layer +var outputs = layers.Dense(10).Apply(x); +// build keras model +var model = keras.Model(inputs, outputs, name: "toy_resnet"); +model.summary(); +// compile keras model in tensorflow static graph +model.compile(optimizer: keras.optimizers.RMSprop(1e-3f), + loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true), + metrics: new[] { "acc" }); +// prepare dataset +var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data(); +// normalize the input +x_train = x_train / 255.0f; +// training +model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], + batch_size: 64, + epochs: 10, + validation_split: 0.2f); +// save the model +model.save("./toy_resnet_model"); +``` + +The F# example for linear regression is available [here](docs/Example-fsharp.md). + +More adcanced examples could be found in [TensorFlow.NET Examples](https://github.com/SciSharp/TensorFlow.NET-Examples). + +## Version Relationships + +| TensorFlow.NET Versions | tensorflow 1.14, cuda 10.0 | tensorflow 1.15, cuda 10.0 | tensorflow 2.3, cuda 10.1 | tensorflow 2.4, cuda 11 | tensorflow 2.7, cuda 11 |tensorflow 2.10, cuda 11 | +| -------------------------- | ------------- | -------------- | ------------- | ------------- | ------------ | ------------ | +| tf.net 0.10x, tf.keras 0.10 | | | | | | x | +| tf.net 0.7x, tf.keras 0.7 | | | | | x | | +| tf.net 0.4x, tf.keras 0.5 | | | | x | | | +| tf.net 0.3x, tf.keras 0.4 | | | x | | | | +| tf.net 0.2x | | x | x | | | | +| tf.net 0.15 | x | x | | | | | +| tf.net 0.14 | x | | | | | | + + +``` +tf.net 0.4x -> tf native 2.4 +tf.net 0.6x -> tf native 2.6 +tf.net 0.7x -> tf native 2.7 +tf.net 0.10x -> tf native 2.10 +... +``` + +## Contribution: + +Feel like contributing to one of the hottest projects in the Machine Learning field? Want to know how Tensorflow magically creates the computational graph? + +We appreciate every contribution however small! There are tasks for novices to experts alike, if everyone tackles only a small task the sum of contributions will be huge. + +You can: +- Star Tensorflow.NET or share it with others +- Tell us about the missing APIs compared to Tensorflow +- Port Tensorflow unit tests from Python to C# or F# +- Port Tensorflow examples to C# or F# and raise issues if you come accross missing parts of the API or BUG +- Debug one of the unit tests that is marked as Ignored to get it to work +- Debug one of the not yet working examples and get it to work +- Help us to complete the documentions. + + +#### How to debug unit tests: + +The best way to find out why a unit test is failing is to single step it in C# or F# and its corresponding Python at the same time to see where the flow of execution digresses or where variables exhibit different values. Good Python IDEs like PyCharm let you single step into the tensorflow library code. + +#### Git Knowhow for Contributors + +Add SciSharp/TensorFlow.NET as upstream to your local repo ... +```git +git remote add upstream git@github.com:SciSharp/TensorFlow.NET.git +``` + +Please make sure you keep your fork up to date by regularly pulling from upstream. +```git +git pull upstream master +``` + +### Support +Buy our book to make open source project be sustainable [TensorFlow.NET实战](https://item.jd.com/13441549.html) +

+ + + +

+ +### Contact + +Join our chat on [Discord](https://discord.gg/qRVm82fKTS) or [Gitter](https://gitter.im/sci-sharp/community). + +Follow us on [Twitter](https://twitter.com/ScisharpStack), [Facebook](https://www.facebook.com/scisharp.stack.9), [Medium](https://medium.com/scisharp), [LinkedIn](https://www.linkedin.com/company/scisharp-stack/). + +TensorFlow.NET is a part of [SciSharp STACK](https://scisharp.github.io/SciSharp/) +
+ diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 8936dd3d9..e0c273568 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -1,42 +1,390 @@  Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 15 -VisualStudioVersion = 15.0.28307.136 +# Visual Studio Version 17 +VisualStudioVersion = 17.4.33213.308 MinimumVisualStudioVersion = 10.0.40219.1 -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.UnitTest", "test\TensorFlowNET.UnitTest\TensorFlowNET.UnitTest.csproj", "{029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{1B1BC950-2CB0-48E2-B4CD-8172AFF67A10}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding.UnitTest", "test\TensorFlowNET.UnitTest\Tensorflow.Binding.UnitTest.csproj", "{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras", "src\TensorFlowNET.Keras\Tensorflow.Keras.csproj", "{49D71826-C03D-4FA7-9BAC-22C1327E65CF}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{6ACED8FF-F08E-40E6-A75D-D01BAAA41072}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Text", "src\TensorFlowNET.Text\Tensorflow.Text.csproj", "{1AB8108D-4FFE-4A16-88E7-328EAF686370}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Recommenders", "src\TensorFlowNET.Recommenders\Tensorflow.Recommenders.csproj", "{F17AAECB-960A-4E18-A270-BAD776F0E55B}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Native.UnitTest", "test\TensorFlowNET.Native.UnitTest\Tensorflow.Native.UnitTest.csproj", "{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\TensorFlowNET.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorflowNET.Hub\Tensorflow.Hub.csproj", "{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub.Unittest", "test\TensorflowNET.Hub.Unittest\Tensorflow.Hub.Unittest.csproj", "{7DEA8760-E401-4872-81F3-405F185A13A0}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{01A1787F-A9BE-4221-84E8-6360DD010AB6}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{1B0918B9-65AD-4F34-A287-AF4597B27DBD}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{E1A5D2B7-10AF-4876-85C0-7714EF274214}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.CodeGen", "tools\Tensorflow.CodeGen\Tensorflow.CodeGen.csproj", "{3D92142F-EEDB-469B-B03C-4E38728BFE4C}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Redist.NativeLibrarySplitter", "tools\Tensorflow.Redist.NativeLibrarySplitter\Tensorflow.Redist.NativeLibrarySplitter.csproj", "{AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.UnitTest.RedistHolder", "tools\Tensorflow.UnitTest.RedistHolder\Tensorflow.UnitTest.RedistHolder.csproj", "{D24FCAA5-548C-4251-B226-A1B6535D0845}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Benchmark", "tools\TensorFlowNET.Benchmarks\Tensorflow.Benchmark.csproj", "{C23563DB-FE21-48E7-A411-87A109E4A899}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Console", "tools\TensorFlowNET.Console\Tensorflow.Console.csproj", "{1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlow.Kernel.UnitTest", "test\TensorFlow.Kernel.UnitTest\TensorFlow.Kernel.UnitTest.csproj", "{654A027D-1364-4729-880B-144DFE1FF5BB}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.UnitTest", "test\Tensorflow.UnitTest\Tensorflow.UnitTest.csproj", "{A73DF5A6-866E-4AED-9017-AA2EE86368C4}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU + Debug|x64 = Debug|x64 + Debug|x86 = Debug|x86 + GPU|Any CPU = GPU|Any CPU + GPU|x64 = GPU|x64 + GPU|x86 = GPU|x86 Release|Any CPU = Release|Any CPU + Release|x64 = Release|x64 + Release|x86 = Release|x86 EndGlobalSection GlobalSection(ProjectConfigurationPlatforms) = postSolution - {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Debug|Any CPU.Build.0 = Debug|Any CPU - {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Release|Any CPU.ActiveCfg = Release|Any CPU - {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Release|Any CPU.Build.0 = Release|Any CPU - {1B1BC950-2CB0-48E2-B4CD-8172AFF67A10}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {1B1BC950-2CB0-48E2-B4CD-8172AFF67A10}.Debug|Any CPU.Build.0 = Debug|Any CPU - {1B1BC950-2CB0-48E2-B4CD-8172AFF67A10}.Release|Any CPU.ActiveCfg = Release|Any CPU - {1B1BC950-2CB0-48E2-B4CD-8172AFF67A10}.Release|Any CPU.Build.0 = Release|Any CPU - {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU - {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU - {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU - {6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Debug|Any CPU.Build.0 = Debug|Any CPU - {6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Release|Any CPU.ActiveCfg = Release|Any CPU - {6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Release|Any CPU.Build.0 = Release|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|x64 + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.Build.0 = Debug|x64 + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x86.ActiveCfg = Debug|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x86.Build.0 = Debug|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.GPU|Any CPU.ActiveCfg = GPU|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.GPU|Any CPU.Build.0 = GPU|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.GPU|x64.ActiveCfg = GPU|x64 + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.GPU|x64.Build.0 = GPU|x64 + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.GPU|x86.ActiveCfg = GPU|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.GPU|x86.Build.0 = GPU|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|x64 + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.Build.0 = Release|x64 + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x86.ActiveCfg = Release|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x86.Build.0 = Release|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|x64 + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.Build.0 = Debug|x64 + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x86.ActiveCfg = Debug|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x86.Build.0 = Debug|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.GPU|Any CPU.ActiveCfg = Release|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.GPU|Any CPU.Build.0 = Release|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.GPU|x64.ActiveCfg = Release|x64 + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.GPU|x64.Build.0 = Release|x64 + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.GPU|x86.ActiveCfg = Release|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.GPU|x86.Build.0 = Release|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.ActiveCfg = Release|x64 + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.Build.0 = Release|x64 + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.ActiveCfg = Release|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.Build.0 = Release|Any CPU + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Debug|Any CPU.Build.0 = Debug|Any CPU + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Debug|x64.ActiveCfg = Debug|x64 + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Debug|x64.Build.0 = Debug|x64 + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Debug|x86.ActiveCfg = Debug|Any CPU + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Debug|x86.Build.0 = Debug|Any CPU + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.GPU|Any CPU.ActiveCfg = GPU|Any CPU + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.GPU|Any CPU.Build.0 = GPU|Any CPU + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.GPU|x64.ActiveCfg = GPU|x64 + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.GPU|x64.Build.0 = GPU|x64 + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.GPU|x86.ActiveCfg = GPU|Any CPU + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.GPU|x86.Build.0 = GPU|Any CPU + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Release|Any CPU.ActiveCfg = Release|Any CPU + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Release|Any CPU.Build.0 = Release|Any CPU + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Release|x64.ActiveCfg = Release|x64 + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Release|x64.Build.0 = Release|x64 + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Release|x86.ActiveCfg = Release|Any CPU + {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Release|x86.Build.0 = Release|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.ActiveCfg = Debug|x64 + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.Build.0 = Debug|x64 + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.ActiveCfg = Debug|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.Build.0 = Debug|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.GPU|Any CPU.ActiveCfg = Release|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.GPU|Any CPU.Build.0 = Release|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.GPU|x64.ActiveCfg = Release|x64 + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.GPU|x64.Build.0 = Release|x64 + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.GPU|x86.ActiveCfg = Release|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.GPU|x86.Build.0 = Release|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|Any CPU.Build.0 = Release|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x64.ActiveCfg = Release|x64 + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x64.Build.0 = Release|x64 + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x86.ActiveCfg = Release|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x86.Build.0 = Release|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.ActiveCfg = Debug|x64 + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.Build.0 = Debug|x64 + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.ActiveCfg = Debug|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.Build.0 = Debug|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.GPU|Any CPU.ActiveCfg = Release|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.GPU|Any CPU.Build.0 = Release|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.GPU|x64.ActiveCfg = Release|x64 + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.GPU|x64.Build.0 = Release|x64 + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.GPU|x86.ActiveCfg = Release|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.GPU|x86.Build.0 = Release|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|Any CPU.Build.0 = Release|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x64.ActiveCfg = Release|x64 + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x64.Build.0 = Release|x64 + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.ActiveCfg = Release|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.Build.0 = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x64.ActiveCfg = Debug|x64 + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x64.Build.0 = Debug|x64 + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.ActiveCfg = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.Build.0 = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.GPU|Any CPU.ActiveCfg = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.GPU|Any CPU.Build.0 = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.GPU|x64.ActiveCfg = Release|x64 + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.GPU|x64.Build.0 = Release|x64 + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.GPU|x86.ActiveCfg = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.GPU|x86.Build.0 = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.Build.0 = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.ActiveCfg = Release|x64 + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.Build.0 = Release|x64 + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.ActiveCfg = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.Build.0 = Release|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Debug|Any CPU.Build.0 = Debug|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Debug|x64.ActiveCfg = Debug|x64 + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Debug|x64.Build.0 = Debug|x64 + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Debug|x86.ActiveCfg = Debug|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Debug|x86.Build.0 = Debug|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.GPU|Any CPU.ActiveCfg = Release|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.GPU|Any CPU.Build.0 = Release|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.GPU|x64.ActiveCfg = Release|x64 + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.GPU|x64.Build.0 = Release|x64 + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.GPU|x86.ActiveCfg = Release|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.GPU|x86.Build.0 = Release|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|Any CPU.ActiveCfg = Release|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|Any CPU.Build.0 = Release|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x64.ActiveCfg = Release|x64 + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x64.Build.0 = Release|x64 + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x86.ActiveCfg = Release|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x86.Build.0 = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x64.ActiveCfg = Debug|x64 + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x64.Build.0 = Debug|x64 + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x86.ActiveCfg = Debug|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x86.Build.0 = Debug|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.GPU|Any CPU.ActiveCfg = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.GPU|Any CPU.Build.0 = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.GPU|x64.ActiveCfg = Release|x64 + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.GPU|x64.Build.0 = Release|x64 + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.GPU|x86.ActiveCfg = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.GPU|x86.Build.0 = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|Any CPU.Build.0 = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.ActiveCfg = Release|x64 + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.ActiveCfg = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.Build.0 = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.ActiveCfg = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.Build.0 = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.GPU|Any CPU.ActiveCfg = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.GPU|Any CPU.Build.0 = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.GPU|x64.ActiveCfg = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.GPU|x64.Build.0 = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.GPU|x86.ActiveCfg = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.GPU|x86.Build.0 = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.Build.0 = Release|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.ActiveCfg = Release|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.Build.0 = Release|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.ActiveCfg = Release|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.Build.0 = Release|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.ActiveCfg = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.Build.0 = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.ActiveCfg = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.Build.0 = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.GPU|Any CPU.ActiveCfg = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.GPU|Any CPU.Build.0 = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.GPU|x64.ActiveCfg = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.GPU|x64.Build.0 = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.GPU|x86.ActiveCfg = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.GPU|x86.Build.0 = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.Build.0 = Release|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.ActiveCfg = Release|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.Build.0 = Release|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.ActiveCfg = Release|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.Build.0 = Release|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Debug|x64.ActiveCfg = Debug|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Debug|x64.Build.0 = Debug|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Debug|x86.ActiveCfg = Debug|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Debug|x86.Build.0 = Debug|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.GPU|Any CPU.ActiveCfg = Debug|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.GPU|Any CPU.Build.0 = Debug|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.GPU|x64.ActiveCfg = Debug|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.GPU|x64.Build.0 = Debug|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.GPU|x86.ActiveCfg = Debug|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.GPU|x86.Build.0 = Debug|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Release|Any CPU.Build.0 = Release|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Release|x64.ActiveCfg = Release|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Release|x64.Build.0 = Release|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Release|x86.ActiveCfg = Release|Any CPU + {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Release|x86.Build.0 = Release|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Debug|x64.ActiveCfg = Debug|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Debug|x64.Build.0 = Debug|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Debug|x86.ActiveCfg = Debug|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Debug|x86.Build.0 = Debug|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.GPU|Any CPU.ActiveCfg = Debug|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.GPU|Any CPU.Build.0 = Debug|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.GPU|x64.ActiveCfg = Debug|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.GPU|x64.Build.0 = Debug|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.GPU|x86.ActiveCfg = Debug|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.GPU|x86.Build.0 = Debug|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Release|Any CPU.Build.0 = Release|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Release|x64.ActiveCfg = Release|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Release|x64.Build.0 = Release|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Release|x86.ActiveCfg = Release|Any CPU + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Release|x86.Build.0 = Release|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.Debug|x64.ActiveCfg = Debug|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.Debug|x64.Build.0 = Debug|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.Debug|x86.ActiveCfg = Debug|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.Debug|x86.Build.0 = Debug|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.GPU|Any CPU.ActiveCfg = Debug|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.GPU|Any CPU.Build.0 = Debug|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.GPU|x64.ActiveCfg = Debug|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.GPU|x64.Build.0 = Debug|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.GPU|x86.ActiveCfg = Debug|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.GPU|x86.Build.0 = Debug|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.Release|Any CPU.Build.0 = Release|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.Release|x64.ActiveCfg = Release|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.Release|x64.Build.0 = Release|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.Release|x86.ActiveCfg = Release|Any CPU + {D24FCAA5-548C-4251-B226-A1B6535D0845}.Release|x86.Build.0 = Release|Any CPU + {C23563DB-FE21-48E7-A411-87A109E4A899}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C23563DB-FE21-48E7-A411-87A109E4A899}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C23563DB-FE21-48E7-A411-87A109E4A899}.Debug|x64.ActiveCfg = Debug|x64 + {C23563DB-FE21-48E7-A411-87A109E4A899}.Debug|x64.Build.0 = Debug|x64 + {C23563DB-FE21-48E7-A411-87A109E4A899}.Debug|x86.ActiveCfg = Debug|Any CPU + {C23563DB-FE21-48E7-A411-87A109E4A899}.Debug|x86.Build.0 = Debug|Any CPU + {C23563DB-FE21-48E7-A411-87A109E4A899}.GPU|Any CPU.ActiveCfg = Debug|Any CPU + {C23563DB-FE21-48E7-A411-87A109E4A899}.GPU|Any CPU.Build.0 = Debug|Any CPU + {C23563DB-FE21-48E7-A411-87A109E4A899}.GPU|x64.ActiveCfg = Debug|x64 + {C23563DB-FE21-48E7-A411-87A109E4A899}.GPU|x64.Build.0 = Debug|x64 + {C23563DB-FE21-48E7-A411-87A109E4A899}.GPU|x86.ActiveCfg = Debug|Any CPU + {C23563DB-FE21-48E7-A411-87A109E4A899}.GPU|x86.Build.0 = Debug|Any CPU + {C23563DB-FE21-48E7-A411-87A109E4A899}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C23563DB-FE21-48E7-A411-87A109E4A899}.Release|Any CPU.Build.0 = Release|Any CPU + {C23563DB-FE21-48E7-A411-87A109E4A899}.Release|x64.ActiveCfg = Release|x64 + {C23563DB-FE21-48E7-A411-87A109E4A899}.Release|x64.Build.0 = Release|x64 + {C23563DB-FE21-48E7-A411-87A109E4A899}.Release|x86.ActiveCfg = Release|Any CPU + {C23563DB-FE21-48E7-A411-87A109E4A899}.Release|x86.Build.0 = Release|Any CPU + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Debug|x64.ActiveCfg = Debug|x64 + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Debug|x64.Build.0 = Debug|x64 + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Debug|x86.ActiveCfg = Debug|Any CPU + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Debug|x86.Build.0 = Debug|Any CPU + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.GPU|Any CPU.ActiveCfg = Debug|Any CPU + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.GPU|Any CPU.Build.0 = Debug|Any CPU + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.GPU|x64.ActiveCfg = Debug|x64 + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.GPU|x64.Build.0 = Debug|x64 + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.GPU|x86.ActiveCfg = Debug|Any CPU + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.GPU|x86.Build.0 = Debug|Any CPU + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|Any CPU.Build.0 = Release|Any CPU + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x64.ActiveCfg = Release|x64 + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x64.Build.0 = Release|x64 + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x86.ActiveCfg = Release|Any CPU + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x86.Build.0 = Release|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x64.ActiveCfg = Debug|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x64.Build.0 = Debug|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x86.ActiveCfg = Debug|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x86.Build.0 = Debug|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|Any CPU.ActiveCfg = Debug|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|Any CPU.Build.0 = Debug|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x64.ActiveCfg = Debug|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x64.Build.0 = Debug|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x86.ActiveCfg = Debug|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x86.Build.0 = Debug|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|Any CPU.Build.0 = Release|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x64.ActiveCfg = Release|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x64.Build.0 = Release|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x86.ActiveCfg = Release|Any CPU + {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x86.Build.0 = Release|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Debug|x64.ActiveCfg = Debug|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Debug|x64.Build.0 = Debug|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Debug|x86.ActiveCfg = Debug|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Debug|x86.Build.0 = Debug|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.GPU|Any CPU.ActiveCfg = Debug|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.GPU|Any CPU.Build.0 = Debug|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.GPU|x64.ActiveCfg = Debug|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.GPU|x64.Build.0 = Debug|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.GPU|x86.ActiveCfg = Debug|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.GPU|x86.Build.0 = Debug|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Release|Any CPU.Build.0 = Release|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Release|x64.ActiveCfg = Release|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Release|x64.Build.0 = Release|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Release|x86.ActiveCfg = Release|Any CPU + {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144} = {01A1787F-A9BE-4221-84E8-6360DD010AB6} + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD} + {49D71826-C03D-4FA7-9BAC-22C1327E65CF} = {01A1787F-A9BE-4221-84E8-6360DD010AB6} + {1AB8108D-4FFE-4A16-88E7-328EAF686370} = {01A1787F-A9BE-4221-84E8-6360DD010AB6} + {F17AAECB-960A-4E18-A270-BAD776F0E55B} = {01A1787F-A9BE-4221-84E8-6360DD010AB6} + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD} + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD} + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD} + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18} = {01A1787F-A9BE-4221-84E8-6360DD010AB6} + {7DEA8760-E401-4872-81F3-405F185A13A0} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD} + {3D92142F-EEDB-469B-B03C-4E38728BFE4C} = {E1A5D2B7-10AF-4876-85C0-7714EF274214} + {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C} = {E1A5D2B7-10AF-4876-85C0-7714EF274214} + {D24FCAA5-548C-4251-B226-A1B6535D0845} = {E1A5D2B7-10AF-4876-85C0-7714EF274214} + {C23563DB-FE21-48E7-A411-87A109E4A899} = {E1A5D2B7-10AF-4876-85C0-7714EF274214} + {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0} = {E1A5D2B7-10AF-4876-85C0-7714EF274214} + {654A027D-1364-4729-880B-144DFE1FF5BB} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD} + {A73DF5A6-866E-4AED-9017-AA2EE86368C4} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD} + EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {2DEAD3CC-486B-4918-A607-50B0DE7B114A} EndGlobalSection diff --git a/TensorFlow.NET.sln.DotSettings b/TensorFlow.NET.sln.DotSettings new file mode 100644 index 000000000..aba8725cc --- /dev/null +++ b/TensorFlow.NET.sln.DotSettings @@ -0,0 +1,2 @@ + + True \ No newline at end of file diff --git a/data/dbpedia_subset.zip b/data/dbpedia_subset.zip new file mode 100644 index 000000000..120ac8a10 Binary files /dev/null and b/data/dbpedia_subset.zip differ diff --git a/data/imdb.zip b/data/imdb.zip new file mode 100644 index 000000000..f38c48402 Binary files /dev/null and b/data/imdb.zip differ diff --git a/data/img001.bmp b/data/img001.bmp new file mode 100644 index 000000000..d149d76f1 Binary files /dev/null and b/data/img001.bmp differ diff --git a/data/linear_regression.zip b/data/linear_regression.zip new file mode 100644 index 000000000..50415d840 Binary files /dev/null and b/data/linear_regression.zip differ diff --git a/data/lstm_crf_ner.zip b/data/lstm_crf_ner.zip new file mode 100644 index 000000000..9e47ca934 Binary files /dev/null and b/data/lstm_crf_ner.zip differ diff --git a/data/nb_example.npy b/data/nb_example.npy new file mode 100644 index 000000000..4547812ca Binary files /dev/null and b/data/nb_example.npy differ diff --git a/data/shasta-daisy.jpg b/data/shasta-daisy.jpg new file mode 100644 index 000000000..9a0a46eb0 Binary files /dev/null and b/data/shasta-daisy.jpg differ diff --git a/data/text8.zip b/data/text8.zip new file mode 100644 index 000000000..436e05b2d Binary files /dev/null and b/data/text8.zip differ diff --git a/data/tfhub_modules.zip b/data/tfhub_modules.zip new file mode 100644 index 000000000..a61ba9c30 Binary files /dev/null and b/data/tfhub_modules.zip differ diff --git a/docs/Example-fsharp.md b/docs/Example-fsharp.md new file mode 100644 index 000000000..578543454 --- /dev/null +++ b/docs/Example-fsharp.md @@ -0,0 +1,55 @@ +Linear Regression in `Eager` mode: + +```fsharp +#r "nuget: TensorFlow.Net" +#r "nuget: TensorFlow.Keras" +#r "nuget: SciSharp.TensorFlow.Redist" + +open Tensorflow +open Tensorflow.NumPy +open type Tensorflow.Binding +open type Tensorflow.KerasApi + +let tf = New() +tf.enable_eager_execution() + +// Parameters +let training_steps = 1000 +let learning_rate = 0.01f +let display_step = 100 + +// Sample data +let train_X = + np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, + 7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f) +let train_Y = + np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f, + 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f) +let n_samples = train_X.shape.[0] + +// We can set a fixed init value in order to demo +let W = tf.Variable(-0.06f,name = "weight") +let b = tf.Variable(-0.73f, name = "bias") +let optimizer = keras.optimizers.SGD(learning_rate) + +// Run training for the given number of steps. +for step = 1 to (training_steps + 1) do + // Run the optimization to update W and b values. + // Wrap computation inside a GradientTape for automatic differentiation. + use g = tf.GradientTape() + // Linear regression (Wx + b). + let pred = W * train_X + b + // Mean square error. + let loss = tf.reduce_sum(tf.pow(pred - train_Y,2)) / (2 * n_samples) + // should stop recording + // compute gradients + let gradients = g.gradient(loss,struct (W,b)) + + // Update W and b following gradients. + optimizer.apply_gradients(zip(gradients, struct (W,b))) + + if (step % display_step) = 0 then + let pred = W * train_X + b + let loss = tf.reduce_sum(tf.pow(pred-train_Y,2)) / (2 * n_samples) + printfn $"step: {step}, loss: {loss.numpy()}, W: {W.numpy()}, b: {b.numpy()}" +``` \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 000000000..69fe55ecf --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,19 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/docs/README-CN.md b/docs/README-CN.md new file mode 100644 index 000000000..9776b0fb8 --- /dev/null +++ b/docs/README-CN.md @@ -0,0 +1,228 @@ +![logo](assets/tf.net.logo.png) + +**Tensorflow.NET**是AI框架[TensorFlow](https://www.tensorflow.org/)在.NET平台上的实现,支持C#和F#,可以用来搭建深度学习模型并进行训练和推理,并内置了Numpy API,可以用来进行其它科学计算。 + +Tensorflow.NET并非对于Python的简单封装,而是基于C API的pure C#实现,因此使用时无需额外的环境,可以很方便地用NuGet直接安装使用。并且dotnet团队提供的[ML.NET](https://github.com/dotnet/machinelearning)也依赖于Tensorflow.NET,支持调用Tensorflow.NET进行训练和推理,可以很方便地融入.NET生态。 + +与tensorflow相同,Tensorflow.NET也内置了Keras这一高级API,只要在安装Tensorflow.NET的同时安装Tensorflow.Keras就可以使用,Keras支持以模块化的方式调用模型,给模型的搭建提供了极大的便利。 + +[![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community) +[![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) +[![NuGet](https://img.shields.io/nuget/dt/TensorFlow.NET.svg)](https://www.nuget.org/packages/TensorFlow.NET) +[![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) +[![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US) +[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/javiercp/BinderTF.NET/master?urlpath=lab) + +中文 | [English](https://github.com/SciSharp/TensorFlow.NET#readme) + +*当前主分支与Tensorflow2.10版本相对应,支持Eager Mode,同时也支持v1的静态图。* + + +![tensors_flowing](assets/tensors_flowing.gif) + +## Why Tensorflow.NET? + +`SciSharp STACK`开源社区的目标是构建.NET平台下易用的科学计算库,而Tensorflow.NET就是其中最具代表性的仓库之一。在深度学习领域Python是主流,无论是初学者还是资深开发者,模型的搭建和训练都常常使用Python写就的AI框架,比如tensorflow。但在实际应用深度学习模型的时候,又可能希望用到.NET生态,亦或只是因为.NET是自己最熟悉的领域,这时候Tensorflow.NET就有显著的优点,因为它不仅可以和.NET生态很好地贴合,其API还使得开发者很容易将Python代码迁移过来。下面的对比就是很好的例子,Python代码和C#代码有着高度相似的API,这会使得迁移的时候无需做过多修改。 + +![python vs csharp](assets/syntax-comparision.png) + +除了高度相似的API外,Tensorflow.NET与tensorflow也已经打通数据通道,tensorflow训练并保存的模型可以在Tensorflow.NET中直接读取并继续训练或推理,反之Tensorflow.NET保存的模型也可以在tensorflow中读取,这大大方便了模型的训练和部署。 + +与其它类似的库比如[TensorFlowSharp](https://www.nuget.org/packages/TensorFlowSharp/)相比,Tensorflow.NET的实现更加完全,提供了更多的高级API,使用起来更为方便,更新也更加迅速。 + + +## 文档 + +基本介绍与简单用例:[Tensorflow.NET Documents](https://scisharp.github.io/tensorflow-net-docs) + +详细文档:[The Definitive Guide to Tensorflow.NET](https://tensorflownet.readthedocs.io/en/latest/FrontCover.html) + +例程:[TensorFlow.NET Examples](https://github.com/SciSharp/TensorFlow.NET-Examples) + +运行例程常见问题:[Tensorflow.NET FAQ](tensorflowlib/README.md) + +## 安装与使用 + +安装可以在NuGet包管理器中搜索包名安装,也可以用下面命令行的方式。 + +安装分为两个部分,第一部分是Tensorflow.NET的主体: + +```sh +### 安装Tensorflow.NET +PM> Install-Package TensorFlow.NET + +### 安装Tensorflow.Keras +PM> Install-Package TensorFlow.Keras +``` + +第二部分是计算支持部分,只需要根据自己的设备和系统选择下面之一即可: + +``` +### CPU版本,支持Windows、Linux和Mac +PM> Install-Package SciSharp.TensorFlow.Redist + +### Windows下的GPU版本(需要安装CUDA和cuDNN) +PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU + +### Linux下的GPU版本(需要安装CUDA和cuDNN) +PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU +``` + +下面给出两个简单的例子,更多例子可以在[TensorFlow.NET Examples]中查看。 + +### 简单例子(使用Eager Mode进行线性回归) + +```csharp +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using Tensorflow; +using Tensorflow.NumPy; + +// Parameters +var training_steps = 1000; +var learning_rate = 0.01f; +var display_step = 100; + +// Sample data +var X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, + 7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f); +var Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f, + 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); +var n_samples = X.shape[0]; + +// We can set a fixed init value in order to demo +var W = tf.Variable(-0.06f, name: "weight"); +var b = tf.Variable(-0.73f, name: "bias"); +var optimizer = keras.optimizers.SGD(learning_rate); + +// Run training for the given number of steps. +foreach (var step in range(1, training_steps + 1)) +{ + // Run the optimization to update W and b values. + // Wrap computation inside a GradientTape for automatic differentiation. + using var g = tf.GradientTape(); + // Linear regression (Wx + b). + var pred = W * X + b; + // Mean square error. + var loss = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples); + // should stop recording + // Compute gradients. + var gradients = g.gradient(loss, (W, b)); + + // Update W and b following gradients. + optimizer.apply_gradients(zip(gradients, (W, b))); + + if (step % display_step == 0) + { + pred = W * X + b; + loss = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples); + print($"step: {step}, loss: {loss.numpy()}, W: {W.numpy()}, b: {b.numpy()}"); + } +} +``` + +这一用例也可以在[Jupyter Notebook Example](https://github.com/SciSharp/SciSharpCube)进行运行. + +### 简单例子(使用Keras搭建Resnet) + +```csharp +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using Tensorflow; +using Tensorflow.NumPy; + +var layers = keras.layers; +// input layer +var inputs = keras.Input(shape: (32, 32, 3), name: "img"); +// convolutional layer +var x = layers.Conv2D(32, 3, activation: "relu").Apply(inputs); +x = layers.Conv2D(64, 3, activation: "relu").Apply(x); +var block_1_output = layers.MaxPooling2D(3).Apply(x); +x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_1_output); +x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x); +var block_2_output = layers.Add().Apply(new Tensors(x, block_1_output)); +x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_2_output); +x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x); +var block_3_output = layers.Add().Apply(new Tensors(x, block_2_output)); +x = layers.Conv2D(64, 3, activation: "relu").Apply(block_3_output); +x = layers.GlobalAveragePooling2D().Apply(x); +x = layers.Dense(256, activation: "relu").Apply(x); +x = layers.Dropout(0.5f).Apply(x); +// output layer +var outputs = layers.Dense(10).Apply(x); +// build keras model +var model = keras.Model(inputs, outputs, name: "toy_resnet"); +model.summary(); +// compile keras model in tensorflow static graph +model.compile(optimizer: keras.optimizers.RMSprop(1e-3f), + loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true), + metrics: new[] { "acc" }); +// prepare dataset +var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data(); +// normalize the input +x_train = x_train / 255.0f; +// training +model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], + batch_size: 64, + epochs: 10, + validation_split: 0.2f); +// save the model +model.save("./toy_resnet_model"); +``` + +此外,Tensorflow.NET也支持用F#搭建上述模型进行训练和推理。 + +## Tensorflow.NET版本对应关系 + +| TensorFlow.NET Versions | tensorflow 1.14, cuda 10.0 | tensorflow 1.15, cuda 10.0 | tensorflow 2.3, cuda 10.1 | tensorflow 2.4, cuda 11 | tensorflow 2.7, cuda 11 |tensorflow 2.10, cuda 11 | +| -------------------------- | ------------- | -------------- | ------------- | ------------- | ------------ | ------------ | +| tf.net 0.10x, tf.keras 0.10 | | | | | | x | +| tf.net 0.7x, tf.keras 0.7 | | | | | x | | +| tf.net 0.4x, tf.keras 0.5 | | | | x | | | +| tf.net 0.3x, tf.keras 0.4 | | | x | | | | +| tf.net 0.2x | | x | x | | | | +| tf.net 0.15 | x | x | | | | | +| tf.net 0.14 | x | | | | | | + + +``` +tf.net 0.4x -> tf native 2.4 +tf.net 0.6x -> tf native 2.6 +tf.net 0.7x -> tf native 2.7 +tf.net 0.10x -> tf native 2.10 +... +``` + +如果使用过程中发现有缺失的版本,请告知我们,谢谢! + +请注意Tensorflow.NET与Tensorflow.Keras版本存在一一对应关系,请安装与Tensorflow.NET对应的Tensorflow.Keras版本。 + +## 参与我们的开发: + +我们欢迎任何人的任何形式的贡献!无论是文档中的错误纠正,新特性提议,还是BUG修复等等,都会使得Tensorflow.NET项目越来越好,Tensorflow.NET的全体开发者也会积极帮助解决您提出的问题。 + +下面任何一种形式都可以帮助Tensorflow.NET越来越好: + +* Star和分享Tensorflow.NET项目 +* 为Tensorflow.NET添加更多的用例 +* 在issue中告知我们Tensorflow.NET目前相比tensorflow缺少的API或者没有对齐的特性 +* 在issue中提出Tensorflow.NET存在的BUG或者可以改进的地方 +* 在待办事项清单中选择一个进行或者解决某个issue +* 帮助我们完善文档,这也十分重要 + + +## 支持我们 +我们推出了[TensorFlow.NET实战](https://item.jd.com/13441549.html)这本书,包含了Tensorflow.NET主要开发者编写的讲解与实战例程,欢迎您的购买,希望这本书可以给您带来帮助。 +

+ + + +

+ +## 联系我们 + +可以在 [Twitter](https://twitter.com/ScisharpStack), [Facebook](https://www.facebook.com/scisharp.stack.9), [Medium](https://medium.com/scisharp), [LinkedIn](https://www.linkedin.com/company/scisharp-stack/)中关注我们,也可以在[Gitter](https://gitter.im/sci-sharp/community)中与项目开发者以及其它使用者进行沟通交流,也欢迎在仓库中提起issue。 + +TensorFlow.NET is a part of [SciSharp STACK](https://scisharp.github.io/SciSharp/) +
+ diff --git a/docs/README.md b/docs/README.md index e69de29bb..0e3c00484 100644 --- a/docs/README.md +++ b/docs/README.md @@ -0,0 +1,21 @@ +### Instll Sphinx +```cmd +pip install sphinx +pip install recommonmark +pip install sphinx_rtd_theme +``` + +### Init the docs +```cmd +sphinx-quickstarts +``` + +### Build the docs +```cmd +make html +``` + + + +Access the compiled docs: [https://tensorflownet.readthedocs.io](https://tensorflownet.readthedocs.io/) + diff --git a/docs/RELEASE.md b/docs/RELEASE.md new file mode 100644 index 000000000..62a1be238 --- /dev/null +++ b/docs/RELEASE.md @@ -0,0 +1,44 @@ +# Release Notes + +**Thanks to our Contributors!** + +This release contains contributions from many people at SciSharp as well as the external contributors. + +**Release Date 02/06/2021** + +### TensorFlow.Binding v0.33.0 + +* Improve memory usage +* Fix minor bugs + +### TensorFlow.Keras v0.4.0 + +* Add Subtract layer + +* Add model.load_weights and model.save_weights + +* Fix memory leak issue + +* Support to build YOLOv3 object detection model + + + +**Release Date 01/09/2021** + +### TensorFlow.Binding v0.32.0 + +* Fix input `dtype` for `MapDataset`. +* Fix `image_dataset_from_directory` function. +* Fix `tf.transpose`. +* Add `array_ops.where_v2`, `array_ops.select_v2`, `array_ops.softplus`. +* Add `dataset.dataset_cardinality`. + +### TensorFlow.Keras v0.3.0 + +* Fix `weight` init value for `double` type in `compute_weighted_loss`. +* Add `MeanSquaredError `, `MeanAbsolutePercentageError `, `MeanAbsoluteError` and `MeanSquaredLogarithmicError` loss functions. +* `Sequential` model API works. +* Add `ShellProgressBar` to show training progress better. + + + diff --git a/docs/The-Definitive-Guide/CH_1 Tensor.md b/docs/The-Definitive-Guide/CH_1 Tensor.md deleted file mode 100644 index 1d09ba42d..000000000 --- a/docs/The-Definitive-Guide/CH_1 Tensor.md +++ /dev/null @@ -1,41 +0,0 @@ -# 第一章: Tensor - -### Represents one of the outputs of an Operation - -### 表示一个操作的输出 - - - -##### What is Tensor? - -##### Tensor 是什么? - -Tensor holds a multi-dimensional array of elements of a single data type which is very similar with numpy's ndarray. - -Tensor是一个具有单一数据类型的多维数组容器,非常类似于numpy里的ndarray。如果你对numpy非常熟悉的话,那么对Tensor的理解会相当容易。 - - - -##### How to create a Tensor? - -##### 如何创建一个Tensor? - - - - - -TF uses column major order. - -TF 采用的是按列存储模式,如果我们用NumSharp产生一个2 X 3的矩阵,如果按顺序从0到5访问数据的话,是不会得到1-6的数字的,而是得到1,4, 2, 5, 3, 6这个顺序的一组数字。 - -```cs -// generate a matrix:[[1, 2, 3], [4, 5, 6]] -var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); -// the index will be 0 2 4 1 3 5, it's column-major order. -``` - - - -![column-major order](assets/column-major-order.png) - -![row-major order](assets/row-major-order.png) diff --git a/docs/The-Definitive-Guide/CH_3 Operation.md b/docs/The-Definitive-Guide/CH_3 Operation.md deleted file mode 100644 index e69de29bb..000000000 diff --git a/docs/The-Definitive-Guide/CH_4 Variable.md b/docs/The-Definitive-Guide/CH_4 Variable.md deleted file mode 100644 index e69de29bb..000000000 diff --git a/docs/The-Definitive-Guide/CH_5 Session.md b/docs/The-Definitive-Guide/CH_5 Session.md deleted file mode 100644 index e69de29bb..000000000 diff --git a/docs/The-Definitive-Guide/CH_6 Graph.md b/docs/The-Definitive-Guide/CH_6 Graph.md deleted file mode 100644 index e69de29bb..000000000 diff --git a/docs/The-Definitive-Guide/Foreword.md b/docs/The-Definitive-Guide/Foreword.md deleted file mode 100644 index 0a5232f9f..000000000 --- a/docs/The-Definitive-Guide/Foreword.md +++ /dev/null @@ -1,9 +0,0 @@ -# Foreword 前言 - -One of the most nerve-wracking periods when releasing the first version of an open source project occurs when the gitter community is created. You are all alone, eagerly hoping and wishing for the first user to come along. I still vividly remember those days. - - - -当我开始写这个项目的时候,我同时也在整理编码过程时候的想法,Tensorflow是个巨大最复杂的工程,很容易超出个人能力范围,所以想尽可能地把当时的思路记录下来,也想趁着记录整理的过程把思路理清。 - -When I started writing this project, I was also sorting out the idea of the coding process. Tensorflow is a huge and complicated project, and it is easy to go beyond the scope of personal ability. Therefore, I want to record the thoughts at the time as much as possible. The process of recording and sorting clears the way of thinking. \ No newline at end of file diff --git a/docs/The-Definitive-Guide/Preface.md b/docs/The-Definitive-Guide/Preface.md deleted file mode 100644 index ad91bf5ca..000000000 --- a/docs/The-Definitive-Guide/Preface.md +++ /dev/null @@ -1,35 +0,0 @@ - - -# Preface 序 - - - - - - - - - - - - - -Why do I start the Tensorflow.NET project? - -我为什么会写Tensorflow.NET? - -再过几天就是2018年圣诞节,看着孩子一天天长大并懂事,感慨时间过得太快。IT技术更新换代比以往任何时候都更快,各种前后端技术纷纷涌现。大数据,人工智能和区块链,容器技术和微服务,分布式计算和无服务器技术,让人眼花缭乱。Amazon AI服务接口宣称不需要具有任何机器学习经验的工程师就能使用,让像我这样刚静下心来学习了两年并打算将来转行做AI架构的想法泼了一桶凉水。 - -TensorFlow is an open source project for machine learning especially for deep learning. It's used for both research and production at Google company. It's designed according to dataflow programming pattern across a range of tasks. - - - -为了避免混淆,本书中对TensorFlow中定义的特有类不进行翻译,比如Tensor, Graph, Shape这些词都会保留英文名称。 - - - -术语简称: - -TF: Google TensorFlow - -TF.NET: Tensorflow.NET \ No newline at end of file diff --git a/docs/The-Definitive-Guide/README.md b/docs/The-Definitive-Guide/README.md deleted file mode 100644 index 926aded7b..000000000 --- a/docs/The-Definitive-Guide/README.md +++ /dev/null @@ -1,26 +0,0 @@ -# The Definitive Guide to Tensorflow.NET -# Tensorflow.NET 权威指南 - - - - -### The CSharp binding for Google's TensorFlow - An Open Source Machine Learning Framework for Everyone -### 谷歌TensorFlow的C#封装库,开源机器学习框架。 - - - - - - - -

-Haiping Chen & Christian Kahr
-Christmas, 2018
-陈海平 & 克里斯汀 卡尔
-2018年圣诞节 -

- - - - - diff --git a/docs/The-Definitive-Guide/Table of Contents.md b/docs/The-Definitive-Guide/Table of Contents.md deleted file mode 100644 index 6551e9262..000000000 --- a/docs/The-Definitive-Guide/Table of Contents.md +++ /dev/null @@ -1,18 +0,0 @@ -# Table of Contents - -### Foreword...........................................................................................xxi - -### Preface..............................................................................................xxiii - -## Part I. Getting Started - -##### 1. You Know, for Machine Learning............................................................................................ 3 - -​ Installing Tensorflow.NET -​ Running Tensorflow.NET -​ Talking to Tensorflow.NET - -## Part II. Tensorflow.NET in Depth - -## Part III. Dealing with Human Language - diff --git a/docs/_config.yml b/docs/_config.yml new file mode 100644 index 000000000..c4192631f --- /dev/null +++ b/docs/_config.yml @@ -0,0 +1 @@ +theme: jekyll-theme-cayman \ No newline at end of file diff --git a/docs/assets/Cover.psd b/docs/assets/Cover.psd new file mode 100644 index 000000000..665487608 Binary files /dev/null and b/docs/assets/Cover.psd differ diff --git a/docs/assets/Logo.md b/docs/assets/Logo.md new file mode 100644 index 000000000..21e7858ae --- /dev/null +++ b/docs/assets/Logo.md @@ -0,0 +1,3 @@ +TensorFlow.NET logo (c) 2019 by Meinrad Recheis. + +The logo is based on the original Tensorflow logo which is copyrighted by the respective creator. \ No newline at end of file diff --git a/docs/assets/TensorBoard-nn.png b/docs/assets/TensorBoard-nn.png new file mode 100644 index 000000000..23ccc3db5 Binary files /dev/null and b/docs/assets/TensorBoard-nn.png differ diff --git a/docs/assets/WeChatCollection.jpg b/docs/assets/WeChatCollection.jpg new file mode 100644 index 000000000..587b54991 Binary files /dev/null and b/docs/assets/WeChatCollection.jpg differ diff --git a/docs/assets/cnn-result.png b/docs/assets/cnn-result.png new file mode 100644 index 000000000..e1cea1e48 Binary files /dev/null and b/docs/assets/cnn-result.png differ diff --git a/docs/assets/cnn.png b/docs/assets/cnn.png new file mode 100644 index 000000000..78c7a6808 Binary files /dev/null and b/docs/assets/cnn.png differ diff --git a/docs/assets/eager-mode-add.png b/docs/assets/eager-mode-add.png new file mode 100644 index 000000000..e3700fa62 Binary files /dev/null and b/docs/assets/eager-mode-add.png differ diff --git a/docs/assets/graph_vis_animation.gif b/docs/assets/graph_vis_animation.gif new file mode 100644 index 000000000..556383270 Binary files /dev/null and b/docs/assets/graph_vis_animation.gif differ diff --git a/docs/assets/mnist.png b/docs/assets/mnist.png new file mode 100644 index 000000000..824818721 Binary files /dev/null and b/docs/assets/mnist.png differ diff --git a/docs/assets/nn-result.png b/docs/assets/nn-result.png new file mode 100644 index 000000000..7957b8214 Binary files /dev/null and b/docs/assets/nn-result.png differ diff --git a/docs/assets/nn.png b/docs/assets/nn.png new file mode 100644 index 000000000..8fbb6f9b8 Binary files /dev/null and b/docs/assets/nn.png differ diff --git a/docs/assets/performance-comparison.jpg b/docs/assets/performance-comparison.jpg new file mode 100644 index 000000000..382f7ab61 Binary files /dev/null and b/docs/assets/performance-comparison.jpg differ diff --git a/docs/assets/syntax-comparision.png b/docs/assets/syntax-comparision.png new file mode 100644 index 000000000..d42b5cb9c Binary files /dev/null and b/docs/assets/syntax-comparision.png differ diff --git a/docs/assets/tf.net.architecture.svg b/docs/assets/tf.net.architecture.svg new file mode 100644 index 000000000..933fd8027 --- /dev/null +++ b/docs/assets/tf.net.architecture.svg @@ -0,0 +1,370 @@ + + + + + + + + + + + + image/svg+xml + + + + + + + + + TensorFlow + + + Tensor Computation Layer (C++) + + + Graph Manipulation Layer (Python) + + + TensorFlow.NET + + + Graph Manipulation Layer (C#) + + + + Tensor Computation Layer (C++) + + + + C++ API (Python) + + + + C++ API (C#) + + + + Tensor Computation Layer (C++) + + + + C++ API (C#) + + TensorFlowSharp + (by Microsoft) + + + diff --git a/docs/assets/tf.net.icon-purple.svg b/docs/assets/tf.net.icon-purple.svg new file mode 100644 index 000000000..7498987b8 --- /dev/null +++ b/docs/assets/tf.net.icon-purple.svg @@ -0,0 +1,141 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/assets/tf.net.icon-purple128.png b/docs/assets/tf.net.icon-purple128.png new file mode 100644 index 000000000..d79ee7962 Binary files /dev/null and b/docs/assets/tf.net.icon-purple128.png differ diff --git a/docs/assets/tf.net.icon-purple512.png b/docs/assets/tf.net.icon-purple512.png new file mode 100644 index 000000000..0aa94f168 Binary files /dev/null and b/docs/assets/tf.net.icon-purple512.png differ diff --git a/docs/assets/tf.net.icon-transparent.svg b/docs/assets/tf.net.icon-transparent.svg new file mode 100644 index 000000000..e361115b3 --- /dev/null +++ b/docs/assets/tf.net.icon-transparent.svg @@ -0,0 +1,141 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/assets/tf.net.icon-transparent128.png b/docs/assets/tf.net.icon-transparent128.png new file mode 100644 index 000000000..7831c9eb3 Binary files /dev/null and b/docs/assets/tf.net.icon-transparent128.png differ diff --git a/docs/assets/tf.net.icon-transparent512.png b/docs/assets/tf.net.icon-transparent512.png new file mode 100644 index 000000000..57227d9a9 Binary files /dev/null and b/docs/assets/tf.net.icon-transparent512.png differ diff --git a/docs/assets/tf.net.logo.png b/docs/assets/tf.net.logo.png new file mode 100644 index 000000000..ceebc184d Binary files /dev/null and b/docs/assets/tf.net.logo.png differ diff --git a/docs/assets/tf.net.logo.svg b/docs/assets/tf.net.logo.svg new file mode 100644 index 000000000..b6e048ad8 --- /dev/null +++ b/docs/assets/tf.net.logo.svg @@ -0,0 +1,210 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/assets/tf.net.logo512.png b/docs/assets/tf.net.logo512.png new file mode 100644 index 000000000..2e1b4eff9 Binary files /dev/null and b/docs/assets/tf.net.logo512.png differ diff --git a/docs/assets/tf2.jpg b/docs/assets/tf2.jpg new file mode 100644 index 000000000..c4ebd31ec Binary files /dev/null and b/docs/assets/tf2.jpg differ diff --git a/docs/assets/tf2.psd b/docs/assets/tf2.psd new file mode 100644 index 000000000..1cde30235 Binary files /dev/null and b/docs/assets/tf2.psd differ diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 000000000..4d9eb83d9 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/docs/source/Constant.md b/docs/source/Constant.md new file mode 100644 index 000000000..dd6aa3bf0 --- /dev/null +++ b/docs/source/Constant.md @@ -0,0 +1,85 @@ +# Chapter 2. Constant + +In TensorFlow, a constant is a special Tensor that cannot be modified while the graph is running. Like in a linear model `y = ax + b`, constant `b` can be represented as a `Constant` Tensor. Since the constant is a Tensor, it also has all the data characteristics of Tensor, including: + +* value: scalar value or constant list matching the data type defined in TensorFlow; +* dtype: data type; +* shape: dimensions; +* name: constant's name; + + + +### How to create a Constant + +TensorFlow provides a handy function to create a Constant. In TF.NET, you can use the same function name `tf.constant` to create it. TF.NET takes the same name as python binding for the API. Naming, although this will make developers who are used to C# naming convention feel uncomfortable, but after careful consideration, I decided to give up the C# convention naming method. One of reason is for model developer, they don't have to learn a totally new different APIs. + +Initialize a scalar constant: + +```csharp +var c1 = tf.constant(3); // int +var c2 = tf.constant(1.0f); // float +var c3 = tf.constant(2.0); // double +var c4 = tf.constant("Big Tree"); // string +``` + +Initialize a constant through ndarray: + +TF.NET works very well with `NumSharp`'s `NDArray`. You can create a tensor from .NET primitive data type and NDArray as well. An `ndarray` is a (usually fixed-size) multidimensional container of items of the same type and size. The number of dimensions and items in an array is defined by its `shape`, which is a tuple of N non-negative integers that specify the sizes of each dimension. + +```csharp +// dtype=int, shape=(2, 3) +var nd = np.array(new int[,] +{ + {1, 2, 3}, + {4, 5, 6} +}); +var tensor = tf.constant(nd); +``` + +### Dive in Constant + +Now let's explore how `constant` works in `eager` mode inside the black box. + +Let's continue using the last examples, we're going to initialize a tensor in an ndarray of `[shape(2, 3), int32]`. + +##### NDArray + +The first thing we need to know is about `ndarray`'s memory model. The ndarray memory model is a very important data structure, and almost all underlying computation are inseparable from this datb a structure. One fundamental aspect of the ndarray is that an array is seen as a "chunk" of memory starting at some location. The interpretation of this memory depends on the stride information. A segment of memory is inherently 1-dimensional, and there are many different schemes for arranging the items of an N-dimensional array in a 1-dimensional block. `ndarray` objects can accommodate any strided indexing scheme. In a strided scheme, the N-dimensional index corresponds to the offset (in bytes) : . + + + +If we take a look at the real memory allocation in Visual Studio, below diagram helps us understand the data structure more intuitively. The strides keep track the size of every single dimension, help identify the actual offset in heap memory. The formula to calculate offset is: `offset = i * strides[0] + j * strides[1]`. + +For example: if you want to seek the value in `[1, 1]`, you just need to calculate `1 * 3 + 1 * 1 = 4`, converted to pointer is `0x000002556B194260 + 4 = 0x000002556B194264` where has a value `05`. + + + +Through the above diagram, we know how the data is stored in memory, and then we will look at how the data is transferred to `TensorFlow`. + +##### Tensor + +If you don't understand very well what `Tensor` is, you can go back to the chapter `Tensor` there is pretty much explanation if you skipped that chapter. Tensor is actually an NDArray that is with more than 2 dimensions. + +TensorFlow will decide whether to copy the data or use the same pointer. Normally speaking, it's more safe whenever you copy data for the following process, especially in interoperating between .NET runtime and C++ runtime that they all have their own garbage collection (GC) mechanism, application will crash if someone access a block of destroyed memory. `TF_STRING` and `TF_RESOURCE` tensors have a different representation in `TF_Tensor` than they do in `tensorflow::Tensor`. Other types have the same representation, so copy only if it is safe to do so. + + + +Before tensorflow is creating the `TF_Tensor`, it checks the shape and data size. If the size doesn't match, it will return `nullptr` pointer. + +##### Get the data of Tensor + +For `eager` mode, it's pretty simple to view the actual value in a `tensor`. + +```csharp +var data = tensor.numpy() +``` + +The `data` will be a `ndarray` variable. + +##### Other functions to create a Constant + +* tf.zeros +* tf.zeros_like +* tf.ones +* tf.ones_like +* tf.fill \ No newline at end of file diff --git a/docs/source/ConvolutionNeuralNetwork.md b/docs/source/ConvolutionNeuralNetwork.md new file mode 100644 index 000000000..6b47c9d8d --- /dev/null +++ b/docs/source/ConvolutionNeuralNetwork.md @@ -0,0 +1,350 @@ +# Chapter. Convolution Neural Network + +In this chapter, we'll implement a simple Convolutional Neural Network model. We'll implement this model to classify MNIST dataset. + + + +The structure of the neural network we're going to build is as follows. The hand-written digits images of the MNIST data which has 10 classes (from 0 to 9). The network is with 2 convolutional layers followed by 2 full-connected layers at the end. + +![neural network architecture](../assets/cnn.png) + +Get started with the implementation: + +1. **Prepare data** + + MNIST is dataset of handwritten digits which contains 55,000 examples for training, 5,000 examples for validation and 10,000 example for testing. The digits have been size-normalized and centered in a fixed-size image (28 x 28 pixels) with values from 0 and 1.Each image has been flattened and converted to a 1-D array of 784 features. It's also kind of benchmark of datasets for deep learning. + + ![MNIST dataset](../assets/mnist.png) + + We define some variables makes it easier to modify them later. + + ```csharp + using System; + using NumSharp; + using Tensorflow; + using TensorFlowNET.Examples.Utility; + using static Tensorflow.Python; + ``` + + ```csharp + const int img_h = 28; + const int img_w = 28; + int n_classes = 10; // Number of classes, one class per digit + int n_channels = 1; + ``` + + We'll write the function which automatically loads the MNIST data and returns it in our desired shape and format. There is an MNIST data helper to make life easier. + + ```csharp + Datasets mnist; + public void PrepareData() + { + mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); + } + ``` + + Other than a function for loading the images and corresponding labels, we still need three more functions: + + **reformat:** reformats the data to the format acceptable for convolutional layer. + + ```csharp + private (NDArray, NDArray) Reformat(NDArray x, NDArray y) + { + var (img_size, num_ch, num_class) = (np.sqrt(x.shape[1]), 1, len(np.unique(np.argmax(y, 1)))); + var dataset = x.reshape(x.shape[0], img_size, img_size, num_ch).astype(np.float32); + //y[0] = np.arange(num_class) == y[0]; + //var labels = (np.arange(num_class) == y.reshape(y.shape[0], 1, y.shape[1])).astype(np.float32); + return (dataset, y); + } + ``` + + + + **randomize**: which randomizes the order of images and their labels. At the beginning of each epoch, we will re-randomize the order of data samples to make sure that the trained model is not sensitive to the order of data. + + ```csharp + private (NDArray, NDArray) randomize(NDArray x, NDArray y) + { + var perm = np.random.permutation(y.shape[0]); + + np.random.shuffle(perm); + return (mnist.train.images[perm], mnist.train.labels[perm]); + } + ``` + + **get_next_batch**: which only selects a few number of images determined by the batch_size variable (as per Stochastic Gradient Descent method). + + ```csharp + private (NDArray, NDArray) get_next_batch(NDArray x, NDArray y, int start, int end) + { + var x_batch = x[$"{start}:{end}"]; + var y_batch = y[$"{start}:{end}"]; + return (x_batch, y_batch); + } + ``` + +2. **Set Hyperparameters** + + There're about 55,000 images in training set, it takes a long time to calculate the gradient of the model using all there images. Therefore we use a small batch of images in each iteration of the optimizer by Stochastic Gradient Descent. + + * epoch: one forward pass and one backward pass of all the training examples. + * batch size: the number of training examples in one forward/backward pass. The higher the batch size, the more memory space you'll need. + * iteration: one forward pass and one backward pass of one batch of images the training examples. + + ```csharp + int epochs = 10; + int batch_size = 100; + float learning_rate = 0.001f; + int display_freq = 200; // Frequency of displaying the training results + ``` + +3. **Network configuration** + + 1st convolutional layer: + + ```csharp + int filter_size1 = 5; // Convolution filters are 5 x 5 pixels. + int num_filters1 = 16; // There are 16 of these filters. + int stride1 = 1; // The stride of the sliding window + ``` + + 2nd convolutional layer: + + ```csharp + int filter_size2 = 5; // Convolution filters are 5 x 5 pixels. + int num_filters2 = 32;// There are 32 of these filters. + int stride2 = 1; // The stride of the sliding window + ``` + + Fully-connected layer: + + ```csharp + h1 = 128 # Number of neurons in fully-connected layer. + ``` + + + +4. **Building the neural network** + + Let's make some functions to help build computation graph. + + **variables**: We need to define two variables `W` and `b` to construct our linear model. We use `Tensorflow Variables` of proper size and initialization to define them. + + ```csharp + // Create a weight variable with appropriate initialization + private RefVariable weight_variable(string name, int[] shape) + { + var initer = tf.truncated_normal_initializer(stddev: 0.01f); + return tf.get_variable(name, + dtype: tf.float32, + shape: shape, + initializer: initer); + } + + // Create a bias variable with appropriate initialization + private RefVariable bias_variable(string name, int[] shape) + { + var initial = tf.constant(0f, shape: shape, dtype: tf.float32); + return tf.get_variable(name, + dtype: tf.float32, + initializer: initial); + } + ``` + + **2D convolution layer**: This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. + + ```csharp + private Tensor conv_layer(Tensor x, int filter_size, int num_filters, int stride, string name) + { + return with(tf.variable_scope(name), delegate { + + var num_in_channel = x.shape[x.NDims - 1]; + var shape = new[] { filter_size, filter_size, num_in_channel, num_filters }; + var W = weight_variable("W", shape); + // var tf.summary.histogram("weight", W); + var b = bias_variable("b", new[] { num_filters }); + // tf.summary.histogram("bias", b); + var layer = tf.nn.conv2d(x, W, + strides: new[] { 1, stride, stride, 1 }, + padding: "SAME"); + layer += b; + return tf.nn.relu(layer); + }); + } + ``` + + **max-pooling layer**: Max pooling operation for temporal data. + + ```csharp + private Tensor max_pool(Tensor x, int ksize, int stride, string name) + { + return tf.nn.max_pool(x, + ksize: new[] { 1, ksize, ksize, 1 }, + strides: new[] { 1, stride, stride, 1 }, + padding: "SAME", + name: name); + } + ``` + + **flatten_layer**: Flattens the output of the convolutional layer to be fed into fully-connected layer. + + ```csharp + private Tensor flatten_layer(Tensor layer) + { + return with(tf.variable_scope("Flatten_layer"), delegate + { + var layer_shape = layer.TensorShape; + var num_features = layer_shape[new Slice(1, 4)].Size; + var layer_flat = tf.reshape(layer, new[] { -1, num_features }); + + return layer_flat; + }); + } + ``` + + + + **fully-connected layer**: Neural network consists of stacks of fully-connected (dense) layers. Having the weight (W) and bias (b) variables, a fully-connected layer is defined as `activation(W x X + b)`. The complete `fc_layer` function is as below: + + ```csharp + private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true) + { + return with(tf.variable_scope(name), delegate + { + var in_dim = x.shape[1]; + + var W = weight_variable("W_" + name, shape: new[] { in_dim, num_units }); + var b = bias_variable("b_" + name, new[] { num_units }); + + var layer = tf.matmul(x, W) + b; + if (use_relu) + layer = tf.nn.relu(layer); + + return layer; + }); + } + ``` + + **inputs**: Now we need to define the proper tensors to feed in the input to our model. Placeholder variable is the suitable choice for the input images and corresponding labels. This allow us to change the inputs (images and labels) to the TensorFlow graph. + + ```csharp + with(tf.name_scope("Input"), delegate + { + // Placeholders for inputs (x) and outputs(y) + x = tf.placeholder(tf.float32, shape: (-1, img_h, img_w, n_channels), name: "X"); + y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y"); + }); + ``` + + Placeholder `y` is the variable for the true labels associated with the images that were input in the placeholder variable `x`. It holds an arbitrary number of labels and each label is a vector of length `num_classes` which is 10. + + **network layers**: After creating the proper input, we have to pass it to our model. Since we have a neural network, we can stack multiple fully-connected layers using `fc_layer` method. Note that we will not use any activation function (use_relu = false) in the last layer. The reason is that we can use `tf.nn.softmax_cross_entropy_with_logits` to calculate the loss. + + ```csharp + var conv1 = conv_layer(x, filter_size1, num_filters1, stride1, name: "conv1"); + var pool1 = max_pool(conv1, ksize: 2, stride: 2, name: "pool1"); + var conv2 = conv_layer(pool1, filter_size2, num_filters2, stride2, name: "conv2"); + var pool2 = max_pool(conv2, ksize: 2, stride: 2, name: "pool2"); + var layer_flat = flatten_layer(pool2); + var fc1 = fc_layer(layer_flat, h1, "FC1", use_relu: true); + var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false); + ``` + + **loss function, optimizer, accuracy, prediction**: After creating the network, we have to calculate the loss and optimize it, we have to calculate the `prediction` and `accuracy`. + + ```csharp + with(tf.variable_scope("Train"), delegate + { + + + with(tf.variable_scope("Optimizer"), delegate + { + optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss); + }); + + with(tf.variable_scope("Accuracy"), delegate + { + var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred"); + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy"); + }); + + with(tf.variable_scope("Prediction"), delegate + { + cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions"); + }); + }); + ``` + + **initialize variables**: We have to invoke a variable initializer operation to initialize all variables. + +```csharp + var init = tf.global_variables_initializer(); +``` + +5. **Train** + + After creating the graph, we can train our model. To train the model, we have to create a session and run the graph in the session. + + ```csharp + // Number of training iterations in each epoch + var num_tr_iter = y_train.len / batch_size; + + var init = tf.global_variables_initializer(); + sess.run(init); + + float loss_val = 100.0f; + float accuracy_val = 0f; + + foreach (var epoch in range(epochs)) + { + print($"Training epoch: {epoch + 1}"); + // Randomly shuffle the training data at the beginning of each epoch + (x_train, y_train) = mnist.Randomize(x_train, y_train); + + foreach (var iteration in range(num_tr_iter)) + { + var start = iteration * batch_size; + var end = (iteration + 1) * batch_size; + var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end); + + // Run optimization op (backprop) + sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); + + if (iteration % display_freq == 0) + { + // Calculate and display the batch loss and accuracy + var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); + loss_val = result[0]; + accuracy_val = result[1]; + print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}"); + } + } + + // Run validation after every epoch + var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_valid), new FeedItem(y, y_valid)); + loss_val = results1[0]; + accuracy_val = results1[1]; + print("---------------------------------------------------------"); + print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); + print("---------------------------------------------------------"); + } + ``` + +6. **Test** + + After the training is done, we have to test our model to see how good it performs on a new dataset. + + ```csharp + public void Test(Session sess) + { + var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_test), new FeedItem(y, y_test)); + loss_test = result[0]; + accuracy_test = result[1]; + print("---------------------------------------------------------"); + print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); + print("---------------------------------------------------------"); + } + ``` + +![](../assets/cnn-result.png) + diff --git a/docs/source/EagerMode.md b/docs/source/EagerMode.md new file mode 100644 index 000000000..ded56d41f --- /dev/null +++ b/docs/source/EagerMode.md @@ -0,0 +1,3 @@ +# Chapter 4. Eager Mode + +TensorFlow's eager execution is an imperative programming environment that evaluates operations immediately, without building graphs: operations return concrete values instead of constructing a computational graph to run later. This makes it easy to get started with TensorFlow and debug models, and it reduces boilerplate as well. \ No newline at end of file diff --git a/docs/source/Foreword.md b/docs/source/Foreword.md new file mode 100644 index 000000000..256094f57 --- /dev/null +++ b/docs/source/Foreword.md @@ -0,0 +1,11 @@ +# Foreword + +One of the most nerve-wracking periods when releasing the first version of an open source project occurs when the [gitter](https://gitter.im/sci-sharp/community) community is created. You are all alone, eagerly hoping and wishing for the first user to come along. I still vividly remember those days. + + + +TensorFlow.NET is my third open source project. BotSharp and NumSharp are the first two. The response is pretty good. I also got a lot of stars on github. Although the first two projects are very difficult, I can't admit that TensorFlow.NET is much more difficult than the previous two, and it is an area I have never been involved with. Mainly related to GPU parallel computing, distributed computing and neural network model. When I started writing this project, I was also sorting out the idea of the coding process. TensorFlow is a huge and complicated project, and it is easy to go beyond the scope of personal ability. Therefore, I want to record the thoughts at the time as much as possible. The process of recording and sorting clears the way of thinking. + + + +All the examples in this book can be found in the github repository of TensorFlow.NET. When the source code and the code in the book are inconsistent, please refer to the source code. The sample code is typically located in the Example or UnitTest project. diff --git a/docs/source/FrontCover.md b/docs/source/FrontCover.md new file mode 100644 index 000000000..322d431c9 --- /dev/null +++ b/docs/source/FrontCover.md @@ -0,0 +1,47 @@ +# The Definitive Guide to TensorFlow.NET + + + + + + + + + + +![Front Cover](_static/front-cover.jpg) + + + + + + + + + +### The CSharp binding for Google's TensorFlow + +#### An Open Source Machine Learning Framework for Everyone + + + + + + + + + + + +

+Haiping Chen
+Christmas, 2018
+

+ + + + + + + + diff --git a/docs/source/Gradient.md b/docs/source/Gradient.md new file mode 100644 index 000000000..818ec73e7 --- /dev/null +++ b/docs/source/Gradient.md @@ -0,0 +1,15 @@ +# Chapter. Gradient + +### Register custom gradient function + +TF.NET is extensible which can be added custom gradient function. + +```csharp +// define gradient function +ops.RegisterGradientFunction("ConcatV2", (oper, out_grads) => +{ + var grad = grads[0]; + return new Tensor[]{ }; +}); +``` + diff --git a/docs/source/Graph.md b/docs/source/Graph.md new file mode 100644 index 000000000..874cd9a42 --- /dev/null +++ b/docs/source/Graph.md @@ -0,0 +1,81 @@ +# Chapter 3. Graph + +TensorFlow uses a **dataflow graph** to represent your computation in terms of the dependencies between individual operations. A graph defines the computation. It doesn't compute anything, it doesn't hold any values, it just defines the operations that you specified in your code. + +### Defining the Graph + +We define a graph with a variable and three operations: `variable` returns the current value of our variable. `initialize` assigns the initial value of 31 to that variable. `assign` assigns the new value of 12 to that variable. + +```csharp +with(tf.Graph().as_default(), graph => +{ + var variable = tf.Variable(31, name: "tree"); + tf.global_variables_initializer(); + variable.assign(12); +}); +``` + +TF.NET simulate a `with` syntax to manage the Graph lifecycle which will be disposed when the graph instance is no long need. The graph is also what the sessions in the next chapter use when not manually specifying a graph because use invoked the `as_default()`. + +A typical graph is looks like below: + +![image](../assets/graph_vis_animation.gif) + + + +### Save Model + +Saving the model means saving all the values of the parameters and the graph. + +```python +saver = tf.train.Saver() +saver.save(sess,'./tensorflowModel.ckpt') +``` + +After saving the model there will be four files: + +* tensorflowModel.ckpt.meta: +* tensorflowModel.ckpt.data-00000-of-00001: +* tensorflowModel.ckpt.index +* checkpoint + +We also created a protocol buffer file .pbtxt. It is human readable if you want to convert it to binary: `as_text: false`. + +* tensorflowModel.pbtxt: + +This holds a network of nodes, each representing one operation, connected to each other as inputs and outputs. + + + +### Freezing the Graph + +##### *Why we need it?* + +When we need to keep all the values of the variables and the Graph structure in a single file we have to freeze the graph. + +```csharp +from tensorflow.python.tools import freeze_graph + +freeze_graph.freeze_graph(input_graph = 'logistic_regression/tensorflowModel.pbtxt', + input_saver = "", + input_binary = False, + input_checkpoint = 'logistic_regression/tensorflowModel.ckpt', + output_node_names = "Softmax", + restore_op_name = "save/restore_all", + filename_tensor_name = "save/Const:0", + output_graph = 'frozentensorflowModel.pb', + clear_devices = True, + initializer_nodes = "") + +``` + +### Optimizing for Inference + +To Reduce the amount of computation needed when the network is used only for inferences we can remove some parts of a graph that are only needed for training. + + + +### Restoring the Model + + + diff --git a/docs/source/HelloWorld.md b/docs/source/HelloWorld.md new file mode 100644 index 000000000..8b7fbf733 --- /dev/null +++ b/docs/source/HelloWorld.md @@ -0,0 +1,77 @@ +# Get started with TensorFlow.NET + +I would describe TensorFlow as an open source machine learning framework developed by Google which can be used to build neural networks and perform a variety of machine learning tasks. it works on data flow graph where nodes are the mathematical operations and the edges are the data in the form of tensor, hence the name Tensor-Flow. + + + +Let's run a classic HelloWorld program first and see if TensorFlow is running on .NET. I can't think of a simpler way to be a HelloWorld. + + + +### Install the TensorFlow.NET SDK + +TensorFlow.NET uses the .NET Standard 2.0 standard, so your new project Target Framework can be .NET Framework or .NET Core/ .NET 5. All the examples in this book are using .NET Core 3.1 and Microsoft Visual Studio Community 2019. To start building TensorFlow program you just need to download and install the .NET SDK (Software Development Kit). You have to download the latest .NET Core SDK from offical website: https://dotnet.microsoft.com/download. + + + +1. New a project + + ![New Project](_static/new-project.png) + +2. Choose Console App (.NET Core) + + ![Console App](_static/new-project-console.png) + + + +```cmd +### install tensorflow C# binding +PM> Install-Package TensorFlow.NET + +### Install tensorflow binary +### For CPU version +PM> Install-Package SciSharp.TensorFlow.Redist + +### For GPU version (CUDA and cuDNN are required) +PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU +``` + +### Start coding Hello World + +After installing the TensorFlow.NET package, you can use the `using static Tensorflow.Binding` to introduce the TensorFlow .NET library. + +TensorFlow 2.x enabled `Eager Mode` by default. About what eager mode is, I will introduce it in detail in the following chapters. + +```csharp +using System; +using static Tensorflow.Binding; + +namespace TensorFlowNET.Examples +{ + /// + /// Simple hello world using TensorFlow + /// + class Program + { + static void Main(string[] args) + { + var hello = tf.constant("Hello, TensorFlow!"); + Console.WriteLine(hello); + } + } +} +``` +After CTRL + F5 run, you will get the output. +```cmd +9/20/2020 2:15:09 AM Starting Hello World +tf.Tensor: shape=(), dtype=string, numpy=Hello, TensorFlow.NET! +9/20/2020 2:15:09 AM Completed Hello World +Example: Hello World in 0.1273463s is OK! +TensorFlow.NET v0.20.1.0 +TensorFlow Binary v2.3.0 +1 of 21 example(s) are completed. +Press [Enter] to continue... +``` + +This sample code can be found at [here](https://github.com/SciSharp/SciSharp-Stack-Examples/blob/master/src/TensorFlowNET.Examples/HelloWorld.cs). + diff --git a/docs/source/ImageRecognition.md b/docs/source/ImageRecognition.md new file mode 100644 index 000000000..74d3ee5bf --- /dev/null +++ b/docs/source/ImageRecognition.md @@ -0,0 +1,137 @@ +# Chapter. Image Recognition + +An example for using the [TensorFlow.NET](https://github.com/SciSharp/TensorFlow.NET) and [NumSharp](https://github.com/SciSharp/NumSharp) for image recognition, it will use a pre-trained inception model to predict a image which outputs the categories sorted by probability. The original paper is [here](https://arxiv.org/pdf/1512.00567.pdf). The Inception architecture of GoogLeNet was designed to perform well even under strict constraints on memory and computational budget. The computational cost of Inception is also much lower than other performing successors. This has made it feasible to utilize Inception networks in big-data scenarios, where huge amount of data needed to be processed at reasonable cost or scenarios where memory or computational capacity is inherently limited, for example in mobile vision settings. + +The GoogLeNet architecture conforms to below design principles: + +* Avoid representational bottlenecks, especially early in the network. +* Higher dimensional representations are easier to process locally within a network. +* Spatial aggregation can be done over lower dimensional embeddings without much or any loss in representational power. +* Balance the width and depth of the network. + +#### Let's get started with real code. + +##### 1. Prepare data + +This example will download the dataset and uncompress it automatically. Some external paths are omitted, please refer to the source code for the real path. + +```csharp +private void PrepareData() +{ + Directory.CreateDirectory(dir); + + // get model file + string url = "models/inception_v3_2016_08_28_frozen.pb.tar.gz"; + + string zipFile = Path.Join(dir, $"{pbFile}.tar.gz"); + Utility.Web.Download(url, zipFile); + + Utility.Compress.ExtractTGZ(zipFile, dir); + + // download sample picture + string pic = "grace_hopper.jpg"; + Utility.Web.Download($"data/{pic}", Path.Join(dir, pic)); +} +``` + +##### 2. Load image file and normalize + +We need to load a sample image to test our pre-trained inception model. Convert it into tensor and normalized the input image. The pre-trained model takes input in the form of a 4-dimensional tensor with shape [BATCH_SIZE, INPUT_HEIGHT, INPUT_WEIGHT, 3] where: + +- BATCH_SIZE allows for inference of multiple images in one pass through the graph +- INPUT_HEIGHT is the height of the images on which the model was trained +- INPUT_WEIGHT is the width of the images on which the model was trained +- 3 is the (R, G, B) values of the pixel colors represented as a float. + +```csharp +private NDArray ReadTensorFromImageFile(string file_name, + int input_height = 299, + int input_width = 299, + int input_mean = 0, + int input_std = 255) +{ + return with(tf.Graph().as_default(), graph => + { + var file_reader = tf.read_file(file_name, "file_reader"); + var image_reader = tf.image.decode_jpeg(file_reader, channels: 3, name: "jpeg_reader"); + var caster = tf.cast(image_reader, tf.float32); + var dims_expander = tf.expand_dims(caster, 0); + var resize = tf.constant(new int[] { input_height, input_width }); + var bilinear = tf.image.resize_bilinear(dims_expander, resize); + var sub = tf.subtract(bilinear, new float[] { input_mean }); + var normalized = tf.divide(sub, new float[] { input_std }); + + return with(tf.Session(graph), sess => sess.run(normalized)); + }); +} +``` + +##### 3. Load pre-trained model and predict + +Load the pre-trained inception model which is saved as Google's protobuf file format. Construct a new graph then set input and output operations in a new session. After run the session, you will get a numpy-like ndarray which is provided by NumSharp. With NumSharp, you can easily perform various operations on multiple dimensional arrays in the .NET environment. + +```csharp +public void Run() +{ + PrepareData(); + + var labels = File.ReadAllLines(Path.Join(dir, labelFile)); + + var nd = ReadTensorFromImageFile(Path.Join(dir, picFile), + input_height: input_height, + input_width: input_width, + input_mean: input_mean, + input_std: input_std); + + var graph = Graph.ImportFromPB(Path.Join(dir, pbFile)); + var input_operation = graph.get_operation_by_name(input_name); + var output_operation = graph.get_operation_by_name(output_name); + + var results = with(tf.Session(graph), + sess => sess.run(output_operation.outputs[0], + new FeedItem(input_operation.outputs[0], nd))); + + results = np.squeeze(results); + + var argsort = results.argsort(); + var top_k = argsort.Data() + .Skip(results.size - 5) + .Reverse() + .ToArray(); + + foreach (float idx in top_k) + Console.WriteLine($"{picFile}: {idx} {labels[(int)idx]}, {results[(int)idx]}"); +} +``` + +##### 4. Print the result + +The best probability is `military uniform` which is 0.8343058. It's the correct classification. + +```powershell +2/18/2019 3:56:18 AM Starting InceptionArchGoogLeNet +label_image_data\inception_v3_2016_08_28_frozen.pb.tar.gz already exists. +label_image_data\grace_hopper.jpg already exists. +2019-02-19 21:56:18.684463: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 +create_op: Const 'file_reader/filename', inputs: empty, control_inputs: empty, outputs: file_reader/filename:0 +create_op: ReadFile 'file_reader', inputs: file_reader/filename:0, control_inputs: empty, outputs: file_reader:0 +create_op: DecodeJpeg 'jpeg_reader', inputs: file_reader:0, control_inputs: empty, outputs: jpeg_reader:0 +create_op: Cast 'Cast/Cast', inputs: jpeg_reader:0, control_inputs: empty, outputs: Cast/Cast:0 +create_op: Const 'ExpandDims/dim', inputs: empty, control_inputs: empty, outputs: ExpandDims/dim:0 +create_op: ExpandDims 'ExpandDims', inputs: Cast/Cast:0, ExpandDims/dim:0, control_inputs: empty, outputs: ExpandDims:0 +create_op: Const 'Const', inputs: empty, control_inputs: empty, outputs: Const:0 +create_op: ResizeBilinear 'ResizeBilinear', inputs: ExpandDims:0, Const:0, control_inputs: empty, outputs: ResizeBilinear:0 +create_op: Const 'y', inputs: empty, control_inputs: empty, outputs: y:0 +create_op: Sub 'Sub', inputs: ResizeBilinear:0, y:0, control_inputs: empty, outputs: Sub:0 +create_op: Const 'y_1', inputs: empty, control_inputs: empty, outputs: y_1:0 +create_op: RealDiv 'truediv', inputs: Sub:0, y_1:0, control_inputs: empty, outputs: truediv:0 +grace_hopper.jpg: 653 military uniform, 0.8343058 +grace_hopper.jpg: 668 mortarboard, 0.02186947 +grace_hopper.jpg: 401 academic gown, 0.01035806 +grace_hopper.jpg: 716 pickelhaube, 0.008008132 +grace_hopper.jpg: 466 bulletproof vest, 0.005350832 +2/18/2019 3:56:25 AM Completed InceptionArchGoogLeNet +``` + +You can find the full source code from [github](https://github.com/SciSharp/TensorFlow.NET-Examples/tree/master/src/TensorFlowNET.Examples/ImageProcessing). + diff --git a/docs/source/LinearRegression.md b/docs/source/LinearRegression.md new file mode 100644 index 000000000..8033625c3 --- /dev/null +++ b/docs/source/LinearRegression.md @@ -0,0 +1,85 @@ +# Chapter. Linear Regression + + + +### What is linear regression? + +Linear regression is a linear approach to modelling the relationship between a scalar response (or dependent variable) and one or more explanatory variables (or independent variables). + +Consider the case of a single variable of interest y and a single predictor variable x. The predictor variables are called by many names: covariates, inputs, features; the predicted variable is often called response, output, outcome. + +We have some data $D=\{x{\tiny i},y{\tiny i}\}$ and we assume a simple linear model of this dataset with Gaussian noise: + + +```csharp +// Prepare training Data +var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, 7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f); +var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f, 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); +var n_samples = train_X.shape[0]; +``` +![regression dataset](_static/regression-dataset.png) + +Based on the given data points, we try to plot a line that models the points the best. The red line can be modelled based on the linear equation: $y = wx + b$. The motive of the linear regression algorithm is to find the best values for $w$ and $b$. Before moving on to the algorithm, le's have a look at two important concepts you must know to better understand linear regression. + + + +### Cost Function + +The cost function helps us to figure out the best possible values for $w$ and $b$ which would provide the best fit line for the data points. Since we want the best values for $w$ and $b$, we convert this search problem into a minimization problem where we would like to minimize the error between the predicted value and the actual value. + + + +![minimize-square-cost](_static/minimize-square-cost.png) + +We choose the above function to minimize. The difference between the predicted values and ground truth measures the error difference. We square the error difference and sum over all data points and divide that +value by the total number of data points. This provides the average squared error over all the data points. Therefore, this cost function is also known as the Mean Squared Error(MSE) function. Now, using this MSE +function we are going to change the values of $w$ and $b$ such that the MSE value settles at the minima. + + + +```csharp +// tf Graph Input +var X = tf.placeholder(tf.float32); +var Y = tf.placeholder(tf.float32); + +// Set model weights +var W = tf.Variable(rng.randn(), name: "weight"); +var b = tf.Variable(rng.randn(), name: "bias"); + +// Construct a linear model +var pred = tf.add(tf.multiply(X, W), b); + +// Mean squared error +var cost = tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * n_samples); +``` + + + +### Gradient Descent + +The another important concept needed to understand is gradient descent. Gradient descent is a method of updating $w$ and $b$ to minimize the cost function. The idea is that we start with some random values for $w$ and $b$ and then we change these values iteratively to reduce the cost. Gradient descent helps us on how to update the values or which direction we would go next. Gradient descent is also know as **steepest descent**. + + + + +![gradient-descent](_static/gradient-descent.png) + +To draw an analogy, imagine a pit in the shape of U and you are standing at the topmost point in the pit and your objective is to reach the bottom of the pit. There is a catch, you can only take a discrete number +of steps to reach the bottom. If you decide to take one step at a time you would eventually reach the bottom of the pit but this would take a longer time. If you choose to take longer steps each time, you would +reach sooner but, there is a chance that you could overshoot the bottom of the pit and not exactly at the bottom. In the gradient descent algorithm, the number of steps you take is the learning rate. This +decides on how fast the algorithm converges to the minima. + + + + +```csharp +// Gradient descent +// Note, minimize() knows to modify W and b because Variable objects are trainable=True by default +var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); +``` + +When we visualize the graph in TensorBoard: + +![linear-regression](_static/linear-regression-tensor-board.png) + +The full example is [here](https://github.com/SciSharp/TensorFlow.NET-Examples/blob/master/src/TensorFlowNET.Examples/BasicModels/LinearRegression.cs). diff --git a/docs/source/LogisticRegression.md b/docs/source/LogisticRegression.md new file mode 100644 index 000000000..ddf75f846 --- /dev/null +++ b/docs/source/LogisticRegression.md @@ -0,0 +1,16 @@ +# Chapter. Logistic Regression + +### What is logistic regression? + +Logistic regression is a statistical analysis method used to predict a data value based on prior observations of a data set. A logistic regression model predicts a dependent data variable by analyzing the relationship between one or more existing independent variables. + + + +The dependent variable of logistics regression can be two-category or multi-category, but the two-category is more common and easier to explain. So the most common use in practice is the logistics of the two classifications. An example used by TensorFlow.NET is a hand-written digit recognition, which is a multi-category. + + + +Softmax regression allows us to handle ![1557035393445](_static\logistic-regression\1557035393445.png) where K is the number of classes. + + +The full example is [here](https://github.com/SciSharp/TensorFlow.NET-Examples/blob/master/src/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs). diff --git a/docs/source/MnistInRnn.md b/docs/source/MnistInRnn.md new file mode 100644 index 000000000..ce8a13909 --- /dev/null +++ b/docs/source/MnistInRnn.md @@ -0,0 +1,5 @@ +# Chapter. MNIST In RNN + +### Recurrent Neural Networks + +Recurrent Neural Networks (RNNs) are popular models that have shown great promise in sequential data classification task. The traditional neural network model cannot make the next prediction input based on the knowledge that has been learned before. \ No newline at end of file diff --git a/docs/source/NearestNeighbor.md b/docs/source/NearestNeighbor.md new file mode 100644 index 000000000..94e300df6 --- /dev/null +++ b/docs/source/NearestNeighbor.md @@ -0,0 +1,5 @@ +# Chapter. Nearest Neighbor + +The nearest neighbour algorithm was one of the first algorithms used to solve the travelling salesman problem. In it, the salesman starts at a random city and repeatedly visits the nearest city until all have been visited. It quickly yields a short tour, but usually not the optimal one. + +The full example is [here](https://github.com/SciSharp/TensorFlow.NET-Examples/blob/master/src/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs). \ No newline at end of file diff --git a/docs/source/NeuralNetwork.md b/docs/source/NeuralNetwork.md new file mode 100644 index 000000000..1d46111b7 --- /dev/null +++ b/docs/source/NeuralNetwork.md @@ -0,0 +1,244 @@ +# Chapter. Neural Network + +In this chapter, we'll learn how to build a graph of neural network model. The key advantage of neural network compared to Linear Classifier is that it can separate data which it not linearly separable. We'll implement this model to classify hand-written digits images from the MNIST dataset. + + + +The structure of the neural network we're going to build is as follows. The hand-written digits images of the MNIST data which has 10 classes (from 0 to 9). The network is with 2 hidden layers: the first layer with 200 hidden units (neurons) and the second one (known as classifier layer) with 10 neurons. + +![neural network architecture](../assets/nn.png) + +Get started with the implementation step by step: + +1. **Prepare data** + + MNIST is dataset of handwritten digits which contains 55,000 examples for training, 5,000 examples for validation and 10,000 example for testing. The digits have been size-normalized and centered in a fixed-size image (28 x 28 pixels) with values from 0 and 1.Each image has been flattened and converted to a 1-D array of 784 features. It's also kind of benchmark of datasets for deep learning. + + ![MNIST dataset](../assets/mnist.png) + + We define some variables makes it easier to modify them later. It's important to note that in a linear model, we have to flatten the input images to a vector. + + ```csharp + using System; + using NumSharp; + using Tensorflow; + using TensorFlowNET.Examples.Utility; + using static Tensorflow.Python; + ``` + + ```csharp + const int img_h = 28; + const int img_w = 28; + int img_size_flat = img_h * img_w; // 784, the total number of pixels + int n_classes = 10; // Number of classes, one class per digit + ``` + + We'll write the function which automatically loads the MNIST data and returns it in our desired shape and format. There is an MNIST data helper to make life easier. + + ```csharp + Datasets mnist; + public void PrepareData() + { + mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); + } + ``` + + Other than a function for loading the images and corresponding labels, we still need two more functions: + + **randomize**: which randomizes the order of images and their labels. At the beginning of each epoch, we will re-randomize the order of data samples to make sure that the trained model is not sensitive to the order of data. + + ```csharp + private (NDArray, NDArray) randomize(NDArray x, NDArray y) + { + var perm = np.random.permutation(y.shape[0]); + + np.random.shuffle(perm); + return (mnist.train.images[perm], mnist.train.labels[perm]); + } + ``` + + **get_next_batch**: which only selects a few number of images determined by the batch_size variable (as per Stochastic Gradient Descent method). + + ```csharp + private (NDArray, NDArray) get_next_batch(NDArray x, NDArray y, int start, int end) + { + var x_batch = x[$"{start}:{end}"]; + var y_batch = y[$"{start}:{end}"]; + return (x_batch, y_batch); + } + ``` + +2. **Set Hyperparameters** + + There're about 55,000 images in training set, it takes a long time to calculate the gradient of the model using all there images. Therefore we use a small batch of images in each iteration of the optimizer by Stochastic Gradient Descent. + + * epoch: one forward pass and one backward pass of all the training examples. + * batch size: the number of training examples in one forward/backward pass. The higher the batch size, the more memory space you'll need. + * iteration: one forward pass and one backward pass of one batch of images the training examples. + + ```csharp + int epochs = 10; + int batch_size = 100; + float learning_rate = 0.001f; + int h1 = 200; // number of nodes in the 1st hidden layer + ``` + +3. **Building the neural network** + + Let's make some functions to help build computation graph. + + **variables**: We need to define two variables `W` and `b` to construct our linear model. We use `Tensorflow Variables` of proper size and initialization to define them. + + ```csharp + // weight_variable + var in_dim = x.shape[1]; + + var initer = tf.truncated_normal_initializer(stddev: 0.01f); + var W = tf.get_variable("W_" + name, + dtype: tf.float32, + shape: (in_dim, num_units), + initializer: initer); + + // bias_variable + var initial = tf.constant(0f, num_units); + var b = tf.get_variable("b_" + name, + dtype: tf.float32, + initializer: initial); + ``` + + **fully-connected layer**: Neural network consists of stacks of fully-connected (dense) layers. Having the weight (W) and bias (b) variables, a fully-connected layer is defined as `activation(W x X + b)`. The complete `fc_layer` function is as below: + + ```csharp + private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true) + { + var in_dim = x.shape[1]; + + var initer = tf.truncated_normal_initializer(stddev: 0.01f); + var W = tf.get_variable("W_" + name, + dtype: tf.float32, + shape: (in_dim, num_units), + initializer: initer); + + var initial = tf.constant(0f, num_units); + var b = tf.get_variable("b_" + name, + dtype: tf.float32, + initializer: initial); + + var layer = tf.matmul(x, W) + b; + if (use_relu) + layer = tf.nn.relu(layer); + + return layer; + } + ``` + + **inputs**: Now we need to define the proper tensors to feed in the input to our model. Placeholder variable is the suitable choice for the input images and corresponding labels. This allow us to change the inputs (images and labels) to the TensorFlow graph. + + ```csharp + // Placeholders for inputs (x) and outputs(y) + x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X"); + y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y"); + ``` + + Placeholder `x` is defined for the images, the shape is set to `[None, img_size_flat]`, where `None` means that the tensor may hold an arbitrary number of images with each image being a vector of length `img_size_flat`. + + Placeholder `y` is the variable for the true labels associated with the images that were input in the placeholder variable `x`. It holds an arbitrary number of labels and each label is a vector of length `num_classes` which is 10. + + **network layers**: After creating the proper input, we have to pass it to our model. Since we have a neural network, we can stack multiple fully-connected layers using `fc_layer` method. Note that we will not use any activation function (use_relu = false) in the last layer. The reason is that we can use `tf.nn.softmax_cross_entropy_with_logits` to calculate the loss. + + ```csharp + // Create a fully-connected layer with h1 nodes as hidden layer + var fc1 = fc_layer(x, h1, "FC1", use_relu: true); + // Create a fully-connected layer with n_classes nodes as output layer + var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false); + ``` + + **loss function**: After creating the network, we have to calculate the loss and optimize it, we have to calculate the `correct_prediction` and `accuracy`. + + ```csharp + // Define the loss function, optimizer, and accuracy + var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits); + loss = tf.reduce_mean(logits, name: "loss"); + optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss); + var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred"); + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy"); + ``` + + **initialize variables**: We have to invoke a variable initializer operation to initialize all variables. + + ```csharp + var init = tf.global_variables_initializer(); + ``` + + The complete computation graph is looks like below: + + ![TensorBoard-nn](../assets/TensorBoard-nn.png) + +4. **Train** + + After creating the graph, we can train our model. To train the model, we have to create a session and run the graph in the session. + + ```csharp + // Number of training iterations in each epoch + var num_tr_iter = mnist.train.labels.len / batch_size; + with(tf.Session(), sess => + { + sess.run(init); + + float loss_val = 100.0f; + float accuracy_val = 0f; + + foreach (var epoch in range(epochs)) + { + print($"Training epoch: {epoch + 1}"); + // Randomly shuffle the training data at the beginning of each epoch + var (x_train, y_train) = randomize(mnist.train.images, mnist.train.labels); + + foreach (var iteration in range(num_tr_iter)) + { + var start = iteration * batch_size; + var end = (iteration + 1) * batch_size; + var (x_batch, y_batch) = get_next_batch(x_train, y_train, start, end); + + // Run optimization op (backprop) + sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); + + if (iteration % display_freq == 0) + { + // Calculate and display the batch loss and accuracy + var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); + loss_val = result[0]; + accuracy_val = result[1]; + print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}"); + } + } + + // Run validation after every epoch + var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.images), new FeedItem(y, mnist.validation.labels)); + loss_val = results1[0]; + accuracy_val = results1[1]; + print("---------------------------------------------------------"); + print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); + print("---------------------------------------------------------"); + } + }); + ``` + +5. **Test** + + After the training is done, we have to test our model to see how good it performs on a new dataset. + + ```csharp + var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels)); + loss_test = result[0]; + accuracy_test = result[1]; + print("---------------------------------------------------------"); + print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); + print("---------------------------------------------------------"); + ``` + + ![result](../assets/nn-result.png) + + + + diff --git a/docs/source/Operation.md b/docs/source/Operation.md new file mode 100644 index 000000000..67a25a37f --- /dev/null +++ b/docs/source/Operation.md @@ -0,0 +1,3 @@ +# Chapter. Operation + +`Operation` represents a `Graph` node that performs computation on tensors. An operation is a `Node` in a `Graph` that takes zero or more `Tensor`s (produced by other Operations in the Graph) as input, and produces zero or more Tensors as output. \ No newline at end of file diff --git a/docs/source/Placeholder.md b/docs/source/Placeholder.md new file mode 100644 index 000000000..2cf345bd0 --- /dev/null +++ b/docs/source/Placeholder.md @@ -0,0 +1,20 @@ +# Chapter. Placeholder + +In this chapter we will talk about another common data type in TensorFlow: Placeholder. It is a simplified variable that can be passed to the required value by the session when the graph is run, that is, when you build the graph, you don't need to specify the value of that variable, but delay the session to the beginning. In TensorFlow terminology, we then feed data into the graph through these placeholders. The difference between placeholders and constants is that placeholders can specify coefficient values more flexibly without modifying the code that builds the graph. For example, mathematical constants are suitable for Constant, and some model smoothing values can be specified with Placeholder. + + + +```csharp +var x = tf.placeholder(tf.int32); +var y = x * 3; + +using (var sess = tf.Session()) +{ + var result = sess.run(y, feed_dict: new FeedItem[] + { + new FeedItem(x, 2) + }); + // (int)result should be 6; +} +``` + diff --git a/docs/source/Preface.md b/docs/source/Preface.md new file mode 100644 index 000000000..65db8c1d2 --- /dev/null +++ b/docs/source/Preface.md @@ -0,0 +1,15 @@ + + +# Preface + +Why do I start the TensorFlow.NET project? + +In a few days, it was Christmas in 2018. I watched my children grow up and be sensible every day, and I felt that time passed too fast. IT technology updates are faster than ever, and a variety of front-end technologies are emerging. Big data, Artificial Intelligence and Blockchain, Container technology and Microservices, Distributed Computing and Serverless technology are dazzling. The Amazon AI service interface claims that engineers who don't need any machine learning experience can use it, so that the idea of just calming down for two years and planning to switch to an AI architecture in the future is a splash of cold water. + + + +TensorFlow is an open source project for machine learning especially for deep learning. It's used for both research and production at Google company. It's designed according to dataflow programming pattern across a range of tasks. TensorFlow is not just a deep learning library. As long as you can represent your calculation process as a data flow diagram, you can use TensorFlow for distributed computing. TensorFlow uses a computational graph to build a computing network while operating on the graph. Users can write their own upper-level models in Python based on TensorFlow, or extend the underlying C++ custom action code to TensorFlow. + + + +In order to avoid confusion, the unique classes defined in TensorFlow are not translated in this book. For example, Tensor, Graph, Shape will retain the English name. diff --git a/docs/source/Queue.md b/docs/source/Queue.md new file mode 100644 index 000000000..7f137fb32 --- /dev/null +++ b/docs/source/Queue.md @@ -0,0 +1,157 @@ +# Chapter. Queue + +ThensorFlow is capable to handle multiple threads, and queues are powerful mechanism for asynchronous computation. If we have large datasets this can significantly speed up the training process of our models. This functionality is especially handy when reading, pre-processing and extracting in mini-batches our training data. The secret to being able to do professional and high performance training of our model is understanding TensorFlow queuing operations. TensorFlow has implemented 4 types of Queue: **FIFOQueue**, **PaddingFIFOQueue**, **PriorityQueue** and **RandomShuffleQueue**. + +![FIFOQueue](_static/FIFOQueue-example.jpg) + +Like everything in TensorFlow, a queue is a node in a computation graph. It's a stateful node, like a variable: other nodes can modify its content, In particular, nodes can enqueue new items into the queue, or dequeue existing items from the queue. + +To get started with queue, let's consider a simple example. We will create a "first in, first out" queue (FIFOQueue) and fill it with numbers. Then we'll construct a graph that takes an item off the queue, adds one to that item, and puts it back on the end of the queue. + +```csharp +[TestMethod] +public void FIFOQueue() +{ + // create a first in first out queue with capacity up to 2 + // and data type set as int32 + var queue = tf.FIFOQueue(2, tf.int32); + // init queue, push 2 elements into queue. + var init = queue.enqueue_many(new[] { 10, 20 }); + // pop out the first element + var x = queue.dequeue(); + // add 1 + var y = x + 1; + // push back into queue + var inc = queue.enqueue(y); + + using (var sess = tf.Session()) + { + // init queue + init.run(); + + // pop out first element and push back calculated y + (int dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(10, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(20, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(11, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(21, dequeued); + + // thread will hang or block if you run sess.run(x) again + // until queue has more element. + } +} +``` + +`Enqueue`, `EnqueueMany` and `Dequeue` are special nodes. They take a pointer to the queue instead of a normal value, allowing them to change it. I first create a FIFOQueue *queue* of size up to 3, I enqueue two values into the *queue*. Then I immediately attempt to *dequeue* a value from it and assign it to *y* where I simply add 1 to the dequeued variable. Next, we start up a *session* and run. After we've run this operation a few times the queue will be empty - if we try and run the operation again, the main thread of the program will hang or block - this is because it will be waiting for another operation to be run to put more values in the queue. + +#### FIFOQueue + +Creates a queue that dequeues elements in a first-in first-out order. A `FIFOQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `FIFOQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. + +#### PaddingFIFOQueue + +A FIFOQueue that supports batching variable-sized tensors by padding. A `PaddingFIFOQueue` may contain components with dynamic shape, while also supporting `dequeue_many`. A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are described by the `shapes` argument. + +```chsarp +[TestMethod] +public void PaddingFIFOQueue() +{ + var numbers = tf.placeholder(tf.int32); + var queue = tf.PaddingFIFOQueue(10, tf.int32, new TensorShape(-1)); + var enqueue = queue.enqueue(numbers); + var dequeue_many = queue.dequeue_many(n: 3); + + using(var sess = tf.Session()) + { + sess.run(enqueue, (numbers, new[] { 1 })); + sess.run(enqueue, (numbers, new[] { 2, 3 })); + sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); + + var result = sess.run(dequeue_many[0]); + + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray())); + } +} +``` + + + +#### PriorityQueue + +A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument. + +```csharp +[TestMethod] +public void PriorityQueue() +{ + var queue = tf.PriorityQueue(3, tf.@string); + var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); + var x = queue.dequeue(); + + using (var sess = tf.Session()) + { + init.run(); + + // output will 2, 3, 4 + var result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 2L); + + result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 3L); + + result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 4L); + } +} +``` + + + +#### RandomShuffleQueue + +A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. + +```csharp +[TestMethod] +public void RandomShuffleQueue() +{ + var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32); + var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }); + var x = queue.dequeue(); + + string results = ""; + using (var sess = tf.Session()) + { + init.run(); + + foreach(var i in range(9)) + results += (int)sess.run(x) + "."; + + // output in random order + // 1.2.3.4.5.6.7.8.9. + } +} +``` + + + +Queue methods must run on the same device as the queue. `FIFOQueue` and `RandomShuffleQueue` are important TensorFlow objects for computing tensor asynchronously in a graph. For example, a typical input architecture is to use a `RandomShuffleQueue` to prepare inputs for training a model: + +* Multiple threads prepare training examples and push them in the queue. +* A training thread executes a training op that dequeues mini-batches from the queue. + +This architecture simplifies the construction of input pipelines. + + + +From the above example, once the output gets to the point above you’ll actually have to terminate the program as it is blocked. Now, this isn’t very useful. What we really want to happen is for our little program to reload or enqueue more values whenever our queue is empty or is about to become empty. We could fix this by explicitly running our *enqueue_op* again in the code above to reload our queue with values. However, for large, more realistic programs, this will become unwieldy. Thankfully, TensorFlow has a solution. + +TensorFlow provides two classes to help multi-threading task: `tf.Coordinator` and `tf.QueueRunner`. There two classes are designed to be used together. The `Coordinator` class helps multiple threads stop together and report exceptions to a main thread. The `QueueRunner` class is used to create a number of threads cooperating to enqueue tensors in the same queue. diff --git a/docs/source/Session.md b/docs/source/Session.md new file mode 100644 index 000000000..d6f249048 --- /dev/null +++ b/docs/source/Session.md @@ -0,0 +1,29 @@ +# Chapter. Session + +TensorFlow **session** runs parts of the graph across a set of local and remote devices. A session allows to execute graphs or part of graphs. It allocates resources (on one or more machines) for that and holds the actual values of intermediate results and variables. + + + +### Running Computations in a Session + +Let's complete the example in last chapter. To run any of the operations, we need to create a session for that graph. The session will also allocate memory to store the current value of the variable. + + + +```csharp +with(tf.Graph(), graph => +{ + var variable = tf.Variable(31, name: "tree"); + var init = tf.global_variables_initializer(); + + var sess = tf.Session(graph); + sess.run(init); + + var result = sess.run(variable); // 31 + + var assign = variable.assign(12); + result = sess.run(assign); // 12 +}); +``` + +The value of our variables is only valid within one session. If we try to get the value in another session. TensorFlow will raise an error of `Attempting to use uninitialized value foo`. Of course, we can use the graph in more than one session, because session copies graph definition to new memory area. We just have to initialize the variables again. The values in the new session will be completely independent from the previous one. diff --git a/docs/source/Table of Contents.md b/docs/source/Table of Contents.md new file mode 100644 index 000000000..b28505cc8 --- /dev/null +++ b/docs/source/Table of Contents.md @@ -0,0 +1,44 @@ +# Table of Contents + +### Foreword...........................................................................................xxi + +### Preface..............................................................................................xxiii + +## Part I. Getting Started + +##### 1. You Know, for Machine Learning............................................................................................ 3 + +​ Installing Tensorflow.NET +​ Running Tensorflow.NET +​ Talking to Tensorflow.NET + +##### 2. Hello World + + + +## Part II. Tensorflow.NET in Depth + +##### 1. Control Dependency ...................................................................................................... + +##### 2. Graph .................................... + +##### 3. Session ............................ + + + +## Part III. Dealing with Human Language + +##### 1. Text Classification ............................................................................................ + +##### 2. Named Entity Recognition .............................................................................. + +##### 3. Sentiment Analyze ........................................................................................... + +##### 4. Sentence Dependency ........................................................................ + + + +## Part IV. Image Recognition + +##### 1. Inception Model ................................................................................................................. 100 + diff --git a/docs/source/Tensor.md b/docs/source/Tensor.md new file mode 100644 index 000000000..aefb884f7 --- /dev/null +++ b/docs/source/Tensor.md @@ -0,0 +1,52 @@ +# Chapter 1. Tensor + +### Represents one of the outputs of an Operation + + + +##### What is Tensor? + +Tensor holds a multi-dimensional array of elements of a single data type which is very similar with `NumPy`'s `ndarray`. When the dimension is zero, it can be called a scalar. When the dimension is 2, it can be called a matrix. When the dimension is greater than 2, it is usually called a tensor. If you are very familiar with `NumPy`, then understanding Tensor will be quite easy. + + + +##### How to create a Tensor? + +There are many ways to initialize a Tensor object in TF.NET. It can be initialized from a scalar, string, matrix or tensor. But the best way to create a Tensor is using high level APIs like `tf.constant`, `tf.zeros` and `tf.ones`. We'll talk about constant more detail in next chapter. + +```csharp +// Create a tensor holds a scalar value +var t1 = new Tensor(3); + +// Init from a string +var t2 = new Tensor("Hello! TensorFlow.NET"); + +// Tensor holds a ndarray +var nd = new NDArray(new int[]{3, 1, 1, 2}); +var t3 = new Tensor(nd); + +Console.WriteLine($"t1: {t1}, t2: {t2}, t3: {t3}"); +``` + + + +##### Data Structure of Tensor + +TF uses column major order. If we use NumSharp to generate a 2 x 3 matrix, if we access the data from 0 to 5 in order, we won't get a number of 1-6, but we get the order of 1, 4, 2, 5, 3, 6. a set of numbers. + +```csharp +// Generate a matrix:[[1, 2, 3], [4, 5, 6]] +var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); +// The index will be 0 2 4 1 3 5, it's column-major order. +``` + + + +![column-major order](_static/column-major-order.png) + +![row-major order](_static/row-major-order.png) + +##### Index/ Slice of Tensor + +Tensor element can be accessed by `index` and `slice` related operations. Through some high level APIs, we can easily access specific dimension's data. + diff --git a/docs/source/Train.md b/docs/source/Train.md new file mode 100644 index 000000000..85d441ba0 --- /dev/null +++ b/docs/source/Train.md @@ -0,0 +1,12 @@ +# Chapter. Trainer + +### Saver + +The `tf.train.saver` class provides methods to save and restore models. + + + +### Saver Builder + +##### Bulk Saver Builder + diff --git a/docs/source/Variable.md b/docs/source/Variable.md new file mode 100644 index 000000000..c4f6a6af6 --- /dev/null +++ b/docs/source/Variable.md @@ -0,0 +1,18 @@ +# Chapter. Variable + +The variables in TensorFlow are mainly used to represent variable parameter values in the machine learning model. Variables can be initialized by the `tf.Variable` function. During the graph computation the variables are modified by other operations. Variables exist in the session, as long as they are in the same session, other computing nodes on the network can access the same variable value. Variables use lazy loading and will only request memory space when they are used. + + + +```csharp +var x = tf.Variable(10, name: "x"); +using (var session = tf.Session()) +{ + session.run(x.initializer); + var result = session.run(x); + Console.Write(result); // should be 10 +} +``` + +The above code first creates a variable operation, initializes the variable, then runs the session, and finally gets the result. This code is very simple, but it shows the complete process how TensorFlow operates on variables. When creating a variable, you pass a `tensor` as the initial value to the function `Variable()`. TensorFlow provides a series of operators to initialize the tensor, the initial value is a constant or a random value. + diff --git a/docs/source/_static/FIFOQueue-example.jpg b/docs/source/_static/FIFOQueue-example.jpg new file mode 100644 index 000000000..ac2749346 Binary files /dev/null and b/docs/source/_static/FIFOQueue-example.jpg differ diff --git a/docs/The-Definitive-Guide/assets/column-major-order.png b/docs/source/_static/column-major-order.png similarity index 100% rename from docs/The-Definitive-Guide/assets/column-major-order.png rename to docs/source/_static/column-major-order.png diff --git a/docs/source/_static/constant/n-index-formula-offset.svg b/docs/source/_static/constant/n-index-formula-offset.svg new file mode 100644 index 000000000..6c5a3219c --- /dev/null +++ b/docs/source/_static/constant/n-index-formula-offset.svg @@ -0,0 +1,41 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/docs/source/_static/constant/n-index-formula.svg b/docs/source/_static/constant/n-index-formula.svg new file mode 100644 index 000000000..5d05c06f0 --- /dev/null +++ b/docs/source/_static/constant/n-index-formula.svg @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/docs/source/_static/contiguous-block-of-memory-ndarray-example-1.png b/docs/source/_static/contiguous-block-of-memory-ndarray-example-1.png new file mode 100644 index 000000000..140e37716 Binary files /dev/null and b/docs/source/_static/contiguous-block-of-memory-ndarray-example-1.png differ diff --git a/docs/source/_static/contiguous-block-of-memory.png b/docs/source/_static/contiguous-block-of-memory.png new file mode 100644 index 000000000..44d3ab62f Binary files /dev/null and b/docs/source/_static/contiguous-block-of-memory.png differ diff --git a/docs/source/_static/front-cover.jpg b/docs/source/_static/front-cover.jpg new file mode 100644 index 000000000..3452f8001 Binary files /dev/null and b/docs/source/_static/front-cover.jpg differ diff --git a/docs/source/_static/gradient-descent.png b/docs/source/_static/gradient-descent.png new file mode 100644 index 000000000..3fcde528d Binary files /dev/null and b/docs/source/_static/gradient-descent.png differ diff --git a/docs/source/_static/linear-regression-tensor-board.png b/docs/source/_static/linear-regression-tensor-board.png new file mode 100644 index 000000000..ea9304a02 Binary files /dev/null and b/docs/source/_static/linear-regression-tensor-board.png differ diff --git a/docs/source/_static/logistic-regression/1557035393445.png b/docs/source/_static/logistic-regression/1557035393445.png new file mode 100644 index 000000000..a9ca67a8b Binary files /dev/null and b/docs/source/_static/logistic-regression/1557035393445.png differ diff --git a/docs/source/_static/minimize-square-cost.png b/docs/source/_static/minimize-square-cost.png new file mode 100644 index 000000000..229a4cf53 Binary files /dev/null and b/docs/source/_static/minimize-square-cost.png differ diff --git a/docs/source/_static/new-project-console.png b/docs/source/_static/new-project-console.png new file mode 100644 index 000000000..d4bfbc68c Binary files /dev/null and b/docs/source/_static/new-project-console.png differ diff --git a/docs/source/_static/new-project.png b/docs/source/_static/new-project.png new file mode 100644 index 000000000..789b5f1fd Binary files /dev/null and b/docs/source/_static/new-project.png differ diff --git a/docs/source/_static/regression-dataset.png b/docs/source/_static/regression-dataset.png new file mode 100644 index 000000000..0cd46f46e Binary files /dev/null and b/docs/source/_static/regression-dataset.png differ diff --git a/docs/The-Definitive-Guide/assets/row-major-order.png b/docs/source/_static/row-major-order.png similarity index 100% rename from docs/The-Definitive-Guide/assets/row-major-order.png rename to docs/source/_static/row-major-order.png diff --git a/docs/source/_static/sigmoid.png b/docs/source/_static/sigmoid.png new file mode 100644 index 000000000..4321a1a41 Binary files /dev/null and b/docs/source/_static/sigmoid.png differ diff --git a/docs/source/_static/tensor-constant-ndarray.png b/docs/source/_static/tensor-constant-ndarray.png new file mode 100644 index 000000000..3610ee0cd Binary files /dev/null and b/docs/source/_static/tensor-constant-ndarray.png differ diff --git a/docs/source/_static/tensor-naming.png b/docs/source/_static/tensor-naming.png new file mode 100644 index 000000000..7b1d408b9 Binary files /dev/null and b/docs/source/_static/tensor-naming.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 000000000..68b6f5aee --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# +# Configuration file for the Sphinx documentation builder. +# +# This file does only contain a selection of the most common options. For a +# full list see the documentation: +# http://www.sphinx-doc.org/en/master/config + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + + +# -- Project information ----------------------------------------------------- + +project = 'TensorFlow.NET' +copyright = '2019, Haiping Chen' +author = 'Haiping Chen' + +# The short X.Y version +version = '0.6.0' +# The full version, including alpha/beta/rc tags +release = '0.6.0' + + +# -- General configuration --------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.mathjax', +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +from recommonmark.parser import CommonMarkParser +source_parsers = {'.md': CommonMarkParser} +source_suffix = ['.rst', '.md'] + +# The master toctree document. +master_doc = 'index' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# The default sidebars (for documents that don't match any pattern) are +# defined by theme itself. Builtin themes are using these templates by +# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', +# 'searchbox.html']``. +# +# html_sidebars = {} + + +# -- Options for HTMLHelp output --------------------------------------------- + +# Output file base name for HTML help builder. +htmlhelp_basename = 'TensorFlowNETdoc' + + +# -- Options for LaTeX output ------------------------------------------------ + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'TensorFlowNET.tex', 'TensorFlow.NET Documentation', + 'Haiping Chen', 'manual'), +] + + +# -- Options for manual page output ------------------------------------------ + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'tensorflownet', 'TensorFlow.NET Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ---------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'TensorFlowNET', 'TensorFlow.NET Documentation', + author, 'TensorFlowNET', 'One line description of project.', + 'Miscellaneous'), +] + + +# -- Options for Epub output ------------------------------------------------- + +# Bibliographic Dublin Core info. +epub_title = project + +# The unique identifier of the text. This can be a ISBN number +# or the project homepage. +# +# epub_identifier = '' + +# A unique identification for the text. +# +# epub_uid = '' + +# A list of files that should not be packed into the epub file. +epub_exclude_files = ['search.html'] + + +# -- Extension configuration ------------------------------------------------- diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 000000000..61f0d752e --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,34 @@ +.. TensorFlow.NET documentation master file, created by + sphinx-quickstart on Sat Jan 5 09:26:55 2019. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to TensorFlow.NET's documentation! +========================================== + + +.. toctree:: + :maxdepth: 3 + :caption: The Definitive Guide to TensorFlow.NET + + FrontCover + Foreword + Preface + HelloWorld + Tensor + Constant + Variable + Placeholder + Graph + Session + Operation + Queue + Gradient + Train + EagerMode + LinearRegression + LogisticRegression + NearestNeighbor + ImageRecognition + NeuralNetwork + ConvolutionNeuralNetwork \ No newline at end of file diff --git a/graph/InceptionV3.meta b/graph/InceptionV3.meta new file mode 100644 index 000000000..fe220cce1 Binary files /dev/null and b/graph/InceptionV3.meta differ diff --git a/graph/README.md b/graph/README.md new file mode 100644 index 000000000..491e2374a --- /dev/null +++ b/graph/README.md @@ -0,0 +1 @@ +These are models built with the original tensorflow library. They can be imported in TensorFlow.NET and trained. See the examples for how to do so. \ No newline at end of file diff --git a/graph/att_rnn_untrained.meta b/graph/att_rnn_untrained.meta new file mode 100644 index 000000000..438e37aeb Binary files /dev/null and b/graph/att_rnn_untrained.meta differ diff --git a/graph/char_cnn_untrained.meta b/graph/char_cnn_untrained.meta new file mode 100644 index 000000000..1e99f99b6 Binary files /dev/null and b/graph/char_cnn_untrained.meta differ diff --git a/graph/cond_test.meta b/graph/cond_test.meta new file mode 100644 index 000000000..2110d5770 Binary files /dev/null and b/graph/cond_test.meta differ diff --git a/graph/kmeans.meta b/graph/kmeans.meta new file mode 100644 index 000000000..0ad4f03f2 Binary files /dev/null and b/graph/kmeans.meta differ diff --git a/graph/lstm_crf_ner.meta b/graph/lstm_crf_ner.meta new file mode 100644 index 000000000..19a267e29 Binary files /dev/null and b/graph/lstm_crf_ner.meta differ diff --git a/graph/rcnn_untrained.meta b/graph/rcnn_untrained.meta new file mode 100644 index 000000000..1cd86abba Binary files /dev/null and b/graph/rcnn_untrained.meta differ diff --git a/graph/vd_cnn.meta b/graph/vd_cnn.meta new file mode 100644 index 000000000..b857fc6c5 Binary files /dev/null and b/graph/vd_cnn.meta differ diff --git a/graph/word2vec.meta b/graph/word2vec.meta new file mode 100644 index 000000000..df120b7f8 Binary files /dev/null and b/graph/word2vec.meta differ diff --git a/graph/word_cnn.meta b/graph/word_cnn.meta new file mode 100644 index 000000000..141947b19 Binary files /dev/null and b/graph/word_cnn.meta differ diff --git a/graph/word_cnn_untrained.meta b/graph/word_cnn_untrained.meta new file mode 100644 index 000000000..a29a33a0a Binary files /dev/null and b/graph/word_cnn_untrained.meta differ diff --git a/graph/word_rnn_untrained.meta b/graph/word_rnn_untrained.meta new file mode 100644 index 000000000..a5c749a9b Binary files /dev/null and b/graph/word_rnn_untrained.meta differ diff --git a/graph/xor.meta b/graph/xor.meta new file mode 100644 index 000000000..f466e49af Binary files /dev/null and b/graph/xor.meta differ diff --git a/redist/SciSharp.TensorFlow-Cpu.Redist/SciSharp.TensorFlow-Cpu.Redist.csproj b/redist/SciSharp.TensorFlow-Cpu.Redist/SciSharp.TensorFlow-Cpu.Redist.csproj new file mode 100644 index 000000000..4d0fa1f0e --- /dev/null +++ b/redist/SciSharp.TensorFlow-Cpu.Redist/SciSharp.TensorFlow-Cpu.Redist.csproj @@ -0,0 +1,63 @@ + + + + netstandard2.0 + win-x64;linux-x64 + SciSharp.Tensorflow-Cpu.Redist + + SciSharp.Tensorflow-Cpu.Redist + 1.0.0 + SciSharp team + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + git + + Meta-package for GPU Tensoflow library runtime distribution. + Libraries can be directly downloaded from https://storage.googleapis.com/tensorflow/libtensorflow/ + + Apache-2.0 + + https://github.com/SciSharp/TensorFlow.NET + native;tensorflow;machine-learning;ML + ../../packages + false + + false + false + false + + + + + + + + + + + + + + + + + ../../packages;$(RestoreSources);https://api.nuget.org/v3/index.json + + + + + + + + + + + + + runtime.json + true + PreserveNewest + + + + diff --git a/redist/SciSharp.TensorFlow-Cpu.Redist/runtime.json b/redist/SciSharp.TensorFlow-Cpu.Redist/runtime.json new file mode 100644 index 000000000..a7a39cb52 --- /dev/null +++ b/redist/SciSharp.TensorFlow-Cpu.Redist/runtime.json @@ -0,0 +1,14 @@ +{ + "runtimes": { + "linux-x64": { + "SciSharp.TensorFlow-Gpu.Redist": { + "runtime.linux-x64.SciSharp.Tensorflow-Cpu.Redist": "1.0.0" + } + }, + "win-x64": { + "SciSharp.TensorFlow-Gpu.Redist": { + "runtime.win-x64.SciSharp.Tensorflow-Cpu.Redist": "1.0.0" + } + } + } +} diff --git a/redist/SciSharp.TensorFlow-Gpu.Redist/SciSharp.TensorFlow-Gpu.Redist.csproj b/redist/SciSharp.TensorFlow-Gpu.Redist/SciSharp.TensorFlow-Gpu.Redist.csproj new file mode 100644 index 000000000..61ea992ed --- /dev/null +++ b/redist/SciSharp.TensorFlow-Gpu.Redist/SciSharp.TensorFlow-Gpu.Redist.csproj @@ -0,0 +1,81 @@ + + + + Library + netstandard2.0 + + win-x64;linux-x64 + SciSharp.Tensorflow-Gpu.Redist + + SciSharp.Tensorflow-Gpu.Redist + 1.0.0 + SciSharp team + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + git + + Meta-package for GPU Tensoflow library runtime distribution. + Libraries can be directly downloaded from https://storage.googleapis.com/tensorflow/libtensorflow/ + + Apache-2.0 + + https://github.com/SciSharp/TensorFlow.NET + native;tensorflow;machine-learning;ML + ../../packages + false + + false + false + false + + + + + + + + + + + + + + + + + ../../packages;$(RestoreSources);https://api.nuget.org/v3/index.json + + + + + + + + + + + + + runtime.json + true + PreserveNewest + + + + diff --git a/redist/SciSharp.TensorFlow-Gpu.Redist/runtime.json b/redist/SciSharp.TensorFlow-Gpu.Redist/runtime.json new file mode 100644 index 000000000..392dc3ccf --- /dev/null +++ b/redist/SciSharp.TensorFlow-Gpu.Redist/runtime.json @@ -0,0 +1,14 @@ +{ + "runtimes": { + "linux-x64": { + "SciSharp.TensorFlow-Gpu.Redist": { + "runtime.linux-x64.SciSharp.Tensorflow-Gpu.Redist": "1.0.0" + } + }, + "win-x64": { + "SciSharp.TensorFlow-Gpu.Redist": { + "runtime.win-x64.SciSharp.Tensorflow-Gpu.Redist": "1.0.0" + } + } + } +} diff --git a/redist/TensorFlow.NET.Redist.sln b/redist/TensorFlow.NET.Redist.sln new file mode 100644 index 000000000..a21dc9dc9 --- /dev/null +++ b/redist/TensorFlow.NET.Redist.sln @@ -0,0 +1,60 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 16 +VisualStudioVersion = 16.0.29102.190 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{1E65784D-C976-4DFF-991A-DD5C57FFC8E2}" + ProjectSection(SolutionItems) = preProject + scripts\Copy-NativeTensorFlowLibs.ps1 = scripts\Copy-NativeTensorFlowLibs.ps1 + EndProjectSection +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist", "runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist\runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist.csproj", "{9834D2B4-01BF-4D18-8DCF-F498AC481FE7}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist", "runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist\runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist.csproj", "{9D853997-3143-4F87-B995-7D7024CF4E1A}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist", "runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist\runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist.csproj", "{878C1EE4-B945-41BF-98DE-C4747C28022A}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist", "runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist\runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist.csproj", "{744A3D51-CEF6-4685-B4C3-718FA61143A0}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SciSharp.TensorFlow-Cpu.Redist", "SciSharp.TensorFlow-Cpu.Redist\SciSharp.TensorFlow-Cpu.Redist.csproj", "{0A281E9C-6E3D-4172-84BA-2B5F6E9F4D5B}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SciSharp.TensorFlow-Gpu.Redist", "SciSharp.TensorFlow-Gpu.Redist\SciSharp.TensorFlow-Gpu.Redist.csproj", "{1910BE36-82E3-4465-B3B1-788BFD252DB7}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {9834D2B4-01BF-4D18-8DCF-F498AC481FE7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9834D2B4-01BF-4D18-8DCF-F498AC481FE7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9834D2B4-01BF-4D18-8DCF-F498AC481FE7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9834D2B4-01BF-4D18-8DCF-F498AC481FE7}.Release|Any CPU.Build.0 = Release|Any CPU + {9D853997-3143-4F87-B995-7D7024CF4E1A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9D853997-3143-4F87-B995-7D7024CF4E1A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9D853997-3143-4F87-B995-7D7024CF4E1A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9D853997-3143-4F87-B995-7D7024CF4E1A}.Release|Any CPU.Build.0 = Release|Any CPU + {878C1EE4-B945-41BF-98DE-C4747C28022A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {878C1EE4-B945-41BF-98DE-C4747C28022A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {878C1EE4-B945-41BF-98DE-C4747C28022A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {878C1EE4-B945-41BF-98DE-C4747C28022A}.Release|Any CPU.Build.0 = Release|Any CPU + {744A3D51-CEF6-4685-B4C3-718FA61143A0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {744A3D51-CEF6-4685-B4C3-718FA61143A0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {744A3D51-CEF6-4685-B4C3-718FA61143A0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {744A3D51-CEF6-4685-B4C3-718FA61143A0}.Release|Any CPU.Build.0 = Release|Any CPU + {0A281E9C-6E3D-4172-84BA-2B5F6E9F4D5B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {0A281E9C-6E3D-4172-84BA-2B5F6E9F4D5B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {0A281E9C-6E3D-4172-84BA-2B5F6E9F4D5B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {0A281E9C-6E3D-4172-84BA-2B5F6E9F4D5B}.Release|Any CPU.Build.0 = Release|Any CPU + {1910BE36-82E3-4465-B3B1-788BFD252DB7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1910BE36-82E3-4465-B3B1-788BFD252DB7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1910BE36-82E3-4465-B3B1-788BFD252DB7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1910BE36-82E3-4465-B3B1-788BFD252DB7}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {CD7D5F34-42AE-4CCB-BDFA-1619B3A84708} + EndGlobalSection +EndGlobal diff --git a/redist/runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist/runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist.csproj b/redist/runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist/runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist.csproj new file mode 100644 index 000000000..ea6d4186d --- /dev/null +++ b/redist/runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist/runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist.csproj @@ -0,0 +1,39 @@ + + + + netstandard2.0 + linux-x64 + SciSharp.Tensorflow-Cpu.Redist + + runtime.linux-x64.SciSharp.Tensorflow-Cpu.Redist + 1.0.0 + SciSharp team + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + git + + Distribution of the Linux CPU Tensoflow library. + The libraries can be directly downloaded from https://storage.googleapis.com/tensorflow/libtensorflow/ + + Apache-2.0 + + https://github.com/SciSharp/TensorFlow.NET + native;tensorflow;machine-learning;ML + ../../packages + false + + false + false + false + + + + + + runtimes/$(RuntimeIdentifier)/native/%(Filename)%(Extension) + true + PreserveNewest + + + + diff --git a/redist/runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist/runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist.csproj b/redist/runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist/runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist.csproj new file mode 100644 index 000000000..d680f38a6 --- /dev/null +++ b/redist/runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist/runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist.csproj @@ -0,0 +1,40 @@ + + + + Library + netstandard2.0 + linux-x64 + SciSharp.Tensorflow-Gpu.Redist + + runtime.linux-x64.SciSharp.Tensorflow-Gpu.Redist + 1.0.0 + SciSharp team + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + git + + Distribution of the Linux GPU Tensoflow library. + Dll can be directly downloaded from https://storage.googleapis.com/tensorflow/libtensorflow/ + + Apache-2.0 + + https://github.com/SciSharp/TensorFlow.NET + native;tensorflow;machine-learning;ML + ../../packages + false + + false + false + false + + + + + + runtimes/$(RuntimeIdentifier)/native/%(Filename)%(Extension) + true + PreserveNewest + + + + \ No newline at end of file diff --git a/redist/runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist/runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist.csproj b/redist/runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist/runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist.csproj new file mode 100644 index 000000000..19e7854cb --- /dev/null +++ b/redist/runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist/runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist.csproj @@ -0,0 +1,39 @@ + + + + netstandard2.0 + win-x64 + SciSharp.Tensorflow-Cpu.Redist + + runtime.win-x64.SciSharp.Tensorflow-Cpu.Redist + 1.0.0 + SciSharp team + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + git + + Distribution of the windows GPU Tensoflow library. + The libraries can be directly downloaded from https://storage.googleapis.com/tensorflow/libtensorflow/ + + Apache-2.0 + + https://github.com/SciSharp/TensorFlow.NET + native;tensorflow;machine-learning;ML + ../../packages + false + + false + false + false + + + + + + runtimes/$(RuntimeIdentifier)/native/%(Filename)%(Extension) + true + PreserveNewest + + + + diff --git a/redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist/.gitignore b/redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist/.gitignore new file mode 100644 index 000000000..fca132d9b --- /dev/null +++ b/redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist/.gitignore @@ -0,0 +1 @@ +tensorflow.dll diff --git a/redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist.csproj b/redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist.csproj new file mode 100644 index 000000000..915e0e2a5 --- /dev/null +++ b/redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist.csproj @@ -0,0 +1,40 @@ + + + + Library + netstandard2.0 + win-x64 + SciSharp.Tensorflow-Gpu.Redist + + runtime.win-x64.SciSharp.Tensorflow-Gpu.Redist + 1.0.0 + SciSharp team + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + git + + Distribution of the windows GPU Tensoflow library. + Dll can be directly downloaded from https://storage.googleapis.com/tensorflow/libtensorflow/ + + Apache-2.0 + + https://github.com/SciSharp/TensorFlow.NET + native;tensorflow;machine-learning;ML + ../../packages + false + + false + false + false + + + + + + runtimes/$(RuntimeIdentifier)/native/%(Filename)%(Extension) + true + PreserveNewest + + + + \ No newline at end of file diff --git a/src/SciSharp.TensorFlow.Redist/CommonPackage.props b/src/SciSharp.TensorFlow.Redist/CommonPackage.props new file mode 100644 index 000000000..08fbb153a --- /dev/null +++ b/src/SciSharp.TensorFlow.Redist/CommonPackage.props @@ -0,0 +1,24 @@ + + + + + + PreserveNewest + false + %(Filename)%(Extension) + + + PreserveNewest + false + %(Filename)%(Extension) + + + + \ No newline at end of file diff --git a/src/SciSharp.TensorFlow.Redist/README.md b/src/SciSharp.TensorFlow.Redist/README.md new file mode 100644 index 000000000..4002aa21d --- /dev/null +++ b/src/SciSharp.TensorFlow.Redist/README.md @@ -0,0 +1,40 @@ +## SciSharp.TensorFlow.Redist ## + + +`SciSharp.TensorFlow.Redist` is a migration from [Microsoft.ML.TensorFlow.Redist](https://github.com/dotnet/machinelearning/tree/release/1.2/src/Redist/Microsoft.ML.TensorFlow.Redist). [ML.NET](https://github.com/dotnet/machinelearning) team will not maintain the package since [ML.NET](https://www.nuget.org/packages/Microsoft.ML) v1.3.0 going forward. + +* CPU version for all platforms (Windows, Linux, OSX) +```powershell +PM> Install-Package SciSharp.TensorFlow.Redist +``` + +* GPU version for Windows +```powershell +PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU +``` + +* GPU version for Linux +```powershell +PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU +``` + +https://www.nuget.org/packages/SciSharp.TensorFlow.Redist + +Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5ba61ad0e400623821236bd117cc24c6cb77). + + + +#### Download pre-build package + +[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.10.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.10.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.10.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.10.0.zip), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.10.0.zip) + + + +#### Pack and Deploy #### + +On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries. + +1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux. +2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.10.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600` + + diff --git a/src/SciSharp.TensorFlow.Redist/Redist-CPU.nuspec b/src/SciSharp.TensorFlow.Redist/Redist-CPU.nuspec new file mode 100644 index 000000000..1524a0f86 --- /dev/null +++ b/src/SciSharp.TensorFlow.Redist/Redist-CPU.nuspec @@ -0,0 +1,27 @@ + + + + $packageId$ + $version$ + The TensorFlow Authors + The TensorFlow Authors + true + LICENSE.txt + https://aka.ms/deprecateLicenseUrl + https://www.tensorflow.org/ + https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + $packageId$ contains the TensorFlow C library CPU version $version$ redistributed as a NuGet package. + https://github.com/tensorflow/tensorflow/releases/tag/v$version$ + Copyright 2019 The TensorFlow Authors. All rights reserved. + TensorFlow + + + + + + + + + + + \ No newline at end of file diff --git a/src/SciSharp.TensorFlow.Redist/Redist-Linux-GPU.nuspec b/src/SciSharp.TensorFlow.Redist/Redist-Linux-GPU.nuspec new file mode 100644 index 000000000..3a08b37d6 --- /dev/null +++ b/src/SciSharp.TensorFlow.Redist/Redist-Linux-GPU.nuspec @@ -0,0 +1,27 @@ + + + + $packageId$ + $version$ + The TensorFlow Authors + The TensorFlow Authors + true + LICENSE.txt + https://aka.ms/deprecateLicenseUrl + https://www.tensorflow.org/ + https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + $packageId$ contains the TensorFlow C library GPU for Linux version $version$ redistributed as a NuGet package. + https://github.com/tensorflow/tensorflow/releases/tag/v$version$ + Copyright 2019 The TensorFlow Authors. All rights reserved. + TensorFlow + + + + + + + + + + + \ No newline at end of file diff --git a/src/SciSharp.TensorFlow.Redist/Redist-Windows-GPU.nuspec b/src/SciSharp.TensorFlow.Redist/Redist-Windows-GPU.nuspec new file mode 100644 index 000000000..769838368 --- /dev/null +++ b/src/SciSharp.TensorFlow.Redist/Redist-Windows-GPU.nuspec @@ -0,0 +1,27 @@ + + + + $packageId$ + $version$ + The TensorFlow Authors + The TensorFlow Authors + true + LICENSE.txt + https://aka.ms/deprecateLicenseUrl + https://www.tensorflow.org/ + https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + $packageId$ contains the TensorFlow C library GPU for Windows version $version$ redistributed as a NuGet package. + https://github.com/tensorflow/tensorflow/releases/tag/v$version$ + Copyright 2019 The TensorFlow Authors. All rights reserved. + TensorFlow + + + + + + + + + + + \ No newline at end of file diff --git a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Linux-GPU.nupkgproj b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Linux-GPU.nupkgproj new file mode 100644 index 000000000..63e11ed4e --- /dev/null +++ b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Linux-GPU.nupkgproj @@ -0,0 +1,174 @@ + + + + $(MSBuildThisFileDirectory) + $(ProjDir)bin\ + $(ProjDir)obj\ + + x64 + netstandard2.0 + 1.14.0 + 1 + + $(BinDir)packages\ + $(MSBuildProjectName) + $(TensorFlowVersion) + + true + false + + Redist-Linux-GPU.nuspec + packageId=$(PackageId);version=$(PackageVersion) + $(ProjDir) + + CopyFilesFromArchive + + win + linux + osx + $(PackageRid)-$(TargetArchitecture) + + + + + false + + + + + + + + + + + + + + + + + + + + + + <_downloadFiles Include="@(TensorFlowArchive);@(AdditionalDownloadFile)" Url="%(Identity)" DestinationFile="%(DownloadFile)" /> + + + + + + + + + + + + + + + + + + + + + + @(FilesWithHashes->'%(FileHash)') + $([System.IO.File]::ReadAllText('%(LocalShaFile)').Replace("%0A", "").Replace("%0D", "")) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + <_fileFromArchive Include="%(TensorFlowArchive.FilesFromArchive)" ExtractDirectory="%(TensorFlowArchive.ExtractDirectory)" Runtime="%(TensorFlowArchive.Runtime)" /> + <_fileFromArchive DestinationFile="%(FileName)%(Extension)"/> + <_fileFromArchive PackagePath="runtimes\%(_fileFromArchive.Runtime)\native\%(_fileFromArchive.DestinationFile)" /> + + + <_fileFromArchive Condition="'%(DestinationFile)' == 'LICENSE'" PackagePath="THIRD_PARTY_NOTICES.txt" Runtime="" /> + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj new file mode 100644 index 000000000..e2b101fac --- /dev/null +++ b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj @@ -0,0 +1,172 @@ + + + + $(MSBuildThisFileDirectory) + $(ProjDir)bin\ + $(ProjDir)obj\ + + x64 + netstandard2.0 + 1.14.0 + 1 + + $(BinDir)packages\ + $(MSBuildProjectName) + $(TensorFlowVersion) + + true + false + + Redist-Windows-GPU.nuspec + packageId=$(PackageId);version=$(PackageVersion) + $(ProjDir) + + CopyFilesFromArchive + + win + linux + osx + $(PackageRid)-$(TargetArchitecture) + + + + + false + + + + + + + + + + + + + + + + + + + + + + <_downloadFiles Include="@(TensorFlowArchive);@(AdditionalDownloadFile)" Url="%(Identity)" DestinationFile="%(DownloadFile)" /> + + + + + + + + + + + + + + + + + + + + + + @(FilesWithHashes->'%(FileHash)') + $([System.IO.File]::ReadAllText('%(LocalShaFile)').Replace("%0A", "").Replace("%0D", "")) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + <_fileFromArchive Include="%(TensorFlowArchive.FilesFromArchive)" ExtractDirectory="%(TensorFlowArchive.ExtractDirectory)" Runtime="%(TensorFlowArchive.Runtime)" /> + <_fileFromArchive DestinationFile="%(FileName)%(Extension)"/> + <_fileFromArchive PackagePath="runtimes\%(_fileFromArchive.Runtime)\native\%(_fileFromArchive.DestinationFile)" /> + + + <_fileFromArchive Condition="'%(DestinationFile)' == 'LICENSE'" PackagePath="THIRD_PARTY_NOTICES.txt" Runtime="" /> + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist.nupkgproj b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist.nupkgproj new file mode 100644 index 000000000..85ca28984 --- /dev/null +++ b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist.nupkgproj @@ -0,0 +1,187 @@ + + + + $(MSBuildThisFileDirectory) + $(ProjDir)bin\ + $(ProjDir)obj\ + + x64 + netstandard2.0 + 1.15.0 + 1 + + $(BinDir)packages\ + $(MSBuildProjectName) + $(TensorFlowVersion) + + true + false + + Redist-CPU.nuspec + packageId=$(PackageId);version=$(PackageVersion) + $(ProjDir) + + CopyFilesFromArchive + + win + linux + osx + $(PackageRid)-$(TargetArchitecture) + + + + + false + + + + + + + + + + + + + + + + + + + + + + + + + + <_downloadFiles Include="@(TensorFlowArchive);@(AdditionalDownloadFile)" Url="%(Identity)" DestinationFile="%(DownloadFile)" /> + + + + + + + + + + + + + + + + + + + + + + @(FilesWithHashes->'%(FileHash)') + $([System.IO.File]::ReadAllText('%(LocalShaFile)').Replace("%0A", "").Replace("%0D", "")) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + <_fileFromArchive Include="%(TensorFlowArchive.FilesFromArchive)" ExtractDirectory="%(TensorFlowArchive.ExtractDirectory)" Runtime="%(TensorFlowArchive.Runtime)" /> + <_fileFromArchive DestinationFile="%(FileName)%(Extension)"/> + <_fileFromArchive PackagePath="runtimes\%(_fileFromArchive.Runtime)\native\%(_fileFromArchive.DestinationFile)" /> + + + <_fileFromArchive Condition="'%(DestinationFile)' == 'LICENSE'" PackagePath="THIRD_PARTY_NOTICES.txt" Runtime="" /> + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs new file mode 100644 index 000000000..a91b86827 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -0,0 +1,107 @@ +/***************************************************************************** + Copyright 2020 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Runtime.InteropServices; +using static Tensorflow.CppShapeInferenceResult.Types; + +namespace Tensorflow +{ + /// + /// C API for TensorFlow. + /// Port from tensorflow\c\c_api.h + /// + /// The API leans towards simplicity and uniformity instead of convenience + /// since most usage will be by language specific wrappers. + /// + /// The params type mapping between c_api and .NET + /// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op) + /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) + /// struct => struct (TF_Output output) => (TF_Output output) + /// struct* => struct[] (TF_Output* output) => (TF_Output[] output) + /// struct* => struct* for ref + /// const char* => string + /// int32_t => int + /// int64_t* => long[] + /// size_t* => ulong[] + /// size_t* => ref ulong + /// void* => IntPtr + /// string => IntPtr c_api.StringPiece(IntPtr) + /// unsigned char => byte + /// + public partial class c_api + { + public const string TensorFlowLibName = "tensorflow"; + + public static string StringPiece(IntPtr handle) + { + return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); + } + + public unsafe static byte[] ByteStringPiece(Buffer? handle) + { + if (handle is null) + { + return new byte[0]; + } + var data = handle.ToArray(); + return data; + } + + public unsafe static byte[] ByteStringPieceFromNativeString(IntPtr handle) + { + if (handle == IntPtr.Zero) + { + return new byte[0]; + } + + byte* str_data = (byte*)handle.ToPointer(); + List bytes = new List(); + byte current = 255; + while (current != ((byte)'\0')) + { + current = *(str_data++); + bytes.Add(current); + } + var data = bytes.ToArray(); + return data; + } + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DeallocatorV2(IntPtr data, long size, IntPtr args); + + public struct DeallocatorArgs + { + internal static unsafe c_api.DeallocatorArgs* EmptyPtr; + internal static unsafe IntPtr Empty; + + static unsafe DeallocatorArgs() + { + Empty = new IntPtr(EmptyPtr = (DeallocatorArgs*)Marshal.AllocHGlobal(Marshal.SizeOf())); + *EmptyPtr = new DeallocatorArgs() { gc_handle = IntPtr.Zero, deallocator_called = false }; + } + + public bool deallocator_called; + public IntPtr gc_handle; + } + + [DllImport(TensorFlowLibName)] + internal static extern IntPtr TF_Version(); + } +} diff --git a/src/TensorFlowNET.Core/APIs/c_api.customize.cs b/src/TensorFlowNET.Core/APIs/c_api.customize.cs new file mode 100644 index 000000000..bee4897ee --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/c_api.customize.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + public partial class c_api + { + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + public static extern SafeBufferHandle TF_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); + [DllImport(TensorFlowLibName)] + public static extern void TF_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); + } +} diff --git a/src/TensorFlowNET.Core/APIs/c_api_lite.cs b/src/TensorFlowNET.Core/APIs/c_api_lite.cs new file mode 100644 index 000000000..5a437d261 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/c_api_lite.cs @@ -0,0 +1,91 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; +using Tensorflow.Lite; + +namespace Tensorflow +{ + public class c_api_lite + { + public const string TensorFlowLibName = "tensorflowlite_c"; + + public static string StringPiece(IntPtr handle) + { + return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); + } + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TfLiteVersion(); + + [DllImport(TensorFlowLibName)] + public static extern SafeTfLiteModelHandle TfLiteModelCreateFromFile(string model_path); + + [DllImport(TensorFlowLibName)] + public static extern void TfLiteModelDelete(IntPtr model); + + [DllImport(TensorFlowLibName)] + public static extern SafeTfLiteInterpreterOptionsHandle TfLiteInterpreterOptionsCreate(); + + [DllImport(TensorFlowLibName)] + public static extern void TfLiteInterpreterOptionsDelete(IntPtr options); + + [DllImport(TensorFlowLibName)] + public static extern void TfLiteInterpreterOptionsSetNumThreads(SafeTfLiteInterpreterOptionsHandle options, int num_threads); + + [DllImport(TensorFlowLibName)] + public static extern SafeTfLiteInterpreterHandle TfLiteInterpreterCreate(SafeTfLiteModelHandle model, SafeTfLiteInterpreterOptionsHandle optional_options); + + [DllImport(TensorFlowLibName)] + public static extern void TfLiteInterpreterDelete(IntPtr interpreter); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteStatus TfLiteInterpreterAllocateTensors(SafeTfLiteInterpreterHandle interpreter); + + [DllImport(TensorFlowLibName)] + public static extern int TfLiteInterpreterGetInputTensorCount(SafeTfLiteInterpreterHandle interpreter); + + [DllImport(TensorFlowLibName)] + public static extern int TfLiteInterpreterGetOutputTensorCount(SafeTfLiteInterpreterHandle interpreter); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteStatus TfLiteInterpreterResizeInputTensor(SafeTfLiteInterpreterHandle interpreter, + int input_index, int[] input_dims, int input_dims_size); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteTensor TfLiteInterpreterGetInputTensor(SafeTfLiteInterpreterHandle interpreter, int input_index); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteDataType TfLiteTensorType(TfLiteTensor tensor); + + [DllImport(TensorFlowLibName)] + public static extern int TfLiteTensorNumDims(TfLiteTensor tensor); + + [DllImport(TensorFlowLibName)] + public static extern int TfLiteTensorDim(TfLiteTensor tensor, int dim_index); + + [DllImport(TensorFlowLibName)] + public static extern int TfLiteTensorByteSize(TfLiteTensor tensor); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TfLiteTensorData(TfLiteTensor tensor); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TfLiteTensorName(TfLiteTensor tensor); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteQuantizationParams TfLiteTensorQuantizationParams(TfLiteTensor tensor); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor tensor, IntPtr input_data, int input_data_size); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteStatus TfLiteInterpreterInvoke(SafeTfLiteInterpreterHandle interpreter); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TfLiteInterpreterGetOutputTensor(SafeTfLiteInterpreterHandle interpreter, int output_index); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteStatus TfLiteTensorCopyToBuffer(TfLiteTensor output_tensor, IntPtr output_data, int output_data_size); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs new file mode 100644 index 000000000..b529cd319 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -0,0 +1,350 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using static Tensorflow.Binding; +using Tensorflow.Operations; + +namespace Tensorflow +{ + public partial class tensorflow + { + /// + /// A convenient alias for None, useful for indexing arrays. + /// + public Slice newaxis = Slice.NewAxis; + /// + /// A convenient alias for ... + /// + public Slice ellipsis = Slice.Ellipsis; + + /// + /// BatchToSpace for N-D tensors of type T. + /// + /// + /// + /// + /// + /// + /// + public Tensor batch_to_space_nd(T input, int[] block_shape, int[,] crops, string name = null) + => gen_array_ops.batch_to_space_nd(ops.convert_to_tensor(input), ops.convert_to_tensor(block_shape), + ops.convert_to_tensor(crops), name: name); + + /// + /// Apply boolean mask to tensor. + /// + /// + /// + /// N-D tensor. + /// K-D boolean tensor, K <= N and K must be known statically. + /// + /// A 0-D int Tensor representing the axis in tensor to mask from. + /// (N-K+1)-dimensional tensor populated by entries in tensor corresponding to True values in mask. + public Tensor boolean_mask(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) + => array_ops.boolean_mask(tensor, mask, name: name, axis: axis); + + /// + /// Broadcast an array for a compatible shape. + /// + /// + /// + /// + /// + public Tensor broadcast_to(Tensor input, Shape shape, string name = null) + => gen_array_ops.broadcast_to(input, shape, name: name); + + public Tensor check_numerics(Tensor tensor, string message, string name = null) + => gen_array_ops.check_numerics(tensor, message, name: name); + + /// + /// Concatenates tensors along one dimension. + /// + /// A list of `Tensor` objects or a single `Tensor`. + /// + /// + /// A `Tensor` resulting from concatenation of the input tensors. + public Tensor concat(IEnumerable values, int axis, string name = "concat") + { + if (values.Count() == 1) + { + return tf_with(ops.name_scope(name), scope => + { + var tensor = ops.convert_to_tensor(axis, name: "concat_dim", dtype: dtypes.int32); + Debug.Assert(tensor.shape.ndim == 0); + return identity(values.First(), name: scope); + }); + } + return array_ops.concat(values.ToArray(), axis, name: name); + } + + /// + /// Inserts a dimension of 1 into a tensor's shape. + /// + /// + /// + /// + /// + /// A `Tensor` with the same data as `input`, but its shape has an additional + /// dimension of size 1 added. + /// + public Tensor expand_dims(Tensor input, int axis = -1, string name = null) + => array_ops.expand_dims(input, axis, name); + + /// + /// Creates a tensor filled with a scalar value. + /// + /// + /// + /// + /// + public Tensor fill(Tensor dims, T value, string name = null) + => gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); + + public Tensor fill(Shape dims, T value, string name = null) + => array_ops.fill(dims, value, name: name); + + /// + /// Return a tensor with the same shape and contents as input. + /// + /// + /// + /// + public Tensor identity(Tensor input, string name = null) + => array_ops.identity(input, name: name); + + /// + /// Gather slices from params axis axis according to indices. + /// + /// + /// + /// + /// + /// + public Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0) + => array_ops.gather(@params, indices, name: name, axis: ops.convert_to_tensor(axis)); + + /// + /// Gather slices from `params` into a Tensor with shape specified by `indices`. + /// + /// + /// + /// + /// + public Tensor gather_nd(Tensor @params, Tensor indices, string name = null) + => gen_array_ops.gather_nd(@params, indices, name: name); + + /// + /// Return the elements, either from `x` or `y`, depending on the `condition`. + /// + /// + public Tensor where(Tensor condition, Tx x, Ty y, string name = null) + => array_ops.where(condition, x, y, name); + + /// + /// Transposes `a`. Permutes the dimensions according to `perm`. + /// + /// + /// + /// + /// + /// + public Tensor transpose(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false) + => array_ops.transpose(a, perm, name, conjugate); + + /// + /// Reverses specific dimensions of a tensor. + /// + /// + /// The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)). + /// + /// + public Tensor reverse(Tensor tensor, Axis axis, string name = null) + { + if (axis.IsScalar) + { + axis = new Axis(axis.axis); + } + return array_ops.reverse(tensor, axis, name: name); + } + + /// + /// Returns the rank of a tensor. + /// + /// + /// + /// Returns a 0-D `int32` `Tensor` representing the rank of `input`. + public Tensor rank(Tensor input, string name = null) + => array_ops.rank(input, name: name); + + /// + /// Extracts a slice from a tensor. + /// + /// A `Tensor`. + /// An `int32` or `int64` `Tensor`. + /// An `int32` or `int64` `Tensor`. + /// A name for the operation (optional). + /// A `Tensor` the same type as `input`. + public Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null) + => array_ops.slice(input, begin.Select(x => ops.convert_to_tensor(x)).ToArray(), + size.Select(x => ops.convert_to_tensor(x)).ToArray(), name: name); + + public Tensor squeeze(Tensor input, int axis, string name = null, int squeeze_dims = -1) + => array_ops.squeeze(input, new[] { axis }, name); + + public Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1) + => array_ops.squeeze(input, axis, name); + + /// + /// Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor. + /// + /// + /// + /// + /// + public Tensor stack(object values, int axis = 0, string name = "stack") + => array_ops.stack(values, axis, name: name); + + /// + /// Creates a tensor with all elements set to 1. + /// + /// + /// + /// A name for the operation (optional). + /// + /// if true, attempt to statically determine the shape of 'tensor' and + /// encode it as a constant. + /// + /// A `Tensor` with all elements set to 1. + public Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => array_ops.ones_like(tensor, dtype: dtype, name: name, optimize: optimize); + + public Tensor ones_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => array_ops.ones_like(nd, dtype: dtype, name: name, optimize: optimize); + + public Tensor one_hot(Tensor indices, int depth, + Tensor on_value = null, + Tensor off_value = null, + TF_DataType dtype = TF_DataType.DtInvalid, + int axis = -1, + string name = null) => array_ops.one_hot(indices, ops.convert_to_tensor(depth), dtype: dtype, axis: axis, name: name); + + /// + /// Pads a tensor + /// + /// + /// + /// + /// + /// + /// + public Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT", string name = null, int constant_values = 0) + => array_ops.pad(tensor, paddings, mode: mode, name: name, constant_values: constant_values); + + /// + /// A placeholder op that passes through `input` when its output is not fed. + /// + /// + /// A `Tensor`. The default value to produce when output is not fed. + /// + /// A `tf.Shape` or list of `int`s. The (possibly partial) shape of + /// the tensor. + /// + /// A name for the operation (optional). + /// A `Tensor`. Has the same type as `input`. + public Tensor placeholder_with_default(T input, int[] shape, string name = null) + => gen_array_ops.placeholder_with_default(ops.convert_to_tensor(input), shape, name: name); + + /// + /// Returns the shape of a tensor. + /// + /// + /// + /// + /// + public Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) + => array_ops.shape_internal(input, name, optimize: true, out_type: out_type); + + /// + /// Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor. + /// + /// + /// + /// + /// A stacked `Tensor` with the same type as `values`. + public Tensor stack(Tensor[] values, int axis = 0, string name = "stack") + => array_ops.stack(values, axis: axis, name: name); + + /// + /// Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors. + /// + /// + /// + /// + /// + /// + public Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name = "unstack") + => array_ops.unstack(value, num: num, axis: axis, name: name); + + /// + /// Creates a tensor with all elements set to zero. + /// + /// + /// + /// + /// + /// A `Tensor` with all elements set to zero. + public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize); + + public Tensor zeros_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => array_ops.zeros_like(nd, dtype: dtype, name: name, optimize: optimize); + + /// + /// Stops gradient computation. + /// + /// + /// + /// + public Tensor stop_gradient(Tensor x, string name = null) + => gen_array_ops.stop_gradient(x, name: name); + + public TensorArray TensorArray(TF_DataType dtype, int size = 0, bool dynamic_size = false, + bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true, + bool infer_shape = true) + => tf.executing_eagerly() ? + new _EagerTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size, + clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, + colocate_with_first_write_call: colocate_with_first_write_call) : + new _GraphTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size, + clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, + colocate_with_first_write_call: colocate_with_first_write_call); + + public TensorArray TensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false, + bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true, + bool infer_shape = true) + => tf.executing_eagerly() ? + new _EagerTensorArray(dtype, size: size, dynamic_size: dynamic_size, + clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, + colocate_with_first_write_call: colocate_with_first_write_call) : + new _GraphTensorArray(dtype, size: size, dynamic_size: dynamic_size, + clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, + colocate_with_first_write_call: colocate_with_first_write_call); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.audio.cs b/src/TensorFlowNET.Core/APIs/tf.audio.cs new file mode 100644 index 000000000..573b11ec3 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.audio.cs @@ -0,0 +1,37 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using Tensorflow.IO; + +namespace Tensorflow +{ + public partial class tensorflow + { + public AudioAPI audio { get; } = new AudioAPI(); + + public class AudioAPI + { + audio_ops audio_ops = new audio_ops(); + + public Tensors decode_wav(Tensor contents, int desired_channels = -1, int desired_samples = -1, string name = null) + => audio_ops.decode_wav(contents, + desired_channels: desired_channels, + desired_samples: desired_samples, + name: name); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.autograph.cs b/src/TensorFlowNET.Core/APIs/tf.autograph.cs new file mode 100644 index 000000000..55acac621 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.autograph.cs @@ -0,0 +1,25 @@ +/***************************************************************************** + Copyright 2020 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Graphs; + +namespace Tensorflow +{ + public partial class tensorflow + { + public AutoGraph autograph = new AutoGraph(); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.bitwise.cs b/src/TensorFlowNET.Core/APIs/tf.bitwise.cs new file mode 100644 index 000000000..b05182447 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.bitwise.cs @@ -0,0 +1,25 @@ +/***************************************************************************** + Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Operations; + +namespace Tensorflow +{ + public partial class tensorflow + { + public bitwise_ops bitwise = new bitwise_ops(); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.compat.cs b/src/TensorFlowNET.Core/APIs/tf.compat.cs new file mode 100644 index 000000000..8a30badd9 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.compat.cs @@ -0,0 +1,71 @@ +/***************************************************************************** + Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using System.Text; + +namespace Tensorflow +{ + public partial class tensorflow + { + public CompatApi compat { get; } = new CompatApi(); + + public class CompatApi + { + public CompatV1Api v1 { get; } = new CompatV1Api(); + + internal string as_text(string bytes_or_text, Encoding? encoding = null) + { + if(encoding is null) encoding = Encoding.UTF8; + return bytes_or_text; + } + internal string as_text(byte[] bytes_or_text, Encoding? encoding = null) + { + if(encoding is null) encoding = Encoding.UTF8; + return encoding.GetString(bytes_or_text); + } + + internal string as_str(string bytes_or_text, Encoding? encoding = null) + { + return as_text(bytes_or_text, encoding); + } + internal string as_str(byte[] bytes_or_text, Encoding? encoding = null) + { + return as_text(bytes_or_text, encoding); + } + + public ByteString as_bytes(ByteString bytes, Encoding encoding = null) + { + return bytes; + } + public ByteString as_bytes(byte[] bytes, Encoding encoding = null) + { + return ByteString.CopyFrom(bytes); + } + public ByteString as_bytes(string text, Encoding encoding = null) + { + if(encoding is null) + { + encoding = Encoding.UTF8; + } + return ByteString.CopyFrom(encoding.GetBytes(text)); + } + } + + public bool executing_eagerly() + => Context.executing_eagerly(); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.compat.v1.cs b/src/TensorFlowNET.Core/APIs/tf.compat.v1.cs new file mode 100644 index 000000000..982e7ccce --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.compat.v1.cs @@ -0,0 +1,60 @@ +/***************************************************************************** + Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class CompatV1Api + { + public void disable_eager_execution() + => tf.Context.graph_mode(); + + public IVariableV1 get_variable(string name, + Shape shape = null, + TF_DataType dtype = TF_DataType.DtInvalid, + object initializer = null, // IInitializer or Tensor + bool? trainable = null, + List collections = null, + bool? use_resource = null, + bool validate_shape = true, + VariableSynchronization synchronization = VariableSynchronization.Auto, + VariableAggregation aggregation = VariableAggregation.None) + { + var scope = Tensorflow.variable_scope.get_variable_scope(); + var store = Tensorflow.variable_scope._get_default_variable_store(); + return scope.get_variable(store, + name, + shape: shape, + dtype: dtype, + use_resource: use_resource, + validate_shape: validate_shape, + initializer: initializer, + trainable: trainable, + collections: collections); + } + + public Operation global_variables_initializer() + { + var g = variables.global_variables(); + return variables.variables_initializer(g.ToArray()); + } + + public Session Session() + => new Session().as_default(); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.config.cs b/src/TensorFlowNET.Core/APIs/tf.config.cs new file mode 100644 index 000000000..3c30ffb48 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.config.cs @@ -0,0 +1,32 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Contexts; +using Tensorflow.Framework; + +namespace Tensorflow +{ + public partial class tensorflow + { + /// + /// Public API for tf.debugging namespace + /// https://www.tensorflow.org/api_docs/python/tf/debugging + /// More debugging instructions + /// https://developer.ibm.com/technologies/artificial-intelligence/tutorials/debug-tensorflow/ + /// + public ConfigImpl config => new ConfigImpl(); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs new file mode 100644 index 000000000..cd5a71e50 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs @@ -0,0 +1,73 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow +{ + public partial class tensorflow + { + public Tensor cond(Tensor pred, + Tensor true_value, + Tensor false_false) + => control_flow_ops.cond(pred, () => true_value, () => false_false); + + public Tensor cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, + string name = null) + => control_flow_ops.cond(pred, true_fn, false_fn, name: name); + + /// + /// Create an op that groups multiple operations. + /// + /// + /// + /// + /// An Operation that executes all its inputs. + public Operation group(T[] inputs, string name = null) where T : ITensorOrOperation + => control_flow_ops.group(inputs, name: name); + + public Tensor while_loop(Func cond, + Func body, + Tensor loop_vars, + int parallel_iterations = 10) + { + Func cond1 = x + => cond(x[0]); + + Func body1 = x + => new[] { body(x[0]) }; + + var results = control_flow_ops.while_loop(cond1, + body1, + new[] { loop_vars }); + return results[0]; + } + + public Tensor[] while_loop(Func cond, + Func body, + Tensors loop_vars, + int parallel_iterations = 10, + string name = null) + => control_flow_ops.while_loop(cond, body, loop_vars, + parallel_iterations: parallel_iterations, + name: name); + + public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) + => ops.control_dependencies(control_inputs); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.data.cs b/src/TensorFlowNET.Core/APIs/tf.data.cs new file mode 100644 index 000000000..6c41a8393 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.data.cs @@ -0,0 +1,31 @@ +/***************************************************************************** + Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + public DataOps data { get; } = new DataOps(); + + public class DataOps + { + public int AUTOTUNE = -1; + public int INFINITE_CARDINALITY = -1; + public int UNKNOWN_CARDINALITY = -2; + public DatasetManager Dataset { get; } = new DatasetManager(); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.data_flow.cs b/src/TensorFlowNET.Core/APIs/tf.data_flow.cs new file mode 100644 index 000000000..e4c0a83cc --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.data_flow.cs @@ -0,0 +1,43 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + /// + /// Interleave the values from the data tensors into a single tensor. + /// + /// + /// + /// + /// + public Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = null) + => gen_data_flow_ops.dynamic_stitch(indices, data, name: name); + + /// + /// Partitions `data` into `num_partitions` tensors using indices from `partitions`. + /// + /// + /// + /// The number of partitions to output. + /// + /// + public Tensor[] dynamic_partition(Tensor data, Tensor partitions, int num_partitions, + string name = null) + => gen_data_flow_ops.dynamic_partition(data, partitions, num_partitions, name: name); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.debugging.cs b/src/TensorFlowNET.Core/APIs/tf.debugging.cs new file mode 100644 index 000000000..b3b3529e4 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.debugging.cs @@ -0,0 +1,35 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Debugging; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class tensorflow + { + /// + /// Public API for tf.debugging namespace + /// https://www.tensorflow.org/api_docs/python/tf/debugging + /// More debugging instructions + /// https://developer.ibm.com/technologies/artificial-intelligence/tutorials/debug-tensorflow/ + /// + public DebugImpl debugging => new DebugImpl(); + + public void print(Tensor input) + => tf.logging.print_v2(input); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.distributions.cs b/src/TensorFlowNET.Core/APIs/tf.distributions.cs new file mode 100644 index 000000000..c9ccad917 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.distributions.cs @@ -0,0 +1,32 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + public distributions_internal distributions { get; } = new distributions_internal(); + + public class distributions_internal + { + public Normal Normal(Tensor loc, + Tensor scale, + bool validate_args = false, + bool allow_nan_stats = true, + string name = "Normal") => new Normal(loc, scale, validate_args = false, allow_nan_stats = true, "Normal"); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.gradients.cs b/src/TensorFlowNET.Core/APIs/tf.gradients.cs new file mode 100644 index 000000000..d722cb143 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.gradients.cs @@ -0,0 +1,98 @@ +/***************************************************************************** + Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using Tensorflow.Gradients; + +namespace Tensorflow +{ + public partial class tensorflow + { + GradientTape _tapeSet; + + /// + /// Record operations for automatic differentiation. + /// + /// + /// + /// Tape set + public GradientTape GradientTape(bool persistent = false, + bool watch_accessed_variables = true) + { + var tape = _tapeSet.PushTape(persistent: persistent, + watch_accessed_variables: watch_accessed_variables); + tape.StartRecord(); + return _tapeSet; + } + + public Stack GetTapeSet() + => _tapeSet.GetTapeSet(); + + public Tensor[] gradients(Tensor[] ys, + Tensor[] xs, + Tensor[] grad_ys = null, + string name = "gradients", + bool colocate_gradients_with_ops = false, + bool gate_gradients = false, + int? aggregation_method = null, + Tensor[] stop_gradients = null) + { + return gradients_util._GradientsHelper(ys, + xs, + grad_ys, + name, + colocate_gradients_with_ops, + gate_gradients, + stop_gradients: stop_gradients); + } + + public Tensor[] gradients(Tensor ys, + Tensor[] xs, + Tensor[] grad_ys = null, + string name = "gradients", + bool colocate_gradients_with_ops = false, + bool gate_gradients = false, + int? aggregation_method = null, + Tensor[] stop_gradients = null) + { + return gradients_util._GradientsHelper(new Tensor[] { ys }, + xs, + grad_ys, + name, + colocate_gradients_with_ops, + gate_gradients, + stop_gradients: stop_gradients); + } + + public Tensor[] gradients(Tensor ys, + Tensor xs, + Tensor[] grad_ys = null, + string name = "gradients", + bool colocate_gradients_with_ops = false, + bool gate_gradients = false, + int? aggregation_method = null, + Tensor[] stop_gradients = null) + { + return gradients_util._GradientsHelper(new Tensor[] { ys }, + new Tensor[] { xs }, + grad_ys, + name, + colocate_gradients_with_ops, + gate_gradients, + stop_gradients: stop_gradients); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs new file mode 100644 index 000000000..c1b033aee --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs @@ -0,0 +1,43 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using static Tensorflow.ops; + +namespace Tensorflow +{ + public partial class tensorflow + { + public graph_util_impl graph_util { get; } = new graph_util_impl(); + public GraphTransformer graph_transforms { get; } = new GraphTransformer(); + public GraphKeys GraphKeys { get; } = new GraphKeys(); + + public void reset_default_graph() + => ops.reset_default_graph(); + + public Graph get_default_graph() + => ops.get_default_graph(); + + public Graph peak_default_graph() + => ops.peak_default_graph(); + + /// + /// Creates a new graph. + /// + ///Has no interaction with graph defaulting. Equivalent to new Graph(); + public Graph Graph() + => new Graph(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/APIs/tf.image.cs b/src/TensorFlowNET.Core/APIs/tf.image.cs new file mode 100644 index 000000000..41ef52967 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.image.cs @@ -0,0 +1,376 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using OneOf.Types; +using System; +using System.Buffers.Text; +using Tensorflow.Contexts; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class tensorflow + { + public image_internal image = new image_internal(); + + public class image_internal + { + public Tensor random_flip_up_down(Tensor image, int seed = 0) + => image_ops_impl.random_flip_up_down(image, seed); + + public Tensor random_flip_left_right(Tensor image, int seed = 0) + => image_ops_impl.random_flip_left_right(image, seed); + + public Tensor flip_left_right(Tensor image) + => image_ops_impl.flip_left_right(image); + + public Tensor flip_up_down(Tensor image) + => image_ops_impl.flip_up_down(image); + + public Tensor rot90(Tensor image, int k = 1, string name = null) + => image_ops_impl.rot90(image, k, name); + + public Tensor transpose(Tensor image, string name = null) + => image_ops_impl.transpose(image, name); + + public Tensor central_crop(Tensor image, float central_fraction) + => image_ops_impl.central_crop(image, central_fraction); + + public Tensor pad_to_bounding_box(Tensor image, int offset_height, int offset_width, int target_height, int target_width) + => image_ops_impl.pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width); + + public Tensor crop_to_bounding_box(Tensor image, int offset_height, int offset_width, int target_height, int target_width) + => image_ops_impl.crop_to_bounding_box(image, offset_height, offset_width, target_height, target_width); + + public Tensor resize_image_with_crop_or_pad(Tensor image, object target_height, object target_width) + => image_ops_impl.resize_image_with_crop_or_pad(image, target_height, target_width); + + public Tensor resize_images(Tensor images, Tensor size, string method = ResizeMethod.BILINEAR, bool preserve_aspect_ratio = false, bool antialias = false, + string name = null) + => image_ops_impl.resize_images(images, size, method, preserve_aspect_ratio, antialias, name); + + public Tensor resize_images_v2(Tensor images, Shape size, string method = ResizeMethod.BILINEAR, bool preserve_aspect_ratio = false, bool antialias = false, + string name = null) + => image_ops_impl.resize_images_v2(images, size, method, preserve_aspect_ratio, antialias, name); + + public Tensor resize_images_v2(Tensor images, Tensor size, string method = ResizeMethod.BILINEAR, bool preserve_aspect_ratio = false, bool antialias = false, + string name = null) + => image_ops_impl.resize_images_v2(images, size, method, preserve_aspect_ratio, antialias, name); + + public Tensor resize_images_with_pad(Tensor image, int target_height, int target_width, string method, bool antialias) + => image_ops_impl.resize_images_with_pad(image, target_height, target_width, method, antialias); + + public Tensor per_image_standardization(Tensor image) + => image_ops_impl.per_image_standardization(image); + + public Tensor random_brightness(Tensor image, float max_delta, int seed = 0) + => image_ops_impl.random_brightness(image, max_delta, seed); + + public Tensor random_contrast(Tensor image, float lower, float upper, int seed = 0) + => image_ops_impl.random_contrast(image, lower, upper, seed); + + public Tensor adjust_brightness(Tensor image, Tensor delta) + => image_ops_impl.adjust_brightness(image, delta); + + public Tensor adjust_contrast(Tensor images, Tensor contrast_factor) + => image_ops_impl.adjust_contrast(images, contrast_factor); + + public Tensor adjust_gamma(Tensor image, int gamma = 1, int gain = 1) + => image_ops_impl.adjust_gamma(image, gamma, gain); + + public Tensor rgb_to_grayscale(Tensor images, string name = null) + => image_ops_impl.rgb_to_grayscale(images, name); + + public Tensor grayscale_to_rgb(Tensor images, string name = null) + => image_ops_impl.grayscale_to_rgb(images, name); + + public Tensor random_hue(Tensor image, float max_delta, int seed = 0) + => image_ops_impl.random_hue(image, max_delta, seed); + + public Tensor adjust_hue(Tensor image, Tensor delta, string name = null) + => image_ops_impl.adjust_hue(image, delta, name); + + public Tensor random_jpeg_quality(Tensor image, float min_jpeg_quality, float max_jpeg_quality, int seed = 0) + => image_ops_impl.random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed); + + public Tensor adjust_jpeg_quality(Tensor image, Tensor jpeg_quality, string name = null) + => image_ops_impl.adjust_jpeg_quality(image, jpeg_quality, name); + + public Tensor random_saturation(Tensor image, float lower, float upper, int seed = 0) + => image_ops_impl.random_saturation(image, lower, upper, seed); + + public Tensor adjust_saturation(Tensor image, Tensor saturation_factor, string name = null) + => image_ops_impl.adjust_saturation(image, saturation_factor, name); + + public Tensor total_variation(Tensor images, string name = null) + => image_ops_impl.total_variation(images, name); + + public (Tensor, Tensor, Tensor) sample_distorted_bounding_box(Tensor image_size, Tensor bounding_boxes, + int seed = 0, + Tensor min_object_covered = null, + float[] aspect_ratio_range = null, + float[] area_range = null, + int max_attempts = 100, + bool use_image_if_no_bounding_boxes = false, + string name = null) + => image_ops_impl.sample_distorted_bounding_box_v2(image_size, bounding_boxes, seed, min_object_covered, aspect_ratio_range, + area_range, max_attempts, use_image_if_no_bounding_boxes, name); + + public Tensor non_max_suppression(Tensor boxes, Tensor scores, Tensor max_output_size, float iou_threshold = 0.5f, + float score_threshold = -1f / 0f, /*float soft_nms_sigma = 0.0f,*/ string name = null) + => image_ops_impl.non_max_suppression(boxes, scores, max_output_size, iou_threshold, score_threshold, name); + + public Tensor non_max_suppression_with_overlaps(Tensor overlaps, Tensor scores, Tensor max_output_size, + float overlap_threshold = 0.5f, float score_threshold = -1 / 0f, string name = null) + => image_ops_impl.non_max_suppression_with_overlaps(overlaps, scores, max_output_size, overlap_threshold, score_threshold, name); + + public Tensor rgb_to_yiq(Tensor images) + => image_ops_impl.rgb_to_yiq(images); + + public Tensor yiq_to_rgb(Tensor images) + => image_ops_impl.yiq_to_rgb(images); + + public Tensor rgb_to_yuv(Tensor images) + => image_ops_impl.rgb_to_yuv(images); + + public Tensor yuv_to_rgb(Tensor images) + => image_ops_impl.yuv_to_rgb(images); + + public Tensor psnr(Tensor a, Tensor b, Tensor max_val, string name = null) + => image_ops_impl.psnr(a, b, max_val, name); + + public Tensor ssim(Tensor img1, Tensor img2, float max_val = 1f, float filter_size = 11f, float filter_sigma = 1.5f, + float k1 = 0.01f, float k2 = 0.03f) + => image_ops_impl.ssim(img1, img2, max_val, filter_size, filter_sigma, k1, k2); + + public Tensor ssim_multiscale(Tensor img1, Tensor img2, float max_val, float[] power_factors = null, float filter_size = 11f, + float filter_sigma = 1.5f, float k1 = 0.01f, float k2 = 0.03f) + => image_ops_impl.ssim_multiscale(img1, img2, max_val, power_factors, filter_size, filter_sigma, k1, k2); + + public (Tensor, Tensor) image_gradients(Tensor image) + => image_ops_impl.image_gradients(image); + + public Tensor sobel_edges(Tensor image) + => image_ops_impl.sobel_edges(image); + + /// + /// Adjust contrast of RGB or grayscale images. + /// + /// Images to adjust. At least 3-D. + /// + /// A float multiplier for adjusting contrast. + /// The contrast-adjusted image or images. + public Tensor adjust_contrast(Tensor images, float contrast_factor, string name = null) + => gen_image_ops.adjust_contrastv2(images, contrast_factor, name); + + /// + /// Adjust hue of RGB images. + /// + /// RGB image or images. The size of the last dimension must be 3. + /// float. How much to add to the hue channel. + /// A name for this operation (optional). + /// Adjusted image(s), same shape and DType as `image`. + /// if `delta` is not in the interval of `[-1, 1]`. + public Tensor adjust_hue(Tensor images, float delta, string name = null) + { + if (tf.Context.executing_eagerly()) + { + if (delta < -1f || delta > 1f) + throw new ValueError("delta must be in the interval [-1, 1]"); + } + return gen_image_ops.adjust_hue(images, delta, name: name); + } + + /// + /// Adjust saturation of RGB images. + /// + /// RGB image or images. The size of the last dimension must be 3. + /// float. Factor to multiply the saturation by. + /// A name for this operation (optional). + /// Adjusted image(s), same shape and DType as `image`. + public Tensor adjust_saturation(Tensor image, float saturation_factor, string name = null) + => gen_image_ops.adjust_saturation(image, saturation_factor, name); + + /// + /// Greedily selects a subset of bounding boxes in descending order of score. + /// + /// + /// A 4-D float `Tensor` of shape `[batch_size, num_boxes, q, 4]`. If `q` + /// is 1 then same boxes are used for all classes otherwise, if `q` is equal + /// to number of classes, class-specific boxes are used. + /// + /// + /// A 3-D float `Tensor` of shape `[batch_size, num_boxes, num_classes]` + /// representing a single score corresponding to each box(each row of boxes). + /// + /// + /// A scalar integer `Tensor` representing the + /// maximum number of boxes to be selected by non-max suppression per class + /// + /// + /// A int32 scalar representing maximum number of boxes retained + /// over all classes.Note that setting this value to a large number may + /// result in OOM error depending on the system workload. + /// + /// + /// A float representing the threshold for deciding whether boxes + /// overlap too much with respect to IOU. + /// + /// + /// A float representing the threshold for deciding when to + /// remove boxes based on score. + /// + /// + /// If false, the output nmsed boxes, scores and classes are + /// padded/clipped to `max_total_size`. If true, the output nmsed boxes, scores and classes are padded to be of length `max_size_per_class`*`num_classes`, + /// unless it exceeds `max_total_size` in which case it is clipped to `max_total_size`. Defaults to false. + /// + /// + /// If true, the coordinates of output nmsed boxes will be clipped + /// to[0, 1]. If false, output the box coordinates as it is. Defaults to true. + /// + /// + /// 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor containing the non-max suppressed boxes. + /// 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing the scores for the boxes. + /// 'nmsed_classes': A [batch_size, max_detections] float32 tensor containing the class for boxes. + /// 'valid_detections': A [batch_size] int32 tensor indicating the number of + /// valid detections per batch item. Only the top valid_detections[i] entries + /// in nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the + /// entries are zero paddings. + /// + public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression( + Tensor boxes, + Tensor scores, + int max_output_size_per_class, + int max_total_size, + float iou_threshold, + float score_threshold, + bool pad_per_class = false, + bool clip_boxes = true) + { + var iou_threshold_t = ops.convert_to_tensor(iou_threshold, TF_DataType.TF_FLOAT, name: "iou_threshold"); + var score_threshold_t = ops.convert_to_tensor(score_threshold, TF_DataType.TF_FLOAT, name: "score_threshold"); + var max_total_size_t = ops.convert_to_tensor(max_total_size); + var max_output_size_per_class_t = ops.convert_to_tensor(max_output_size_per_class); + return gen_image_ops.combined_non_max_suppression(boxes, scores, max_output_size_per_class_t, max_total_size_t, + iou_threshold_t, score_threshold_t, pad_per_class, clip_boxes); + } + + /// + /// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change. + /// Returns a tensor with crops from the input image at positions defined at the bounding box locations in boxes.The cropped boxes are all resized(with bilinear or nearest neighbor interpolation) to a fixed size = [crop_height, crop_width].The result is a 4 - D tensor[num_boxes, crop_height, crop_width, depth].The resizing is corner aligned. In particular, if boxes = [[0, 0, 1, 1]], the method will give identical results to using tf.image.resize_bilinear() or tf.image.resize_nearest_neighbor() (depends on the method argument) with align_corners = True. + /// + /// A Tensor. Must be one of the following types: uint8, uint16, int8, int16, int32, int64, half, float32, float64. A 4-D tensor of shape [batch, image_height, image_width, depth]. Both image_height and image_width need to be positive. + /// A Tensor of type float32. A 2-D tensor of shape [num_boxes, 4]. The i-th row of the tensor specifies the coordinates of a box in the box_ind[i] image and is specified in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of y is mapped to the image coordinate at y * (image_height - 1), so as the [0, 1] interval of normalized image height is mapped to [0, image_height - 1] in image height coordinates. We do allow y1 > y2, in which case the sampled crop is an up-down flipped version of the original image. The width dimension is treated similarly. Normalized coordinates outside the [0, 1] range are allowed, in which case we use extrapolation_value to extrapolate the input image values. + /// A Tensor of type int32. A 1-D tensor of shape [num_boxes] with int32 values in [0, batch). The value of box_ind[i] specifies the image that the i-th box refers to. + /// A Tensor of type int32. A 1-D tensor of 2 elements, size = [crop_height, crop_width]. All cropped image patches are resized to this size. The aspect ratio of the image content is not preserved. Both crop_height and crop_width need to be positive. + /// An optional string from: "bilinear", "nearest". Defaults to "bilinear". A string specifying the sampling method for resizing. It can be either "bilinear" or "nearest" and default to "bilinear". Currently two sampling methods are supported: Bilinear and Nearest Neighbor. + /// An optional float. Defaults to 0. Value used for extrapolation, when applicable. + /// A name for the operation (optional). + /// A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth]. + public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) => + gen_image_ops.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name); + + public Tensor decode_jpeg(Tensor contents, + int channels = 0, + int ratio = 1, + bool fancy_upscaling = true, + bool try_recover_truncated = false, + int acceptable_fraction = 1, + string dct_method = "", + string name = null) + => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, + fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, + acceptable_fraction: acceptable_fraction, dct_method: dct_method); + + public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true, + bool uniform_noise = true, string name = null) + => image_ops_impl.extract_glimpse(input, size, offsets, centered, normalized, uniform_noise, name); + + public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression(Tensor boxes, Tensor scores, Tensor max_output_size_per_class, + Tensor max_total_size, float iou_threshold = 0.5f, float score_threshold = -1f / 0f, bool pad_per_class = false, bool clip_boxes = true, + string name = null) + => image_ops_impl.combined_non_max_suppression(boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, score_threshold, + pad_per_class, clip_boxes, name); + + public (Tensor, Tensor) non_max_suppression_padded(Tensor boxes, Tensor scores, Tensor max_output_size, + float iou_threshold = 0.5f, + float score_threshold = -1f / 0f, + bool pad_to_max_output_size = false, + string name = null, + bool sorted_input = false, + bool canonicalized_coordinates = false, + int tile_size = 512) + => image_ops_impl.non_max_suppression_padded(boxes, scores, max_output_size, iou_threshold, score_threshold, pad_to_max_output_size, + name, sorted_input, canonicalized_coordinates, tile_size); + + public Tensor resize(Tensor image, Shape size, string method = ResizeMethod.BILINEAR) + => image_ops_impl.resize_images_v2(image, size, method: method); + + public Tensor resize(Tensor image, Tensor size, string method = ResizeMethod.BILINEAR) + => image_ops_impl.resize_images_v2(image, size, method: method); + + public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, bool half_pixel_centers = false, string name = null) + => gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, half_pixel_centers: half_pixel_centers, name: name); + + public Tensor resize_images(Tensor images, Tensor size, string method = ResizeMethod.BILINEAR, + bool preserve_aspect_ratio = false, string name = null) + => image_ops_impl.resize_images(images, size, method: method, + preserve_aspect_ratio: preserve_aspect_ratio, name: name); + + public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null) + => gen_image_ops.convert_image_dtype(image, dtype, saturate: saturate, name: name); + + public Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, + string name = null, bool expand_animations = true) + => image_ops_impl.decode_image(contents, channels: channels, dtype: dtype, + name: name, expand_animations: expand_animations); + + public Tensor encode_png(Tensor contents, string name = null) + => image_ops_impl.encode_png(contents, name: name); + + public Tensor encode_jpeg(Tensor contents, string name = null) + => image_ops_impl.encode_jpeg(contents, name: name); + + + /// + /// Convenience function to check if the 'contents' encodes a JPEG image. + /// + /// + /// + /// + public Tensor is_jpeg(Tensor contents, string name = null) + => image_ops_impl.is_jpeg(contents, name: name); + + /// + /// Resize `images` to `size` using nearest neighbor interpolation. + /// + /// + /// + /// + /// + /// + /// + public Tensor resize_nearest_neighbor(Tensor images, Tsize size, bool align_corners = false, + string name = null, bool half_pixel_centers = false) + => image_ops_impl.resize_nearest_neighbor(images, size, align_corners: align_corners, + name: name, half_pixel_centers: half_pixel_centers); + + public Tensor draw_bounding_boxes(Tensor images, Tensor boxes, Tensor colors = null, string name = null) + => image_ops_impl.draw_bounding_boxes(images, boxes, colors, name); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs new file mode 100644 index 000000000..8635f6620 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -0,0 +1,104 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Operations.Initializers; + +namespace Tensorflow +{ + public partial class tensorflow + { + public InitializersImpl initializers { get; } = new InitializersImpl(); + + public IInitializer constant_initializer(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) + => new Constant(value, dtype: dtype, verify_shape: verify_shape); + public IInitializer zeros_initializer => new Zeros(); + public IInitializer ones_initializer => new Ones(); + public IInitializer glorot_uniform_initializer => new GlorotUniform(); + public IInitializer random_uniform_initializer => new RandomUniform(); + public IInitializer orthogonal_initializer => new Orthogonal(); + + public variable_scope variable_scope(string name, + string default_name = null, + Tensor[] values = null, + bool? reuse = null, + bool auxiliary_name_scope = true) => new variable_scope(name, + default_name, + values, + reuse: reuse, + auxiliary_name_scope: auxiliary_name_scope); + + public variable_scope variable_scope(VariableScope scope, + string default_name = null, + Tensor[] values = null, + bool? reuse = null, + bool auxiliary_name_scope = true) => new variable_scope(scope, + default_name, + values, + reuse: reuse, + auxiliary_name_scope: auxiliary_name_scope); + + public IInitializer truncated_normal_initializer(float mean = 0.0f, + float stddev = 1.0f, + int? seed = null, + TF_DataType dtype = TF_DataType.DtInvalid) => new TruncatedNormal(mean: mean, + stddev: stddev, + seed: seed, + dtype: dtype); + + public IInitializer random_normal_initializer(float mean = 0.0f, + float stddev = 1.0f, + int? seed = null, + TF_DataType dtype = TF_DataType.DtInvalid) => new RandomNormal(mean: mean, + stddev: stddev, + seed: seed, + dtype: dtype); + + /// + /// Initializer capable of adapting its scale to the shape of weights tensors. + /// + /// + /// + /// + /// + /// + /// + public IInitializer variance_scaling_initializer(float factor = 1.0f, + string mode = "fan_in", + string distribution = "truncated_normal", + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) => new VarianceScaling( + scale: factor, + mode: mode, + distribution: distribution, + seed: seed, + dtype: dtype); + + public class InitializersImpl + { + public IInitializer random_normal_initializer(float mean = 0.0f, + float stddev = 0.05f, + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) => new RandomNormal(mean: mean, + stddev: stddev, + seed: seed, + dtype: dtype); + + public IInitializer zeros_initializer(Shape shape = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) => new Zeros(shape: shape, + dtype: dtype); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.io.cs b/src/TensorFlowNET.Core/APIs/tf.io.cs new file mode 100644 index 000000000..ea1e44b28 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.io.cs @@ -0,0 +1,66 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using Tensorflow.IO; +using Tensorflow.Operations; + +namespace Tensorflow +{ + public partial class tensorflow + { + public IoApi io { get; } = new IoApi(); + + public class IoApi + { + io_ops ops; + public GFile gfile; + public IoApi() + { + ops = new io_ops(); + gfile = new GFile(); + } + + public Tensor read_file(string filename, string name = null) + => ops.read_file(filename, name); + + public Tensor read_file(Tensor filename, string name = null) + => ops.read_file(filename, name); + + public Operation save_v2(Tensor prefix, string[] tensor_names, + string[] shape_and_slices, Tensor[] tensors, string name = null) + => ops.save_v2(prefix, tensor_names, shape_and_slices, tensors, name: name); + + public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, + string[] shape_and_slices, TF_DataType[] dtypes, string name = null) + => ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name); + + public Operation write_file(string filename, Tensor conentes, string name = null) + => write_file(Tensorflow.ops.convert_to_tensor(filename, TF_DataType.TF_STRING), conentes, name); + + public Operation write_file(Tensor filename, Tensor conentes, string name = null) + => gen_ops.write_file(filename, conentes, name); + } + + public GFile gfile = new GFile(); + + public ITensorOrOperation[] import_graph_def(GraphDef graph_def, + Dictionary input_map = null, + string[] return_elements = null, + string name = null, + OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs new file mode 100644 index 000000000..32f64ec35 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs @@ -0,0 +1,111 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class tensorflow + { + public LinalgApi linalg { get; } = new LinalgApi(); + + public class LinalgApi + { + linalg_ops ops = new linalg_ops(); + + public Tensor einsum(string equation, Tensors inputs, string name = null) + => math_ops.einsum(equation, inputs, name: name); + + public Tensor eye(int num_rows, + int num_columns = -1, + Shape batch_shape = null, + TF_DataType dtype = TF_DataType.TF_DOUBLE, + string name = null) + => ops.eye(num_rows, num_columns: num_columns, batch_shape: batch_shape, dtype: dtype, name: name); + + public Tensor diag(Tensor diagonal, string name = null) + => gen_array_ops.diag(diagonal, name: name); + + public Tensor matmul(Tensor a, Tensor b) + => math_ops.matmul(a, b); + + public Tensor norm(Tensor a, string ord = "euclidean", Axis axis = null, string name = null) + => ops.norm(a, ord: ord, axis: axis, name: name); + + public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) + => math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name); + + public Tensor inv(Tensor input, bool adjoint = false, string name = null) + => ops.matrix_inverse(input, adjoint: adjoint, name: name); + + public Tensor global_norm(Tensor[] t_list, string name = null) + => clip_ops.global_norm(t_list, name: name); + + public Tensor l2_normalize(Tensor x, + int axis = 0, + float epsilon = 1e-12f, + string name = null) + => nn_impl.l2_normalize(x, axis: axis, epsilon: constant_op.constant(epsilon), name: name); + + public Tensor lstsq(Tensor matrix, Tensor rhs, + NDArray l2_regularizer = null, bool fast = true, string name = null) + => ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name); + + public Tensors qr(Tensor input, bool full_matrices = true, string name = null) + => ops.qr(input, full_matrices: full_matrices, name: name); + + public Tensor tensor_diag_part(Tensor input, string name = null) + => gen_array_ops.diag_part(input, name: name); + + public Tensor tensordot(Tensor x, Tensor y, NDArray axes, string name = null) + => math_ops.tensordot(x, y, axes, name: name); + } + + public Tensor diag(Tensor diagonal, string name = null) + => gen_array_ops.diag(diagonal, name: name); + + public Tensor matmul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false) + => math_ops.matmul(a, b, transpose_a: transpose_a, transpose_b: transpose_b); + + /// + /// Multiply slices of the two matrices "x" and "y". + /// + /// + /// The `BatchMatMul` operation is embedded into the + /// `MatMul` operation on the DLL side. However the expected + /// attributes are not the same, hence we need to expose this + /// method to have the right args list on the `_apply_op_helper` + /// function. + /// + /// For each rank > 2 the first rank - 2 dimensions are considered + /// as fixed, and have to be consistent across the two matrices. A + /// common matrix multiplication is then applied over the residual + /// 2 dimensions. + /// + /// e.g. + /// x is (3, 6, 12); y is (3, 12, 6) + /// batch_matmul(x, y) ==> (3, 6, 6) + /// + /// + /// + /// + /// + /// + /// + public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) + => math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.logging.cs b/src/TensorFlowNET.Core/APIs/tf.logging.cs new file mode 100644 index 000000000..0e10c1610 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.logging.cs @@ -0,0 +1,23 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + public logging_ops logging => new logging_ops(); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.loss.cs b/src/TensorFlowNET.Core/APIs/tf.loss.cs new file mode 100644 index 000000000..48ed01500 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.loss.cs @@ -0,0 +1,23 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + public LossesImpl losses => new LossesImpl(); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs new file mode 100644 index 000000000..da54a9dd7 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -0,0 +1,628 @@ +/***************************************************************************** + Copyright 2023 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using Tensorflow.Operations; + +namespace Tensorflow +{ + public partial class tensorflow + { + public MathApi math { get; } = new MathApi(); + public class MathApi + { + public Tensor argmax(Tensor input, Axis axis = null, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64) + => gen_math_ops.arg_max(input, axis, name: name, output_type: output_type); + + public Tensor count_nonzero(Tensor input, Axis? axis = null, bool? keepdims = null, TF_DataType dtype = TF_DataType.TF_INT64, string name = null) + => math_ops.count_nonzero_v2(input, axis: axis, keepdims: keepdims ?? false, dtype: dtype); + public Tensor log(Tensor x, string name = null) + => gen_math_ops.log(x, name); + + /// + /// Computes the Gauss error function of `x` element-wise. + /// + /// + /// + /// + public Tensor erf(Tensor x, string name = null) + => math_ops.erf(x, name); + + public Tensor multiply(Tensor x, Tensor y, string name = null) + => math_ops.multiply(x, y, name: name); + public Tensor divide_no_nan(Tensor a, Tensor b, string name = null) + => math_ops.div_no_nan(a, b); + + /// + /// Computes the Euclidean norm of elements across dimensions of a tensor. + /// + /// The tensor to reduce. Should have numeric type. + /// The dimensions to reduce. If `None` (the default), reduces all dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))` + /// If true, retains reduced dimensions with length 1. + /// A name for the operation (optional). + /// The reduced tensor, of the same dtype as the input_tensor. + public Tensor reduce_euclidean_norm(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_euclidean_norm(input_tensor, axis: axis, keepdims: keepdims, name); + + public Tensor square(Tensor x, string name = null) + => math_ops.square(x, name: name); + + public Tensor sum(Tensor x, Axis? axis = null, string name = null) + => math_ops.reduce_sum(x, axis: axis, name: name); + + public Tensor softplus(Tensor features, string name = null) + => nn_ops.softplus(features, name: name); + + public Tensor tanh(Tensor x, string name = null) + => math_ops.tanh(x, name: name); + + /// + /// Finds values and indices of the `k` largest entries for the last dimension. + /// + /// + /// + /// + /// + /// + public Tensors top_k(Tensor input, int k, bool sorted = true, string name = null) + => nn_ops.top_kv2(input, k, sorted: sorted, name: name); + + public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK") + => nn_ops.in_top_k(predictions, targets, k, name); + + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor bincount(Tensor arr, Tensor weights = null, + Tensor minlength = null, + Tensor maxlength = null, + TF_DataType dtype = TF_DataType.TF_INT32, + string name = null, + Shape axis = null, + bool binary_output = false) + => math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength, + dtype: dtype, name: name, axis: axis, binary_output: binary_output); + + public Tensor real(Tensor x, string name = null) + => gen_ops.real(x, x.dtype.real_dtype(), name); + public Tensor imag(Tensor x, string name = null) + => gen_ops.imag(x, x.dtype.real_dtype(), name); + + public Tensor conj(Tensor x, string name = null) + => gen_ops.conj(x, name); + public Tensor angle(Tensor x, string name = null) + => gen_ops.angle(x, x.dtype.real_dtype(), name); + } + + public Tensor abs(Tensor x, string name = null) + => math_ops.abs(x, name); + + /// + /// Computes acos of x element-wise. + /// + /// + /// + /// + public Tensor acos(Tensor x, string name = null) + => gen_math_ops.acos(x, name); + + /// + /// Computes asin of x element-wise. + /// + /// + /// + /// + public Tensor asin(Tensor x, string name = null) + => gen_math_ops.asin(x, name); + + public Tensor add(Tensor a, Tensor b, string name = null) + => gen_math_ops.add(a, b, name: name); + + public Tensor add(Tx a, Ty b, string name = null) + => gen_math_ops.add(ops.convert_to_tensor(a), ops.convert_to_tensor(b), name: name); + + /// + /// Adds all input tensors element-wise. + /// + /// + /// + /// A `Tensor` of same shape and type as the elements of `inputs`. + public Tensor add_n(Tensor[] inputs, string name = null) + => math_ops.add_n(inputs, name: name); + + /// + /// Computes atan of x element-wise. + /// + /// + /// + /// + public Tensor atan(Tensor x, string name = null) + => gen_math_ops.atan(x, name); + + public Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) + => gen_math_ops.arg_max(input, ops.convert_to_tensor(dimension), output_type: output_type, name: name); + + public Tensor arg_min(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) + => gen_math_ops.arg_min(input, ops.convert_to_tensor(dimension), output_type: output_type, name: name); + + public Tensor is_finite(Tensor input, string name = null) + => gen_math_ops.is_finite(input, name); + + public Tensor is_nan(Tensor input, string name = null) + => gen_math_ops.is_nan(input, name); + + /// + /// Returns element-wise smallest integer not less than x. + /// + /// + /// + /// + public Tensor ceil(Tensor x, string name = null) + => gen_math_ops.ceil(x, name); + + /// + /// Computes sin of x element-wise. + /// + /// + /// + /// + public Tensor sin(Tensor x, string name = null) + => gen_math_ops.sin(x, name); + + /// + /// Computes hyperbolic sine of x element-wise. + /// + /// + /// + /// + public Tensor sinh(Tensor x, string name = null) + => gen_math_ops.sinh(x, name); + + /// + /// Computes cos of x element-wise. + /// + /// + /// + /// + public Tensor cos(Tensor x, string name = null) + => gen_math_ops.cos(x, name); + + public Tensor cos(float x, string name = null) + => gen_math_ops.cos(ops.convert_to_tensor(x), name); + + /// + /// Computes hyperbolic cosine of x element-wise. + /// + /// + /// + /// + public Tensor cosh(Tensor x, string name = null) + => gen_math_ops.cosh(x, name); + + public Tensor tan(Tensor x, string name = null) + => gen_math_ops.tan(x, name); + + public Tensor tanh(Tensor x, string name = null) + => gen_math_ops.tanh(x, name); + + /// + /// Returns element-wise largest integer not greater than x. + /// + /// + /// + /// + public Tensor floor(Tensor x, string name = null) + => gen_math_ops.floor(x, name); + + /// + /// Returns the truth value of (x > y) element-wise. + /// + /// + /// + /// + /// + /// + /// + public Tensor greater(Tx x, Ty y, string name = null) + => gen_math_ops.greater(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); + + /// + /// Returns the truth value of (x >= y) element-wise. + /// + /// + /// + /// + /// + /// + /// + public Tensor greater_equal(Tx x, Ty y, string name = null) + => gen_math_ops.greater_equal(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); + + /// + /// Returns the truth value of (x < y) element-wise. + /// + /// + /// + /// + /// + /// + /// + public Tensor less(Tx x, Ty y, string name = null) + => gen_math_ops.less(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); + + /// + /// Computes the log of the absolute value of `Gamma(x)` element-wise. + /// + /// A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`. + /// A name for the operation (optional). + /// A `Tensor`. Has the same type as `x`. + public Tensor lgamma(Tensor x, string name = null) + => gen_math_ops.lgamma(x, name: name); + + /// + /// Returns the truth value of (x <= y) element-wise. + /// + /// + /// + /// + /// + /// + /// + public Tensor less_equal(Tx x, Ty y, string name = null) + => gen_math_ops.less_equal(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); + + /// + /// Computes natural logarithm of (1 + x) element-wise. + /// + /// + /// + /// + public Tensor log1p(Tensor x, string name = null) + => gen_math_ops.log1p(x, name); + + public Tensor logical_and(T x, T y, string name = null) + => gen_math_ops.logical_and(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); + + public Tensor logical_not(Tensor x, string name = null) + => gen_math_ops.logical_not(x, name); + + public Tensor logical_or(Tensor x, Tensor y, string name = null) + => gen_math_ops.logical_or(x, y, name); + + public Tensor logical_xor(Tensor x, Tensor y, string name = "LogicalXor") + { + return gen_math_ops.logical_and(gen_math_ops.logical_or(x, y), + gen_math_ops.logical_not(gen_math_ops.logical_and(x, y)), name); + } + + /// + /// Clips tensor values to a specified min and max. + /// + /// + /// + /// + /// + /// + public Tensor _clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null) + => gen_math_ops.clip_by_value(t, clip_value_min, clip_value_max); + + /// + /// Clips tensor values to a specified min and max. + /// + /// + /// A Tensor. + /// + /// + /// A 0-D (scalar) Tensor, or a Tensor with the same shape + /// as t. The minimum value to clip by. + /// + /// + /// A 0-D (scalar) Tensor, or a Tensor with the same shape + /// as t. The maximum value to clip by. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ClipByValue'. + /// + /// + /// A clipped Tensor with the same shape as input 't'. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor t, this operation returns a tensor of the same type and + /// shape as t with its values clipped to clip_value_min and clip_value_max. + /// Any values less than clip_value_min are set to clip_value_min. Any values + /// greater than clip_value_max are set to clip_value_max. + /// + public Tensor clip_by_value(Tensor t, T1 clip_value_min, T2 clip_value_max, string name = "ClipByValue") + => clip_ops.clip_by_value(t, clip_value_min, clip_value_max, name); + + public Tensor sub(Tx a, Ty b, string name = null) + => gen_math_ops.sub(ops.convert_to_tensor(a), ops.convert_to_tensor(b), name: name); + + public Tensor divide(Tensor a, Tensor b) + => a / b; + + public Tensor sqrt(Tensor a, string name = null) + => math_ops.sqrt(a, name); + + public Tensor sign(Tensor a, string name = null) + => gen_math_ops.sign(a, name); + + public Tensor subtract(Tensor x, T[] y, string name = null) where T : struct + => gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name); + + /// + /// return x - y + /// + /// + /// + /// + /// + public Tensor subtract(Tensor x, Tensor y, string name = null) + => gen_math_ops.sub(x, y, name); + + public Tensor log(Tensor x, string name = null) + => gen_math_ops.log(x, name); + + public Tensor equal(Tensor x, Tensor y, string name = null) + => gen_math_ops.equal(x, y, name: name); + + /// + /// Computes arctangent of `y/x` element-wise, respecting signs of the arguments. + /// + /// + /// + /// + /// + public Tensor atan2(Tensor y, Tensor x, string name = null) + => gen_math_ops.atan2(y, x, name); + + /// + /// Computes the maximum of elements across dimensions of a tensor. + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor max(Tx input, Ty axis, bool keep_dims = false, string name = null) + => gen_math_ops.max(ops.convert_to_tensor(input), ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); + + /// + /// Computes the minimum of elements across dimensions of a tensor. + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor min(Tx input, Ty axis, bool keep_dims = false, string name = null) + => gen_math_ops.min(ops.convert_to_tensor(input), ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); + + /// + /// Returns the max of x and y (i.e. x > y ? x : y) element-wise. + /// + /// + /// + /// + /// + /// + /// + public Tensor maximum(T1 x, T2 y, string name = null) + => gen_math_ops.maximum(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); + + /// + /// Returns the min of x and y (i.e. x < y ? x : y) element-wise. + /// + /// + /// + /// + /// + /// + /// + public Tensor minimum(T1 x, T2 y, string name = null) + => gen_math_ops.minimum(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); + + public Tensor multiply(Tensor x, Tensor y, string name = null) + => gen_math_ops.mul(x, y, name: name); + + /// + /// return x * y + /// + /// + /// + /// + /// + /// + /// + public Tensor multiply(Tx x, Ty y, string name = null) + => gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); + /// + /// return scalar product + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor dot_prod(Tx x, Ty y, NDArray axes, string name = null) + => math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name); + public Tensor negative(Tensor x, string name = null) + => gen_math_ops.neg(x, name); + + /// + /// Returns the truth value of (x != y) element-wise. + /// + /// + /// + /// + /// A `Tensor` of type bool with the same size as that of x or y. + public Tensor not_equal(Tx x, Ty y, string name = null) + => math_ops.not_equal(x, y, name: name); + + /// + /// Divides x / y elementwise (using Python 2 division operator semantics). + /// + /// + /// + /// + /// + public Tensor div(Tensor x, Tensor y, string name = null) + => math_ops.div(x, y, name: name); + + public Tensor divide(Tensor x, T[] y, string name = null) where T : struct + => x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); + + public Tensor pow(T1 x, T2 y, string name = "pow") + => math_ops.pow(x, y, name: name); + + /// + /// Divides `x / y` elementwise, rounding toward the most negative integer. + /// + /// + /// + /// + /// `x / y` rounded down. + public Tensor floordiv(Tensor x, Tensor y, string name = null) + => math_ops.floordiv(x, y, name: name); + + /// + /// Divides x / y elementwise (using Python 3 division operator semantics). + /// + /// + /// + /// + /// `x / y` evaluated in floating point. + public static Tensor truediv(Tensor x, Tensor y, string name = null) + => math_ops.truediv(x, y, name: name); + + public Tensor range(object start, object limit = null, object delta = null, TF_DataType? dtype = null, string name = "range") + => math_ops.range(start, limit: limit, delta: delta, dtype: dtype, name: name); + + public Tensor real(Tensor input, string name = null) + => math_ops.real(input, name); + + /// + /// Computes the "logical or" of elements across dimensions of a tensor. + /// + /// The boolean tensor to reduce. + /// The dimensions to reduce. + /// If true, retains reduced dimensions with length 1. + /// + /// The reduced tensor. + public Tensor reduce_any(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_any(input_tensor, axis: axis, keepdims: keepdims, name: name); + + /// + /// Computes the "logical and" of elements across dimensions of a tensor. + /// + /// + /// + /// + /// + /// The reduced tensor. + public Tensor reduce_all(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_all(input_tensor, axis: axis, keepdims: keepdims, name: name); + + /// + /// Computes the product of elements across dimensions of a tensor. + /// + /// + /// + /// + /// + /// + public Tensor reduce_prod(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_prod(input_tensor, axis: axis, keepdims: keepdims, name: name); + + /// + /// Computes the sum of elements across dimensions of a tensor. + /// + /// + /// + /// + public Tensor reduce_sum(Tensor input, Axis? axis = null, Axis? reduction_indices = null, + bool keepdims = false, string name = null) + { + if (keepdims) + return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices), keepdims: keepdims, name: name); + else + return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices)); + } + + /// + /// Computes the maximum of elements across dimensions of a tensor. + /// + /// + /// + /// + /// + /// + public Tensor reduce_max(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_max(input_tensor, axis, keepdims, name); + + public Tensor reduce_min(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_min(input_tensor, axis, keepdims, name); + + public Tensor reduce_std(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_std(input_tensor, axis, keepdims, name); + + public Tensor reduce_variance(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_variance(input_tensor, axis, keepdims, name); + + public Tensor sigmoid(T x, string name = null) + => math_ops.sigmoid(x, name: name); + + public Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null) + => gen_math_ops.sum(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); + + public Tensor reduce_mean(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) + => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); + + public Tensor round(Tensor x, string name = null) + => gen_math_ops.round(x, name: name); + + public Tensor cast(Tensor x, TF_DataType dtype, string name = null) + => math_ops.cast(x, dtype, name); + + public Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null) + => math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name); + + public Tensor square(Tensor x, string name = null) + => gen_math_ops.square(x, name: name); + public Tensor squared_difference(Tensor x, Tensor y, string name = null) + => gen_math_ops.squared_difference(x: x, y: y, name: name); + public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null, + string name = null) => gen_ops.complex(real, imag, dtype, name); + public Tensor exp(Tensor x, + string name = null) => gen_math_ops.exp(x, name); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs new file mode 100644 index 000000000..112c48628 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -0,0 +1,248 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Xml.Linq; +using Tensorflow.Operations; +using Tensorflow.Operations.Activation; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class tensorflow + { + public nn_internal nn { get; } = new nn_internal(); + + public class nn_internal + { + public Tensor conv2d(Tensor input, Tensor filter, int[] strides, string padding, bool use_cudnn_on_gpu = true, + string data_format = "NHWC", int[] dilations = null, string name = null) + { + return gen_nn_ops.conv2d(input, filter, strides, padding, use_cudnn_on_gpu, + data_format: data_format, dilations: dilations, name: name); + } + + public Tensor[] ctc_greedy_decoder(Tensor inputs, Tensor sequence_length, bool merge_repeated = true, string name = null) + => gen_ctc_ops.ctc_greedy_decoder(inputs, sequence_length, merge_repeated: merge_repeated, name: name); + + /// + /// Computes dropout. + /// + /// A floating point tensor. + /// (deprecated) A deprecated alias for `(1-rate)`. + /// + /// Used to create random seeds. + /// + /// A scalar `Tensor` with the same type as `x`. + /// A Tensor of the same shape of `x`. + public Tensor dropout(Tensor x, Tensor keep_prob = null, Tensor noise_shape = null, int? seed = null, string name = null, + float? rate = null) + { + Tensor keep = null; + if (keep_prob != null) + keep = 1.0f - keep_prob; + var rate_tensor = rate.HasValue ? tf.constant(rate.Value) : keep; + return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name); + } + + /// + /// Creates a recurrent neural network specified by RNNCell `cell`. + /// + /// An instance of RNNCell. + /// The RNN inputs. + /// + /// + /// + /// A pair (outputs, state) + public (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs, + Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, + int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) + => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype, + parallel_iterations: parallel_iterations, swap_memory: swap_memory, + time_major: time_major); + + public Tensor elu(Tensor features, string name = null) + => gen_nn_ops.elu(features, name: name); + + public (Tensor, Tensor) moments(Tensor x, + Axis axes, + string name = null, + bool keep_dims = false) => nn_impl.moments(x, + axes, + name: name, + keep_dims: keep_dims); + + public Tensor embedding_lookup(IVariableV1 @params, + Tensor ids, + string partition_strategy = "mod", + string name = null) => embedding_ops._embedding_lookup_and_transform(@params, + ids, + partition_strategy: partition_strategy, + name: name); + + public Tensor embedding_lookup(Tensor @params, + Tensor ids, + string partition_strategy = "mod", + string name = null) => embedding_ops._embedding_lookup_and_transform(new Tensor[] { @params }, + ids, + partition_strategy: partition_strategy, + name: name); + + public IActivation relu() => new relu(); + + + public IActivation swish() => new swish(); + public IActivation tanh() => new tanh(); + + public IActivation softmax() => new softmax(); + public Tensor tanh(Tensor x, string name = null) + => gen_math_ops.tanh(x, name); + + public Tensor relu(Tensor features, string name = null) + => gen_nn_ops.relu(features, name); + + public Tensor relu6(Tensor features, string name = null) + => gen_nn_ops.relu6(features, name); + + public Tensor[] fused_batch_norm(Tensor x, + Tensor scale, + Tensor offset, + Tensor mean = null, + Tensor variance = null, + float epsilon = 0.001f, + string data_format = "NHWC", + bool is_training = true, + string name = null, + float exponential_avg_factor = 1.0f) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance, + epsilon: epsilon, + data_format: data_format, + is_training: is_training, + name: name, + exponential_avg_factor: exponential_avg_factor); + + /// + /// Normalizes a tensor by `mean` and `variance`, and applies (optionally) a`scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\). + /// + /// A floating point tensor. + /// A mean `Tensor`. + /// A variance `Tensor`. + /// An offset `Tensor`, often denoted \\(\beta\\) in equations, or NULL. If present, will be added to the normalized tensor. + /// A scale `Tensor`, often denoted \\(\gamma\\) in equations, or NULL. If present, the scale is applied to the normalized tensor. + /// A small float number to avoid dividing by 0. + /// A name for this operation. + /// the normalized, scaled, offset tensor. + public Tensor batch_normalization(Tensor x, + Tensor mean, + Tensor variance, + Tensor offset, + Tensor scale, + float variance_epsilon, + string name = null) => nn_impl.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name); + + + public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null) + => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name); + + public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK") + => nn_ops.in_top_k(predictions, targets, k, name); + + public Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null) + => gen_nn_ops.top_kv2(input, k: ops.convert_to_tensor(k), sorted: sorted, name: name); + + public Tensor bias_add(Tensor value, IVariableV1 bias, string data_format = null, string name = null) + { + return tf_with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope => + { + name = scope; + return gen_nn_ops.bias_add(value, ops.convert_to_tensor(bias), data_format: data_format, name: name); + }); + } + + public Tensor l2_loss(Tensor t, string name = null) + => nn_ops.l2_loss(t, name: name); + + /// + /// Local Response Normalization. + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor lrn(Tensor input, int depth_radius = 5, int bias = 1, + int alpha = 1, float beta = 0.5f, string name = null) + => gen_nn_ops.lrn(input, depth_radius: depth_radius, bias: bias, + alpha: alpha, beta: beta, name: name); + + public Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) + => nn_ops.leaky_relu(features, alpha: alpha, name: name); + + public rnn_cell_impl rnn_cell => new rnn_cell_impl(); + + public Tensor sigmoid_cross_entropy_with_logits(Tensor labels, Tensor logits, string name = null) + => nn_impl.sigmoid_cross_entropy_with_logits(labels: labels, logits: logits, name: name); + + public Tensor softmax(Tensor logits, int axis = -1, string name = null) + => gen_nn_ops.softmax(logits, name); + + + /// + /// Computes sparse softmax cross entropy between `logits` and `labels`. + /// + /// + /// + /// + /// + public Tensor sparse_softmax_cross_entropy_with_logits(Tensor labels = null, + Tensor logits = null, string name = null) + => nn_ops.sparse_softmax_cross_entropy_with_logits(labels: labels, logits: logits, name: name); + + /// + /// Computes softmax cross entropy between `logits` and `labels`. + /// + /// + /// + /// + /// + /// + public Tensor softmax_cross_entropy_with_logits(Tensor labels, Tensor logits, int dim = -1, string name = null) + { + tf_with(ops.name_scope(name, "softmax_cross_entropy_with_logits_sg", new { logits, labels }), scope => + { + name = scope; + labels = array_ops.stop_gradient(labels, name: "labels_stop_gradient"); + }); + + return softmax_cross_entropy_with_logits_v2(labels, logits, axis: dim, name: name); + } + + public Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null) + => nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name); + + /// + /// Computes sigmoid of `x` element-wise. + /// Specifically, `y = 1 / (1 + exp(-x))`. + /// + /// + /// + /// A name for the operation (optional). + /// A Tensor with the same type as `x`. + public Tensor sigmoid(T x, string name = null) + => math_ops.sigmoid(x, name: name); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.numpy.cs b/src/TensorFlowNET.Core/APIs/tf.numpy.cs new file mode 100644 index 000000000..392ba915f --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.numpy.cs @@ -0,0 +1,29 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; + +namespace Tensorflow +{ + public partial class tensorflow + { + /// + /// NumPy API on TensorFlow + /// https://www.tensorflow.org/api_docs/python/tf/experimental/numpy + /// + public NumPyImpl numpy => new NumPyImpl(); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs new file mode 100644 index 000000000..ebf35e3f9 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs @@ -0,0 +1,97 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; + +namespace Tensorflow +{ + public partial class tensorflow + { + public void add_to_collection(string name, T value) + => get_default_graph().add_to_collection(name, value); + + public void add_to_collections(List names, T value) + => get_default_graph().add_to_collections(names, value); + + public (Tensors, Tensor) clip_by_global_norm(Tensor[] t_list, float clip_norm, Tensor use_norm = null, string name = null) + => clip_ops.clip_by_global_norm(t_list, clip_norm, use_norm: use_norm, name: name); + + public Tensor assign(IVariableV1 @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) + => state_ops.assign(@ref, value, validate_shape, use_locking, name); + + public void device(string device_name) + => get_default_graph().device(device_name); + + public List get_collection(string key, string scope = "") + => get_default_graph().get_collection(key, scope: scope); + + /// + /// A context manager that lifts ops out of control-flow scopes and function-building graphs. + /// When eager execution is enabled, code inside an init_scope block runs with + /// eager execution enabled even when tracing a `tf.function`. + /// + public ops.NameScope init_scope() + => ops.init_scope(); + + /// + /// Returns a context manager that creates hierarchical names for operations. + /// + /// The name argument that is passed to the op function. + /// The default name to use if the name argument is None. + /// The list of Tensor arguments that are passed to the op function. + /// The scope name. + public ops.NameScope name_scope(string name, string default_name = "", object values = null) + => new ops.NameScope(name, default_name, values); + + /// + /// Does nothing. Only useful as a placeholder for control edges. + /// + /// + /// + public Operation no_op(string name = null) + => gen_control_flow_ops.no_op(name: name); + + /// + /// map on the list of tensors unpacked from `elems` on dimension 0. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A tensor or (possibly nested) sequence of tensors. + public Tensor map_fn(Func fn, + Tensor elems, + TF_DataType dtype = TF_DataType.DtInvalid, + int parallel_iterations = -1, + bool back_prop = true, + bool swap_memory = false, + bool infer_shape = true, + string name = null) + => Operation.map_fn(fn, + elems, + dtype, + parallel_iterations: parallel_iterations, + back_prop: back_prop, + swap_memory: swap_memory, + infer_shape: infer_shape, + name: name); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.queue.cs b/src/TensorFlowNET.Core/APIs/tf.queue.cs new file mode 100644 index 000000000..a4757890e --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.queue.cs @@ -0,0 +1,126 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Queues; + +namespace Tensorflow +{ + public partial class tensorflow + { + /// + /// A FIFOQueue that supports batching variable-sized tensors by padding. + /// + /// + /// + /// + /// + /// + /// + /// + public PaddingFIFOQueue PaddingFIFOQueue(int capacity, + TF_DataType[] dtypes, + Shape[] shapes, + string[] names = null, + string shared_name = null, + string name = "padding_fifo_queue") + => new PaddingFIFOQueue(capacity, + dtypes, + shapes, + names, + shared_name: shared_name, + name: name); + + public PaddingFIFOQueue PaddingFIFOQueue(int capacity, + TF_DataType dtype, + Shape shape, + string shared_name = null, + string name = "padding_fifo_queue") + => new PaddingFIFOQueue(capacity, + new[] { dtype }, + new[] { shape }, + shared_name: shared_name, + name: name); + + /// + /// A queue implementation that dequeues elements in first-in first-out order. + /// + /// + /// + /// + /// + /// + /// + /// + public FIFOQueue FIFOQueue(int capacity, + TF_DataType[] dtypes, + Shape[] shapes = null, + string[] names = null, + string shared_name = null, + string name = "fifo_queue") + => new FIFOQueue(capacity, + dtypes, + shapes, + names, + shared_name: shared_name, + name: name); + + public FIFOQueue FIFOQueue(int capacity, + TF_DataType dtype, + Shape shape = null, + string shared_name = null, + string name = "fifo_queue") + => new FIFOQueue(capacity, + new[] { dtype }, + new[] { shape ?? Shape.Null }, + shared_name: shared_name, + name: name); + + /// + /// Creates a queue that dequeues elements in a first-in first-out order. + /// + /// + /// + /// + /// + /// + /// + public PriorityQueue PriorityQueue(int capacity, + TF_DataType dtype, + Shape shape = null, + string shared_name = null, + string name = "priority_queue") + => new PriorityQueue(capacity, + new[] { dtype }, + new[] { shape ?? Shape.Null }, + shared_name: shared_name, + name: name); + + public RandomShuffleQueue RandomShuffleQueue(int capacity, + int min_after_dequeue, + TF_DataType dtype, + Shape shape = null, + int? seed = null, + string shared_name = null, + string name = "random_shuffle_queue") + => new RandomShuffleQueue(capacity, + min_after_dequeue: min_after_dequeue, + new[] { dtype }, + new[] { shape ?? Shape.Null }, + seed: seed, + shared_name: shared_name, + name: name); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs new file mode 100644 index 000000000..4f4962840 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.random.cs @@ -0,0 +1,128 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + public Random random => new Random(); + + public class Random + { + /// + /// Outputs random values from a normal distribution. + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor normal(Shape shape, + float mean = 0.0f, + float stddev = 1.0f, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); + + public Tensor stateless_normal(Shape shape, + float mean = 0.0f, + float stddev = 1.0f, + TF_DataType dtype = TF_DataType.TF_FLOAT, + string name = null) => stateless_random_ops.stateless_random_normal(shape, mean, stddev, dtype, name: name); + + /// + /// Outputs random values from a truncated normal distribution. + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor truncated_normal(Shape shape, + float mean = 0.0f, + float stddev = 1.0f, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) => random_ops.truncated_normal(shape, mean, stddev, dtype, seed, name); + + public Tensor categorical( + Tensor logits, + int num_samples, + int? seed = null, + string name = null, + TF_DataType output_dtype = TF_DataType.DtInvalid) => random_ops.multinomial(logits, num_samples, seed: seed, name: name, output_dtype: output_dtype); + + public Tensor uniform(Shape shape, + float minval = 0, + float maxval = 1, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) + { + if (dtype.is_integer()) + return random_ops.random_uniform_int(shape, (int)minval, (int)maxval, seed, name); + else + return random_ops.random_uniform(shape, minval, maxval, dtype, seed, name); + } + } + + public Tensor random_uniform(Shape shape, + float minval = 0, + float maxval = 1, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) + => random.uniform(shape, minval: minval, maxval: maxval, dtype: dtype, seed: seed, name: name); + + public Tensor truncated_normal(Shape shape, + float mean = 0.0f, + float stddev = 1.0f, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) + => random_ops.truncated_normal(shape, mean, stddev, dtype, seed, name); + + /// + /// Randomly shuffles a tensor along its first dimension. + /// + /// + /// + /// + /// + /// A tensor of same shape and type as value, shuffled along its + /// first dimension. + /// + public Tensor random_shuffle(Tensor value, int? seed = null, string name = null) + => random_ops.random_shuffle(value, seed: seed, name: name); + + public void set_random_seed(int seed) + { + if (executing_eagerly()) + Context.set_global_seed(seed); + else + ops.get_default_graph().seed = seed; + } + + public Tensor multinomial(Tensor logits, int num_samples, int? seed = null, + string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) + => random_ops.multinomial(logits, num_samples, seed: seed, + name: name, output_dtype: output_dtype); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs b/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs new file mode 100644 index 000000000..41f0ec45d --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs @@ -0,0 +1,27 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + public Tensor reduce_logsumexp(Tensor input_tensor, + Axis? axis = null, + bool keepdims = false, + string name = null) => math_ops.reduce_logsumexp(input_tensor, axis, keepdims, name); + + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs new file mode 100644 index 000000000..102a81323 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs @@ -0,0 +1,36 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + public Tensor reshape(Tensor tensor, + Shape shape, + string name = null) + => gen_array_ops.reshape(tensor, shape, name); + + public Tensor reshape(Tensor tensor, + Tensor shape, + string name = null) + => gen_array_ops.reshape(tensor, shape, name); + + public Tensor reshape(Tensor tensor, + object[] shape, + string name = null) + => array_ops.reshape(tensor, shape, name); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.saved_model.cs b/src/TensorFlowNET.Core/APIs/tf.saved_model.cs new file mode 100644 index 000000000..ef6251ca8 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.saved_model.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow +{ + public partial class tensorflow + { + public SavedModelAPI saved_model { get; } = new SavedModelAPI(); + } + + public class SavedModelAPI + { + public Trackable load(string export_dir, LoadOptions? options = null) + { + return Loader.load(export_dir, options); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.scan.cs b/src/TensorFlowNET.Core/APIs/tf.scan.cs new file mode 100644 index 000000000..5642eaaf1 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.scan.cs @@ -0,0 +1,35 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow +{ + public partial class tensorflow + { + public Tensor scan( + Func fn, + Tensor elems, + Tensor initializer = null, + int parallel_iterations = 10, + bool back_prop = true, + bool swap_memory = false, + bool infer_shape = true, + bool reverse = false, + string name = null) => functional_ops.scan(fn, elems, initializer, parallel_iterations, back_prop, + swap_memory, infer_shape, reverse, name); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/APIs/tf.signal.cs b/src/TensorFlowNET.Core/APIs/tf.signal.cs new file mode 100644 index 000000000..2471124c5 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.signal.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2023 Konstantin Balashov All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Operations; + +namespace Tensorflow +{ + public partial class tensorflow + { + public SignalApi signal { get; } = new SignalApi(); + public class SignalApi + { + public Tensor fft(Tensor input, string name = null) + => gen_ops.f_f_t(input, name: name); + public Tensor ifft(Tensor input, string name = null) + => gen_ops.i_f_f_t(input, name: name); + public Tensor fft2d(Tensor input, string name = null) + => gen_ops.f_f_t2d(input, name: name); + public Tensor ifft2d(Tensor input, string name = null) + => gen_ops.i_f_f_t2d(input, name: name); + public Tensor fft3d(Tensor input, string name = null) + => gen_ops.f_f_t3d(input, name: name); + public Tensor ifft3d(Tensor input, string name = null) + => gen_ops.i_f_f_t3d(input, name: name); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.sparse.cs b/src/TensorFlowNET.Core/APIs/tf.sparse.cs new file mode 100644 index 000000000..f124f6105 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.sparse.cs @@ -0,0 +1,62 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Framework; + +namespace Tensorflow +{ + public partial class tensorflow + { + public SparseTensor SparseTensor(long[,] indices, Array values, long[] dense_shape) + => new SparseTensor(indices, values, dense_shape); + + public Tensor sparse_tensor_to_dense(SparseTensor sp_input, + Array default_value = default, + bool validate_indices = true, + string name = null) + => gen_sparse_ops.sparse_to_dense(sp_input.indices, + sp_input.dense_shape, + sp_input.values, + default_value: default_value, + validate_indices: validate_indices, + name: name); + + /// + /// Converts a sparse representation into a dense tensor. + /// + /// + /// + /// + /// + /// + /// + /// + /// Dense `Tensor` of shape `output_shape`. Has the same type as `sparse_values`. + public Tensor sparse_to_dense(Tensor sparse_indices, + Shape output_shape, + T sparse_values, + T default_value = default, + bool validate_indices = true, + string name = null) + => gen_sparse_ops.sparse_to_dense(sparse_indices, + output_shape, + sparse_values, + default_value: default_value, + validate_indices: validate_indices, + name: name); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.state.cs b/src/TensorFlowNET.Core/APIs/tf.state.cs new file mode 100644 index 000000000..d86f88b17 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.state.cs @@ -0,0 +1,25 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + public ITensorOrOperation assign_add(IVariableV1 @ref, T value, + bool use_locking = false, string name = null) + => state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.strings.cs b/src/TensorFlowNET.Core/APIs/tf.strings.cs new file mode 100644 index 000000000..ecaf775d0 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.strings.cs @@ -0,0 +1,95 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class tensorflow + { + public StringsApi strings { get; } = new StringsApi(); + + public class StringsApi + { + string_ops ops = new string_ops(); + + /// + /// Converts all uppercase characters into their respective lowercase replacements. + /// + /// + /// + /// + /// + public Tensor lower(Tensor input, string encoding = "", string name = null) + => ops.lower(input: input, encoding: encoding, name: name); + + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor regex_replace(Tensor input, string pattern, string rewrite, + bool replace_global = true, string name = null) + => ops.regex_replace(input, pattern, rewrite, + replace_global: replace_global, name: name); + + /// + /// Return substrings from `Tensor` of strings. + /// + /// + /// + /// + /// + /// + /// + public Tensor substr(Tensor input, int pos, int len, + string name = null, string @uint = "BYTE") + => ops.substr(input, pos, len, @uint: @uint, name: name); + + public Tensor substr(string input, int pos, int len, + string name = null, string @uint = "BYTE") + => ops.substr(input, pos, len, @uint: @uint, name: name); + + /// + /// String lengths of `input`. + /// + /// + /// + /// + /// + public Tensor string_length(Tensor input, string name = null, string unit = "BYTE") + => ops.string_length(input, name: name, unit: unit); + + public Tensor format(string template, Tensor[] inputs, string placeholder = "{}", int summarize = 3, string name = null) + => ops.string_format(inputs, template: template, placeholder: placeholder, summarize: summarize, name: name); + + public RaggedTensor split(Tensor input, char sep = ' ', int maxsplit = -1, string name = null) + => ops.string_split_v2(input, sep: sep.ToString(), maxsplit : maxsplit, name : name); + + public (RaggedTensor, RaggedTensor) unicode_decode_with_offsets(Tensor input, string input_encoding, + string errors = "replace", int replacement_char = 0xFFFD, + bool replace_control_characters = false, string name = null) + => ops.unicode_decode_with_offsets(input, input_encoding, errors, + replacement_char: replacement_char, + replace_control_characters: replace_control_characters, + name: name); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.summary.cs b/src/TensorFlowNET.Core/APIs/tf.summary.cs new file mode 100644 index 000000000..4d0492b60 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.summary.cs @@ -0,0 +1,26 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + public Summaries.Summary summary = new Summaries.Summary(); + + public Tensor scalar(string name, Tensor tensor) + => summary.scalar(name, tensor); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.tensor.cs b/src/TensorFlowNET.Core/APIs/tf.tensor.cs new file mode 100644 index 000000000..b03168ab3 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.tensor.cs @@ -0,0 +1,97 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Operations; + +namespace Tensorflow +{ + public partial class tensorflow + { + public Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid) + => ops.convert_to_tensor(value, dtype, name, preferred_dtype: preferred_dtype); + + public Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides = null, + int begin_mask = 0, + int end_mask = 0, + int ellipsis_mask = 0, + int new_axis_mask = 0, + int shrink_axis_mask = 0, + string name = null) => gen_array_ops.strided_slice(input: input, + begin: begin, + end: end, + strides: strides, + begin_mask: begin_mask, + end_mask: end_mask, + ellipsis_mask: ellipsis_mask, + new_axis_mask: new_axis_mask, + shrink_axis_mask: shrink_axis_mask, + name: name); + + public Tensor strided_slice(Tensor input, T[] begin, T[] end, T[] strides = null, + int begin_mask = 0, + int end_mask = 0, + int ellipsis_mask = 0, + int new_axis_mask = 0, + int shrink_axis_mask = 0, + string name = null) => array_ops.strided_slice(input, + begin: ops.convert_to_tensor(begin), + end: ops.convert_to_tensor(end), + strides: ops.convert_to_tensor(strides), + begin_mask: begin_mask, + end_mask: end_mask, + ellipsis_mask: ellipsis_mask, + new_axis_mask: new_axis_mask, + shrink_axis_mask: shrink_axis_mask, + name: name); + + /// + /// Splits a tensor into sub tensors. + /// + /// The Tensor to split. + /// Either an integer indicating the number of splits along split_dim or a 1-D integer + /// Tensor or Python list containing the sizes of each output tensor along split_dim. + /// If a scalar then it must evenly divide value.shape[axis]; otherwise the sum of sizes along the split dimension must match that of the value. + /// An integer or scalar int32 Tensor. The dimension along which to split. Must be in the range [-rank(value), rank(value)). Defaults to 0. + /// A name for the operation (optional) + /// if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; + /// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value. + public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null) + => array_ops.split( + value: value, + num_or_size_splits: num_split, + axis: axis, + name: name); + + public Tensor[] split(Tensor value, int[] num_split, Axis axis, string name = null) + => array_ops.split( + value: value, + num_or_size_splits: num_split, + axis: axis, + name: name); + + //public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null) + // => array_ops.split( + // value: value, + // num_or_size_splits: num_split, + // axis: axis, + // name: name); + + public Tensor ensure_shape(Tensor x, Shape shape, string name = null) + { + return gen_ops.ensure_shape(x, shape, name); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.tile.cs b/src/TensorFlowNET.Core/APIs/tf.tile.cs new file mode 100644 index 000000000..a3b497e8a --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.tile.cs @@ -0,0 +1,34 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class tensorflow + { + public Tensor tile(Tensor input, Tensor multiples, string name = null) + => gen_array_ops.tile(input, multiples, name); + + public Tensor tile(Tensor input, object[] multiples, string name = null) + => array_ops.tile(input, constant_op.constant(shape_utils.from_object_array(multiples).dims), name); + + public Tensor tile(Tensor input, Shape multiples, string name = null) + { + var multiples_tensor = constant_op.constant(multiples); + return gen_array_ops.tile(input, multiples_tensor, name); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs new file mode 100644 index 000000000..cf02ed599 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -0,0 +1,110 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using Tensorflow.Train; + +namespace Tensorflow +{ + public partial class tensorflow + { + public train_internal train { get; } = new train_internal(); + + public class train_internal + { + public IVariableV1 create_global_step(Graph graph) + => TrainingUtil.create_global_step(graph); + + public IVariableV1 get_global_step(Graph graph) + => TrainingUtil.get_global_step(graph); + + public Optimizer GradientDescentOptimizer(float learning_rate) + => new GradientDescentOptimizer(learning_rate); + + public Optimizer GradientDescentOptimizer(Tensor learning_rate) + => new GradientDescentOptimizer(learning_rate); + + public Optimizer AdamOptimizer(float learning_rate, float epsilon = 1e-8f, string name = "Adam") + => new AdamOptimizer(learning_rate, epsilon: epsilon, name: name); + + public Optimizer AdamOptimizer(float learning_rate, TF_DataType dtype, string name = "Adam") + => new AdamOptimizer(learning_rate, name: name, dtype: dtype); + + public Optimizer AdamOptimizer(IVariableV1 learning_rate, string name = "Adam") + => new AdamOptimizer(learning_rate.AsTensor(), name: name); + + public Optimizer AdamOptimizer(Tensor learning_rate, string name = "Adam") + => new AdamOptimizer(learning_rate, name: name); + + public ExponentialMovingAverage ExponentialMovingAverage(float decay) + => new ExponentialMovingAverage(decay); + + public Saver Saver(IVariableV1[] var_list = null, int max_to_keep = 5) + => new Saver(var_list: var_list, max_to_keep: max_to_keep); + + public string write_graph(Graph graph, string logdir, string name, bool as_text = true) + => graph_io.write_graph(graph, logdir, name, as_text); + + public Graph load_graph(string freeze_graph_pb) + => saver.load_graph(freeze_graph_pb); + + public string freeze_graph(string checkpoint_dir, string output_pb_name, string[] output_node_names) + => saver.freeze_graph(checkpoint_dir, output_pb_name, output_node_names); + + public Saver import_meta_graph(string meta_graph_or_file, + bool clear_devices = false, + string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file, + clear_devices, + import_scope).Item1; + + public (MetaGraphDef, Dictionary) export_meta_graph(string filename = "", + bool as_text = false, + bool clear_devices = false, + bool clear_extraneous_savers = false, + bool strip_default_attrs = false) => meta_graph.export_scoped_meta_graph(filename: filename, + as_text: as_text, + clear_devices: clear_devices, + clear_extraneous_savers: clear_extraneous_savers, + strip_default_attrs: strip_default_attrs); + + public string latest_checkpoint(string checkpoint_dir, string latest_filename = null) + => checkpoint_management.latest_checkpoint(checkpoint_dir, latest_filename: latest_filename); + + public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) + => checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename); + + /*public Tensor polynomial_decay(float learning_rate, + RefVariable global_step, + float decay_steps, + float end_learning_rate = 0.0001f, + float power = 1.0f, + bool cycle = false, + string name = null) + { + var decayed = new PolynomialDecay(learning_rate, + decay_steps, + end_learning_rate: end_learning_rate, + power: power, + cycle: cycle, + name: name); + + var decayed_lr = decayed.__call__(global_step); + + return decayed_lr; + }*/ + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs new file mode 100644 index 000000000..9ce864bd8 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs @@ -0,0 +1,53 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class tensorflow + { + public IVariableV1[] global_variables(string scope = null) + { + return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List) + .ToArray(); + } + + /// + /// Returns an Op that initializes a list of variables. + /// + /// List of `Variable` objects to initialize. + /// Optional name for the returned operation. + /// An Op that run the initializers of all the specified variables. + public Operation variables_initializer(IVariableV1[] var_list, string name = "init") + => variables.variables_initializer(var_list, name: name); + + public Operation global_variables_initializer() + => tf.compat.v1.global_variables_initializer(); + + /// + /// Returns all variables created with `trainable=True`. + /// + /// + /// + public IVariableV1[] trainable_variables(string scope = null) + => (variables.trainable_variables() as List).ToArray(); + + public VariableScope get_variable_scope() + => Tensorflow.variable_scope.get_variable_scope(); + } +} diff --git a/src/TensorFlowNET.Core/Assembly/Properties.cs b/src/TensorFlowNET.Core/Assembly/Properties.cs new file mode 100644 index 000000000..290a72df0 --- /dev/null +++ b/src/TensorFlowNET.Core/Assembly/Properties.cs @@ -0,0 +1,4 @@ +using System.Runtime.CompilerServices; +#if DEBUG +[assembly: InternalsVisibleTo("Tensorflow.UnitTest, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] +#endif diff --git a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs new file mode 100644 index 000000000..ba6f653a1 --- /dev/null +++ b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs @@ -0,0 +1,122 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Runtime.InteropServices; + +namespace Tensorflow +{ + public partial class c_api + { + /// + /// Fills in `value` with the value of the attribute `attr_name`. `value` must + /// point to an array of length at least `max_length` (ideally set to + /// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, + /// attr_name)). + /// + /// TF_Operation* + /// const char* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern TF_AttrMetadata TF_OperationGetAttrMetadata(IntPtr oper, string attr_name, SafeStatusHandle status); + + /// + /// Fills in `value` with the value of the attribute `attr_name`. `value` must + /// point to an array of length at least `max_length` (ideally set to + /// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, + /// attr_name)). + /// + /// TF_Operation* + /// const char* + /// void* + /// size_t + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TF_OperationGetAttrString(IntPtr oper, string attr_name, IntPtr value, uint max_length, SafeStatusHandle status); + + /// + /// Sets `output_attr_value` to the binary-serialized AttrValue proto + /// representation of the value of the `attr_name` attr of `oper`. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, SafeBufferHandle output_attr_value, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_OperationGetAttrType(IntPtr oper, string attr_name, IntPtr value, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_OperationGetAttrInt(IntPtr oper, string attr_name, IntPtr value, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_OperationGetAttrFloat(IntPtr oper, string attr_name, IntPtr value, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_OperationGetAttrBool(IntPtr oper, string attr_name, IntPtr value, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_OperationGetAttrShape(IntPtr oper, string attr_name, long[] value, int num_dims, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); + + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status); + + /// + /// Set `num_dims` to -1 to represent "unknown rank". + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttrShape(IntPtr desc, string attr_name, long[] dims, int num_dims); + + /// + /// Call some TF_SetAttr*() function for every attr that is not + /// inferred from an input and doesn't have a default value you wish to + /// keep. + /// + /// `value` must point to a string of length `length` bytes. + /// + /// TF_OperationDescription* + /// const char* + /// const void* + /// size_t + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttrString(IntPtr desc, string attr_name, string value, uint length); + + /// + /// + /// + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, IntPtr[] values, uint[] lengths, int num_values); + + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, SafeTensorHandle value, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value); + } +} diff --git a/src/TensorFlowNET.Core/Binding.FuncTools.cs b/src/TensorFlowNET.Core/Binding.FuncTools.cs new file mode 100644 index 000000000..42a7b4ef9 --- /dev/null +++ b/src/TensorFlowNET.Core/Binding.FuncTools.cs @@ -0,0 +1,25 @@ +using System; + +namespace Tensorflow +{ + public static partial class Binding + { + public static class functools + { + public static PartialFunc partial(Func func, Tin arg) + => new PartialFunc + { + args = arg, + invoke = func + }; + } + + public class PartialFunc + { + public Tin args { get; set; } + public object[] keywords { get; set; } + + public Func invoke { get; set; } + } + } +} diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs new file mode 100644 index 000000000..99ed5c1f3 --- /dev/null +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -0,0 +1,537 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.IO; +using System.Linq; +using Tensorflow.Operations; + +namespace Tensorflow +{ + /// + /// Binding utilities to mimic python functions. + /// + public static partial class Binding + { + public static T2 get(this Dictionary dict, T1 key) + => key == null ? + default : + (dict.ContainsKey(key) ? dict[key] : default); + + public static void Update(this IList list, T element) + { + var index = list.IndexOf(element); + if (index < 0) + list.Add(element); + else + { + list[index] = element; + } + } + + public static void difference_update(this IList list, IList list2) + { + foreach(var el in list2) + { + if (list.Contains(el)) + list.Remove(el); + } + } + + public static void add(this IList list, T element) + => list.Add(element); + + public static void add(this IList list, IEnumerable elements) + { + foreach (var ele in elements) + list.Add(ele); + } + + public static void append(this IList list, T element) + => list.Insert(list.Count, element); + + public static void append(this IList list, IList elements) + { + for (int i = 0; i < elements.Count(); i++) + list.Insert(list.Count, elements[i]); + } + + public static T[] concat(this IList list1, IList list2) + { + var list = new List(); + list.AddRange(list1); + list.AddRange(list2); + return list.ToArray(); + } + + public static void extend(this List list, IEnumerable elements) + => list.AddRange(elements); + + private static string _tostring(object obj) + { + switch (obj) + { + case NDArray nd: + return nd.ToString(); + /*case Array arr: + if (arr.Rank != 1 || arr.GetType().GetElementType()?.IsArray == true) + arr = Arrays.Flatten(arr); + var objs = toObjectArray(arr); + return $"[{string.Join(", ", objs.Select(_tostring))}]";*/ + default: + return obj?.ToString() ?? "null"; + } + } + + private static TextWriter _writer = Console.Out; + + public static TextWriter tf_output_redirect { + set + { + if(_writer != null) + { + _writer.Flush(); + if (_writer is StringWriter sw) + sw.GetStringBuilder().Clear(); + } + + _writer = value; + } + get => _writer ?? Console.Out; + } + + public static void print(object obj) + { + tf_output_redirect.WriteLine(_tostring(obj)); + } + + public static void print(string format, params object[] objects) + { + if (!format.Contains("{}")) + { + tf_output_redirect.WriteLine(format + " " + string.Join(" ", objects.Select(x => x.ToString()))); + return; + } + + foreach (var obj in objects) + { + + } + + tf_output_redirect.WriteLine(format); + } + + public static int len(object a) + { + switch (a) + { + case Tensor tensor: + return (int)tensor.shape[0]; + case Tensors arr: + return arr.Length; + case Array arr: + return arr.Length; + case IList arr: + return arr.Count; + case ICollection arr: + return arr.Count; + case IEnumerable enumerable: + return enumerable.OfType().Count(); + case Axis axis: + return axis.size; + case Shape arr: + return arr.ndim; + } + throw new NotImplementedException("len() not implemented for type: " + a.GetType()); + } + + public static int min(int a, int b) + => Math.Min(a, b); + + public static float min(float a, float b) + => Math.Min(a, b); + + public static int max(int a, int b) + => Math.Max(a, b); + + public static T[] list(IEnumerable list) + => list.ToArray(); + + public static IEnumerable range(int end) + { + return Enumerable.Range(0, end); + } + + public static IEnumerable range(int start, int end) + { + return Enumerable.Range(start, end - start); + } + + public static IEnumerable reversed(IList values) + { + var len = values.Count; + for (int i = len - 1; i >= 0; i--) + yield return values[i]; + } + + [DebuggerStepThrough] + public static void tf_with(T py, Action action) where T : ITensorFlowObject + { + py.__enter__(); + action(py); + py.__exit__(); + } + + [DebuggerStepThrough] + public static TOut tf_with(TIn py, Func action) where TIn : ITensorFlowObject + { + py.__enter__(); + var result = action(py); + py.__exit__(); + return result; + } + + public static float time() + { + return (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds; + } + + public static IEnumerable<(T1, T2)> zip((T1, T1) t1, (T2, T2) t2) + { + for (int i = 0; i < 2; i++) + { + if (i == 0) + yield return (t1.Item1, t2.Item1); + else + yield return (t1.Item2, t2.Item2); + } + } + + public static IEnumerable<(T, T)> zip(NDArray t1, NDArray t2, Axis axis = null) + where T : unmanaged + { + if (axis == null) + { + var a = t1.ToArray(); + var b = t2.ToArray(); + for (int i = 0; i < a.Length; i++) + yield return (a[i], b[i]); + } + else + throw new NotImplementedException(""); + } + + public static IEnumerable<(T1, T2)> zip(IList t1, IList t2) + { + for (int i = 0; i < t1.Count; i++) + yield return (t1[i], t2[i]); + } + + public static IEnumerable<(T1, T2, T3)> zip(IList t1, IList t2, IList t3) + { + for (int i = 0; i < t1.Count; i++) + yield return (t1[i], t2[i], t3[i]); + } + + public static IEnumerable<(T1, T2)> zip(NDArray t1, NDArray t2) + where T1 : unmanaged + where T2 : unmanaged + { + //var a = t1.AsIterator(); + //var b = t2.AsIterator(); + //while (a.HasNext() && b.HasNext()) + //yield return (a.MoveNext(), b.MoveNext()); + throw new NotImplementedException(""); + } + + public static IEnumerable<(T1, T2)> zip(IEnumerable e1, IEnumerable e2) + { + return e1.Zip(e2, (t1, t2) => (t1, t2)); + } + + public static IEnumerable<(TKey, TValue)> enumerate(Dictionary values) + { + foreach (var item in values) + yield return (item.Key, item.Value); + } + + public static IEnumerable<(TKey, TValue)> enumerate(KeyValuePair[] values) + { + var len = values.Length; + for (var i = 0; i < len; i++) + { + var item = values[i]; + yield return (item.Key, item.Value); + } + } + + public static IEnumerable<(int, T)> enumerate(IList values) + { + var len = values.Count; + for (int i = 0; i < len; i++) + yield return (i, values[i]); + } + + public static IEnumerable<(int, T)> enumerate(IEnumerable values, int start = 0, int step = 1) + { + int i = 0; + foreach (var val in values) + { + if (i++ < start) + continue; + + yield return (i - 1, val); + } + } + + [DebuggerStepThrough] + public static Dictionary ConvertToDict(object dyn) + { + var dictionary = new Dictionary(); + foreach (PropertyDescriptor propertyDescriptor in TypeDescriptor.GetProperties(dyn)) + { + object obj = propertyDescriptor.GetValue(dyn); + string name = propertyDescriptor.Name; + dictionary.Add(name, obj); + } + return dictionary; + } + + + public static bool all(IEnumerable enumerable) + { + foreach (var e1 in enumerable) + { + if (!Convert.ToBoolean(e1)) + return false; + } + return true; + } + + public static bool any(IEnumerable enumerable) + { + foreach (var e1 in enumerable) + { + if (Convert.ToBoolean(e1)) + return true; + } + return false; + } + + public static double sum(IEnumerable enumerable) + { + var typedef = new Type[] { typeof(double), typeof(int), typeof(float) }; + var sum = 0.0d; + foreach (var e1 in enumerable) + { + if (!typedef.Contains(e1.GetType())) + throw new Exception("Numeric array expected"); + sum += (double)e1; + } + return sum; + } + + public static float sum(IEnumerable enumerable) + => enumerable.Sum(); + + public static int sum(IEnumerable enumerable) + => enumerable.Sum(); + + public static double sum(Dictionary values) + { + return sum(values.Keys); + } + + public static IEnumerable slice(double start, double end, double step = 1) + { + for (double i = start; i < end; i += step) + yield return i; + } + + public static IEnumerable slice(float start, float end, float step = 1) + { + for (float i = start; i < end; i += step) + yield return i; + } + + public static IEnumerable slice(int start, int end, int step = 1) + { + for (int i = start; i < end; i += step) + yield return i; + } + + public static IEnumerable slice(int range) + { + for (int i = 0; i < range; i++) + yield return i; + } + + public static bool hasattr(object obj, string key) + { + var __type__ = (obj).GetType(); + + var __member__ = __type__.GetMembers(); + var __memberobject__ = __type__.GetMember(key); + return (__memberobject__.Length > 0) ? true : false; + } + + public static IEnumerable TupleToEnumerable(object tuple) + { + Type t = tuple.GetType(); + if (t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) + { + var flds = t.GetFields(); + for (int i = 0; i < flds.Length; i++) + { + yield return flds[i].GetValue(tuple); + } + } + else + { + throw new System.Exception("Expected Tuple."); + } + } + + public static bool isinstance(object Item1, Type Item2) + { + return Item1.GetType() == Item2; + } + + public static bool isinstance(object Item1, object tuple) + { + foreach (var t in TupleToEnumerable(tuple)) + if (isinstance(Item1, (Type)t)) + return true; + return false; + } + + public static bool issubset(this IEnumerable subset, IEnumerable src) + { + bool issubset = true; + foreach (var element in subset) + { + if (!src.Contains(element)) + { + issubset = false; + continue; + } + } + + return true; + } + + public static void extendleft(this Queue queue, IEnumerable elements) + { + foreach (var element in elements.Reverse()) + queue.Enqueue(element); + } + + public static bool empty(this Queue queue) + => queue.Count == 0; + + public static TValue SetDefault(this Dictionary dic, TKey key, TValue defaultValue) + { + if (dic.ContainsKey(key)) + return dic[key]; + + dic[key] = defaultValue; + return defaultValue; + } + + public static TValue Get(this Dictionary dic, TKey key, TValue defaultValue) + { + if (dic.ContainsKey(key)) + return dic[key]; + + return defaultValue; + } + + public static Shape GetShape(this object data) + { + if (data is NDArray nd) + return nd.shape; + else if (data is Tensor tensor) + return tensor.shape; + else if (data is Axis axis) + return axis.IsScalar ? Shape.Scalar : new Shape(axis.axis.Length); + else if (data is Shape shape) + return new Shape(shape.rank); + else if (!data.GetType().IsArray) + return Shape.Scalar; + + switch (data) + { + case Array array: + var dims = range(array.Rank).Select(x => (long)array.GetLength(x)).ToArray(); + return new Shape(dims); + default: + throw new NotImplementedException(""); + } + } + public static NDArray GetFlattenArray(NDArray x) + { + switch (x.GetDataType()) + { + case TF_DataType.TF_FLOAT: + x = x.ToArray(); + break; + case TF_DataType.TF_DOUBLE: + x = x.ToArray(); + break; + case TF_DataType.TF_INT16: + case TF_DataType.TF_INT32: + x = x.ToArray(); + break; + case TF_DataType.TF_INT64: + x = x.ToArray(); + break; + default: + break; + } + return x; + } + public static TF_DataType GetDataType(this object data) + { + var type = data.GetType(); + switch (data) + { + case Shape: + return TF_DataType.TF_INT64; + case Axis: + return TF_DataType.TF_INT32; + case NDArray nd: + return nd.dtype; + case Tensor tensor: + return tensor.dtype; + case Tensors tensors: + return tensors.dtype; + case IEnumerable tensors: + return tensors.Where(x => x is not null).First().dtype; + case RefVariable variable: + return variable.dtype; + case ResourceVariable variable: + return variable.dtype; + default: + return type.as_tf_dtype(); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Binding.cs b/src/TensorFlowNET.Core/Binding.cs new file mode 100644 index 000000000..004f35a3a --- /dev/null +++ b/src/TensorFlowNET.Core/Binding.cs @@ -0,0 +1,21 @@ +using System.Diagnostics; + +namespace Tensorflow +{ + public static partial class Binding + { + public static tensorflow tf { get; } = new tensorflow(); + + /// + /// Alias to null, similar to python's None. + /// For Shape, please use Unknown + /// + public static readonly object None = null; + + /// + /// Used for Shape None + /// + /// + public static readonly int Unknown = -1; + } +} diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index bf2799cbd..330e30caa 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -1,29 +1,124 @@ -using System; -using System.Collections.Generic; -using System.Runtime.InteropServices; -using System.Text; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.IO; +using System.Runtime.CompilerServices; +using Tensorflow.Util; +using static Tensorflow.c_api; namespace Tensorflow { - public class Buffer + /// + /// Represents a TF_Buffer that can be passed to Tensorflow. + /// + public sealed class Buffer { - private IntPtr _handle; - public IntPtr Handle => _handle; - //public TF_Buffer buffer => Marshal.PtrToStructure(_handle); + SafeBufferHandle _handle; + + /// + /// + /// + private unsafe ref readonly TF_Buffer DangerousBuffer + => ref Unsafe.AsRef(_handle.DangerousGetHandle().ToPointer()); + + /// + /// The memory block representing this buffer. + /// + /// + /// The deallocator is set to null. + /// + /// + /// + public unsafe MemoryStream DangerousMemoryBlock + { + get + { + ref readonly TF_Buffer buffer = ref DangerousBuffer; + return new MemoryStream(ToArray()); + } + } - public unsafe Buffer() + /// + /// The bytes length of this buffer. + /// + public ulong Length { - _handle = Marshal.AllocHGlobal(sizeof(TF_Buffer)); + get + { + using (_handle.Lease()) + { + return DangerousBuffer.length; + } + } } - public byte[] GetBuffer() + public Buffer() + => _handle = TF_NewBuffer(); + + public Buffer(SafeBufferHandle handle) + => _handle = handle; + + public Buffer(byte[] data) + => _handle = _toBuffer(data); + + private static SafeBufferHandle _toBuffer(byte[] data) { - var buffer = Marshal.PtrToStructure(_handle); + if (data == null) + throw new ArgumentNullException(nameof(data)); - var data = Marshal.AllocHGlobal(buffer.length); - //var bytes = c_api.TF_GetBuffer(buffer.data); + unsafe + { + fixed (byte* src = data) + return TF_NewBufferFromString(new IntPtr(src), (ulong)data.LongLength); + } + } - return null; + /// + /// Copies this buffer's contents onto a array. + /// + public unsafe byte[] ToArray() + { + using (_handle.Lease()) + { + ref readonly TF_Buffer buffer = ref DangerousBuffer; + + if (buffer.length == 0) + return new byte[0]; + + var data = new byte[DangerousBuffer.length]; + fixed (byte* dst = data) + System.Buffer.MemoryCopy(buffer.data.ToPointer(), dst, buffer.length, buffer.length); + + return data; + } + } + + public void Release() + { + _handle.Dispose(); + _handle = null; + } + + public override string ToString() + => $"0x{_handle.DangerousGetHandle():x16}"; + + public static implicit operator SafeBufferHandle(Buffer buffer) + { + return buffer._handle; } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Buffers/SafeBufferHandle.cs b/src/TensorFlowNET.Core/Buffers/SafeBufferHandle.cs new file mode 100644 index 000000000..82678d549 --- /dev/null +++ b/src/TensorFlowNET.Core/Buffers/SafeBufferHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow +{ + public sealed class SafeBufferHandle : SafeTensorflowHandle + { + private SafeBufferHandle() + { + } + + public SafeBufferHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteBuffer(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Buffers/TF_Buffer.cs b/src/TensorFlowNET.Core/Buffers/TF_Buffer.cs index 90fc98db2..c10f7b5f1 100644 --- a/src/TensorFlowNET.Core/Buffers/TF_Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/TF_Buffer.cs @@ -1,7 +1,21 @@ -using System; -using System.Collections.Generic; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; using System.Runtime.InteropServices; -using System.Text; namespace Tensorflow { @@ -9,7 +23,34 @@ namespace Tensorflow public struct TF_Buffer { public IntPtr data; - public int length; + public ulong length; public IntPtr data_deallocator; + + public unsafe Span AsSpan() where T: unmanaged + { + if(length > int.MaxValue) + { + throw new ValueError($"The length {length} is too large to use in the span."); + } + return new Span(data.ToPointer(), (int)length); + } + + public unsafe byte[] ToByteArray() + { + byte[] res = new byte[length]; + if(length > int.MaxValue) + { + byte* root = (byte*)data; + for(ulong i = 0; i < length; i++) + { + res[i] = *(root++); + } + } + else + { + new Span(data.ToPointer(), (int)length).CopyTo(res.AsSpan()); + } + return res; + } } } diff --git a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs index 0cc081c25..2e2422306 100644 --- a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs @@ -1,13 +1,47 @@ -using System; -using System.Collections.Generic; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; using System.Runtime.InteropServices; -using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { [DllImport(TensorFlowLibName)] - public static extern string TF_GetBuffer(IntPtr buffer); + public static extern void TF_DeleteBuffer(IntPtr buffer); + + /// + /// Useful for passing *out* a protobuf. + /// + /// + [DllImport(TensorFlowLibName)] + public static extern SafeBufferHandle TF_NewBuffer(); + + [DllImport(TensorFlowLibName)] + public static extern TF_Buffer TF_GetBuffer(SafeBufferHandle buffer); + + /// + /// Makes a copy of the input and sets an appropriate deallocator. Useful for + /// passing in read-only, input protobufs. + /// + /// const void* + /// size_t + /// + [DllImport(TensorFlowLibName)] + public static extern SafeBufferHandle TF_NewBufferFromString(IntPtr proto, ulong proto_len); } } diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs new file mode 100644 index 000000000..071b41875 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -0,0 +1,171 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; + +namespace Tensorflow.Checkpoint; + +public static class CheckPointUtils +{ + private static string _ESCAPE_CHAR = "."; + public static (IList, IDictionary>, IDictionary, + IDictionary>, + IDictionary) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) + { + var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); + Dictionary object_names = new(); + foreach (var pair in node_paths) + { + object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); + } + + Dictionary node_ids = new(); + for (int i = 0; i < trackable_objects.Count; i++) + { + node_ids[trackable_objects[i]] = i; + } + + var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names); + return (trackable_objects, node_paths, node_ids, slot_variables, object_names); + } + + public static + IDictionary> + serialize_slot_variables(IEnumerable trackable_objects, + IDictionary node_ids, IDictionary object_names) + { + var non_slot_objects = trackable_objects.ToList(); + Dictionary> + slot_variables = new(); + foreach (var trackable in non_slot_objects) + { + if (trackable is not Optimizer) + { + continue; + } + + var optim = (Optimizer)trackable; + var slot_names = optim.get_slot_names(); + foreach (var slot_name in slot_names) + { + for (int original_variable_node_id = 0; + original_variable_node_id < non_slot_objects.Count; + original_variable_node_id++) + { + var original_variable = non_slot_objects[original_variable_node_id]; + IVariableV1 slot_variable; + if (original_variable is not IVariableV1) + { + slot_variable = null; + } + slot_variable = optim.get_slot((IVariableV1)original_variable, slot_name); + if(slot_variable is null) continue; + + // There're some problems about the inherits of `Variable` and `Trackable`. + throw new NotImplementedException(); + } + } + } + + return slot_variables; + } + + public static Trackable get_mapped_trackable(Trackable trackable, IDictionary? object_map) + { + if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res)) + { + return trackable; + } + else + { + return possible_res; + } + } + + public static string get_full_name(Trackable variable) + { + // TODO: This state is not correct, the whole framework need to be updated in the future. + if (!(variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable))) + { + return ""; + } + // skip the check of attribute `_save_slice_info` . + + // TODO: Need to be revised!!! + Debug.Assert(variable is BaseResourceVariable); + return ((BaseResourceVariable)variable).Name; + } + + public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto) + { + HashSet checkpointed_trackables = new(); + Dictionary> parents = new(); + for (int i = 0; i < object_graph_proto.Nodes.Count; i++) + { + var object_proto = object_graph_proto.Nodes[i]; + // skip the process of registered saver. + if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 || + object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0) + { + checkpointed_trackables.Add(i); + } + + foreach (var child_proto in object_proto.Children) + { + var child = child_proto.NodeId; + if (!parents.ContainsKey(child)) + { + parents[child] = new HashSet(); + } + + parents[child].Add(i); + } + } + + Queue to_visit = new(checkpointed_trackables.AsEnumerable()); + while (to_visit.Count > 0) + { + var trackable = to_visit.Dequeue(); + if (!parents.ContainsKey(trackable)) continue; + var current_parents = parents[trackable]; + foreach (var parent in current_parents) + { + checkpointed_trackables.Add(parent); + if (parents.ContainsKey(parent)) + { + to_visit.Enqueue(parent); + } + } + parents.Remove(trackable); + } + + // TODO: Complete it after supporting checkpoint. + // for (int i = 0; i < object_graph_proto.Nodes.Count; i++) + // { + // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); + // } + } + + /// + /// Traverse the object graph and list all accessible objects. + /// + /// + public static IList list_objects(ObjectGraphView graph_view) + { + return objects_ids_and_slot_variables_and_paths(graph_view).Item1; + } + + internal static IEnumerable _objects_with_attributes(IEnumerable full_list) + { + return full_list.Where(x => + { + var saveables = x.gather_saveables_for_checkpoint(); + return saveables is not null && saveables.Count > 0; + }); + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs new file mode 100644 index 000000000..75b392af8 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs @@ -0,0 +1,5 @@ +namespace Tensorflow.Checkpoint; + +public record class CheckpointOptions( + string? experimental_io_device = null, + bool experimental_enable_async_checkpoint = false); diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs new file mode 100644 index 000000000..a1dba371c --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs @@ -0,0 +1,69 @@ +namespace Tensorflow.Checkpoint; + +public class CheckpointReader +{ + private SafeCheckpointReaderHandle _handle; + public Dictionary VariableToDataTypeMap { get; set; } + public Dictionary VariableToShapeMap { get; set; } + + public CheckpointReader(string filename) + { + Status status = new Status(); + VariableToDataTypeMap = new Dictionary(); + VariableToShapeMap = new Dictionary(); + _handle = c_api.TF_NewCheckpointReader(filename, status); + status.Check(true); + ReadAllShapeAndType(); + } + + public int HasTensor(string name) + => c_api.TF_CheckpointReaderHasTensor(_handle, name); + + /// + /// Get the variable name. + /// + /// + /// + public string GetVariable(int index) + => c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index)); + + public int Size() + => c_api.TF_CheckpointReaderSize(_handle); + + public TF_DataType GetVariableDataType(string name) + => c_api.TF_CheckpointReaderGetVariableDataType(_handle, name); + + public Shape GetVariableShape(string name) + { + int num_dims = GetVariableNumDims(name); + long[] dims = new long[num_dims]; + Status status = new Status(); + c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status); + status.Check(true); + return new Shape(dims); + } + + public int GetVariableNumDims(string name) + => c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name); + + public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) + { + Status status = new Status(); + var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status); + status.Check(true); + return new Tensor(tensor); + } + + private void ReadAllShapeAndType() + { + int size = Size(); + for(int i = 0; i < size; i++) + { + var name = GetVariable(i); + var shape = GetVariableShape(name); + var dtype = GetVariableDataType(name); + VariableToDataTypeMap[name] = dtype; + VariableToShapeMap[name] = shape; + } + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs new file mode 100644 index 000000000..f435dd88b --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs @@ -0,0 +1,64 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Serilog.Debugging; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Train; + +namespace Tensorflow.Checkpoint; + +public class ObjectGraphView: TrackableView, ICloneable +{ + protected IEnumerable? _attached_dependencies; + // TODO: attached_dependencies + public ObjectGraphView(Trackable root, IEnumerable? attached_dependencies = null): base(root) + { + _attached_dependencies = attached_dependencies; + } + + public object Clone() + { + // TODO: Implement real deep copy corresponding to tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__ + return new ObjectGraphView(Root, _attached_dependencies); + } + + public virtual List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary>? serialization_cache = null) + { + List res = base.children(obj, save_type, serialization_cache) + .Select(x => new TrackableReference(x.Key, x.Value)).ToList(); + // Check the reference, not value. + if (obj == Root && _attached_dependencies is not null) + { + res.AddRange(_attached_dependencies); + } + + return res; + } + + public override IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary>? serialization_cache = null) + { + return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer); + } + + public IEnumerable? AttachedDependencies + { + get => _attached_dependencies; + } + + public virtual (IList, IDictionary>) breadth_first_traversal() + { + return base._descendants_with_paths(); + } + + // TODO: complete the implementation + public void serialize_object_graph(object? saveables_cache = null) + { + throw new NotImplementedException(); + } + + // TODO: complete the implementation + public void frozen_saveable_objects(object? object_map = null, object? to_graph = null, object call_with_mapped_captures = null) + { + throw new NotImplementedException(); + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/SafeCheckpointReaderHandle.cs b/src/TensorFlowNET.Core/Checkpoint/SafeCheckpointReaderHandle.cs new file mode 100644 index 000000000..674e83512 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SafeCheckpointReaderHandle.cs @@ -0,0 +1,21 @@ +using Tensorflow.Util; + +namespace Tensorflow.Checkpoint; + +public sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle +{ + private SafeCheckpointReaderHandle() : base () + { + } + + public SafeCheckpointReaderHandle(IntPtr handle) : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteCheckpointReader(handle); + SetHandle(IntPtr.Zero); + return true; + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs new file mode 100644 index 000000000..7a5da7e3a --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -0,0 +1,261 @@ +using OneOf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using Tensorflow.Train; +using Tensorflow.Training; +using Tensorflow.Common.Extensions; +using pbc = global::Google.Protobuf.Collections; + +namespace Tensorflow.Checkpoint +{ + internal record class TrackableData( + // A trackable in the root Trackable object graph. + Trackable trackable, + // The index at which the Trackable appears in TrackableObjectGraph.nodes. + int node_id, + // The BFS-generated path from the root object / used to generate readable checkpoint keys. + string object_name, + // A list of ObjectReference for each child connected to this Trackable. + pbc::RepeatedField children_proto, + // A list of SlotVariableReference to save to the object (only valid for Optimizer objects). + pbc::RepeatedField slot_variable_proto, + // The object to save to checkpoint. Usually this is the same as `trackable`, + // but can differ when the the caller wants to specify a different object to + // save. For example, when saving checkpoints asynchronously, variables are + // copied to the CPU. `object_to_save` is set as the copied variable. + Trackable object_to_save + ); + public static class SaveUtil + { + public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) + serialize_graph_view(ObjectGraphView graph_view, IDictionary? object_map = null, bool call_with_mapped_captures = false, object? cache = null) + { + var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); + var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data); + + var object_graph_proto = fill_object_graph_proto(trackable_data); + + var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto); + var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto); + + Dictionary feed_additions; + if(cache is null) + { + feed_additions = null; + serialized_tensors = serialized_tensors.Concat(get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures, + cache, object_graph_proto)).ToDictionary(x => x.Key, x => x.Value); + } + else + { + feed_additions = null; + // TODO: deal with cache. + throw new NotFiniteNumberException(); + } + + CheckPointUtils.add_checkpoint_values_check(object_graph_proto); + + return (serialized_tensors, feed_additions, registered_savers, object_graph_proto); + } + + private static (IList, IDictionary) gather_trackable_data(ObjectGraphView graph_view, IDictionary? object_map) + { + var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); + Dictionary object_names = new(); + foreach(var pair in node_paths) + { + object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); + } + Dictionary node_ids = new(); + for(int i = 0; i < trackable_objects.Count; i++) + { + node_ids[trackable_objects[i]] = i; + } + var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names); + List trackable_data = new(); + foreach(var trackable in trackable_objects) + { + pbc::RepeatedField children_proto = new(); + foreach(var child in graph_view.list_children(trackable)) + { + children_proto.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference() + { + NodeId = node_ids[child.Refer], + LocalName = child.Name + }); + } + slot_variables.TryGetValue(trackable, out var slot_variable); + trackable_data.Add(new TrackableData( + trackable: trackable, + node_id: node_ids[trackable], + object_name: object_names[trackable], + children_proto: children_proto, + slot_variable_proto: slot_variable??new pbc.RepeatedField(), + object_to_save: CheckPointUtils.get_mapped_trackable(trackable, object_map) + )); + } + return (trackable_data, node_ids); + } + + private static TrackableObjectGraph fill_object_graph_proto(IList trackable_data) + { + TrackableObjectGraph object_graph_proto = new(); + for(int i = 0; i < trackable_data.Count; i++) + { + var td = trackable_data[i]; + Debug.Assert(td.node_id == i); + TrackableObjectGraph.Types.TrackableObject trackable_object = new(); + trackable_object.SlotVariables.AddRange(td.slot_variable_proto); + trackable_object.Children.AddRange(td.children_proto); + object_graph_proto.Nodes.Add(trackable_object); + } + return object_graph_proto; + } + + /// + /// Creates dictionary of tensors to checkpoint, and updates the proto. + /// + /// + /// + /// + /// + /// + private static IDictionary>>> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, + bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) + { + Dictionary>>> serialized_tensors = new(); + foreach(var td in tensor_trackables) + { + // TODO: deal with cache. + var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; + Trackable trackable = null; + IDictionary>> tensor_dict; + if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) + { + (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); + } + else + { + tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto); + trackable = td.object_to_save; + } + if(trackable is not null) + { + serialized_tensors[trackable] = tensor_dict; + } + else + { + serialized_tensors[Trackable.None] = tensor_dict; + } + } + return serialized_tensors; + } + + private static IDictionary>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + { + var trackable = trackable_data.object_to_save; + + // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. + IDictionary>> ret_tensor_dict; + if (call_with_mapped_captures) + { + throw new NotImplementedException(); + } + else + { + ret_tensor_dict = trackable.serialize_to_tensors(); + } + + Dictionary>> tensor_dict = new(); + foreach(var pair in ret_tensor_dict) + { + var local_name = TrackableUtils.escape_local_name(pair.Key); + var maybe_tensor = pair.Value; + var checkpoint_key = TrackableUtils.checkpoint_key(trackable_data.object_name, local_name); + + tensor_dict[checkpoint_key] = maybe_tensor; + + foreach(var key in maybe_tensor.Keys) + { + if (maybe_tensor[key].IsTypeOrDeriveFrom()) + { + maybe_tensor[key].AsT1.name = local_name + maybe_tensor[key].AsT1.name; + } + } + + if(object_graph_proto is not null) + { + object_graph_proto.Nodes[trackable_data.node_id].Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() + { + Name = local_name, + CheckpointKey = checkpoint_key, + FullName = CheckPointUtils.get_full_name(trackable) + }); + } + } + return tensor_dict; + } + + /// + /// Gets tensors to serialize from a Trackable with legacy SaveableObjects. + /// + /// + /// + /// + /// + /// + private static (Trackable, IDictionary>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, + bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + { + Dictionary object_names = new(); + object_names[trackable_data.trackable] = trackable_data.object_name; + Dictionary object_map = new(); + object_map[trackable_data.trackable] = trackable_data.object_to_save; + + var (checkpoint_factory_map, _) = SaveUtilV1.get_checkpoint_factories_and_keys(object_names, object_map); + var (named_saveable_objects, _) = SaveUtilV1.generate_saveable_objects(checkpoint_factory_map, object_graph_proto, node_ids, object_map, + call_with_mapped_captures, saveables_cache: null); + var trackable = new SaveableCompatibilityConverter(trackable_data.object_to_save, named_saveable_objects); + return (trackable, trackable.serialize_to_tensors()); + } + + private static IDictionary> get_and_write_registered_savers(IDictionary> registered_trackables, TrackableObjectGraph object_graph_proto) + { + Dictionary> registered_savers = new(); + foreach(var pair in registered_trackables) + { + foreach(var td in pair.Value) + { + if (registered_savers.ContainsKey(pair.Key)) + { + registered_savers[pair.Key] = new Dictionary(); + } + else + { + registered_savers[pair.Key][td.object_name] = td.object_to_save; + } + + var object_proto = object_graph_proto.Nodes[td.node_id]; + // TODO: add APIs and complete it. Now the `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`. + } + } + return registered_savers; + } + + private static (IList, IList, IDictionary>) split_trackables(IEnumerable trackable_data) + { + List tensor_trackables = new(); + List py_state_trackables = new(); // skip the process of `PyState` for the lack of API. This is only a pleceholder. + Dictionary> registered_trackables = new(); + + foreach(var td in trackable_data) + { + // TODO: deal with registration. + tensor_trackables.Add(td); + } + return (tensor_trackables, py_state_trackables, registered_trackables); + } + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs new file mode 100644 index 000000000..9280179c0 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -0,0 +1,225 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Exceptions; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; +using static Tensorflow.Binding; +using Google.Protobuf; +using OneOf; + +namespace Tensorflow.Checkpoint; + +public static class SaveUtilV1 +{ + public static (IDictionary>, object?) get_checkpoint_factories_and_keys(IDictionary object_names, + IDictionary? object_map = null) + { + // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, + // till now only internal registrations are allowed. So, we won't return a saver in this function. + // The implementation of this function should be updated if tensorflow update it. + Dictionary> checkpoint_factory_map = new(); + foreach (var pair in object_names) + { + var trackable = pair.Key; + var object_name = pair.Value; + var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); + + // skip the registration process. + + List current_list = new(); + foreach (var name_and_factory in saveable_object_util.saveable_objects_from_trackable(object_to_save)) + { + // treat name as key_suffix. + var name = name_and_factory.Key; + var checkpoint_key = TrackableUtils.checkpoint_key(object_name, name); + + current_list.Add(new CheckpointFactoryData(name_and_factory.Value, name, checkpoint_key)); + } + + checkpoint_factory_map[trackable] = current_list; + } + + return (checkpoint_factory_map, null); + } + + public static (IList, IDictionary>?) frozen_saveables_and_savers(ObjectGraphView graph_view, + IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, + object? saveables_cache = null) + { + if (to_graph is not null) + { + var g = to_graph.as_default(); + var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, + object_map, call_with_mapped_captures, saveables_cache); + var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ => + { + // TODO(Rinne): locate the error that causes transferring TF_STRING to this function throws an exception. + return constant_op.constant(graph_proto.ToByteArray()); + }); + named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + g.Exit(); + return (named_saveable_objects, registered_savers); + } + else + { + using (new ops.NullContextManager()) + { + var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, + object_map, call_with_mapped_captures, saveables_cache); + var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ => + { + return constant_op.constant(graph_proto.ToString()); + }); + named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + return (named_saveable_objects, registered_savers); + } + } + } + + public static (IList, TrackableObjectGraph, object?, IDictionary>?) serialize_gathered_objects(ObjectGraphView graph_view, + IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) + { + var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); + Dictionary object_names = new(); + foreach (var pair in node_paths) + { + object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); + } + + Dictionary node_ids = new(); + for (int i = 0; i < trackable_objects.Count; i++) + { + node_ids[trackable_objects[i]] = i; + } + + var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names); + var object_graph_proto = fill_object_graph_proto(graph_view, trackable_objects, node_ids, slot_variables); + var (named_saveable_objects, feed_additions, registered_savers) = add_attributes_to_object_graph( + trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures, + saveables_cache); + + CheckPointUtils.add_checkpoint_values_check(object_graph_proto); + return (named_saveable_objects, object_graph_proto, feed_additions, registered_savers); + } + + private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList trackable_objects, + IDictionary node_ids, + IDictionary> + slot_variables) + { + TrackableObjectGraph object_graph_proto = new(); + for (int i = 0; i < trackable_objects.Count; i++) + { + var trackable = trackable_objects[i]; + Debug.Assert(node_ids[trackable] == i); + var object_proto = new TrackableObjectGraph.Types.TrackableObject(); + if (slot_variables.TryGetValue(trackable, out var slots)) + { + object_proto.SlotVariables.AddRange(slots); + } + object_graph_proto.Nodes.Add(object_proto); + foreach (var child in graph_view.list_children(trackable)) + { + object_proto.Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference() + { NodeId = node_ids[child.Refer], LocalName = child.Name }); + } + } + + return object_graph_proto; + } + + private static (IList, object?, IDictionary>?) add_attributes_to_object_graph( + IList trackable_objects, + TrackableObjectGraph object_graph_proto, IDictionary node_ids, + IDictionary object_names, IDictionary object_map, + bool call_with_mapped_captures, object? saveables_cache = null) + { + int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count); + for (int i = 0; i < cnt; i++) + { + Debug.Assert(node_ids[trackable_objects[i]] == i); + } + + var (checkpoint_factory_map, unmmaped_registered_savers) = + get_checkpoint_factories_and_keys(object_names, object_map); + + // skip the process of registered savers + + var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map, + object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache); + return (named_saveable_objects, feed_additions, null); + } + + public static (IList, object?) generate_saveable_objects( + IDictionary> checkpoint_factory_map, + TrackableObjectGraph? object_graph_proto, IDictionary? node_ids, + IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) + { + List named_saveable_objects = new(); + foreach (var pair in checkpoint_factory_map) + { + var trackable = pair.Key; + var factory_data_list = pair.Value; + bool fill_object_proto = object_graph_proto is not null && node_ids is not null; + TrackableObjectGraph.Types.TrackableObject object_proto = null!; + if (fill_object_proto) + { + object_proto = object_graph_proto.Nodes[node_ids[trackable]]; + } + + var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); + // skip cache + + foreach (var factory_data in factory_data_list) + { + var name = factory_data.name; + var key = factory_data.checkpoint_key; + var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory); + + // TODO: tensorflow python has a process with callable `saveable_factory`. + List saveables = new(); + if (maybe_saveable.TryPickT1(out var s, out var variable)) + { + saveables.Add(s); + } + else + { + saveables.AddRange(saveable_object_util.saveable_objects_for_op(variable as Trackable, key)); + } + + foreach (var saveable in saveables) + { + if (!saveable.name.Contains(key)) + { + throw new AssertionError($"The object {trackable} produced a SaveableObject with name " + + $"'{saveable.name}' for attribute '{name}'. Expected a name" + + $" containing '{key}'."); + } + } + + // skip the process of PythonState + + named_saveable_objects.AddRange(saveables); + + if(!fill_object_proto) continue; + + // skip the process of `TrackableSaveable` because of lack of APIs. + + object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() + { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); + } + } + + return (named_saveable_objects, null); + } +} + +public record class CheckpointFactoryData +( + Func> factory, + string name, + string checkpoint_key +); diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs b/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs new file mode 100644 index 000000000..fa441d799 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Checkpoint +{ + internal static class SaveableCompat + { + public static string? get_saveable_name(Trackable cls_or_obj) + { + // TODO: implement it with Attribute. + return null; + } + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs new file mode 100644 index 000000000..dab6d5d97 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -0,0 +1,82 @@ +using System; +using Tensorflow.Train; +using System.Collections.Generic; +using System.IO; +using Tensorflow.Keras.Saving.SavedModel; + +namespace Tensorflow.Checkpoint; + +public class TrackableView +{ + protected WeakReference _root_ref; + public TrackableView(Trackable obj) + { + _root_ref = new WeakReference(obj); + } + + public TrackableView(WeakReference obj) + { + _root_ref = obj; + } + + public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) + { + obj._maybe_initialize_trackable(); + Dictionary children = new(); + // Note: in python the return type of `Trackable._trackable_children` is not fixed. + // Therefore it uses `convert_to_trackable` to have an extra process. + foreach (var pair in obj._trackable_children(save_type, cache)) + { + children[pair.Key] = pair.Value; + } + return children; + } + + public Trackable Root + { + get + { + if (_root_ref.TryGetTarget(out Trackable res)) + { + return res; + } + else + { + throw new InvalidDataException( + "Cannot get the object from the weak reference. Please consider if a null reference is passed to the constructor."); + } + } + } + + /// + /// Returns a list of all nodes and its paths from self.root using a breadth first traversal. + /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths + /// + protected (IList, IDictionary>) _descendants_with_paths() + { + List bfs_sorted = new(); + Queue to_visit = new(); + to_visit.Enqueue(Root); + Dictionary> node_paths = new(); + node_paths[this.Root] = new List(); + while (!to_visit.empty()) + { + var current_trackable = to_visit.Dequeue(); + bfs_sorted.Add(current_trackable); + var children_dict = this.children(current_trackable); + foreach (var name in children_dict.Keys) + { + var dependency = children_dict[name]; + if (!node_paths.ContainsKey(dependency)) + { + var list = new List(node_paths[current_trackable]); + list.Add(new TrackableReference(name, dependency)); + node_paths[dependency] = list; + to_visit.Enqueue(dependency); + } + } + } + + return (bfs_sorted, node_paths); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs new file mode 100644 index 000000000..f956e3337 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs @@ -0,0 +1,27 @@ +using System.Runtime.InteropServices; +using Tensorflow.Checkpoint; + +namespace Tensorflow +{ + public unsafe partial class c_api + { + [DllImport(TensorFlowLibName)] + internal static extern SafeCheckpointReaderHandle TF_NewCheckpointReader(string filename, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + internal static extern void TF_DeleteCheckpointReader(IntPtr reader); + [DllImport(TensorFlowLibName)] + internal static extern int TF_CheckpointReaderHasTensor(SafeCheckpointReaderHandle reader, string name); + [DllImport(TensorFlowLibName)] + internal static extern IntPtr TF_CheckpointReaderGetVariable(SafeCheckpointReaderHandle reader, int index); + [DllImport(TensorFlowLibName)] + internal static extern int TF_CheckpointReaderSize(SafeCheckpointReaderHandle reader); + [DllImport(TensorFlowLibName)] + internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(SafeCheckpointReaderHandle reader, string name); + [DllImport(TensorFlowLibName)] + internal static extern void TF_CheckpointReaderGetVariableShape(SafeCheckpointReaderHandle reader, string name, long[] dims, int num_dims, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + internal static extern int TF_CheckpointReaderGetVariableNumDims(SafeCheckpointReaderHandle reader, string name); + [DllImport(TensorFlowLibName)] + internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(SafeCheckpointReaderHandle reader, string name, SafeStatusHandle status); + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs new file mode 100644 index 000000000..30d45e82c --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -0,0 +1,582 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using Tensorflow.Train; +using Tensorflow.Exceptions; +using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; +using static Tensorflow.Binding; +using Tensorflow.Operations; +using Newtonsoft.Json; +using Tensorflow.Training; +using OneOf; + +namespace Tensorflow.Checkpoint; + +/// +/// Saves and restores a `Trackable` object and its dependencies. +/// +public class TrackableSaver +{ + private ObjectGraphView _graph_view; + private Tensor _cached_save_operation; + private TrackableObjectGraph _last_save_object_graph; + private Tensor? _object_graph_feed_tensor = null; + private Tensor? _file_prefix_feed_tensor = null; + private Tensor? _file_prefix_placeholder = null; + private Dictionary? _object_map = null; + private object? _cache = null; + public Tensor? FilePrefixPlaceHolder + { + get + { + return _file_prefix_placeholder; + } + set + { + _file_prefix_placeholder = value; + } + } + public TrackableSaver(ObjectGraphView graph_view) + { + _graph_view = graph_view; + + // TODO: cache when not executing eagerly. + // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder` + // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` + + } + + private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) + gather_serialized_tensors(Tensor? object_graph_tensor = null) + { + var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); + + // TODO: cache. + + if(object_graph_tensor is null) + { + tf_with(ops.device("/cpu:0"), _ => + { + object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); + }); + } + else + { + feed_additions[object_graph_tensor] = graph_proto.ToByteArray(); + } + Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + if (!serialized_tensors.ContainsKey(Trackable.None)) + { + serialized_tensors[Trackable.None] = new Dictionary>>(); + } + serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = new Dictionary>(); + serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY].Add(saveable_object_util.NO_SLICE_SPEC_KEY, object_graph_tensor); + return (serialized_tensors, feed_additions, registered_savers, graph_proto); + } + + private (Tensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + { + var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); + + Func<(Tensor, IDictionary)> run_save = () => + { + if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) + { + var saver = new MultiDeviceSaver(serialized_tensors, registered_savers); + var save_op = saver.save(file_prefix, options); + + // tensorflow python: `with ops.device("/cpu:0"):` + using (ops.control_dependencies(new object[] { save_op })) + { + _cached_save_operation = array_ops.identity(file_prefix); + } + _last_save_object_graph = graph_proto; + } + return (_cached_save_operation, feed_additions); + }; + + if (options.experimental_enable_async_checkpoint) + { + throw new NotImplementedException(); + } + + return run_save(); + } + + private (Tensor, IDictionary) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + { + var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); + + Func<(Tensor, IDictionary)> run_save = () => + { + if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) + { + var saver = new MultiDeviceSaver(serialized_tensors, registered_savers); + var save_op = saver.save(file_prefix, options); + + // tensorflow python: `with ops.device("/cpu:0"):` + using (ops.control_dependencies(new object[] {save_op} )) + { + _cached_save_operation = array_ops.identity(tf.constant(file_prefix)); + } + _last_save_object_graph = graph_proto; + } + return (_cached_save_operation, feed_additions); + }; + + if (options.experimental_enable_async_checkpoint) + { + throw new NotImplementedException(); + } + + return run_save(); + } + + // TODO: parameter write_done_callback + public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, + CheckpointOptions? options = null) + { + if (options is null) + { + options = new CheckpointOptions(); + } + + Dictionary feed_dict = new(); + bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function()); + if (checkpoint_number is not null) + { + file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; + } + + Tensor file_prefix_tensor; + Tensor object_graph_tensor; + string file_prefix_to_save; + if (use_session) + { + if (_object_graph_feed_tensor is null) + { + // In python there is `with ops.device("/cpu:0")`. + _object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING); + _file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING); + } + + object_graph_tensor = _object_graph_feed_tensor; + file_prefix_tensor = _file_prefix_feed_tensor; + feed_dict[file_prefix_tensor] = file_prefix; + file_prefix_to_save = ""; + } + else + { + // In python there is `with ops.device("/cpu:0")`. + file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING); + object_graph_tensor = null; + file_prefix_to_save = file_prefix; + } + + var (save_path, new_feed_additions) = + save_cached_when_graph_building(file_prefix_to_save, object_graph_tensor, options); + + if (new_feed_additions is not null) + { + foreach (var pair in new_feed_additions) + { + feed_dict.Add(pair.Key, pair.Value); + } + } + if(!use_session) + { + session = null; + } + else if (session is null) + { + session = new Session(); // In python it uses `get_session`. + } + + if (session is not null) + { + var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); + return session.run((Tensor)save_path, s); + } + else if (use_session) + { + throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + + "in graph mode without a default session. Please use " + + "`with tf.Session():` to create a session."); + } + else + { + return save_path; + } + } + + public LoadStatus restore(string? save_path, CheckpointOptions? options = null) + { + if (options is null) + { + options = new CheckpointOptions(); + } + if(save_path is null) + { + return new InitializationOnlyStatus(_graph_view, ops.uid()); + } + + CheckpointReader reader = new CheckpointReader(save_path); + bool graph_building = tf.Context.executing_eagerly(); + Dictionary dtype_map = null; + if (!graph_building) + { + dtype_map = reader.VariableToDataTypeMap; + } + Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING); + + Dictionary file_prefix_feed_dict; + Tensor file_prefix_tensor = null; + if (graph_building) + { + if(_file_prefix_placeholder is null) + { + _file_prefix_placeholder = tf_with(ops.device("/cpu:0"), _ => + { + return constant_op.constant("model"); + }); + } + file_prefix_tensor = _file_prefix_placeholder; + file_prefix_feed_dict = new(); + file_prefix_feed_dict[_file_prefix_placeholder] = save_path; + } + else + { + file_prefix_tensor = tf_with(ops.device("/cpu:0"), _ => + { + return constant_op.constant(save_path); + }); + file_prefix_feed_dict = null; + } + TrackableObjectGraph object_graph_proto = new(); + if(object_graph_string.ndim > 0) + { + object_graph_proto.MergeFrom(object_graph_string.BufferToArray()); + } + else + { + object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]); + } + CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator( + object_graph_proto: object_graph_proto, + save_path: save_path, + save_path_tensor: file_prefix_tensor, + reader: reader, + restore_op_cache: null, + graph_view: _graph_view, + options: options, + saveables_cache: null + ); + + new CheckpointPosition(checkpoint, 0).restore(_graph_view.Root); + + if(_graph_view.AttachedDependencies is not null) + { + foreach(var refer in _graph_view.AttachedDependencies) + { + if(refer.Name == "root") + { + continue; + } + int? proto_id = null; + // Find proto ID of attached dependency (if it is in the proto). + foreach (var proto_refer in object_graph_proto.Nodes[0].Children) + { + if(proto_refer.LocalName == refer.Name) + { + proto_id = proto_refer.NodeId; + break; + } + } + + if (proto_id is null) + { + continue; + } + + // Object has already been restored. This can happen when there's an + // indirect connection from the attached object to the root. + if (checkpoint.ObjectByProtoId.ContainsKey(proto_id.Value)) + { + continue; + } + + new CheckpointPosition(checkpoint, proto_id.Value).restore(refer.Refer); + } + } + + return new CheckpointLoadStatus(checkpoint, file_prefix_feed_dict, _graph_view); + } +} + +public class CheckpointRestoreCoordinator +{ + private CheckpointOptions _options; + private TrackableObjectGraph _object_graph_proto; + private int _restore_uid; + private HashSet _matched_proto_ids; + private Tensor _save_path_tensor; + private string _save_path_string; + private CheckpointReader _reader; + private Dictionary _dtype_map; + private Dictionary _shape_map; + private ObjectGraphView _graph_view; + private Dictionary> _slot_restorations; + private bool _expect_partial_attr; + private List _restore_ops; + private List _all_trackables; + private Dictionary _object_by_proto_id; + private Dictionary _restore_ops_by_name; + private Dictionary> _deferred_slot_restorations; + private Dictionary> _unused_attributes; + + public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor, + CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache) + { + // TODO(Rinne): cache. + _options = options; + _object_graph_proto = object_graph_proto; + _restore_uid = ops.uid(); + _save_path_tensor = save_path_tensor; + _save_path_string = save_path; + _reader = reader; + if(_reader is null) + { + _reader = new CheckpointReader(save_path); + } + _dtype_map = _reader.VariableToDataTypeMap; + _shape_map = _reader.VariableToShapeMap; + _graph_view = graph_view; + _restore_ops = new List(); + _restore_ops_by_name = new Dictionary(); + _all_trackables = new List(); + _matched_proto_ids = new HashSet(); + _object_by_proto_id = new Dictionary(); + _slot_restorations = new Dictionary>(); + _deferred_slot_restorations = new Dictionary>(); + + _expect_partial_attr = false; + for(int i = 0; i < _object_graph_proto.Nodes.Count; i++) + { + var node = _object_graph_proto.Nodes[i]; + foreach(var slot_reference in node.SlotVariables) + { + _slot_restorations.SetDefault(slot_reference.OriginalVariableNodeId, new List()) + .Add(new SlotVariableRestoration(i, slot_reference.SlotVariableNodeId, slot_reference.SlotName)); + } + } + + // skip the deleter and cache. + } + + public bool ExpectPartial + { + get + { + return _expect_partial_attr; + } + set + { + _expect_partial_attr = value; + } + } + + /// + /// Corresponding to `all_python_objects` of tensorflow python + /// + public List AllTrackables => _all_trackables; + public HashSet MatchedProtoIds => _matched_proto_ids; + // TODO(Rinne): change to weak ref. + public Dictionary ObjectByProtoId => _object_by_proto_id; + public int RestoreUid => _restore_uid; + public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; + public Dictionary> SlotRestorations => _slot_restorations; + public Dictionary> DeferredSlotRestorations => _deferred_slot_restorations; + public Dictionary RestoreOpsByName => _restore_ops_by_name; + public Dictionary> UnusedAttributes => _unused_attributes; + + public void new_restore_ops(IEnumerable new_ops) + { + _restore_ops.AddRange(new_ops); + // skip the callback. + } + + public List restore_saveables(Dictionary> tensor_saveables, List positions, object? registered_savers = null) + { + List restore_ops = new(); + foreach(var position in positions) + { + var key = position.ObjectProto.Attributes[0].CheckpointKey; + throw new NotImplementedException(); + } + + Dictionary variable_dict = new(); + foreach(var item in tensor_saveables) + { + if(item.Value.TryPickT0(out var variable, out var _)) + { + variable_dict[item.Key] = variable; + } + else + { + throw new TypeError(); + } + } + + if (tensor_saveables is not null && tensor_saveables.Count > 0) + { + var flat_saveables = saveable_object_util.validate_and_slice_inputs(variable_dict); + var new_restore_ops = MultiDeviceSaver.from_saveables(flat_saveables).restore(_save_path_tensor, _options); + if (!tf.Context.executing_eagerly()) + { + foreach(var item in new_restore_ops) + { + restore_ops.Add(item.Value); + Debug.Assert(!_restore_ops_by_name.ContainsKey(item.Key)); + _restore_ops_by_name[item.Key] = item.Value; + } + } + } + return restore_ops; + } +} + +public abstract class LoadStatus +{ + public abstract LoadStatus assert_consumed(); + public abstract LoadStatus assert_existing_objects_matched(); + public abstract LoadStatus assert_nontrivial_match(); + public abstract LoadStatus run_restore_ops(Session? session = null); + public abstract void initialize_or_restore(Session? session = null); + public virtual LoadStatus expect_partial() + { + return this; + } +} + +public class InitializationOnlyStatus: LoadStatus +{ + private int _restore_uid; + private ObjectGraphView _object_graph_view; + private Trackable _root; + public InitializationOnlyStatus(ObjectGraphView object_graph_view, int restore_uid) + { + _restore_uid = restore_uid; + _object_graph_view = object_graph_view; + _root = object_graph_view.Root; + } + public override LoadStatus assert_consumed() + { + throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); + } + public override LoadStatus assert_existing_objects_matched() + { + throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); + } + public override LoadStatus assert_nontrivial_match() + { + throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); + } + public override LoadStatus run_restore_ops(Session? session = null) + { + throw new AssertionError("No checkpoint specified, so no restore ops are available " + + "(save_path=None to Saver.restore)."); + } + public override void initialize_or_restore(Session? session = null) + { + if (tf.Context.executing_eagerly()) + { + return; + } + if(session is null) + { + session = new Session(); + } + var trackable_objects = CheckPointUtils.list_objects(_object_graph_view); + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } +} + +internal class CheckpointLoadStatus: LoadStatus +{ + private CheckpointRestoreCoordinator _checkpoint; + private Dictionary _feed_dict; + private ObjectGraphView _object_graph_view; + private Trackable _root; + public CheckpointLoadStatus(CheckpointRestoreCoordinator checkpoint, Dictionary feed_dict, ObjectGraphView graph_view):base() + { + _checkpoint = checkpoint; + _feed_dict = feed_dict; + _object_graph_view = graph_view; + _root = graph_view.Root; + } + + public CheckpointRestoreCoordinator Checkpoint => _checkpoint; + + public override LoadStatus assert_consumed() + { + throw new NotImplementedException(); + } + + public override LoadStatus assert_existing_objects_matched() + { + for(int i = 0; i < _checkpoint.ObjectGraphProto.Nodes.Count; i++) + { + var node = _checkpoint.ObjectGraphProto.Nodes[i]; + if(_checkpoint.ObjectByProtoId.TryGetValue(i, out var trackable) && + trackable.UpdateUid < _checkpoint.RestoreUid) + { + throw new AssertionError($"Object {node} not assigned a value from checkpoint."); + } + } + foreach(var trackable_object in CheckPointUtils.list_objects(_object_graph_view)) + { + if(trackable_object is TrackableDataStructure && trackable_object._trackable_children().Count == 0) + { + continue; + } + _checkpoint.AllTrackables.Add(trackable_object); + } + var unused_trackables = CheckPointUtils._objects_with_attributes(_checkpoint.AllTrackables) + .Except(_checkpoint.ObjectByProtoId.Values); + if (unused_trackables.Any()) + { + var num_unused_trackables = unused_trackables.Count(); + var num_variables_to_show = Math.Min(10, num_unused_trackables); + throw new AssertionError($"Found {num_unused_trackables} Python objects that were " + + $"not bound to checkpointed values, likely due to changes in the " + + $"Python program. Showing {num_variables_to_show} of " + + $"{num_unused_trackables} unmatched objects: " + + $"{{list(unused_python_objects)[:num_variables_to_show]}}"); + } + return this; + } + + public override LoadStatus assert_nontrivial_match() + { + throw new NotImplementedException(); + } + + public override LoadStatus expect_partial() + { + throw new NotImplementedException(); + } + + public override void initialize_or_restore(Session? session = null) + { + throw new NotImplementedException(); + } + + public override LoadStatus run_restore_ops(Session? session = null) + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs new file mode 100644 index 000000000..211d7d6f0 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -0,0 +1,464 @@ +using System; +using System.Buffers.Text; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; +using static Tensorflow.ApiDef.Types; +using static Tensorflow.CostGraphDef.Types; +using static Tensorflow.OptimizerOptions.Types; +using static Tensorflow.Binding; +using System.Text.RegularExpressions; +using System.Linq; +using Tensorflow.Operations; +using Tensorflow.Training; +using Tensorflow.Graphs; +using System.Xml.Linq; +using System.Diagnostics; +using RestoreFunc = System.Func; +using OneOf; + +namespace Tensorflow.Checkpoint +{ + internal class SingleDeviceSaver + { + private IDictionary>> _tensor_slice_dict; + public SingleDeviceSaver(IDictionary>> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict; + } + public SingleDeviceSaver(IDictionary> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict.ToDictionary( + x => x.Key, x => x.Value.ToDictionary( + y => y.Key, y => OneOf.FromT0(y.Value)) + as IDictionary>); + } + public SingleDeviceSaver(IDictionary> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict.ToDictionary( + x => x.Key, x => x.Value.ToDictionary( + y => y.Key, y => OneOf.FromT1(y.Value)) + as IDictionary>); + } + public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + List tensor_names = new(); + List tensors = new(); + List slice_specs = new(); + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice in tensor_slices) + { + var slice_spec = slice.Key; + var maybe_tensor = slice.Value; + if(maybe_tensor.TryPickT1(out var spec, out var tensor)) + { + var tensor_value = spec.tensor; + if (tensor_value is not null) + { + tensor_names.Add(spec.name); + tensors.Add(tensor_value); + slice_specs.Add(spec.slice_spec); + } + } + else + { + tensor_names.Add(checkpoint_key); + tensors.Add(tensor); + slice_specs.Add(slice_spec); + } + } + } + // TODO: specify the device. + return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray()); + } + + public Operation? save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options); + + public IDictionary> restore(Tensor file_prefix, CheckpointOptions? options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + List tensor_names = new(); + List tensor_dtypes = new(); + List slice_specs = new(); + + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice in tensor_slices) + { + var slice_spec = slice.Key; + var maybe_tensor = slice.Value; + // TODO: deal with other types. Currently only `SaveSpec` is allowed. + if(maybe_tensor.TryPickT1(out var spec, out var tensor)) + { + tensor_dtypes.Add(spec.dtype); + slice_specs.Add(spec.slice_spec); + tensor_names.Add(spec.name); + } + else + { + tensor_dtypes.Add(tensor.dtype); + slice_specs.Add(slice_spec); + tensor_names.Add(checkpoint_key); + } + } + } + + string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!; + + Tensor[] restored_tensors = null; + tf_with(ops.device(restore_device), _ => + { + restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); + }); + + Dictionary> restored_tensor_dict = new(); + int idx = 0; + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice_spec in tensor_slices.Keys) + { + var restored_tensor = restored_tensors[idx++]; + if (!restored_tensor_dict.ContainsKey(checkpoint_key)) + { + restored_tensor_dict[checkpoint_key] = new Dictionary(); + } + restored_tensor_dict[checkpoint_key][slice_spec] = restored_tensor; + } + } + return restored_tensor_dict; + } + + public IDictionary> restore(string file_prefix, CheckpointOptions? options = null) => restore(tf.constant(file_prefix)); + } + /// + /// Saves checkpoints directly from multiple devices. + /// Note that this is a low-level utility which stores Tensors in the keys + /// specified by `SaveableObject`s.Higher-level utilities for object-based + /// checkpointing are built on top of it. + /// + public class MultiDeviceSaver + { + private Dictionary _single_device_savers; + private IDictionary _registered_savers; + private Dictionary<(string, string), RestoreFunc> _keys_to_restore_fn; + private Dictionary> _restore_fn_to_keys; + /// + /// + /// + /// A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. + /// + /// + public MultiDeviceSaver(IDictionary>>> serialized_tensors, + IDictionary>? registered_savers = null, bool call_with_mapped_capture = false) + { + _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); + _restore_fn_to_keys = new Dictionary>(); + Dictionary>>> tensors_by_device= new(); + + foreach(var pair in serialized_tensors) + { + var obj = pair.Key; + var tensor_dict = pair.Value; + RestoreFunc restore_fn; + if(obj == Trackable.None) + { + restore_fn = new RestoreFunc(x => null); + } + else + { + restore_fn = new RestoreFunc(x => + { + if(x is IDictionary>>) + { + return obj._restore_from_tensors(x as IDictionary>>); + } + throw new TypeError($"Expected `IDictionary>>` as input, got{x.GetType()}."); + }); + } + + foreach(var item in tensor_dict) + { + var checkpoint_key = item.Key; + var spec_to_tensor = item.Value; + + foreach(var spec in spec_to_tensor) + { + var slice_spec = spec.Key; + var tensor = spec.Value; + if(_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec))) + { + throw new ValueError("Recieved multiple tensors with the same checkpoint key and " + + $"slice spec. This is invalid because one will overwrite the " + + $"other in the checkpoint. This indicates a bug in the Checkpoint key-generation."); + } + _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; + _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); + + string host_device; + if (tensor.IsT0) + { + host_device = tensor.AsT0.Device; + } + else + { + host_device = tensor.AsT1.device; + } + host_device = saveable_object_util.set_cpu0(host_device); + var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary>>()); + if (!internal_dict.ContainsKey(checkpoint_key)) + { + internal_dict[checkpoint_key] = new Dictionary>(); + } + internal_dict[checkpoint_key][slice_spec] = tensor; + } + } + } + + _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); + + _registered_savers = new Dictionary(); + if(registered_savers is not null && registered_savers.Count > 0) + { + // TODO: complete the implementation. + throw new NotImplementedException(); + } + } + + public Operation save(Tensor file_prefix, CheckpointOptions? options= null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + + Tensor tmp_checkpoint_prefix = null; + tf_with(ops.device("CPU"), _ => + { + var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), + constant_op.constant(".part"), constant_op.constant("_temp/part")); + tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); + IDictionary registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); + }); + + Operation save_fn() + { + List saved_prefixes= new(); + foreach(var saver in _registered_savers) + { + // TODO: implementi it later. + throw new NotImplementedException(); + } + + int num_shards = _single_device_savers.Count; + List sharded_saves = new(); + var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards"); + string? last_device = null; + int shard = 0; + foreach(var pair in _single_device_savers.OrderBy(x => x.Key)) + { + var device = pair.Key; + var saver = pair.Value; + last_device = device; + // skip the extra process of device name because of lack of API. + Tensor shard_prefix = null; + tf_with(ops.device(device), _ => + { + shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); + }); + saved_prefixes.Add(shard_prefix); + tf_with(ops.device(device), _ => + { + sharded_saves.Add(saver.save(shard_prefix, options)); + }); + } + using (var controller = ops.control_dependencies(sharded_saves.ToArray())) + { + string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; + return tf_with(ops.device(merge_device), _ => + { + return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); + }); + } + } + + if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1) + { + // TODO: implement it. Currently `autograph` does not support the function with non parameter. + throw new NotImplementedException(); + } + else + { + return save_fn(); + } + } + + public Operation save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix), options); + + public IDictionary restore(Tensor file_prefix, CheckpointOptions? options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + + IDictionary restore_func() + { + Dictionary>>> restore_fn_inputs = new(); + Dictionary restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); + Dictionary restore_ops = new(); + + foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) + { + var device = single_saver.Key; + var saver = single_saver.Value; + tf_with(ops.device(device), _ => + { + var restored_tensor_dict = saver.restore(file_prefix, options); + + foreach (var pair in restored_tensor_dict) + { + var checkpoint_key = pair.Key; + var slice_and_tensor = pair.Value; + foreach (var item in slice_and_tensor) + { + var slice_spec = item.Key; + var tensor = item.Value; + var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; + var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary>>()); + if (!string.IsNullOrEmpty(slice_spec)) + { + if (!internal_dict.ContainsKey(checkpoint_key)) + { + Dictionary dict = new(); + dict[slice_spec] = tensor; + internal_dict[checkpoint_key] = OneOf>.FromT1(dict); + } + else + { + internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; + } + } + else + { + internal_dict[checkpoint_key] = OneOf>.FromT0(tensor); + } + restore_fn_input_count[restore_fn]--; + + if (restore_fn_input_count[restore_fn] == 0) + { + Dictionary>> restored_tensors = new(); + foreach (var input in restore_fn_inputs[restore_fn]) + { + restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; + } + var ret = restore_fn.DynamicInvoke(restored_tensors); + if (ret is IDictionary) + { + var dict = (IDictionary)ret; + restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); + } + } + } + } + }); + } + + foreach(var item in _registered_savers) + { + throw new NotImplementedException(); + } + return restore_ops; + } + + // TODO: complete the implementation. Currently skip it because of lack of API. + bool has_custom_device_saver = false; + + if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver)) + { + // TODO: implement it. Currently `autograph` does not support the function with non parameter. + throw new NotImplementedException(); + } + else + { + return restore_func(); + } + } + + /// + /// Serializes to a SaverDef referencing the current graph. + /// + public SaverDef to_proto() + { + var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename"); + var traced_save_func = tf.autograph.to_graph(_traced_save, TF_DataType.TF_STRING); + var traced_restore_func = tf.autograph.to_graph(_traced_restore, TF_DataType.TF_STRING); + var save_tensor = traced_save_func(filename_tensor); + var restore_op = traced_restore_func(filename_tensor).op; + return new SaverDef() + { + FilenameTensorName = filename_tensor.name, + SaveTensorName = save_tensor.name, + RestoreOpName = restore_op.name, + Version = SaverDef.Types.CheckpointFormatVersion.V2 + }; + } + + private Tensor _traced_save(Tensor file_prefix) + { + var save_op = save(file_prefix); + return tf_with(ops.device("cpu:0"), _ => + { + return tf_with(ops.control_dependencies(new object[] { save_op }), __ => + { + return array_ops.identity(file_prefix); + }); + }); + } + + private Tensor _traced_restore(Tensor file_prefix) + { + var restore_op = restore(file_prefix); + return tf_with(ops.device("cpu:0"), _ => + { + return tf_with(ops.control_dependencies(restore_op.Values.ToArray()), __ => + { + return array_ops.identity(file_prefix); + }); + }); + } + + public static MultiDeviceSaver from_saveables(IEnumerable saveables, IDictionary>? registered_savers = null, bool call_with_mapped_captures = false) + { + Dictionary>>> serialized_tensors = new(); + foreach (var saveable in saveables) + { + var trackable = new SaveableCompatibilityConverter(saveable, new List() { saveable }); + serialized_tensors[trackable] = trackable.serialize_to_tensors(); + } + return new MultiDeviceSaver(serialized_tensors, registered_savers, call_with_mapped_captures); + } + + private static Tensor registered_saver_filename(Tensor filename_tensor, string saver_name) + { + return gen_ops.string_join(new Tensor[] { filename_tensor, constant_op.constant($"-{saver_name}") }); + } + private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards) + { + return gen_ops.sharded_filename(filename_tensor, tf.constant(shard), num_shards); + } + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/restore.cs b/src/TensorFlowNET.Core/Checkpoint/restore.cs new file mode 100644 index 000000000..0e1a300e9 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/restore.cs @@ -0,0 +1,333 @@ +using OneOf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Security; +using System.Text; +using Tensorflow.Train; +using Tensorflow.Training; +using static Tensorflow.Binding; + +namespace Tensorflow.Checkpoint; + +public class CheckpointPosition +{ + private CheckpointRestoreCoordinator _checkpoint; + private int _proto_id; + private bool _skip_restore; + public CheckpointPosition(CheckpointRestoreCoordinator checkpoint, int proto_id) + { + _checkpoint = checkpoint; + _proto_id = proto_id; + _skip_restore = false; + } + + public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id]; + public CheckpointRestoreCoordinator Checkpoint => _checkpoint; + public TrackableObjectGraph.Types.TrackableObject ObjectProto => _checkpoint.ObjectGraphProto.Nodes[_proto_id]; + + public void restore(Trackable trackable) + { + using (ops.init_scope()) + { + if (bind_project(trackable)) + { + var restore_ops = _restore_descendants(); + if(restore_ops is not null && restore_ops.Count > 0) + { + _checkpoint.new_restore_ops(restore_ops); + } + } + } + } + + /// + /// Set a checkpoint<->object correspondence. + /// + /// + /// + public bool bind_project(Trackable trackable) + { + _checkpoint.AllTrackables.Add(trackable); + _checkpoint.MatchedProtoIds.Add(_proto_id); + if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment) && current_assignment is not null) + { + // skip the `logging.warning`. + return false; + } + else + { + _checkpoint.ObjectByProtoId[_proto_id] = trackable; + return true; + } + } + + public (List, Dictionary>, List, object?) gather_ops_or_named_saveables() + { + // skip the registered_saver + + if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) + { + return (new List(), new Dictionary>(), + new List(), null); + } + + var saveable_factories = saveable_object_util.saveable_objects_from_trackable(this.Trackable); + + List existing_restore_ops; + List positions = new(); + Dictionary> named_saveables; + if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) + { + (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); + } + else if(saveable_factories.Count > 0) + { + (existing_restore_ops, named_saveables) = _create_saveables_by_attribute_name(saveable_factories); + } + else + { + throw new NotImplementedException(); + } + return (existing_restore_ops, named_saveables, positions, null); + } + + public CheckpointPosition create_child_position(int node_id) + { + return new CheckpointPosition(_checkpoint, node_id); + } + + public (CheckpointPosition, BaseResourceVariable) create_slot_variable_position(Optimizer optimizer_object, BaseResourceVariable variable, + int slot_variable_id, string slot_name) + { + //CheckpointPosition slot_variable_position = new(Checkpoint, slot_variable_id); + + // TODO(Rinne): implement it. + return (null, null); + } + + /// + /// Creates a saveable using the _serialize_to_tensor method. + /// + /// + private (List, Dictionary>) _create_serialize_to_tensor_saveable( + IDictionary>> saveable_factories) + { + string suffix = SaveableCompat.get_saveable_name(this.Trackable); + suffix = suffix ?? ""; + var saveable_name = _extract_saveable_name(ObjectProto.Attributes[0].CheckpointKey) + suffix; + + if (!tf.Context.executing_eagerly()) + { + throw new NotImplementedException("The restore under graph mode has not been implemented. " + + "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + + var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); + // skip the cache. + Dictionary> dict = new(); + dict[saveable_name] = saveable; + return (new List(), dict); + } + + private (List, Dictionary>) _create_saveables_by_attribute_name( + IDictionary>> saveable_factories) + { + // TODO(Rinne): implement it. + if(ObjectProto.Attributes is null) + { + return (new List(), new Dictionary>()); + } + + List existing_restore_ops = new(); + HashSet created_compat_names = new(); + Dictionary> named_saveables = new(); + foreach (var serialized_tensor in ObjectProto.Attributes) + { + Operation existing_op; + if (tf.Context.executing_eagerly() || !_checkpoint.RestoreOpsByName.ContainsKey(serialized_tensor.CheckpointKey)) + { + existing_op = null; + } + else + { + existing_op = _checkpoint.RestoreOpsByName[serialized_tensor.CheckpointKey]; + } + + if(existing_op is not null) + { + existing_restore_ops.Add(existing_op); + continue; + } + + if(created_compat_names.Any(x => serialized_tensor.Name.StartsWith(x))) + { + continue; + } + + // TODO(Rinne): deal with cache. + + var saveable = _get_saveable_from_factory(saveable_factories, serialized_tensor, created_compat_names); + if(saveable is null) + { + _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List()).Add(serialized_tensor.Name); + continue; + } + named_saveables[serialized_tensor.CheckpointKey] = saveable.Value; + } + return (existing_restore_ops, named_saveables); + } + + private OneOf? _get_saveable_from_factory(IDictionary>> saveable_factories, + TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet created_compat_names) + { + var expected_factory_name = serialized_tensor.Name; + var factory_input_name = serialized_tensor.CheckpointKey; + + if (!saveable_factories.TryGetValue(expected_factory_name, out var matched_factory)) + { + foreach(var item in saveable_factories) + { + var factory_name = item.Key; + var factory = item.Value; + if (expected_factory_name.StartsWith(factory_name)) + { + if(matched_factory is not null) + { + throw new ValueError($"Forward compatibility load error: Unable to load " + + "checkpoint saved in future version of TensorFlow. " + + "Please update your version of TensorFlow to the " + + "version in which the checkpoint was saved."); + } + } + matched_factory = factory; + factory_input_name = _extract_saveable_name(serialized_tensor.CheckpointKey) + factory_name; + created_compat_names.Add(factory_name); + } + } + return matched_factory(factory_input_name); + } + + private string _extract_saveable_name(string checkpoint_key) + { + var search_key = TrackableUtils.OBJECT_ATTRIBUTES_NAME + "/"; + return checkpoint_key.Substring(0, checkpoint_key.IndexOf(search_key) + search_key.Length); + } + + /// + /// Restore the bound Trackable and dependencies (may be deferred). + /// + private List _restore_descendants() + { + Queue<(CheckpointPosition, Trackable)> visit_queue = new(); + visit_queue.Enqueue((this, this.Trackable)); + List restore_ops = new(); + Dictionary> tensor_saveables = new(); + List positions = new(); + + CheckpointPosition current_position = null; + while (visit_queue.Count > 0) + { + current_position = visit_queue.Dequeue().Item1; + var (new_restore_ops, new_tensor_saveables, new_positions, new_registered_savers) = current_position._single_restore(); + restore_ops.AddRange(new_restore_ops); + foreach(var item in new_tensor_saveables) + { + tensor_saveables.Add(item.Key, item.Value); + } + positions.AddRange(new_positions); + _queue_children_for_restoration(current_position, visit_queue); + _queue_slot_variables(current_position, visit_queue); + } + restore_ops.AddRange(current_position.Checkpoint.restore_saveables(tensor_saveables, positions, null)); + return restore_ops; + } + + private void _queue_children_for_restoration(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue) + { + var trackable = checkpoint_position.Trackable; + foreach(var child in checkpoint_position.ObjectProto.Children) + { + var child_position = checkpoint_position.create_child_position(child.NodeId); + var local_object = trackable._lookup_dependency(child.LocalName); + var child_proto = child_position.ObjectProto; + if(local_object is null) + { + if(child_proto.Children.Any() || child_proto.Attributes.Any() || child_proto.SlotVariables.Any()) + { + trackable.DeferredDependencies.SetDefault(child.LocalName, new List()).Add(child_position); + } + } + else + { + if (child_position.bind_project(local_object)) + { + visit_queue.Enqueue((child_position, local_object)); + } + } + } + } + + private void _queue_slot_variables(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue) + { + var trackable = checkpoint_position.Trackable; + var checkpoint = checkpoint_position.Checkpoint; + if(checkpoint.DeferredSlotRestorations.TryGetValue(checkpoint_position._proto_id, out var positions)) + { + checkpoint.DeferredSlotRestorations.Remove(checkpoint_position._proto_id); + foreach (var deferred_slot_restoration in positions) + { + var (slot_variable_position, slot_variable) = checkpoint_position.create_slot_variable_position( + trackable as Optimizer, deferred_slot_restoration.OriginalVariable, deferred_slot_restoration.SlotVariableId, + deferred_slot_restoration.SlotName + ); + if(slot_variable_position is not null) + { + visit_queue.Enqueue((slot_variable_position, slot_variable)); + } + } + } + if (checkpoint.SlotRestorations.TryGetValue(checkpoint_position._proto_id, out var restorations)) + { + checkpoint.SlotRestorations.Remove(checkpoint_position._proto_id); + foreach (var slot_restoration in restorations) + { + if(Checkpoint.ObjectByProtoId.TryGetValue(slot_restoration.OptimizerId, out var optimizer_object)) + { + throw new NotImplementedException(); + // TODO(Rinne); implement it. + } + else + { + Debug.Assert(trackable is BaseResourceVariable); + Checkpoint.DeferredSlotRestorations.SetDefault(slot_restoration.OptimizerId, new List()) + .Add(new DeferredSlotVariableRestoration(trackable as BaseResourceVariable, slot_restoration.SlotVariableId, slot_restoration.SlotName)); + } + } + } + } + + private (List, Dictionary>, List, object?) _single_restore() + { + var trackable = this.Trackable; + trackable._maybe_initialize_trackable(); + if(_checkpoint.RestoreUid > trackable.UpdateUid) + { + var (restore_ops, tensor_saveables, positions, registered_savers) = gather_ops_or_named_saveables(); + trackable.UpdateUid = _checkpoint.RestoreUid; + return (restore_ops, tensor_saveables, positions, registered_savers); + } + else + { + return (new List(), new Dictionary>(), + new List(), null); + } + } +} + +public record class DeferredSlotVariableRestoration( + BaseResourceVariable OriginalVariable, + int SlotVariableId, + string SlotName +); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Clustering/KMeans.cs b/src/TensorFlowNET.Core/Clustering/KMeans.cs new file mode 100644 index 000000000..04a0375f2 --- /dev/null +++ b/src/TensorFlowNET.Core/Clustering/KMeans.cs @@ -0,0 +1,110 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; + +namespace Tensorflow.Clustering +{ + /// + /// Creates the graph for k-means clustering. + /// + public class KMeans + { + public const string CLUSTERS_VAR_NAME = "clusters"; + + public const string SQUARED_EUCLIDEAN_DISTANCE = "squared_euclidean"; + public const string COSINE_DISTANCE = "cosine"; + public const string RANDOM_INIT = "random"; + public const string KMEANS_PLUS_PLUS_INIT = "kmeans_plus_plus"; + public const string KMC2_INIT = "kmc2"; + + Tensor[] _inputs; + int _num_clusters; + string _initial_clusters; + string _distance_metric; + bool _use_mini_batch; + int _mini_batch_steps_per_iteration; + int _random_seed; + int _kmeans_plus_plus_num_retries; + int _kmc2_chain_length; + + public KMeans(Tensor inputs, + int num_clusters, + string initial_clusters = RANDOM_INIT, + string distance_metric = SQUARED_EUCLIDEAN_DISTANCE, + bool use_mini_batch = false, + int mini_batch_steps_per_iteration = 1, + int random_seed = 0, + int kmeans_plus_plus_num_retries = 2, + int kmc2_chain_length = 200) + { + _inputs = new Tensor[] { inputs }; + _num_clusters = num_clusters; + _initial_clusters = initial_clusters; + _distance_metric = distance_metric; + _use_mini_batch = use_mini_batch; + _mini_batch_steps_per_iteration = mini_batch_steps_per_iteration; + _random_seed = random_seed; + _kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries; + _kmc2_chain_length = kmc2_chain_length; + } + + public object training_graph() + { + var initial_clusters = _initial_clusters; + var num_clusters = ops.convert_to_tensor(_num_clusters); + var inputs = _inputs; + var vars = _create_variables(num_clusters); + var cluster_centers_var = vars[0]; + var cluster_centers_initialized = vars[1]; + var total_counts = vars[2]; + var cluster_centers_updated = vars[3]; + var update_in_steps = vars[4]; + + var init_op = new _InitializeClustersOpFactory(_inputs, num_clusters, initial_clusters, _distance_metric, + _random_seed, _kmeans_plus_plus_num_retries, + _kmc2_chain_length, cluster_centers_var, cluster_centers_updated, + cluster_centers_initialized).op(); + + throw new NotImplementedException("KMeans training_graph"); + } + + private RefVariable[] _create_variables(Tensor num_clusters) + { + var init_value = constant_op.constant(new float[0], dtype: TF_DataType.TF_FLOAT); + var cluster_centers = tf.Variable(init_value, name: CLUSTERS_VAR_NAME, validate_shape: false); + var cluster_centers_initialized = tf.Variable(false, dtype: TF_DataType.TF_BOOL, name: "initialized"); + RefVariable update_in_steps = null; + if (_use_mini_batch && _mini_batch_steps_per_iteration > 1) + throw new NotImplementedException("KMeans._create_variables"); + else + { + var cluster_centers_updated = cluster_centers; + var ones = array_ops.ones(new Tensor[] { num_clusters }, dtype: TF_DataType.TF_INT64); + var cluster_counts = _use_mini_batch ? tf.Variable(ones) : null; + return new RefVariable[] + { + /*cluster_centers, + cluster_centers_initialized, + cluster_counts, + cluster_centers_updated,*/ + update_in_steps + }; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs b/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs new file mode 100644 index 000000000..1b295fcfd --- /dev/null +++ b/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs @@ -0,0 +1,149 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Clustering +{ + /// + /// Internal class to create the op to initialize the clusters. + /// + public class _InitializeClustersOpFactory + { + Tensor[] _inputs; + Tensor _num_clusters; + string _initial_clusters; + string _distance_metric; + int _random_seed; + int _kmeans_plus_plus_num_retries; + int _kmc2_chain_length; + RefVariable _cluster_centers; + RefVariable _cluster_centers_updated; + RefVariable _cluster_centers_initialized; + Tensor _num_selected; + Tensor _num_remaining; + Tensor _num_data; + + public _InitializeClustersOpFactory(Tensor[] inputs, + Tensor num_clusters, + string initial_clusters, + string distance_metric, + int random_seed, + int kmeans_plus_plus_num_retries, + int kmc2_chain_length, + RefVariable cluster_centers, + RefVariable cluster_centers_updated, + RefVariable cluster_centers_initialized) + { + _inputs = inputs; + _num_clusters = num_clusters; + _initial_clusters = initial_clusters; + _distance_metric = distance_metric; + _random_seed = random_seed; + _kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries; + _kmc2_chain_length = kmc2_chain_length; + _cluster_centers = cluster_centers; + _cluster_centers_updated = cluster_centers_updated; + _cluster_centers_initialized = cluster_centers_initialized; + + _num_selected = array_ops.shape(_cluster_centers).slice(0); + _num_remaining = _num_clusters - _num_selected; + + _num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i).slice(0)).ToArray()); + } + + private Tensor _initialize() + { + return tf_with(ops.control_dependencies(new Operation[] + { + check_ops.assert_positive(_num_remaining) + }), delegate + { + var num_now_remaining = _add_new_centers(); + return control_flow_ops.cond(math_ops.equal(num_now_remaining, 0), + () => + { + return state_ops.assign(_cluster_centers_initialized, true); + }, + () => + { + return control_flow_ops.no_op().output.slice(0); + }); + }); + } + + public Tensor op() + { + var x = control_flow_ops.cond(gen_math_ops.equal(_num_remaining, ops.convert_to_tensor(0)), + () => + { + return check_ops.assert_equal(_cluster_centers_initialized, true); + }, + _initialize); + + return x; + } + + private Tensor _add_new_centers() + { + // Adds some centers and returns the number of centers remaining. + var new_centers = _choose_initial_centers(); + if (_distance_metric == KMeans.COSINE_DISTANCE) + new_centers = nn_impl.l2_normalize(new_centers.slice(0), axis: 1); + + // If cluster_centers is empty, it doesn't have the right shape for concat. + var all_centers = control_flow_ops.cond(math_ops.equal(_num_selected, 0), + () => new Tensor[] { new_centers }, + () => new Tensor[] { array_ops.concat(new Tensor[] { _cluster_centers, new_centers }, 0) }); + + var a = state_ops.assign(_cluster_centers, all_centers, validate_shape: false); + + return _num_clusters - array_ops.shape(a).slice(0); + } + + private Tensor _choose_initial_centers() + { + return _greedy_batch_sampler().slice(0); + } + + private Tensor _greedy_batch_sampler() + { + return control_flow_ops.cond(_num_data <= _num_remaining, + () => + { + return array_ops.concat(_inputs, 0); + }, + () => + { + return _random(); + }); + } + + private Tensor _random() + { + var reshape = array_ops.reshape(_num_remaining, new int[] { -1 }); + var cast = math_ops.cast(_num_data, TF_DataType.TF_INT64); + var indices = random_ops.random_uniform( + reshape, + minval: 0, + maxval: cast, + seed: _random_seed, + dtype: TF_DataType.TF_INT64); + return embedding_ops.embedding_lookup(_inputs, indices, partition_strategy: "div"); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs b/src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs new file mode 100644 index 000000000..7502a3a78 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Tensorflow.Common.Extensions +{ + public static class DictionaryExtension + { + public static void Deconstruct(this KeyValuePair pair, out T1 first, out T2 second) + { + first = pair.Key; + second = pair.Value; + } + public static void Update(this Dictionary dic, IDictionary other) + { + foreach(var (key, value) in other) + { + dic[key] = value; + } + } + public static T2 GetOrDefault(this Dictionary dic, T1 key, T2 defaultValue) + { + if (dic.ContainsKey(key)) + { + return dic[key]; + } + return defaultValue; + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs b/src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs new file mode 100644 index 000000000..6ceba445a --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs @@ -0,0 +1,23 @@ +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Extensions +{ + public static class JObjectExtensions + { + public static T? TryGetOrReturnNull(this JObject obj, string key) + { + var res = obj[key]; + if (res is null) + { + return default; + } + else + { + return res.ToObject(); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs b/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs new file mode 100644 index 000000000..287b48cc3 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.Common.Extensions +{ + public static class LinqExtensions + { +#if NETSTANDARD2_0 + public static IEnumerable TakeLast(this IEnumerable sequence, int count) + { + return sequence.Skip(sequence.Count() - count); + } + + public static IEnumerable SkipLast(this IEnumerable sequence, int count) + { + return sequence.Take(sequence.Count() - count); + } +#endif + public static Tensors ToTensors(this Tensor[] tensors) + { + return new Tensors(tensors); + } + + public static Tensors ToTensors(this IList tensors) + { + return new Tensors(tensors); + } + + public static void Deconstruct(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) + { + first = values.Item1; + second = values.Item2; + third = values.Item3; + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs b/src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs new file mode 100644 index 000000000..76bdd6133 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; + +namespace Tensorflow.Common.Extensions +{ + public static class NestExtensions + { + public static Tensors ToTensors(this INestable tensors) + { + return new Tensors(tensors.AsNest()); + } + + public static Tensors? ToTensors(this Nest tensors) + { + return Tensors.FromNest(tensors); + } + + /// + /// If the nested object is already a nested type, this function could reduce it. + /// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. + /// + /// + /// + /// + /// + public static Nest ReduceTo(this INestStructure input) where TIn: INestStructure + { + return Nest.ReduceFrom(input); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs b/src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs new file mode 100644 index 000000000..c7fb80938 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs @@ -0,0 +1,13 @@ +using OneOf; +using System; + +namespace Tensorflow.Common.Extensions +{ + public static class OneofExtension + { + public static bool IsTypeOrDeriveFrom(this IOneOf src) + { + return src.Value is T; + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs b/src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs new file mode 100644 index 000000000..d0c35ee70 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// This is a temp solution, which should be removed after refactoring `Tensors` + /// + [Obsolete] + public class FakeTensorByTensorArray: Tensor + { + public TensorArray TensorArray { get; set; } + + public FakeTensorByTensorArray(TensorArray array) + { + TensorArray = array; + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs new file mode 100644 index 000000000..986136f4d --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs @@ -0,0 +1,69 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; + +namespace Tensorflow.Common.Types +{ + public class GeneralizedTensorShape: Nest + { + public GeneralizedTensorShape(Shape value, string? name = null) + { + NodeValue = value; + NestType = NestType.Node; + } + + public GeneralizedTensorShape(IEnumerable values, string? name = null) + { + ListValue = values.Select(s => new Nest(s) as INestStructure).ToList(); + Name = name; + NestType = NestType.List; + } + + public GeneralizedTensorShape(Dictionary value, string? name = null) + { + DictValue = value.ToDictionary(x => x.Key, x => new Nest(x.Value) as INestStructure); + Name = name; + NestType = NestType.Dictionary; + } + + public GeneralizedTensorShape(Nest other) + { + NestType = other.NestType; + NodeValue = other.NodeValue; + DictValue = other.DictValue; + ListValue = other.ListValue; + Name = other.Name; + } + + public Shape ToSingleShape() + { + var shapes = Flatten().ToList(); + if (shapes.Count != 1) + { + throw new ValueError("The generalized shape contains more than 1 dim."); + } + return shapes[0]; + } + + public long ToNumber() + { + var shapes = Flatten().ToList(); + if (shapes.Count != 1 || shapes[0].ndim != 1) + { + throw new ValueError("The generalized shape contains more than 1 dim."); + } + return shapes[0].dims[0]; + } + + public INestStructure ToTensorShapeConfigs() + { + return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select(x => x == -1 ? null : x).ToArray() }); + } + + public static implicit operator GeneralizedTensorShape(Shape shape) + { + return new GeneralizedTensorShape(shape); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/INestStructure.cs b/src/TensorFlowNET.Core/Common/Types/INestStructure.cs new file mode 100644 index 000000000..32b662937 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/INestStructure.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// This interface indicates that a class may have a nested structure and provide + /// methods to manipulate with the structure. + /// + public interface INestStructure: INestable + { + NestType NestType { get; } + + /// + /// The item count of depth 1 of the nested structure. + /// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3. + /// + int ShallowNestedCount { get; } + /// + /// The total item count of depth 1 of the nested structure. + /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. + /// + int TotalNestedCount { get; } + + /// + /// Flatten the Nestable object. Node that if the object contains only one value, + /// it will be flattened to an enumerable with one element. + /// + /// + IEnumerable Flatten(); + /// + /// Construct a new object with the same nested structure. + /// + /// + /// + /// + INestStructure MapStructure(Func func); + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/INestable.cs b/src/TensorFlowNET.Core/Common/Types/INestable.cs new file mode 100644 index 000000000..7ce49f85a --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/INestable.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + public interface INestable + { + Nest AsNest(); + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs b/src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs new file mode 100644 index 000000000..427e71aaa --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// This interface is used when some corresponding python methods have optional args. + /// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while + /// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs` + /// as the parameter of the method. + /// + public interface IOptionalArgs + { + /// + /// The identifier of the class. It is not an argument but only something to + /// separate different OptionalArgs. + /// + string Identifier { get; } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/NamedTuple.cs b/src/TensorFlowNET.Core/Common/Types/NamedTuple.cs new file mode 100644 index 000000000..48073c61b --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/NamedTuple.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Tensorflow.Common.Types +{ + public class NamedTuple + { + public string Name { get; set; } + public Dictionary ValueDict { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs b/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs new file mode 100644 index 000000000..dc7fd3a1f --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs @@ -0,0 +1,62 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + public static class Nest + { + /// + /// Pack the flat items to a nested sequence by the template. + /// + /// + /// + /// + /// + public static Nest PackSequenceAs(INestable template, TOut[] flatItems) + { + return template.AsNest().PackSequence(flatItems); + } + + /// + /// Pack the flat items to a nested sequence by the template. + /// + /// + /// + /// + /// + public static Nest PackSequenceAs(INestable template, List flatItems) + { + return template.AsNest().PackSequence(flatItems.ToArray()); + } + + /// + /// Flatten the nested object. + /// + /// + /// + /// + public static IEnumerable Flatten(INestable nestedObject) + { + return nestedObject.AsNest().Flatten(); + } + + /// + /// Map the structure with specified function. + /// + /// + /// + /// + /// + /// + public static INestStructure MapStructure(Func func, INestable nestedObject) + { + return nestedObject.AsNest().MapStructure(func); + } + + public static bool IsNested(INestable obj) + { + return obj.AsNest().IsNested(); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/Nest.cs b/src/TensorFlowNET.Core/Common/Types/Nest.cs new file mode 100644 index 000000000..89ce29f2f --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/Nest.cs @@ -0,0 +1,485 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Extensions; + +namespace Tensorflow.Common.Types +{ + public enum NestType + { + Empty, + Node, + List, + Dictionary + } + + /// + /// A nested structure which may inclulde value, list and dictionary. + /// Note that dictionary does not ensure the data order. When using it as IEnumerable, + /// its order is depth-first. + /// + /// + public class Nest : INestStructure, IEnumerable + { + private static readonly Nest _empty = new Nest() + { + NestType = NestType.Empty, + }; + public static Nest Empty => _empty; + public NestType NestType { get; protected set; } + public string? Name { get; set; } + public T? NodeValue { get; protected set; } + public List>? ListValue { get; protected set; } + public Dictionary>? DictValue { get; protected set; } + + public int ShallowNestedCount + { + get + { + if (NestType == NestType.Empty) + { + return 0; + } + else if (NestType == NestType.Node) + { + return 1; + } + else if (NestType == NestType.List) + { + return ListValue!.Count; + } + else // dict + { + return DictValue!.Count; + } + } + } + + public int TotalNestedCount + { + get + { + return Flatten().Count(); + } + } + + protected Nest() { } + + public Nest(T value, string? name = null) + { + NodeValue = value; + Name = name; + NestType = NestType.Node; + } + + public Nest(IEnumerable> values, string? name = null) + { + ListValue = values.ToList(); + Name = name; + NestType = NestType.List; + } + + public Nest(Dictionary> value, string? name = null) + { + DictValue = value; + Name = name; + NestType = NestType.Dictionary; + } + + public Nest(Nest other) + { + NestType = other.NestType; + NodeValue = other.NodeValue; + DictValue = other.DictValue; + ListValue = other.ListValue; + Name = other.Name; + } + + public virtual IEnumerable Flatten() + { + return FlattenInternal(this); + } + public virtual INestStructure MapStructure(Func func) + { + return MapStructureInternal(func); + } + + /// + /// Pack the flat items to a nested sequence by the template. + /// + /// + /// + public virtual Nest PackSequence(TOut[] flatItems) + { + if(flatItems.Length == 0) + { + return Nest.Empty; + } + int index = 0; + return PackSequenceInternal(this, flatItems, ref index); + } + + private static Nest PackSequenceInternal(Nest template, TOut[] flatItems, ref int index) + { + if(template.NestType == NestType.Node) + { + if(index >= flatItems.Length) + { + throw new InvalidArgumentError("The template and flat items are not matched."); + } + return new Nest(flatItems[index++]); + } + else if(template.NestType == NestType.List) + { + List> nestedObjects = new List>(); + for (int i = 0; i < template.ListValue!.Count; i++) + { + nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index)); + } + return new Nest(nestedObjects); + } + else if(template.NestType == NestType.Node) + { + Dictionary> dict = new Dictionary>(); + foreach(var (key, value) in template.DictValue!) + { + dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index); + } + return new Nest(dict); + } + // Consider Empty as invalid type. + throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); + } + + public virtual Nest AsNest() + { + return this; + } + + public virtual Nest MergeWith(Nest? other) + { + if(other is null || other == Nest.Empty) + { + return this; + } + if(this == Nest.Empty) + { + return other; + } + if(NestType == NestType.Node && other.NestType == NestType.Node) + { + return new Nest(new Nest[] { this, other }); + } + else if(NestType == NestType.List && other.NestType == NestType.List) + { + return new Nest(this.ListValue!.Concat(other.ListValue!)); + } + else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary) + { + return new Nest(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value)); + } + else + { + return new Nest(new Nest[] { this, other }); + } + } + + /// + /// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not + /// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. + /// + /// + public bool IsNested() + { + if(NestType is NestType.Empty or NestType.Node) + { + return false; + } + else if(NestType is NestType.List) + { + return ListValue!.Count > 0; + } + else + { + return DictValue!.Count > 0; + } + } + + [Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] + public T this[int index] + { + get + { + bool success = FindInternal(this, index, out var result); + if (success) + { + return result; + } + else + { + throw new IndexOutOfRangeException(); + } + } + set + { + bool success = SetInternal(this, index, value); + if (!success) + { + throw new IndexOutOfRangeException(); + } + } + } + + /// + /// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it + /// to `Nest[T]`. + /// + /// + /// + /// + public static Nest ReduceFrom(INestStructure input) where TOut: INestStructure + { + var nested = input.AsNest(); + return ReduceInternal(nested).AsNest(); + } + + private static INestStructure ReduceInternal(Nest node) where TOut : INestStructure + { + if(node.NestType == NestType.Empty) + { + return Nest.Empty; + } + else if(node.NestType == NestType.Node) + { + return node.NodeValue!.AsNest(); + } + else if(node.NestType == NestType.List) + { + return new Nest(node.ListValue!.Select(x => ReduceInternal(x.AsNest()))); + } + else // Dictionary type + { + return new Nest(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest()))); + } + } + + private static bool FindInternal(Nest node, int index, out T? result) + { + if (node.NestType == NestType.Node) + { + if(index == 0) + { + result = node.NodeValue!; + return true; + } + result = default(T); + return false; + } + else if (node.NestType == NestType.List) + { + foreach (var item in node.ListValue!) + { + if(index == 0) + { + return FindInternal(item.AsNest(), index, out result); + } + index--; + } + result = default(T); + return false; + } + else if(node.NestType == NestType.Dictionary) + { + foreach (var item in node.DictValue!.Values) + { + if (index == 0) + { + return FindInternal(item.AsNest(), index, out result); + } + index--; + } + result = default(T); + return false; + } + else + { + result = default(T); + return false; + } + } + + private static bool SetInternal(Nest node, int index, T newValue) + { + if (node.NestType == NestType.Node) + { + if (index == 0) + { + node.NodeValue = newValue; + return true; + } + return false; + } + else if (node.NestType == NestType.List) + { + foreach (var item in node.ListValue!) + { + if (index == 0) + { + return SetInternal(item.AsNest(), index, newValue); + } + index--; + } + return false; + } + else if (node.NestType == NestType.Dictionary) + { + foreach (var item in node.DictValue!.Values) + { + if (index == 0) + { + return SetInternal(item.AsNest(), index, newValue); + } + index--; + } + return false; + } + else + { + return false; + } + } + + private static IEnumerable FlattenInternal(Nest node) + { + if (node.NestType == NestType.Node) + { + yield return node.NodeValue!; + } + else if (node.NestType == NestType.List) + { + foreach (var item in node.ListValue!) + { + foreach(var val in FlattenInternal(item.AsNest())) + { + yield return val; + } + } + } + else if (node.NestType == NestType.Dictionary) + { + foreach (var item in node.DictValue!.Values) + { + foreach (var val in FlattenInternal(item.AsNest())) + { + yield return val; + } + } + } + } + + private Nest MapStructureInternal(Func func) + { + if (NestType == NestType.Node) + { + return new Nest(func(NodeValue!)); + } + else if (NestType == NestType.List) + { + List> outs = new List>(); + foreach (var item in ListValue!) + { + outs.Add(item.AsNest().MapStructureInternal(func)); + } + return new Nest(outs); + } + else if (NestType == NestType.Dictionary) + { + Dictionary> outs = new Dictionary>(); + foreach (var (key, value) in DictValue!) + { + outs.Add(key, value.AsNest().MapStructureInternal(func)); + } + return new Nest(outs); + } + else + { + return Nest.Empty; + } + } + + public IEnumerator GetEnumerator() + { + return Flatten().GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public override string ToString() + { + StringBuilder sb = new StringBuilder(); + sb.Append("("); + WriteString(this, sb); + sb.Append(")"); + return sb.ToString(); + } + + private static void WriteString(Nest node, StringBuilder sb) + { + if (!string.IsNullOrEmpty(node.Name)) + { + sb.Append($"{node.Name}: "); + } + if (node.NestType == NestType.Node) + { + sb.Append(node.NodeValue!.ToString()); + } + else if (node.NestType == NestType.List) + { + sb.Append("["); + for(int i = 0; i < node.ListValue!.Count; i++) + { + WriteString(node.ListValue![i].AsNest(), sb); + if(i != node.ListValue!.Count - 1) + { + sb.Append(", "); + } + } + sb.Append("]"); + } + else if (node.NestType == NestType.Dictionary) + { + sb.Append("{"); + int count = node.DictValue!.Count; + int i = 0; + foreach (var (key, value) in node.DictValue!) + { + sb.Append($"{key}: "); + WriteString(value.AsNest(), sb); + if (i != count - 1) + { + sb.Append(", "); + } + i++; + } + sb.Append("}"); + } + else + { + sb.Append(""); + } + } + + public static implicit operator Nest((INestStructure, INestStructure) inputs) + { + return new Nest(new INestStructure[] { inputs.Item1, inputs.Item2 }); + } + + public static implicit operator Nest((INestStructure, INestStructure, INestStructure) inputs) + { + return new Nest(new INestStructure[] { inputs.Item1, inputs.Item2, inputs.Item3 }); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs b/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs new file mode 100644 index 000000000..cf1994554 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs @@ -0,0 +1,103 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + public class NestDictionary : INestStructure, IDictionary where TKey : notnull + { + public NestType NestType => NestType.Dictionary; + public IDictionary Value { get; set; } + public int ShallowNestedCount => Values.Count; + + public int TotalNestedCount => Values.Count; + public NestDictionary(IDictionary dict) + { + Value = dict; + } + public IEnumerable Flatten() + { + return Value.Select(x => x.Value); + } + public INestStructure MapStructure(Func func) + { + return new NestList(Value.Select(x => func(x.Value))); + } + + public Nest AsNest() + { + return new Nest(Value.Values.Select(x => new Nest(x))); + } + + // Required IDictionary members + public int Count => Value.Count; + + public bool IsReadOnly => Value.IsReadOnly; + + public ICollection Keys => Value.Keys; + + public ICollection Values => Value.Values; + + public void Add(TKey key, TValue value) + { + Value.Add(key, value); + } + + public void Add(KeyValuePair item) + { + Value.Add(item); + } + + public void Clear() + { + Value.Clear(); + } + + public bool Contains(KeyValuePair item) + { + return Value.Contains(item); + } + + public bool ContainsKey(TKey key) + { + return Value.ContainsKey(key); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + Value.CopyTo(array, arrayIndex); + } + + public IEnumerator> GetEnumerator() + { + return Value.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public bool Remove(TKey key) + { + return Value.Remove(key); + } + + public bool Remove(KeyValuePair item) + { + return Value.Remove(item); + } + + public bool TryGetValue(TKey key, out TValue value) + { + return Value.TryGetValue(key, out value); + } + + // Optional IDictionary members + public TValue this[TKey key] + { + get => Value[key]; + set => Value[key] = value; + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/NestList.cs b/src/TensorFlowNET.Core/Common/Types/NestList.cs new file mode 100644 index 000000000..1e0d272b7 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/NestList.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// The implementation of a list that support nest structure, in which the depth is 1. + /// + /// + public sealed class NestList : INestStructure, IEnumerable + { + public NestType NestType => NestType.List; + public List Values { get; set; } + public int ShallowNestedCount => Values.Count; + + public int TotalNestedCount => Values.Count; + + public NestList(params T[] values) + { + Values = new List(values); + } + + public NestList(IEnumerable values) + { + Values = new List(values); + } + public IEnumerable Flatten() + { + return Values; + } + public INestStructure MapStructure(Func func) + { + return new NestList(Values.Select(x => func(x))); + } + + public Nest AsNest() + { + return new Nest(Values.Select(x => new Nest(x))); + } + + // Enumerator implementation + public IEnumerator GetEnumerator() + { + return Values.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/NestNode.cs b/src/TensorFlowNET.Core/Common/Types/NestNode.cs new file mode 100644 index 000000000..701aade9a --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/NestNode.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// A nested structure with only one element. + /// + /// + public class NestNode : INestStructure + { + public NestType NestType => NestType.Node; + public T Value { get; set; } + public int ShallowNestedCount => 1; + + public int TotalNestedCount => 1; + public NestNode(T value) + { + Value = value; + } + public IEnumerable Flatten() + { + yield return Value; + } + public INestStructure MapStructure(Func func) + { + return new NestNode(func(Value)); + } + + public Nest AsNest() + { + return new Nest(Value); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs b/src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs new file mode 100644 index 000000000..a36930eca --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs @@ -0,0 +1,21 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Common.Types +{ + public class TensorShapeConfig + { + [JsonProperty("class_name")] + public string ClassName { get; set; } = "TensorShape"; + [JsonProperty("items")] + public long?[] Items { get; set; } + + public static implicit operator Shape(TensorShapeConfig shape) + => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); + + public static implicit operator TensorShapeConfig(Shape shape) + => new TensorShapeConfig() { Items = shape.dims.Select(x => x == -1 ? null : x).ToArray() }; + } +} diff --git a/src/TensorFlowNET.Core/Contexts/Context.Config.cs b/src/TensorFlowNET.Core/Contexts/Context.Config.cs new file mode 100644 index 000000000..0c7bded6e --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/Context.Config.cs @@ -0,0 +1,136 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using System; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Common.Extensions; + +namespace Tensorflow.Contexts +{ + /// + /// Environment in which eager operations execute. + /// + public sealed partial class Context + { + protected Device.PhysicalDevice[] _physical_devices; + protected Dictionary _physical_device_to_index; + ConfigProto _config; + public ConfigProto Config + { + get + { + _initialize_physical_devices(); + + var config = new ConfigProto(); + if(_config is not null) + { + config.MergeFrom(_config); + } + config.LogDevicePlacement = _log_device_placement; + + config.DeviceCount["CPU"] = 0; + config.DeviceCount["GPU"] = 0; + foreach(var dev in _physical_devices) + { + if (config.DeviceCount.ContainsKey(dev.DeviceType)) + { + config.DeviceCount[dev.DeviceType] += 1; + } + else + { + config.DeviceCount[dev.DeviceType] = 1; + } + } + + var gpu_options = _compute_gpu_options(); + config.GpuOptions = GPUOptions.Parser.ParseFrom(gpu_options.ToByteArray()); + + return config; + } + set + { + _config = value; + } + } + + protected void _initialize_physical_devices(bool reinitialize = false) + { + if(!reinitialize && _physical_devices is not null) + { + return; + } + var devs = list_physical_devices(); + _physical_devices = devs.Select(d => new Device.PhysicalDevice() + { + DeviceName = d.DeviceName, + DeviceType = d.DeviceType + }).ToArray(); + _physical_device_to_index = _physical_devices.Select((p, i) => new KeyValuePair(p, i)) + .ToDictionary(x => x.Key, x => x.Value); + + _import_config(); + } + + protected void _import_config() + { + if(_config is null) + { + return; + } + if(!_config.DeviceCount.TryGetValue("CPU", out var num_cpus)) + { + num_cpus = 1; + } + if(num_cpus != 1) + { + // TODO(Rinne): implement it. + } + + var gpus = _physical_devices.Where(d => d.DeviceType == "GPU"); + if(gpus.Count() == 0) + { + return; + } + + if(!_config.DeviceCount.TryGetValue("GPU", out var gpu_count)) + { + gpu_count = 0; + } + + // TODO(Rinne): implement it. + } + + ConfigProto MergeConfig() + { + Config.LogDevicePlacement = _log_device_placement; + // var gpu_options = _compute_gpu_options(); + // Config.GpuOptions.AllowGrowth = gpu_options.AllowGrowth; + return Config; + } + + GPUOptions _compute_gpu_options() + { + // By default, TensorFlow maps nearly all of the GPU memory of all GPUs + // https://www.tensorflow.org/guide/gpu + return new GPUOptions() + { + AllowGrowth = get_memory_growth("GPU") + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Contexts/Context.Device.cs b/src/TensorFlowNET.Core/Contexts/Context.Device.cs new file mode 100644 index 000000000..d35d10847 --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/Context.Device.cs @@ -0,0 +1,174 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Eager; +using static Tensorflow.Binding; +using Google.Protobuf; +using Tensorflow.Device; +using Tensorflow.Exceptions; +using System.Collections.Generic; + +namespace Tensorflow.Contexts +{ + /// + /// Environment in which eager operations execute. + /// + public sealed partial class Context + { + internal static Dictionary<(string, string), (string, DeviceSpec)> _device_parsing_cache = new(); + internal List _logical_devices = null; + internal List _context_devices = null; + + ContextDevicePlacementPolicy _device_policy; + bool _log_device_placement; + int _num_gpus; + Dictionary _memory_growth_map = new Dictionary(); + + public string DeviceName { get; set; } = ""; + public DeviceSpec DeviceSpec { get; set; } = null; + + internal List Devices + { + get + { + if(_context_devices is null) + { + throw new AssertionError("Context must be initialized first."); + } + return _context_devices; + } + } + + public void log_device_placement(bool enable) + { + if (_handle != null) + c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status); + _log_device_placement = enable; + // _thread_local_data.function_call_options = null; + } + + public bool get_memory_growth(string device_type) + { + foreach(var map in _memory_growth_map) + { + if (map.Key.DeviceType == device_type) + return map.Value; + } + return false; + } + + public void set_memory_growth(PhysicalDevice device, bool enable) + { + _memory_growth_map[device] = enable; + } + + public PhysicalDevice[] list_physical_devices(string device_type = null) + { + using var opts = c_api.TFE_NewContextOptions(); + using var ctx = c_api.TFE_NewContext(opts, tf.Status); + using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status); + tf.Status.Check(true); + + int num_devices = c_api.TF_DeviceListCount(devices); + var results = new List(); + for (int i = 0; i < num_devices; ++i) + { + var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status)); + tf.Status.Check(true); + + if (dev_type.StartsWith("XLA")) + continue; + + if (device_type == null || dev_type == device_type) + { + var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status); + tf.Status.Check(true); + + results.Add(new PhysicalDevice + { + DeviceName = dev_name, + DeviceType = dev_type + }); + } + } + + return results.ToArray(); + } + + public bool is_custom_device(string device_name) + { + return false; + // TODO(Rinne): After tf2.11 TFE_IsCustomDevice has been added to C APIs. + //ensure_initialized(); + //return c_api.TFE_IsCustomDevice(_handle, device_name); + } + + public EagerDeviceContext device(string name) + { + return new EagerDeviceContext(this, name); + } + + internal void _set_device(string device_name, DeviceSpec device_spec) + { + DeviceSpec = device_spec; + DeviceName = device_name; + } + + internal void _initialize_logical_devices() + { + List logical_devices = new(); + List context_devices = new(); + Status status = new(); + var device_list = c_api.TFE_ContextListDevices(_handle, status); + status.Check(true); + try + { + this._num_gpus = 0; + string current_job = null; + int current_task = -1; + for(int i = 0; i < c_api.TF_DeviceListCount(device_list); i++) + { + var dev_name = c_api.TF_DeviceListName(device_list, i, status); + status.Check(true); + context_devices.Add(DeviceUtils.canonical_name(dev_name)); + var spec = DeviceSpec.from_string(dev_name); + if(spec.Job == "localhost") + { + spec = spec.replace(job: null, replica: -1, task: -1); + } + logical_devices.Add(new LogicalDevice(spec.ToString(), spec.DeviceType)); + var dev_type_memory = c_api.TF_DeviceListType(device_list, i, status); + var dev_type = c_api.StringPiece(dev_type_memory); + status.Check(true); + if(dev_type == "GPU" && spec.Job == current_job && spec.Task == current_task) + { + _num_gpus++; + } + } + } + finally + { + _logical_devices = logical_devices; + _context_devices = context_devices; + } + } + } + + public record class LogicalDevice(string name, string device_type); +} diff --git a/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs b/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs new file mode 100644 index 000000000..f6e0911ca --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs @@ -0,0 +1,106 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Eager; +using static Tensorflow.Binding; +using Google.Protobuf; +using System.Collections.Generic; + +namespace Tensorflow.Contexts +{ + /// + /// Environment in which eager operations execute. + /// + public sealed partial class Context + { + Tensors ExecGraphAction(string OpType, string Name, ExecuteOpArgs args) + { + var keywords = new Dictionary(); + if (args.OpInputArgs != null) + { + foreach (var (i, input) in enumerate(args.OpInputArgs)) + keywords[$"input_{i}"] = input; + } + + if (args.OpAttrs != null) + { + foreach (var attr in args.OpAttrs) + keywords[attr.Key] = attr.Value; + } + + return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs; + } + + Tensors ExecEagerAction(string OpType, string Name, ExecuteOpArgs args) + { + var opExecInfo = new FastPathOpExecInfo(tf.Context, OpType, Name, args.OpInputArgs) + { + attrs = args.OpAttrs + }; + return tf.Runner.TFE_FastPathExecute(opExecInfo); + } + + // [DebuggerStepThrough] + public Tensors ExecuteOp(string opType, string name, ExecuteOpArgs args) + { + if (tf.Context.has_graph_arg(args.OpInputArgs)) + { + if (executing_eagerly()) + { + graph_mode(); + var result = ExecGraphAction(opType, name, args); + restore_mode(); + return result; + } + else + { + var result = ExecGraphAction(opType, name, args); + if (tf.Runner.MustRecordGradient()) + { + var op = result[0].op; + Dictionary attrs; + if (args.GetGradientAttrs == null) + { + attrs = new Dictionary(); + attrs["T"] = op.dtype; + } + else + { + attrs = ConvertToDict(args.GetGradientAttrs(op)); + } + var args1 = new object[attrs.Count() * 2]; + int i = 0; + foreach (var arg in attrs) + { + args1[i] = arg.Key; + args1[i + 1] = arg.Value; + i += 2; + } + tf.Runner.RecordGradient(opType, op.inputs, args1, op.outputs); + } + return result; + } + } + else + { + return ExecEagerAction(opType, name, args); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs new file mode 100644 index 000000000..0507cc2f8 --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -0,0 +1,242 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Eager; +using static Tensorflow.Binding; +using Google.Protobuf; +using Tensorflow.Util; +using Tensorflow.NumPy; + +namespace Tensorflow.Contexts +{ + /// + /// Environment in which eager operations execute. + /// + public sealed partial class Context + { + public const int GRAPH_MODE = 0; + public const int EAGER_MODE = 1; + + int defaultExecutionMode = EAGER_MODE; + public string ScopeName { get; set; } = ""; + bool initialized = false; + ContextSwitchStack context_switches; + protected FunctionCallOptions _function_call_options; + public FunctionCallOptions FunctionCallOptions + { + get + { + if(_function_call_options is null) + { + var config = Config; + _function_call_options = new FunctionCallOptions() + { + Config = config + }; + } + return _function_call_options; + } + set + { + _function_call_options = value; + } + } + + SafeContextHandle _handle; + + int? _seed; + Random _rng; + + public Context() + { + _device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT; + context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); + initialized = false; + FunctionCallOptions = new FunctionCallOptions(); + ensure_initialized(); + } + + /// + /// Initialize handle and devices if not already done so. + /// + public void ensure_initialized() + { + if (initialized) + return; + + Debug.Assert(_context_devices is null); + + Config = MergeConfig(); + FunctionCallOptions.Config = Config; + var config_str = Config.ToByteArray(); + var opts = new ContextOptions(); + var status = new Status(); + c_api.TFE_ContextOptionsSetConfig(opts, config_str, (ulong)config_str.Length, status); + status.Check(true); + c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy); + _handle = c_api.TFE_NewContext(opts, status); + status.Check(true); + _initialize_logical_devices(); + initialized = true; + } + + public void set_global_seed(int? seed) + { + _seed = seed; + if (seed.HasValue) + _rng = new Random(seed.Value); + else + _rng = null; + // Also clear the kernel cache, to reset any existing seeds + if (_handle != null) + c_api.TFE_ContextClearCaches(_handle); + } + + public int? global_seed() + => _seed; + + public int? internal_operation_seed() + => _rng?.Next(0, int.MaxValue); + + public void start_step() + => c_api.TFE_ContextStartStep(_handle); + + public void end_step() + => c_api.TFE_ContextEndStep(_handle); + + /// + /// Checks whether the current thread has eager execution enabled. + /// + /// + [DebuggerStepThrough] + public bool executing_eagerly() + { + if(context_switches.Count() == 0) + tf.enable_eager_execution(); + + return context_switches.Current().EagerMode; + } + + public bool is_build_function() + => context_switches.Current().IsBuildingFunction; + + public string shared_name(string name = null) + => !string.IsNullOrEmpty(name) || !executing_eagerly() ? + name : + "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; + + public string anonymous_name() + { + return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; + } + + public void graph_mode(bool isFunc = false) + => context_switches.Push(false, isFunc); + + public void eager_mode(bool isFunc = false) + => context_switches.Push(true, isFunc); + + public bool switched_to_graph(params object[] args) + { + var switching_to_graph = has_graph_arg(args) && tf.Context.executing_eagerly(); + if (switching_to_graph) + tf.Context.graph_mode(tf.Context.is_build_function()); + return switching_to_graph; + } + + public bool has_graph_arg(params object[] args) + { + var flatten_args = nest.flatten(args); + /*if (flatten_args.Count(x => x.GetType().IsValueType) == flatten_args.Count()) + return tf.Context.executing_eagerly() == false*/ + + bool has_graph_arg = !tf.Context.executing_eagerly(); + foreach (var el in flatten_args) + { + if (el is NDArray) + continue; + else if (el is EagerTensor) + continue; + else if (el is Tensor) + { + has_graph_arg = true; + break; + } + } + return has_graph_arg; + } + + public bool has_function(string name) + { + ensure_initialized(); + return c_api.TFE_ContextHasFunction(_handle, name); + } + + public void add_function(SafeFuncGraphHandle fn) + { + ensure_initialized(); + Status status = new(); + c_api.TFE_ContextAddFunction(_handle, fn, status); + status.Check(true); + } + + public void remove_function(string name) + { + ensure_initialized(); + Status status = new(); + c_api.TFE_ContextRemoveFunction(_handle, name, status); + status.Check(true); + } + + public void add_function_def(FunctionDef fdef) + { + ensure_initialized(); + var fdef_string = fdef.ToByteArray(); + Status status = new Status(); + c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, (ulong)fdef_string.Length, status); + status.Check(true); + } + + public void restore_mode() + { + context_switches.Pop(); + tf.get_default_graph(); + } + + public void reset_context() + { + // ops.reset_uid(); + // tf.defaultSession = null; + ops.reset_default_graph(); + context_switches.Clear(); + tf.Context.ensure_initialized(); + + if (_handle != null) + { + c_api.TFE_ContextClearCaches(_handle); + } + _device_parsing_cache.Clear(); + } + + public static implicit operator SafeContextHandle(Context ctx) + { + return ctx._handle; + } + } +} diff --git a/src/TensorFlowNET.Core/Contexts/ContextDevicePlacementPolicy.cs b/src/TensorFlowNET.Core/Contexts/ContextDevicePlacementPolicy.cs new file mode 100644 index 000000000..96836a2fc --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/ContextDevicePlacementPolicy.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Contexts +{ + public enum ContextDevicePlacementPolicy + { + // Running operations with input tensors on the wrong device will fail. + DEVICE_PLACEMENT_EXPLICIT = 0, + // Copy the tensor to the right device but log a warning. + DEVICE_PLACEMENT_WARN = 1, + // Silently copy the tensor, which has a performance cost since the operation + // will be blocked till the copy completes. This is the default placement + // policy. + DEVICE_PLACEMENT_SILENT = 2, + // Placement policy which silently copies int32 tensors but not other dtypes. + DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, + } +} diff --git a/src/TensorFlowNET.Core/Contexts/ContextOptions.cs b/src/TensorFlowNET.Core/Contexts/ContextOptions.cs new file mode 100644 index 000000000..4a07f1f5c --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/ContextOptions.cs @@ -0,0 +1,34 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Eager; + +namespace Tensorflow.Contexts; + +public sealed class ContextOptions +{ + SafeContextOptionsHandle _handle { get; } + + public ContextOptions() + { + _handle = c_api.TFE_NewContextOptions(); + } + + public static implicit operator SafeContextOptionsHandle(ContextOptions opt) + { + return opt._handle; + } +} diff --git a/src/TensorFlowNET.Core/Contexts/ContextSwitch.cs b/src/TensorFlowNET.Core/Contexts/ContextSwitch.cs new file mode 100644 index 000000000..4046e8772 --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/ContextSwitch.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow.Contexts +{ + public class ContextSwitch + { + public bool EagerMode { get; set; } + + /// + /// Whether the context is building a function. + /// + public bool IsBuildingFunction { get; set; } + + /// + /// A callable that executes the context switch. + /// + public Action EnterContextFn { get; set; } + + public string DeviceStack { get; set; } + + public override string ToString() + => $"EagerMode: {EagerMode}, IsBuildingFunction: {IsBuildingFunction}"; + } +} diff --git a/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs b/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs new file mode 100644 index 000000000..27704b3ee --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs @@ -0,0 +1,63 @@ +/***************************************************************************** + Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; + +namespace Tensorflow.Contexts +{ + /// + /// Match the semantics of DefaultGraphStack + /// + public class ContextSwitchStack + { + Stack stack; + + public ContextSwitchStack(bool isEager, bool isFunc) + { + stack = new Stack(); + Push(isEager, isFunc); + } + + public void Push(bool isEager, bool isFunc) + { + stack.Push(new ContextSwitch + { + EagerMode = isEager, + IsBuildingFunction = isFunc + }); + } + + public void Clear() + { + stack.Clear(); + } + + public void Pop() + { + stack.Pop(); + } + + public int Count() + { + return stack.Count; + } + + public ContextSwitch Current() + { + return stack.Peek(); + } + } +} diff --git a/src/TensorFlowNET.Core/Contexts/EagerDeviceContext.cs b/src/TensorFlowNET.Core/Contexts/EagerDeviceContext.cs new file mode 100644 index 000000000..2d5f61cdb --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/EagerDeviceContext.cs @@ -0,0 +1,71 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Device; + +namespace Tensorflow.Contexts +{ + public class EagerDeviceContext : ITensorFlowObject + { + private Context _ctx; + private string _device_name; + private Stack<(string, DeviceSpec, DeviceSpec)> _stack; + + public EagerDeviceContext(Context ctx, string device_name) + { + _ctx = ctx; + _device_name = device_name; + _stack = new Stack<(string, DeviceSpec, DeviceSpec)>(); + } + public void __enter__() + { + var ctx = _ctx; + var old_device_name = ctx.DeviceName; + var old_device_spec = ctx.DeviceSpec; + var new_device_name = _device_name; + var cache_key = (old_device_name, new_device_name); + DeviceSpec new_device_spec; + if (Context._device_parsing_cache.ContainsKey(cache_key)) + { + (new_device_name, new_device_spec) = Context._device_parsing_cache[cache_key]; + } + else + { + if(new_device_name is not null) + { + var device_spec = DeviceSpec.from_string(new_device_name); + if (!string.IsNullOrEmpty(old_device_name)) + { + new_device_spec = new DeviceSpec(old_device_spec); + } + else + { + ctx.ensure_initialized(); + new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]); + } + new_device_spec = new_device_spec.make_merged_spec(device_spec); + } + else + { + new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]); + } + new_device_name = new_device_spec.ToString(); + Context._device_parsing_cache[cache_key] = (new_device_name, new_device_spec); + } + ctx._set_device(new_device_name, new_device_spec); + _stack.Push((old_device_name, old_device_spec, new_device_spec)); + } + + public void __exit__() + { + var ctx = _ctx; + var (old_device_name, old_device_spec, new_device_spec) = _stack.Pop(); + ctx._set_device(old_device_name, old_device_spec); + } + + public void Dispose() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs b/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs new file mode 100644 index 000000000..2e6337601 --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class ExecuteOpArgs + { + public Func GetGradientAttrs { get; set; } + public object[] OpInputArgs { get; set; } + public Dictionary OpAttrs { get; set; } + + /// + /// + /// + /// For array: OpInputArgs = new object[]{ } + [DebuggerStepThrough] + public ExecuteOpArgs(params object[] inputArgs) + { + OpInputArgs = inputArgs; + } + + [DebuggerStepThrough] + public ExecuteOpArgs SetAttributes(object attrs) + { + OpAttrs = ConvertToDict(attrs); + return this; + } + } +} diff --git a/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs b/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs new file mode 100644 index 000000000..71312d11b --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Google.Protobuf; +using Protobuf.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Contexts +{ + public class FunctionCallOptions + { + public ConfigProto Config { get; set; } + public string ExecutorType { get; set; } + + public ByteString config_proto_serialized() + { + return Config.ToByteString(); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/BatchDataset.cs b/src/TensorFlowNET.Core/Data/BatchDataset.cs new file mode 100644 index 000000000..874c433de --- /dev/null +++ b/src/TensorFlowNET.Core/Data/BatchDataset.cs @@ -0,0 +1,38 @@ +using System; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// A `Dataset` that batches contiguous elements from its input. + /// + public class BatchDataset : UnaryDataset + { + Tensor _batch_size; + Tensor _drop_remainder; + + public BatchDataset(IDatasetV2 input_dataset, int batch_size, bool drop_remainder = false) : + base(input_dataset) + { + _input_dataset = input_dataset; + _batch_size = tf.convert_to_tensor(batch_size, dtype: TF_DataType.TF_INT64, name: "batch_size"); + _drop_remainder = tf.convert_to_tensor(drop_remainder, dtype: TF_DataType.TF_BOOL, name: "drop_remainder"); + + if (drop_remainder) + { + throw new NotImplementedException(""); + } + else + { + structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray(); + } + + variant_tensor = ops.batch_dataset_v2(input_dataset.variant_tensor, + _batch_size, + _drop_remainder, + output_types, + output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/CacheDataset.cs b/src/TensorFlowNET.Core/Data/CacheDataset.cs new file mode 100644 index 000000000..a85d58f72 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/CacheDataset.cs @@ -0,0 +1,20 @@ +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class CacheDataset : UnaryUnchangedStructureDataset + { + Tensor _filename; + public CacheDataset(IDatasetV2 input_dataset, + string filename = "") : + base(input_dataset) + { + _filename = tf.convert_to_tensor(filename, dtype: TF_DataType.TF_STRING, name: "filename"); + variant_tensor = ops.cache_dataset_v2(input_dataset.variant_tensor, + _filename, + ops.dummy_memory_cache(), + output_types, + output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/ConcatenateDataset.cs b/src/TensorFlowNET.Core/Data/ConcatenateDataset.cs new file mode 100644 index 000000000..9d4abd6b2 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/ConcatenateDataset.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Framework; +using Tensorflow.Framework.Models; +using static Tensorflow.Binding; + +namespace Tensorflow.Data +{ + /// + /// A `Dataset` that concatenates its input with given dataset. + /// + public class ConcatenateDataset : DatasetV2 + { + IDatasetV2 _input_dataset; + IDatasetV2 _dataset_to_concatenate; + public ConcatenateDataset(IDatasetV2 input_dataset, IDatasetV2 dataset_to_concatenate) + { + _input_dataset = input_dataset; + _dataset_to_concatenate = dataset_to_concatenate; + var _structure = new List(); + foreach(var (i, spec) in enumerate(dataset_to_concatenate.element_spec)) + { + var shape = _input_dataset.output_shapes[i].most_specific_compatible_shape(spec.shape); + _structure.Add(new TensorSpec(shape, dtype: spec.dtype)); + } + structure = _structure.ToArray(); + + variant_tensor = ops.concatenate_dataset(input_dataset.variant_tensor, + dataset_to_concatenate.variant_tensor, + output_types, output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/DataSetBase.cs b/src/TensorFlowNET.Core/Data/DataSetBase.cs new file mode 100644 index 000000000..2face8bcb --- /dev/null +++ b/src/TensorFlowNET.Core/Data/DataSetBase.cs @@ -0,0 +1,10 @@ +using Tensorflow.NumPy; + +namespace Tensorflow +{ + public abstract class DataSetBase : IDataSet + { + public NDArray Data { get; protected set; } + public NDArray Labels { get; protected set; } + } +} diff --git a/src/TensorFlowNET.Core/Data/DatasetManager.cs b/src/TensorFlowNET.Core/Data/DatasetManager.cs new file mode 100644 index 000000000..b55185059 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/DatasetManager.cs @@ -0,0 +1,44 @@ +using Tensorflow.NumPy; +using System.Collections.Generic; +using Tensorflow.Data; + +namespace Tensorflow +{ + public class DatasetManager + { + public IDatasetV2 from_generator(IEnumerable generator, TF_DataType[] output_types, Shape[] output_shapes) + => new GeneratorDataset(); + + /// + /// Creates a `Dataset` with a single element, comprising the given tensors. + /// + /// + /// + public IDatasetV2 from_tensors(NDArray tensors) + => new TensorDataset(tensors); + + public IDatasetV2 from_tensors(Tensors tensors) + => new TensorDataset(tensors); + + public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) + => new TensorSliceDataset(features, labels); + + public IDatasetV2 from_tensor_slices(Tensor tensor) + => new TensorSliceDataset(tensor); + + public IDatasetV2 from_tensor_slices(string[] array) + => new TensorSliceDataset(array); + + public IDatasetV2 from_tensor_slices(NDArray array) + => new TensorSliceDataset(array); + + public IDatasetV2 range(int count, TF_DataType output_type = TF_DataType.TF_INT64) + => new RangeDataset(count, output_type: output_type); + + public IDatasetV2 range(int start, int stop, int step = 1, TF_DataType output_type = TF_DataType.TF_INT64) + => new RangeDataset(stop, start: start, step: step, output_type: output_type); + + public IDatasetV2 zip(params IDatasetV2[] ds) + => new ZipDataset(ds); + } +} diff --git a/src/TensorFlowNET.Core/Data/DatasetOps.cs b/src/TensorFlowNET.Core/Data/DatasetOps.cs new file mode 100644 index 000000000..171e90f82 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/DatasetOps.cs @@ -0,0 +1,6 @@ +namespace Tensorflow +{ + public class DatasetOps + { + } +} diff --git a/src/TensorFlowNET.Core/Data/DatasetOptions.cs b/src/TensorFlowNET.Core/Data/DatasetOptions.cs new file mode 100644 index 000000000..189b80ce0 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/DatasetOptions.cs @@ -0,0 +1,6 @@ +namespace Tensorflow +{ + public class DatasetOptions + { + } +} diff --git a/src/TensorFlowNET.Core/Data/DatasetSource.cs b/src/TensorFlowNET.Core/Data/DatasetSource.cs new file mode 100644 index 000000000..c235fcf61 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/DatasetSource.cs @@ -0,0 +1,12 @@ +namespace Tensorflow +{ + public class DatasetSource : DatasetV2 + { + protected Tensor[] _tensors; + + public DatasetSource() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs new file mode 100644 index 000000000..c1762d670 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -0,0 +1,174 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Data; +using Tensorflow.Framework.Models; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Abstract class representing a dataset with no inputs. + /// + public class DatasetV2 : IDatasetV2 + { + protected dataset_ops ops = new dataset_ops(); + public string[] class_names { get; set; } + public Tensor variant_tensor { get; set; } + + public TensorSpec[] structure { get; set; } + + public int FirstInputTensorCount { get; set; } = 1; + + public Shape[] output_shapes => structure.Select(x => x.shape).ToArray(); + + public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray(); + + public TensorSpec[] element_spec => structure; + + public int length => cardinality().numpy(); + + public IDatasetV2 cache(string filename = "") + => new CacheDataset(this, filename: filename); + + public IDatasetV2 concatenate(IDatasetV2 dataset) + => new ConcatenateDataset(this, dataset); + + public IDatasetV2 take(int count = -1) + => new TakeDataset(this, count: count); + + public IDatasetV2 batch(int batch_size, bool drop_remainder = false) + => new BatchDataset(this, batch_size, drop_remainder: drop_remainder); + + public IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null) + => new PrefetchDataset(this, buffer_size: buffer_size, slack_period: slack_period); + + public IDatasetV2 repeat(int count = -1) + => new RepeatDataset(this, count: count); + + public IDatasetV2 shard(int num_shards, int index) + => new ShardDataset(this, num_shards, index); + + public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true) + => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration); + + public IDatasetV2 skip(int count) + => new SkipDataset(this, count); + + public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) + => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); + + public IDatasetV2 map(Func map_func, + bool use_inter_op_parallelism = true, + bool preserve_cardinality = true, + bool use_legacy_function = false) + => new MapDataset(this, + map_func, + use_inter_op_parallelism: use_inter_op_parallelism, + preserve_cardinality: preserve_cardinality, + use_legacy_function: use_legacy_function); + + public IDatasetV2 map(Func map_func, int num_parallel_calls) + => new ParallelMapDataset(this, map_func, + num_parallel_calls: num_parallel_calls, + preserve_cardinality: true); + + public IDatasetV2 filter(Func predicate_func) + => new FilterDataset(this, predicate_func); + + public IDatasetV2 filter(Func predicate_func) + => new FilterDataset(this, predicate_func); + + public OwnedIterator make_one_shot_iterator() + { + if (tf.Context.executing_eagerly()) + { + // with ops.colocate_with(self._variant_tensor) + return new OwnedIterator(this); + } + + throw new NotImplementedException(""); + } + + public IDatasetV2 flat_map(Func map_func) + => new FlatMapDataset(this, map_func); + + public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget, long ram_budget) + => new ModelDataset(this, algorithm, cpu_budget, ram_budget); + + public IDatasetV2 with_options(DatasetOptions options) + => new OptionsDataset(this, options); + + public IDatasetV2 apply_options() + { + IDatasetV2 dataset = this; + // (1) Apply threading options + + // (2) Apply autotune options + var autotune = true; + long cpu_budget = 0; + long ram_budget = 0; + if (autotune) + dataset = dataset.model(AutotuneAlgorithm.HILL_CLIMB, cpu_budget, ram_budget); + + // (3) Apply graph rewrite options + var graph_rewrites = new[] + { + "map_and_batch_fusion", + "map_parallelization", + "noop_elimination", + "shuffle_and_repeat_fusion" + }; + var graph_rewrite_configs = new string[] + { + "autotune_buffer_sizes:autotune:true", + "batch_parallelization:autotune:true", + "disable_prefetch_legacy_autotune:autotune:true", + "enable_gradient_descent:autotune:true", + "map_parallelization:autotune:true" + }; + + dataset = new OptimizeDataset(dataset, new string[0], new string[0], graph_rewrites, graph_rewrite_configs); + + // (4) Apply stats aggregator options + + dataset.FirstInputTensorCount = this.FirstInputTensorCount; + return dataset; + } + + public Tensor cardinality(string name = null) + => tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor)); + + public override string ToString() + => $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, " + + $"types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}, " + + $"len: {length}"; + + public IEnumerator<(Tensors, Tensors)> GetEnumerator() + { + using var ownedIterator = new OwnedIterator(this); + + Tensor[] results = null; + while (true) + { + try + { + results = ownedIterator.next(); + } + catch (StopIteration) + { + break; + } + + yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ? + null : new Tensors(results.Skip(FirstInputTensorCount).ToArray())); + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/Datasets.cs b/src/TensorFlowNET.Core/Data/Datasets.cs new file mode 100644 index 000000000..6a4bb1ca1 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/Datasets.cs @@ -0,0 +1,43 @@ +using Tensorflow.NumPy; + +namespace Tensorflow +{ + public class Datasets where TDataSet : IDataSet + { + public TDataSet Train { get; private set; } + + public TDataSet Validation { get; private set; } + + public TDataSet Test { get; private set; } + + public Datasets(TDataSet train, TDataSet validation, TDataSet test) + { + Train = train; + Validation = validation; + Test = test; + } + + public (NDArray, NDArray) Randomize(NDArray x, NDArray y) + { + var perm = np.random.permutation((int)y.dims[0]); + np.random.shuffle(perm); + return (x[perm], y[perm]); + } + + /// + /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) + /// + /// + /// + /// + /// + /// + public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) + { + var slice = new Slice(start, end); + var x_batch = x[slice]; + var y_batch = y[slice]; + return (x_batch, y_batch); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/FilterDataset.cs b/src/TensorFlowNET.Core/Data/FilterDataset.cs new file mode 100644 index 000000000..84dfa0aea --- /dev/null +++ b/src/TensorFlowNET.Core/Data/FilterDataset.cs @@ -0,0 +1,58 @@ +using System; +using Tensorflow.Functions; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// A `Dataset` that filters its input according to a predicate function. + /// + public class FilterDataset : UnaryDataset + { + public FilterDataset(IDatasetV2 input_dataset, + Func predicate_func) : base(input_dataset) + { + Func predicate_func_update = x => + { + var result = predicate_func(x); + return constant_op.constant(result); + }; + + var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}"); + func.Enter(); + var inputs = new Tensors(); + foreach (var input in input_dataset.element_spec) + inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg")); + var outputs = predicate_func_update(inputs); + func.ToGraph(inputs, outputs); + func.Exit(); + + structure = func.OutputStructure; + + variant_tensor = ops.filter_dataset(input_dataset.variant_tensor, + func, + output_types, + output_shapes); + } + + public FilterDataset(IDatasetV2 input_dataset, + Func predicate_func) : base(input_dataset) + { + var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}"); + func.Enter(); + var inputs = new Tensors(); + foreach (var input in input_dataset.element_spec) + inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg")); + var outputs = predicate_func(inputs); + func.ToGraph(inputs, outputs); + func.Exit(); + + structure = func.OutputStructure; + + variant_tensor = ops.filter_dataset(input_dataset.variant_tensor, + func, + output_types, + output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/FlatMapDataset.cs b/src/TensorFlowNET.Core/Data/FlatMapDataset.cs new file mode 100644 index 000000000..8b1872c3b --- /dev/null +++ b/src/TensorFlowNET.Core/Data/FlatMapDataset.cs @@ -0,0 +1,22 @@ +using System; +using Tensorflow.Functions; + +namespace Tensorflow +{ + /// + /// + /// + public class FlatMapDataset : UnaryDataset + { + public FlatMapDataset(IDatasetV2 input_dataset, + Func map_func) : base(input_dataset) + { + var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype); + structure = func.OutputStructure; + variant_tensor = ops.flat_map_dataset(input_dataset.variant_tensor, + func, + output_types, + output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/GeneratorDataset.cs b/src/TensorFlowNET.Core/Data/GeneratorDataset.cs new file mode 100644 index 000000000..b1c46d3b8 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/GeneratorDataset.cs @@ -0,0 +1,7 @@ +namespace Tensorflow.Data +{ + public class GeneratorDataset : DatasetSource + { + + } +} diff --git a/src/TensorFlowNET.Core/Data/IDataSet.cs b/src/TensorFlowNET.Core/Data/IDataSet.cs new file mode 100644 index 000000000..0ac6ee99b --- /dev/null +++ b/src/TensorFlowNET.Core/Data/IDataSet.cs @@ -0,0 +1,10 @@ +using Tensorflow.NumPy; + +namespace Tensorflow +{ + public interface IDataSet + { + NDArray Data { get; } + NDArray Labels { get; } + } +} diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs new file mode 100644 index 000000000..320cbe348 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs @@ -0,0 +1,101 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Framework.Models; + +namespace Tensorflow +{ + public interface IDatasetV2 : IEnumerable<(Tensors, Tensors)> + { + string[] class_names { get; set; } + + Tensor variant_tensor { get; set; } + + Shape[] output_shapes { get; } + + TF_DataType[] output_types { get; } + + TensorSpec[] element_spec { get; } + + TensorSpec[] structure { get; set; } + + int FirstInputTensorCount { get; set; } + + /// + /// Caches the elements in this dataset. + /// + /// + /// + IDatasetV2 cache(string filename = ""); + + /// + /// Creates a `Dataset` by concatenating the given dataset with this dataset. + /// + /// + /// + IDatasetV2 concatenate(IDatasetV2 dataset); + + /// + /// + /// + /// + /// + IDatasetV2 repeat(int count = -1); + + /// + /// Creates a `Dataset` that includes only 1/`num_shards` of this dataset. + /// + /// The number of shards operating in parallel + /// The worker index + /// + IDatasetV2 shard(int num_shards, int index); + + IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true); + + /// + /// Creates a `Dataset` that skips `count` elements from this dataset. + /// + /// + /// + IDatasetV2 skip(int count); + + IDatasetV2 batch(int batch_size, bool drop_remainder = false); + + IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null); + + IDatasetV2 take(int count); + + IDatasetV2 optimize(string[] optimizations, string[] optimization_configs); + + IDatasetV2 map(Func map_func, + bool use_inter_op_parallelism = true, + bool preserve_cardinality = true, + bool use_legacy_function = false); + + IDatasetV2 map(Func map_func, + int num_parallel_calls); + + IDatasetV2 filter(Func map_func); + IDatasetV2 filter(Func map_func); + + OwnedIterator make_one_shot_iterator(); + + IDatasetV2 flat_map(Func map_func); + + IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget, long ram_budget); + + IDatasetV2 with_options(DatasetOptions options); + + /// + /// Apply options, such as optimization configuration, to the dataset. + /// + /// + IDatasetV2 apply_options(); + + /// + /// Returns the cardinality of `dataset`, if known. + /// + /// + /// + Tensor cardinality(string name = null); + } +} diff --git a/src/TensorFlowNET.Core/Data/IModelLoader.cs b/src/TensorFlowNET.Core/Data/IModelLoader.cs new file mode 100644 index 000000000..fd94dbe34 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/IModelLoader.cs @@ -0,0 +1,10 @@ +using System.Threading.Tasks; + +namespace Tensorflow +{ + public interface IModelLoader + where TDataSet : IDataSet + { + Task> LoadAsync(ModelLoadSetting setting); + } +} diff --git a/src/TensorFlowNET.Core/Data/MapDataset.cs b/src/TensorFlowNET.Core/Data/MapDataset.cs new file mode 100644 index 000000000..df7becc4d --- /dev/null +++ b/src/TensorFlowNET.Core/Data/MapDataset.cs @@ -0,0 +1,37 @@ +using System; +using Tensorflow.Functions; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// A `Dataset` that maps a function over elements in its input. + /// + public class MapDataset : UnaryDataset + { + public MapDataset(IDatasetV2 input_dataset, + Func map_func, + bool use_inter_op_parallelism = true, + bool preserve_cardinality = false, + bool use_legacy_function = false) : base(input_dataset) + { + var func = new ConcreteFunction($"{map_func.Method.Name}_{Tensorflow.ops.uid_function()}"); + func.Enter(); + var inputs = new Tensors(); + foreach (var input in input_dataset.element_spec) + inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg")); + var outputs = map_func(inputs); + func.ToGraph(inputs, outputs); + func.Exit(); + + structure = func.OutputStructure; + + variant_tensor = ops.map_dataset(input_dataset.variant_tensor, + func, + output_types, + output_shapes, + use_inter_op_parallelism: use_inter_op_parallelism, + preserve_cardinality: preserve_cardinality); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/MnistDataSet.cs b/src/TensorFlowNET.Core/Data/MnistDataSet.cs new file mode 100644 index 000000000..7e5d0cc21 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/MnistDataSet.cs @@ -0,0 +1,85 @@ +using Tensorflow.NumPy; +using System; +using System.Diagnostics; + +namespace Tensorflow +{ + public class MnistDataSet : DataSetBase + { + public int NumOfExamples { get; private set; } + public int EpochsCompleted { get; private set; } + public int IndexInEpoch { get; private set; } + + public MnistDataSet(NDArray images, NDArray labels, TF_DataType dataType, bool reshape) + { + EpochsCompleted = 0; + IndexInEpoch = 0; + + NumOfExamples = (int)images.dims[0]; + + // images = images.reshape((images.dims[0], images.dims[1] * images.dims[2])); + images = images.astype(dataType); + // for debug np.multiply performance + var sw = new Stopwatch(); + sw.Start(); + images = np.multiply(images, 1.0f / 255.0f); + sw.Stop(); + Binding.tf_output_redirect.WriteLine($"{sw.ElapsedMilliseconds}ms"); + Data = images; + + labels = labels.astype(dataType); + Labels = labels; + } + + public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true) + { + if (IndexInEpoch >= NumOfExamples) + IndexInEpoch = 0; + + var start = IndexInEpoch; + // Shuffle for the first epoch + if (EpochsCompleted == 0 && start == 0 && shuffle) + { + var perm0 = np.arange(NumOfExamples); + np.random.shuffle(perm0); + Data = Data[perm0]; + Labels = Labels[perm0]; + } + + // Go to the next epoch + if (start + batch_size > NumOfExamples) + { + // Finished epoch + EpochsCompleted += 1; + + // Get the rest examples in this epoch + var rest_num_examples = NumOfExamples - start; + var images_rest_part = Data[np.arange(start, NumOfExamples)]; + var labels_rest_part = Labels[np.arange(start, NumOfExamples)]; + // Shuffle the data + if (shuffle) + { + var perm = np.arange(NumOfExamples); + np.random.shuffle(perm); + Data = Data[perm]; + Labels = Labels[perm]; + } + + start = 0; + IndexInEpoch = batch_size - rest_num_examples; + var end = IndexInEpoch; + var images_new_part = Data[np.arange(start, end)]; + var labels_new_part = Labels[np.arange(start, end)]; + + return (np.concatenate(new[] { images_rest_part, images_new_part }, axis: 0), + np.concatenate(new[] { labels_rest_part, labels_new_part }, axis: 0)); + } + else + { + IndexInEpoch += batch_size; + var end = IndexInEpoch; + return (Data[np.arange(start, end)], Labels[np.arange(start, end)]); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Data/MnistModelLoader.cs b/src/TensorFlowNET.Core/Data/MnistModelLoader.cs new file mode 100644 index 000000000..c8b9fa30f --- /dev/null +++ b/src/TensorFlowNET.Core/Data/MnistModelLoader.cs @@ -0,0 +1,178 @@ +using Tensorflow.NumPy; +using System; +using System.IO; +using System.Threading.Tasks; + +namespace Tensorflow +{ + public class MnistModelLoader : IModelLoader + { + private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"; + private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz"; + private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; + private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; + private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; + + public async Task> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false) + { + var setting = new ModelLoadSetting + { + TrainDir = trainDir, + OneHot = oneHot, + ShowProgressInConsole = showProgressInConsole + }; + + if (trainSize.HasValue) + setting.TrainSize = trainSize.Value; + + if (validationSize.HasValue) + setting.ValidationSize = validationSize.Value; + + if (testSize.HasValue) + setting.TestSize = testSize.Value; + + return await LoadAsync(setting); + } + + public async Task> LoadAsync(ModelLoadSetting setting) + { + if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value) + throw new ArgumentException("Validation set should be smaller than training set"); + + var sourceUrl = setting.SourceUrl; + + if (string.IsNullOrEmpty(sourceUrl)) + sourceUrl = DEFAULT_SOURCE_URL; + + // load train images + await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize); + + // load train labels + await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize); + + // load test images + await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize); + + // load test labels + await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize); + + var end = trainImages.dims[0]; + + var validationSize = setting.ValidationSize; + + var validationImages = trainImages[np.arange(validationSize)]; + var validationLabels = trainLabels[np.arange(validationSize)]; + + trainImages = trainImages[np.arange(validationSize, (int)end)]; + trainLabels = trainLabels[np.arange(validationSize, (int)end)]; + + var dtype = setting.DataType; + var reshape = setting.ReShape; + + var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape); + var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape); + var test = new MnistDataSet(testImages, testLabels, dtype, reshape); + + return new Datasets(train, validation, test); + } + + private NDArray ExtractImages(string file, int? limit = null) + { + if (!Path.IsPathRooted(file)) + file = Path.Combine(AppContext.BaseDirectory, file); + + using (var bytestream = new FileStream(file, FileMode.Open)) + { + var magic = Read32(bytestream); + if (magic != 2051) + throw new Exception($"Invalid magic number {magic} in MNIST image file: {file}"); + + var num_images = Read32(bytestream); + num_images = limit == null ? num_images : Math.Min(num_images, (int)limit); + + var rows = Read32(bytestream); + var cols = Read32(bytestream); + + var buf = new byte[rows * cols * num_images]; + + bytestream.Read(buf, 0, buf.Length); + + var data = np.frombuffer(buf, (num_images, rows * cols), np.uint8); + return data; + } + } + + private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes = 10, int? limit = null) + { + if (!Path.IsPathRooted(file)) + file = Path.Combine(AppContext.BaseDirectory, file); + + using (var bytestream = new FileStream(file, FileMode.Open)) + { + var magic = Read32(bytestream); + if (magic != 2049) + throw new Exception($"Invalid magic number {magic} in MNIST label file: {file}"); + + var num_items = Read32(bytestream); + num_items = limit == null ? num_items : Math.Min(num_items, (int)limit); + + var buf = new byte[num_items]; + + bytestream.Read(buf, 0, buf.Length); + + var labels = np.frombuffer(buf, new Shape(num_items), np.uint8); + + if (one_hot) + return DenseToOneHot(labels, num_classes); + + return labels; + } + } + + private NDArray DenseToOneHot(NDArray labels_dense, int num_classes) + { + var num_labels = (int)labels_dense.dims[0]; + // var index_offset = np.arange(num_labels) * num_classes; + var labels_one_hot = np.zeros((num_labels, num_classes)); + var labels = labels_dense.ToArray(); + for (int row = 0; row < num_labels; row++) + { + var col = labels[row]; + labels_one_hot[row, col] = 1.0; + } + + return labels_one_hot; + } + + private int Read32(FileStream bytestream) + { + var buffer = new byte[sizeof(uint)]; + var count = bytestream.Read(buffer, 0, 4); + return np.frombuffer(buffer, ">u4").ToArray()[0]; + } + } +} diff --git a/src/TensorFlowNET.Core/Data/ModelDataset.cs b/src/TensorFlowNET.Core/Data/ModelDataset.cs new file mode 100644 index 000000000..1b01788c4 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/ModelDataset.cs @@ -0,0 +1,24 @@ +using Tensorflow.Framework.Models; + +namespace Tensorflow +{ + /// + /// A `Dataset` that acts as an identity, and models performance. + /// + public class ModelDataset : UnaryUnchangedStructureDataset + { + public ModelDataset(IDatasetV2 input_dataset, + AutotuneAlgorithm algorithm, + long cpu_budget, + long ram_budget) : + base(input_dataset) + { + variant_tensor = ops.model_dataset(input_dataset.variant_tensor, + output_types, + output_shapes, + algorithm, + cpu_budget, + ram_budget); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/ModelLoadSetting.cs b/src/TensorFlowNET.Core/Data/ModelLoadSetting.cs new file mode 100644 index 000000000..11f6928f5 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/ModelLoadSetting.cs @@ -0,0 +1,17 @@ +using System; + +namespace Tensorflow +{ + public class ModelLoadSetting + { + public string TrainDir { get; set; } + public bool OneHot { get; set; } + public TF_DataType DataType { get; set; } = TF_DataType.TF_FLOAT; + public bool ReShape { get; set; } + public int ValidationSize { get; set; } = 5000; + public int? TrainSize { get; set; } + public int? TestSize { get; set; } + public string SourceUrl { get; set; } + public bool ShowProgressInConsole { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Data/OptimizeDataset.cs b/src/TensorFlowNET.Core/Data/OptimizeDataset.cs new file mode 100644 index 000000000..56f36388a --- /dev/null +++ b/src/TensorFlowNET.Core/Data/OptimizeDataset.cs @@ -0,0 +1,40 @@ +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// A `Dataset` that acts as an identity, and applies optimizations. + /// + public class OptimizeDataset : UnaryUnchangedStructureDataset + { + public OptimizeDataset(IDatasetV2 dataset, + string[] optimizations_enabled = null, + string[] optimizations_disabled = null, + string[] optimizations_default = null, + string[] optimization_configs = null) : + base(dataset) + { + if (optimizations_enabled == null) + optimizations_enabled = new string[0]; + if (optimizations_disabled == null) + optimizations_disabled = new string[0]; + if (optimizations_default == null) + optimizations_default = new string[0]; + if (optimization_configs == null) + optimization_configs = new string[0]; + + var _optimizations_enabled = tf.convert_to_tensor(optimizations_enabled, dtype: TF_DataType.TF_STRING, name: "optimizations_enabled"); + var _optimizations_disabled = tf.convert_to_tensor(optimizations_disabled, dtype: TF_DataType.TF_STRING, name: "optimizations_disabled"); + var _optimizations_default = tf.convert_to_tensor(optimizations_default, dtype: TF_DataType.TF_STRING, name: "optimizations_default"); + + variant_tensor = ops.optimize_dataset_v2( + _input_dataset.variant_tensor, + _optimizations_enabled, + _optimizations_disabled, + _optimizations_default, + output_types, + output_shapes, + optimization_configs: optimization_configs); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/OptionsDataset.cs b/src/TensorFlowNET.Core/Data/OptionsDataset.cs new file mode 100644 index 000000000..ae63814fb --- /dev/null +++ b/src/TensorFlowNET.Core/Data/OptionsDataset.cs @@ -0,0 +1,17 @@ +namespace Tensorflow +{ + /// + /// An identity `Dataset` that stores options. + /// + public class OptionsDataset : UnaryUnchangedStructureDataset + { + DatasetOptions options; + + public OptionsDataset(IDatasetV2 input_dataset, DatasetOptions options) + : base(input_dataset) + { + this.options = options; + variant_tensor = input_dataset.variant_tensor; + } + } +} diff --git a/src/TensorFlowNET.Core/Data/OwnedIterator.cs b/src/TensorFlowNET.Core/Data/OwnedIterator.cs new file mode 100644 index 000000000..6f6fd0b58 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/OwnedIterator.cs @@ -0,0 +1,54 @@ +using System; +using System.Linq; +using Tensorflow.Framework.Models; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// An iterator producing tf.Tensor objects from a tf.data.Dataset. + /// + public class OwnedIterator : IDisposable + { + IDatasetV2 _dataset; + TensorSpec[] _element_spec; + dataset_ops ops = new dataset_ops(); + //Tensor _deleter; + Tensor _iterator_resource; + + public OwnedIterator(IDatasetV2 dataset) + { + _create_iterator(dataset); + } + + void _create_iterator(IDatasetV2 dataset) + { + dataset = dataset.apply_options(); + _dataset = dataset; + _element_spec = dataset.element_spec; + _iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes); + // TODO(Rinne): deal with graph mode. + ops.make_iterator(dataset.variant_tensor, _iterator_resource); + } + + public Tensor[] next() + { + try + { + var results = ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes); + foreach(var (i, tensor) in enumerate(results)) + tensor.shape = _element_spec[i].shape; + return results; + } + catch (OutOfRangeError ex) + { + throw new StopIteration(ex.Message); + } + } + + public void Dispose() + { + //tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs b/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs new file mode 100644 index 000000000..6deb30bd2 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs @@ -0,0 +1,40 @@ +using System; +using System.Linq; +using Tensorflow.Functions; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + //A `Dataset` that maps a function over elements in its input in parallel. + public class ParallelMapDataset : UnaryDataset + { + public ParallelMapDataset(IDatasetV2 input_dataset, + Func map_func, + int num_parallel_calls = -1, + bool use_inter_op_parallelism = true, + bool preserve_cardinality = false, + bool use_legacy_function = false) : base(input_dataset) + { + var func = new ConcreteFunction($"{map_func.Method.Name}_{Tensorflow.ops.uid_function()}"); + func.Enter(); + var inputs = new Tensors(); + foreach (var input in input_dataset.element_spec) + inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg")); + var outputs = map_func(inputs); + func.ToGraph(inputs, outputs); + func.Exit(); + + structure = func.OutputStructure; + + var _num_parallel_calls = tf.convert_to_tensor(num_parallel_calls, dtype: tf.int64, + name: "num_parallel_calls"); + variant_tensor = ops.parallel_map_dataset_v2(input_dataset.variant_tensor, + _num_parallel_calls, + func, + output_types, + output_shapes, + use_inter_op_parallelism: use_inter_op_parallelism, + preserve_cardinality: preserve_cardinality); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/PrefetchDataset.cs b/src/TensorFlowNET.Core/Data/PrefetchDataset.cs new file mode 100644 index 000000000..826b5ffa4 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/PrefetchDataset.cs @@ -0,0 +1,24 @@ +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Creates a `Dataset` that prefetches elements from this dataset. + /// + public class PrefetchDataset : UnaryUnchangedStructureDataset + { + public PrefetchDataset(IDatasetV2 input_dataset, + long buffer_size = -1, + int? slack_period = null) : + base(input_dataset) + { + var buffer_size_tensor = tf.convert_to_tensor(buffer_size, dtype: TF_DataType.TF_INT64, name: "buffer_size"); + + variant_tensor = ops.prefetch_dataset(input_dataset.variant_tensor, + buffer_size_tensor, + input_dataset.output_types, + input_dataset.output_shapes, + slack_period: slack_period); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/RangeDataset.cs b/src/TensorFlowNET.Core/Data/RangeDataset.cs new file mode 100644 index 000000000..e3e027669 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/RangeDataset.cs @@ -0,0 +1,21 @@ +using Tensorflow.Framework.Models; +using static Tensorflow.Binding; + +namespace Tensorflow.Data +{ + public class RangeDataset : DatasetSource + { + public RangeDataset(int stop, + int start = 0, + int step = 1, + TF_DataType output_type = TF_DataType.TF_INT64) + { + var start_tensor = tf.convert_to_tensor((long)start); + var step_tensor = tf.convert_to_tensor((long)step); + var stop_tensor = tf.convert_to_tensor((long)stop); + + structure = new TensorSpec[] { new TensorSpec(new int[0], dtype: output_type) }; + variant_tensor = ops.range_dataset(start_tensor, stop_tensor, step_tensor, output_types, output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/RepeatDataset.cs b/src/TensorFlowNET.Core/Data/RepeatDataset.cs new file mode 100644 index 000000000..7cd46452d --- /dev/null +++ b/src/TensorFlowNET.Core/Data/RepeatDataset.cs @@ -0,0 +1,18 @@ +namespace Tensorflow +{ + /// + /// A `Dataset` that repeats its input several times. + /// + public class RepeatDataset : UnaryUnchangedStructureDataset + { + public RepeatDataset(IDatasetV2 input_dataset, int count = -1) : + base(input_dataset) + { + var count_tensor = constant_op.constant(count, dtype: TF_DataType.TF_INT64, name: "count"); + variant_tensor = ops.repeat_dataset(input_dataset.variant_tensor, + count_tensor, + input_dataset.output_types, + input_dataset.output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/ShardDataset.cs b/src/TensorFlowNET.Core/Data/ShardDataset.cs new file mode 100644 index 000000000..673fe2c4c --- /dev/null +++ b/src/TensorFlowNET.Core/Data/ShardDataset.cs @@ -0,0 +1,28 @@ +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// A `Dataset` for sharding its input. + /// + public class ShardDataset : UnaryUnchangedStructureDataset + { + Tensor _num_shards; + Tensor _index; + + public ShardDataset(IDatasetV2 input_dataset, + int num_shards, + int index) : base(input_dataset) + { + _num_shards = tf.convert_to_tensor(num_shards, dtype: TF_DataType.TF_INT64, name: "num_shards"); + _index = tf.convert_to_tensor(index, dtype: TF_DataType.TF_INT64, name: "index"); + + variant_tensor = ops.shard_dataset + (input_dataset.variant_tensor, + num_shards: _num_shards, + index: _index, + input_dataset.output_types, + input_dataset.output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/ShuffleDataset.cs b/src/TensorFlowNET.Core/Data/ShuffleDataset.cs new file mode 100644 index 000000000..8d22ab919 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/ShuffleDataset.cs @@ -0,0 +1,35 @@ +using System; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Randomly shuffles the elements of this dataset. + /// + public class ShuffleDataset : UnaryUnchangedStructureDataset + { + Tensor _buffer_size; + Tensor _seed; + Tensor _seed2; + bool _reshuffle_each_iteration; + + public ShuffleDataset(IDatasetV2 input_dataset, + long buffer_size, + int? seed = null, + bool reshuffle_each_iteration = true) : + base(input_dataset) + { + _buffer_size = tf.convert_to_tensor(buffer_size, dtype: TF_DataType.TF_INT64, name: "buffer_size"); + (_seed, _seed2) = random_seed.get_seed_tensor(seed); + _reshuffle_each_iteration = reshuffle_each_iteration; + var seed_generator = ops.dummy_seed_generator(); + if (tf.Context.executing_eagerly()) + variant_tensor = ops.shuffle_dataset_v3(input_dataset.variant_tensor, _buffer_size, + _seed, _seed2, seed_generator, + output_types, output_shapes, + reshuffle_each_iteration: _reshuffle_each_iteration); + else + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/SkipDataset.cs b/src/TensorFlowNET.Core/Data/SkipDataset.cs new file mode 100644 index 000000000..48746f02b --- /dev/null +++ b/src/TensorFlowNET.Core/Data/SkipDataset.cs @@ -0,0 +1,21 @@ +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// A `Dataset` skipping the first `count` elements from its input. + /// + public class SkipDataset : UnaryUnchangedStructureDataset + { + Tensor _count; + + public SkipDataset(IDatasetV2 input_dataset, + int count) : base(input_dataset) + { + _count = tf.convert_to_tensor(count, dtype: dtypes.int64, name: "count"); + variant_tensor = ops.skip_dataset(input_dataset.variant_tensor, + _count, + output_types, output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/TakeDataset.cs b/src/TensorFlowNET.Core/Data/TakeDataset.cs new file mode 100644 index 000000000..6c4a49f37 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/TakeDataset.cs @@ -0,0 +1,17 @@ +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class TakeDataset : UnaryUnchangedStructureDataset + { + Tensor _count; + + public TakeDataset(IDatasetV2 input_dataset, int count) : + base(input_dataset) + { + _count = tf.convert_to_tensor(count, dtype: dtypes.int64, name: "count"); + variant_tensor = ops.take_dataset(input_dataset.variant_tensor, _count, + output_types, output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/TensorDataset.cs b/src/TensorFlowNET.Core/Data/TensorDataset.cs new file mode 100644 index 000000000..0ac2eeaa1 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/TensorDataset.cs @@ -0,0 +1,29 @@ +using Tensorflow.NumPy; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// A `Dataset` with a single element. + /// + public class TensorDataset : DatasetSource + { + public TensorDataset(Tensors elements) + { + _tensors = elements; + structure = _tensors.Select(x => x.ToTensorSpec()).ToArray(); + + variant_tensor = ops.tensor_dataset(_tensors, output_shapes); + } + + public TensorDataset(NDArray element) + { + _tensors = new[] { tf.convert_to_tensor(element) }; + var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); + structure = batched_spec.ToArray(); + + variant_tensor = ops.tensor_dataset(_tensors, output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs b/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs new file mode 100644 index 000000000..f9d6ea747 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs @@ -0,0 +1,47 @@ +using Tensorflow.NumPy; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Data +{ + public class TensorSliceDataset : DatasetSource + { + public TensorSliceDataset(string[] array) + { + var element = tf.constant(array); + _tensors = new[] { element }; + var batched_spec = new[] { element.ToTensorSpec() }; + structure = batched_spec.Select(x => x._unbatch()).ToArray(); + + variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); + } + + public TensorSliceDataset(NDArray array) + { + var element = tf.constant(array); + _tensors = new[] { element }; + var batched_spec = new[] { element.ToTensorSpec() }; + structure = batched_spec.Select(x => x._unbatch()).ToArray(); + + variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); + } + + public TensorSliceDataset(Tensor tensor) + { + _tensors = new[] { tensor }; + var batched_spec = new[] { tensor.ToTensorSpec() }; + structure = batched_spec.Select(x => x._unbatch()).ToArray(); + + variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); + } + + public TensorSliceDataset(Tensor features, Tensor labels) + { + _tensors = new[] { features, labels }; + var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); + structure = batched_spec.Select(x => x._unbatch()).ToArray(); + + variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/UnaryDataset.cs b/src/TensorFlowNET.Core/Data/UnaryDataset.cs new file mode 100644 index 000000000..8a95b00ec --- /dev/null +++ b/src/TensorFlowNET.Core/Data/UnaryDataset.cs @@ -0,0 +1,16 @@ +namespace Tensorflow +{ + /// + /// Abstract class representing a dataset with one input. + /// + public class UnaryDataset : DatasetV2 + { + protected IDatasetV2 _input_dataset; + + public UnaryDataset(IDatasetV2 input_dataset) + { + _input_dataset = input_dataset; + structure = input_dataset.structure; + } + } +} diff --git a/src/TensorFlowNET.Core/Data/UnaryUnchangedStructureDataset.cs b/src/TensorFlowNET.Core/Data/UnaryUnchangedStructureDataset.cs new file mode 100644 index 000000000..31b718f35 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/UnaryUnchangedStructureDataset.cs @@ -0,0 +1,14 @@ +namespace Tensorflow +{ + /// + /// Represents a unary dataset with the same input and output structure. + /// + public class UnaryUnchangedStructureDataset : UnaryDataset + { + public UnaryUnchangedStructureDataset(IDatasetV2 input_dataset) : + base(input_dataset) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Data/Utils.cs b/src/TensorFlowNET.Core/Data/Utils.cs new file mode 100644 index 000000000..082a9a68d --- /dev/null +++ b/src/TensorFlowNET.Core/Data/Utils.cs @@ -0,0 +1,135 @@ +using System; +using System.IO; +using System.IO.Compression; +using System.Net; +using System.Threading; +using System.Threading.Tasks; + +namespace Tensorflow +{ + public static class Utils + { + public static async Task DownloadAsync(this IModelLoader modelLoader, string url, string saveTo) + where TDataSet : IDataSet + { + var dir = Path.GetDirectoryName(saveTo); + var fileName = Path.GetFileName(saveTo); + await modelLoader.DownloadAsync(url, dir, fileName); + } + + public static async Task DownloadAsync(this IModelLoader modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false) + where TDataSet : IDataSet + { + if (!Path.IsPathRooted(dirSaveTo)) + dirSaveTo = Path.Combine(AppContext.BaseDirectory, dirSaveTo); + + var fileSaveTo = Path.Combine(dirSaveTo, fileName); + + if (showProgressInConsole) + { + Binding.tf_output_redirect.WriteLine($"Downloading {fileName}"); + } + + if (File.Exists(fileSaveTo)) + { + if (showProgressInConsole) + { + Binding.tf_output_redirect.WriteLine($"The file {fileName} already exists"); + } + + return; + } + + Directory.CreateDirectory(dirSaveTo); + + using (var wc = new WebClient()) + { + await wc.DownloadFileTaskAsync(url, fileSaveTo).ConfigureAwait(false); + } + + } + + public static async Task UnzipAsync(this IModelLoader modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false) + where TDataSet : IDataSet + { + if (!Path.IsPathRooted(saveTo)) + saveTo = Path.Combine(AppContext.BaseDirectory, saveTo); + + Directory.CreateDirectory(saveTo); + + if (!Path.IsPathRooted(zipFile)) + zipFile = Path.Combine(AppContext.BaseDirectory, zipFile); + + var destFileName = Path.GetFileNameWithoutExtension(zipFile); + var destFilePath = Path.Combine(saveTo, destFileName); + + if (showProgressInConsole) + Binding.tf_output_redirect.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}"); + + if (File.Exists(destFilePath)) + { + if (showProgressInConsole) + Binding.tf_output_redirect.WriteLine($"The file {destFileName} already exists"); + } + + using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) + { + using (var destStream = File.Create(destFilePath)) + { + await unzipStream.CopyToAsync(destStream).ConfigureAwait(false); + await destStream.FlushAsync().ConfigureAwait(false); + destStream.Close(); + } + + unzipStream.Close(); + } + } + + public static async Task ShowProgressInConsole(this Task task, bool enable) + { + if (!enable) + { + await task; + return; + } + + var cts = new CancellationTokenSource(); + + var showProgressTask = ShowProgressInConsole(cts); + + try + { + await task; + } + finally + { + cts.Cancel(); + } + + await showProgressTask; + Binding.tf_output_redirect.WriteLine("Done."); + } + + private static async Task ShowProgressInConsole(CancellationTokenSource cts) + { + var cols = 0; + + await Task.Delay(100); + + while (!cts.IsCancellationRequested) + { + await Task.Delay(100); + Binding.tf_output_redirect.Write("."); + cols++; + + if (cols % 50 == 0) + { + Binding.tf_output_redirect.WriteLine(); + } + } + + if (cols > 0) + Binding.tf_output_redirect.WriteLine(); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/ZipDataset.cs b/src/TensorFlowNET.Core/Data/ZipDataset.cs new file mode 100644 index 000000000..888948f80 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/ZipDataset.cs @@ -0,0 +1,22 @@ +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Framework.Models; + +namespace Tensorflow +{ + public class ZipDataset : DatasetV2 + { + // keep all dataset references + IDatasetV2[] _inputs; + public ZipDataset(params IDatasetV2[] ds) + { + _inputs = ds; + var input_datasets = ds.Select(x => x.variant_tensor).ToArray(); + var _structure = new List(); + foreach (var dataset in ds) + _structure.AddRange(dataset.structure); + structure = _structure.ToArray(); + variant_tensor = ops.zip_dataset(input_datasets, output_types, output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Debugging/DebugImpl.cs b/src/TensorFlowNET.Core/Debugging/DebugImpl.cs new file mode 100644 index 000000000..816273514 --- /dev/null +++ b/src/TensorFlowNET.Core/Debugging/DebugImpl.cs @@ -0,0 +1,50 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Debugging +{ + public class DebugImpl + { + /// + /// Set if device placements should be logged. + /// + /// Whether to enabled device placement logging. + public void set_log_device_placement(bool enabled) + => tf.Context.log_device_placement(enabled); + + /// + /// Assert the condition `x == y` holds element-wise. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor assert_equal(T1 t1, + T2 t2, + object[] data = null, + string message = null, + string name = null) + => check_ops.assert_equal(t1, + t2, + data: data, + message: message, + name: name); + + public Tensor assert_greater_equal(Tensor x, + Tensor y, + object[] data = null, + string message = null, + string name = null) + => check_ops.assert_greater_equal(x, + y, + data: data, + message: message, + name: name); + } +} diff --git a/src/TensorFlowNET.Core/Device/DeviceSpec.cs b/src/TensorFlowNET.Core/Device/DeviceSpec.cs new file mode 100644 index 000000000..255191cb5 --- /dev/null +++ b/src/TensorFlowNET.Core/Device/DeviceSpec.cs @@ -0,0 +1,206 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; + +namespace Tensorflow.Device +{ + public class DeviceSpec + { + private static ConcurrentDictionary _STRING_TO_COMPONENTS_CACHE = new(); + private static ConcurrentDictionary _COMPONENTS_TO_STRING_CACHE = new(); + private string _job; + private int _replica; + private int _task; + private string _device_type; + private int _device_index; + private string _as_string; + + public string Job => _job; + public int Replica => _replica; + public int Task => _task; + public string DeviceType => _device_type; + public int DeviceIndex => _device_index; + + public DeviceSpec(string job = null, int replica = -1, int task = -1, + string device_type = null, int device_index = -1) + { + _job = job; + _replica = replica; + _task = task; + _device_type = device_type; + _device_index = device_index; + _as_string = _components_to_string(job, replica, task, device_type, _device_index); + + } + + public DeviceSpec(DeviceSpec other) + { + _job = other._job; + _replica = other._replica; + _task = other._task; + _device_type = other._device_type; + _device_index = other._device_index; + _as_string = other._as_string; + } + + protected DeviceSpec(Components com) + { + _job = com.Job; + _replica = com.Replica; + _task = com.Task; + _device_type = com.DeviceType; + _device_index = com.DeviceIndex; + _as_string = _components_to_string(_job, _replica, _task, _device_type, _device_index); + } + + public DeviceSpec replace(string job = null, int replica = -1, int task = -1, + string device_type = null, int device_index = -1) + { + job = job ?? _job; + replica = replica == -1 ? _replica : replica; + task = task == -1 ? _task : task; + device_type = device_type ?? _device_type; + device_index = device_index == -1 ? _device_index : device_index; + return new DeviceSpec(job, replica, task, device_type, device_index); + } + + public static DeviceSpec from_string(string spec) + { + var components = _string_to_components(spec); + return new DeviceSpec(components.Job, components.Replica, components.Task, components.DeviceType, components.DeviceIndex); + } + + public DeviceSpec make_merged_spec(DeviceSpec dev) + { + return new DeviceSpec(_get_combined_properties(dev)); + } + + private Components _get_combined_properties(DeviceSpec dev) + { + return new Components( + dev.Job ?? _job, + dev.Replica == -1 ? _replica : dev.Replica, + dev.Task == -1 ? _task : dev.Task, + dev.DeviceType ?? _device_type, + dev.DeviceIndex == -1 ? _device_index : dev.DeviceIndex + ); + } + + private static string _components_to_string(string job, int replica, int task, string device_type, int device_index) + { + var key = new Components(job, replica, task, device_type, device_index); + if(_COMPONENTS_TO_STRING_CACHE.TryGetValue(key, out var cache_result)) + { + return cache_result; + } + + StringBuilder output = new(); + if(job is not null) + { + output.Append($"/job:{job}"); + } + if(replica != -1) + { + output.Append($"/replica:{replica}"); + } + if(task != -1) + { + output.Append($"/task:{task}"); + } + if (device_type is not null) + { + string device_index_string = "*"; + if (device_index != -1) + { + device_index_string = device_index.ToString(); + } + output.Append($"/device:{device_type}:{device_index_string}"); + } + var result = output.ToString(); + _COMPONENTS_TO_STRING_CACHE[key] = result; + return result; + } + + private static Components _string_to_components(string spec) + { + if(_STRING_TO_COMPONENTS_CACHE.TryGetValue(spec, out var cached_result)) + { + return cached_result; + } + var raw_spec = spec; + var splits = spec.Split('/').Select(x => x.Split(':')); + var valid_device_types = _get_valid_device_types(); + string job = null, device_type = null; + int replica = -1, task = -1, device_index = -1; + foreach (var y in splits) + { + var ly = y.Length; + if (ly > 0) + { + if(ly == 2 && y[0] == "job") + { + job = y[1]; + } + else if(ly == 2 && y[0] == "replica") + { + replica = int.Parse(y[1]); + } + else if(ly == 2 && y[0] == "task") + { + task = int.Parse(y[1]); + } + else if((ly == 1 || ly == 2) && valid_device_types.Contains(y[0].ToUpper())) + { + if (device_type is not null) + { + throw new ValueError($"Multiple device types are not allowed " + + $"while parsing the device spec: {spec}."); + } + device_type = y[0].ToUpper(); + if(ly == 2 && y[1] != "*") + { + device_index = int.Parse(y[1]); + } + } + else if(ly == 3 && y[0] == "device") + { + if(device_type is not null) + { + throw new ValueError($"Multiple device types are not allowed " + + $"while parsing the device spec: {spec}."); + } + device_type = y[1]; + if (y[2] != "*") + { + device_index = int.Parse(y[2]); + } + } + else if (y[0] != "") + { + throw new ValueError($"Unknown attribute '{y[0]}' is encountered " + + $"while parsing the device spec: {spec}."); + } + } + } + + var output = new Components(job, replica, task, device_type, device_index); + _STRING_TO_COMPONENTS_CACHE[raw_spec] = output; + return output; + } + + private static HashSet _get_valid_device_types() + { + // TODO(Rinne): revise it to calling C API (need customized API). + return new HashSet(new string[] { "CPU", "GPU" }); + } + + public override string ToString() + { + return _as_string; + } + + protected record class Components(string Job, int Replica, int Task, string DeviceType, int DeviceIndex); + } +} diff --git a/src/TensorFlowNET.Core/Device/DeviceUtils.cs b/src/TensorFlowNET.Core/Device/DeviceUtils.cs new file mode 100644 index 000000000..8f11e6c8a --- /dev/null +++ b/src/TensorFlowNET.Core/Device/DeviceUtils.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Device +{ + internal static class DeviceUtils + { + public static string canonical_name(string device) + { + if(device is null) + { + return ""; + } + return DeviceSpec.from_string(device).ToString(); + } + public static string canonical_name(DeviceSpec device) + { + if (device is null) + { + return ""; + } + return device.ToString(); + } + } +} diff --git a/src/TensorFlowNET.Core/Device/PhysicalDevice.cs b/src/TensorFlowNET.Core/Device/PhysicalDevice.cs new file mode 100644 index 000000000..3f215d12f --- /dev/null +++ b/src/TensorFlowNET.Core/Device/PhysicalDevice.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Device +{ + public class PhysicalDevice + { + public string DeviceName { get; set; } + public string DeviceType { get; set; } + + public override string ToString() + => $"{DeviceType}: {DeviceName}"; + } +} diff --git a/src/TensorFlowNET.Core/Device/SafeDeviceListHandle.cs b/src/TensorFlowNET.Core/Device/SafeDeviceListHandle.cs new file mode 100644 index 000000000..86e2a4fd4 --- /dev/null +++ b/src/TensorFlowNET.Core/Device/SafeDeviceListHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow.Device +{ + public sealed class SafeDeviceListHandle : SafeTensorflowHandle + { + private SafeDeviceListHandle() + { + } + + public SafeDeviceListHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteDeviceList(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Device/c_api.device.cs b/src/TensorFlowNET.Core/Device/c_api.device.cs new file mode 100644 index 000000000..bd2d12959 --- /dev/null +++ b/src/TensorFlowNET.Core/Device/c_api.device.cs @@ -0,0 +1,96 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Runtime.InteropServices; +using Tensorflow.Device; +using Tensorflow.Eager; +using Tensorflow.Util; + +namespace Tensorflow +{ + public partial class c_api + { + /// + /// Specify the device for `desc`. Defaults to empty, meaning unconstrained. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_SetDevice(IntPtr desc, string device); + + /// + /// Counts the number of elements in the device list. + /// + /// TF_DeviceList* + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_DeviceListCount(SafeDeviceListHandle list); + + /// + /// Retrieves the type of the device at the given index. + /// + /// TF_DeviceList* + /// int + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status); + + /// + /// Deallocates the device list. + /// + /// TF_DeviceList* + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteDeviceList(IntPtr list); + + /// + /// Create a new TFE_TensorHandle with the same contents as 'h' but placed + /// in the memory of the device name 'device_name'. + /// + /// TFE_TensorHandle* + /// TFE_Context* + /// char* + /// TF_Status* + /// TFE_TensorHandle* + [DllImport(TensorFlowLibName)] + public static extern SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status); + + /// + /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) + /// + /// TF_DeviceList* + /// + /// TF_Status* + public static string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) + { + using var _ = list.Lease(); + return StringPiece(TF_DeviceListNameImpl(list, index, status)); + } + + /// + /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) + /// The return value will be a pointer to a null terminated string. The caller + /// must not modify or delete the string. It will be deallocated upon a call to + /// TF_DeleteDeviceList. + /// + /// TF_DeviceList* + /// + /// TF_Status* + [DllImport(TensorFlowLibName, EntryPoint = "TF_DeviceListName")] + private static extern IntPtr TF_DeviceListNameImpl(SafeDeviceListHandle list, int index, SafeStatusHandle status); + } +} diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs new file mode 100644 index 000000000..c3c677fff --- /dev/null +++ b/src/TensorFlowNET.Core/DisposableObject.cs @@ -0,0 +1,161 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using Tensorflow.Train; + +namespace Tensorflow +{ + /// + /// Abstract class for disposable object allocated in unmanaged runtime. + /// https://docs.microsoft.com/en-us/dotnet/api/system.idisposable.dispose?redirectedfrom=MSDN&view=net-5.0#System_IDisposable_Dispose + /// + public abstract class DisposableObject : IDisposable + { + protected IntPtr _handle; + protected bool _disposed; + + protected DisposableObject() + { } + + protected DisposableObject(IntPtr handle) + => _handle = handle; + + private void Dispose(bool disposing) + { + if (_disposed) + return; + + //first handle managed, they might use the unmanaged resources. + if (disposing) + { + // dispose managed state (managed objects). + DisposeManagedResources(); + } + + // free unmanaged memory + if (_handle != IntPtr.Zero) + { + // Call the appropriate methods to clean up + // unmanaged resources here. + // If disposing is false, + // only the following code is executed. + DisposeUnmanagedResources(_handle); + _handle = IntPtr.Zero; + } + + // Note disposing has been done. + _disposed = true; + } + + /// + /// Dispose any managed resources. + /// + /// Equivalent to what you would perform inside + protected virtual void DisposeManagedResources() + { } + + /// + /// Dispose any unmanaged resources related to given . + /// + protected abstract void DisposeUnmanagedResources(IntPtr handle); + + public void Dispose() + { + Dispose(true); + // This object will be cleaned up by the Dispose method. + // Therefore, you should call GC.SupressFinalize to + // take this object off the finalization queue + // and prevent finalization code for this object + // from executing a second time. + GC.SuppressFinalize(this); + } + + ~DisposableObject() + { + Dispose(false); + } + } + + public abstract class DisposableTrackableObject: Trackable, IDisposable + { + protected IntPtr _handle; + protected bool _disposed; + + protected DisposableTrackableObject() + { } + + protected DisposableTrackableObject(IntPtr handle) + => _handle = handle; + + private void Dispose(bool disposing) + { + if (_disposed) + return; + + //first handle managed, they might use the unmanaged resources. + if (disposing) + { + // dispose managed state (managed objects). + DisposeManagedResources(); + } + + // free unmanaged memory + if (_handle != IntPtr.Zero) + { + // Call the appropriate methods to clean up + // unmanaged resources here. + // If disposing is false, + // only the following code is executed. + DisposeUnmanagedResources(_handle); + _handle = IntPtr.Zero; + } + + // Note disposing has been done. + _disposed = true; + } + + /// + /// Dispose any managed resources. + /// + /// Equivalent to what you would perform inside + protected virtual void DisposeManagedResources() + { } + + /// + /// Dispose any unmanaged resources related to given . + /// + protected abstract void DisposeUnmanagedResources(IntPtr handle); + + public void Dispose() + { + Dispose(true); + // This object will be cleaned up by the Dispose method. + // Therefore, you should call GC.SupressFinalize to + // take this object off the finalization queue + // and prevent finalization code for this object + // from executing a second time. + GC.SuppressFinalize(this); + } + + ~DisposableTrackableObject() + { + Dispose(false); + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/Context.cs b/src/TensorFlowNET.Core/Eager/Context.cs deleted file mode 100644 index 3d9c875d4..000000000 --- a/src/TensorFlowNET.Core/Eager/Context.cs +++ /dev/null @@ -1,16 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow.Eager -{ - public class Context - { - public static int GRAPH_MODE = 0; - public static int EAGER_MODE = 1; - - public int default_execution_mode; - - - } -} diff --git a/src/TensorFlowNET.Core/Eager/EagerOperation.cs b/src/TensorFlowNET.Core/Eager/EagerOperation.cs new file mode 100644 index 000000000..3664f1875 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerOperation.cs @@ -0,0 +1,63 @@ +using System; + +namespace Tensorflow.Eager +{ + public class EagerOperation : Operation + { + public string Name { get; set; } + public new int NumInputs; + public IntPtr[] InputHandles { get; set; } + public Tensor[] Inputs { get; set; } + public new int NumOutputs; + public IntPtr[] OutputHandles { get; set; } + public Tensor[] Outputs { get; set; } + public long[] SkipInputIndices { get; set; } + public object[] Attrs { get; set; } + + public EagerOperation() : base(IntPtr.Zero) + { + + } + + public override InputList inputs + { + get + { + if (_inputs_val == null) + { + _inputs_val = new InputList(Inputs); + } + + return _inputs_val; + } + } + + public override Tensor[] outputs + { + get + { + if (_outputs == null) + { + _outputs = Outputs; + } + + return _outputs; + } + } + + public override object get_attr(string attr_name) + { + // var attrType = c_api.TFE_OpNameGetAttrType(tf.Context.Handle, Name, attr_name, ref isList, tf.Status.Handle); + for (int i = 0; i < Attrs.Length; i = i + 2) + { + if (Attrs[i].Equals(attr_name)) + return Attrs[i + 1]; + } + + return null; + } + + public override string ToString() + => $"tf.EagerOperation {Name}"; + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.ArgsToMatchingEager.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.ArgsToMatchingEager.cs new file mode 100644 index 000000000..8a1da87af --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.ArgsToMatchingEager.cs @@ -0,0 +1,57 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Contexts; + +namespace Tensorflow.Eager +{ + public partial class EagerRunner + { + public (TF_DataType, Tensor[]) ArgsToMatchingEager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null) + { + if (args.Length == 0 && default_dtype != TF_DataType.DtInvalid) + return (default_dtype, null); + + if (args.Count(x => x is Tensor) == args.Length) + return ((args[0] as Tensor).dtype, args.Select(x => x as Tensor).ToArray()); + + var dtype = TF_DataType.DtInvalid; + foreach (var x in args) + { + if (x is Tensor et) + dtype = et.dtype; + } + + if (dtype == TF_DataType.DtInvalid) + { + var ret = new List(); + foreach (var t in args) + { + ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as Tensor); + if (dtype == TF_DataType.DtInvalid) + dtype = ret.Last().dtype; + } + + return (dtype, ret.ToArray()); + } + else + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.Execute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.Execute.cs new file mode 100644 index 000000000..690d5a9a1 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.Execute.cs @@ -0,0 +1,64 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Contexts; +using static Tensorflow.Binding; + +namespace Tensorflow.Eager +{ + /// + /// python\eager\pywrap_tfe_src.cc + /// + public partial class EagerRunner + { + /// + /// Execute a TensorFlow operation. + /// + /// + /// Name of the TensorFlow operation (see REGISTER_OP in C++ code) to + /// execute. + /// + /// + /// The number of outputs of the operation to fetch. + /// + /// + /// A list of inputs to the operation. Each entry should be a Tensor, or + /// a value which can be passed to the Tensor constructor to create one. + /// + /// + /// A tuple with alternating string attr names and attr values for this + /// operation. + /// + /// The value of context.context(). + /// Customized name for the operation. + /// List of output Tensor objects. The list is empty if there are no outputs + public Tensor[] Execute(Context ctx, string op_name, int num_outputs, + Tensor[] inputs, object[] attrs, + string name = null) + { + ctx.ensure_initialized(); + + var results = tf.Runner.TFE_Execute(ctx, + ctx.DeviceName, + op_name, + inputs, + attrs, + num_outputs); + + return results; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs new file mode 100644 index 000000000..333827037 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs @@ -0,0 +1,47 @@ +using System; +using Tensorflow.Gradients; +using static Tensorflow.Binding; +using static Tensorflow.tensorflow; + +namespace Tensorflow.Eager +{ + public partial class EagerRunner + { + public bool MustRecordGradient() + { + return HasGradientTape(); + } + + public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors) + { + var tape_set = tf.GetTapeSet(); + var input_ids = MakeTensorIDList(tensors); + var input_dtypes = MakeTensorDtypeList(tensors); + bool some_tape_watching = false; + if (tape_set is not null && tape_set.Count > 0) + { + foreach (var tape in tape_set) + { + if (tape.ShouldRecord(input_ids, input_dtypes)) + { + if (tape.Persistent || some_tape_watching) + { + return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER; + } + some_tape_watching = true; + } + } + } + // skip the forward_accumulators. + + if (some_tape_watching) + { + return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; + } + else + { + return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs new file mode 100644 index 000000000..2bdd65f5b --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs @@ -0,0 +1,148 @@ +using System; +using System.Linq; +using Tensorflow.Gradients; +using static Tensorflow.Binding; + +namespace Tensorflow.Eager +{ + public partial class EagerRunner + { + public bool RecordGradient(string op_name, + Tensor[] inputs, + object[] attrs, + Tensor[] results, + BackwardFunction backwardFunction = null) + { + var input_ids = MakeTensorIDList(inputs); + var input_dtypes = MakeTensorDtypeList(inputs); + bool should_record = false; + foreach (var tape in tf.GetTapeSet()) + { + if (tape.ShouldRecord(input_ids, input_dtypes)) + { + should_record = true; + break; + } + } + + if (!should_record) + { + /*for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) + { + if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) + { + should_record = true; + break; + } + }*/ + } + + if (!should_record) return should_record; + // tf.Logger.Debug($"RecordGradient: op_name={op_name}"); + + /*Tensor[] op_outputs = null; + var unused_output_indices = gradient_exclustions.OpGradientUnusedOutputIndices(op_name); + if (unused_output_indices != null) + { + if (unused_output_indices.Length == 0) + op_outputs = new Tensor[0]; + else + { + // op_outputs = CopySequenceSettingIndicesToNull(results, *unused_output_indices); + } + } + else + op_outputs = results; + + Tensor[] op_inputs = null; + var unused_input_indices = gradient_exclustions.OpGradientUnusedInputIndices(op_name); + if (unused_input_indices != null) + { + if (unused_input_indices.Length == 0) + op_inputs = new Tensor[0]; + else + { + // op_inputs = CopySequenceSettingIndicesToNull(inputs, *unused_input_indices); + } + } + else + op_inputs = inputs;*/ + + backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results); + TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction); + + return true; + } + + BackwardFunction GetGradientFunction(string op_name, + Tensor[] op_inputs, + object[] attrs, + Tensor[] op_outputs) + => (out_grads, unneeded_gradients) => + { + if(!ops.gradientFunctions.ContainsKey(op_name)) + { + throw new Exception($"gradientFunctions not find op_name: {op_name}"); + } + + if (ops.gradientFunctions[op_name] == null) + return new Tensor[op_inputs.Length]; + + var oper = new EagerOperation + { + Name = op_name, + NumInputs = op_inputs.Length, + Inputs = op_inputs, + NumOutputs = op_outputs.Length, + Outputs = op_outputs, + SkipInputIndices = unneeded_gradients, + Attrs = attrs + }; + + /*return op_name switch + { + "Add" => math_grad._AddGrad(oper, out_grads), + "AddV2" => math_grad._AddV2Grad(oper, out_grads), + "BiasAdd" => nn_grad._BiasAddGrad(oper, out_grads), + "Cast" => math_grad._CastGrad(oper, out_grads), + "ConcatV2" => array_grad._ConcatV2Grad(oper, out_grads), + "Conv2D" => nn_grad._Conv2DGrad(oper, out_grads), + "ExpandDims" => array_grad._ExpandDimsGrad(oper, out_grads), + "Exp" => math_grad._ExpGrad(oper, out_grads), + "FusedBatchNormV3" => nn_grad._FusedBatchNormV3Grad(oper, out_grads), + "Id" => math_grad._IdGrad(oper, out_grads), + "LeakyRelu" => nn_grad._LeakyReluGrad(oper, out_grads), + "Log1p" => math_grad._Log1pGrad(oper, out_grads), + "Maximum" => math_grad._MaximumGrad(oper, out_grads), + "Mean" => math_grad._MeanGrad(oper, out_grads), + "Minimum" => math_grad._MinimumGrad(oper, out_grads), + "Mul" => math_grad._MulGrad(oper, out_grads), + "Neg" => math_grad._NegGrad(oper, out_grads), + "Pad" => array_grad._PadGrad(oper, out_grads), + "Pow" => math_grad._PowGrad(oper, out_grads), + "RealDiv" => math_grad._RealDivGrad(oper, out_grads), + "Read" => resource_variable_grad._ReadGrad(oper, out_grads), + "Reshape" => array_grad._ReshapeGrad(oper, out_grads), + "ResizeNearestNeighbor" => image_grad._ResizeNearestNeighborGrad(oper, out_grads), + "Select" => math_grad._SelectGrad(oper, out_grads), + "Sigmoid" => math_grad._SigmoidGrad(oper, out_grads), + "Sum" => math_grad._SumGrad(oper, out_grads), + "Sub" => math_grad._SubGrad(oper, out_grads), + "StridedSlice" => array_grad._StridedSliceGrad(oper, out_grads), + _ => ops.gradientFunctions[op_name](oper, out_grads) + };*/ + + return ops.gradientFunctions[op_name](oper, out_grads); + }; + + bool CouldForwardprop() + { + return HasAccumulator(); + } + + bool CouldBackprop() + { + return HasGradientTape(); + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RunCallbacks.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RunCallbacks.cs new file mode 100644 index 000000000..1dfa40465 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RunCallbacks.cs @@ -0,0 +1,28 @@ +namespace Tensorflow.Eager +{ + public partial class EagerRunner + { + bool RunCallbacks(FastPathOpExecInfo op_exec_info, + int num_inferred_attrs, + Tensor[] inputs, + object[] attrs, + Tensor[] flattened_result) + { + if (op_exec_info.run_gradient_callback) + { + if (!RecordGradient(op_exec_info.op_name, inputs, attrs, + flattened_result)) + { + return false; + } + } + + if (op_exec_info.run_post_exec_callbacks) + { + + } + + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs new file mode 100644 index 000000000..018ba921e --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs @@ -0,0 +1,74 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Contexts; +using Tensorflow.Functions; +using static Tensorflow.Binding; + +namespace Tensorflow.Eager +{ + /// + /// python\eager\pywrap_tfe_src.cc + /// + public partial class EagerRunner + { + public Tensor[] TFE_Execute(Context ctx, + string device_name, + string op_name, + Tensor[] inputs, + object[] attrs, + int num_outputs) + => TFE_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs, num_outputs); + + public Tensor[] TFE_ExecuteCancelable(Context ctx, + string device_name, + string op_name, + Tensor[] inputs, + object[] attrs, + int num_outputs) + { + var status = new Status(); + var op = GetOp(ctx, op_name, status); + c_api.TFE_OpSetDevice(op, device_name, status); + if (status.ok()) + { + for (int i = 0; i < inputs.Length; ++i) + { + SafeEagerTensorHandle tensor_handle = inputs[i] switch + { + EagerTensor et => et.EagerTensorHandle, + Tensor nd => nd.EagerTensorHandle, + _ => throw new NotImplementedException("Eager tensor handle has not been allocated.") + }; + c_api.TFE_OpAddInput(op, tensor_handle, status); + status.Check(true); + } + } + if (status.ok() && attrs != null) + SetOpAttrs(op, attrs); + + var outputs = new SafeEagerTensorHandle[num_outputs]; + if (status.ok()) + { + c_api.TFE_Execute(op, outputs, out num_outputs, status); + status.Check(true); + } + return outputs.Select(x => new EagerTensor(x)).ToArray(); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs new file mode 100644 index 000000000..0ce55841b --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -0,0 +1,378 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using Tensorflow.Contexts; +using Tensorflow.Functions; +using Tensorflow.Util; +using static Tensorflow.Binding; +using static Tensorflow.OpDef.Types; + +namespace Tensorflow.Eager +{ + /// + /// python\eager\pywrap_tfe_src.cc + /// + public partial class EagerRunner + { + UnorderedMap thread_local_eager_operation_map = new UnorderedMap(); + public void ClearEagerOperationMap() + => thread_local_eager_operation_map.Clear(); + + public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info) + { + if (op_exec_info.ctx == null) + op_exec_info.ctx = tf.Context; + if (string.IsNullOrEmpty(op_exec_info.device_name)) + op_exec_info.device_name = tf.Context.DeviceName; + + var attr_list_sizes = new Dictionary(); + + op_exec_info.run_gradient_callback = HasAccumulatorOrTape(); + op_exec_info.run_post_exec_callbacks = op_exec_info.callbacks != null; + op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks; + + var status = tf.Status; + var op = GetOp(op_exec_info.ctx, op_exec_info.op_name, status); + + var op_def = tf.get_default_graph().GetOpDef(op_exec_info.op_name); + + var flattened_attrs = new List(op_def.Attr.Count * 2); + var flattened_inputs = new List(op_def.InputArg.Count); + + // Set non-inferred attrs, including setting defaults if the attr is passed in + // as None. + if(op_exec_info.attrs != null) + { + foreach (var attr1 in op_exec_info.attrs) + { + var attr = op_def.Attr.FirstOrDefault(x => x.Name == attr1.Key); + if (attr != null) + { + flattened_attrs.Add(attr.Name); + flattened_attrs.Add(attr1.Value); + + SetOpAttrWithDefaults(op_exec_info.ctx, op, attr, attr.Name, attr1.Value, attr_list_sizes, status); + status.Check(true); + } + } + } + + // c_api.TFE_OpSetDevice(op, op_exec_info.device_name, status.Handle); + // status.Check(true); + + // Add inferred attrs and inputs. + for (int i = 0; i < op_def.InputArg.Count; i++) + { + var input = op_exec_info.args[i]; + var input_arg = op_def.InputArg[i]; + if (!string.IsNullOrEmpty(input_arg.NumberAttr)) + { + var fast_input_array = input is Tensors tensors ? (object[])tensors : (object[])input; + int len = fast_input_array.Length; + c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len); + if (op_exec_info.run_callbacks) + { + flattened_attrs.Add(input_arg.NumberAttr); + flattened_attrs.Add(len); + } + attr_list_sizes[input_arg.NumberAttr] = len; + + if (len > 0) + { + // First item adds the type attr. + if (!AddInputToOp(fast_input_array[i], true, input_arg, flattened_attrs, flattened_inputs, op, status)) + return null; + + for (var j = 1; j < len; j++) + { + // Since the list is homogeneous, we don't need to re-add the attr. + if (!AddInputToOp(fast_input_array[j], false, input_arg, flattened_attrs, flattened_inputs, op, status)) + return null; + } + } + } + else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) + { + var attr_name = input_arg.TypeListAttr; + var fast_input_array = input as object[]; + var len = fast_input_array.Length; + var attr_values = new TF_DataType[len]; + + for (var j = 0; j < len; j++) + { + var eager_tensor = ops.convert_to_tensor(fast_input_array[j]); + attr_values[j] = eager_tensor.dtype; + + c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status); + + if (op_exec_info.run_callbacks) + { + flattened_inputs.Add(eager_tensor); + } + } + + if (op_exec_info.run_callbacks) + { + flattened_attrs.Add(attr_name); + flattened_attrs.Add(attr_values); + } + c_api.TFE_OpSetAttrTypeList(op, attr_name, attr_values, attr_values.Length); + attr_list_sizes[attr_name] = len; + } + else + { + // The item is a single item. + AddInputToOp(op_exec_info.args[i], true, input_arg, flattened_attrs, flattened_inputs, op, status); + } + } + + int num_retvals = 0; + for (int i = 0; i < op_def.OutputArg.Count; i++) + { + var output_arg = op_def.OutputArg[i]; + var delta = 1L; + if (!string.IsNullOrEmpty(output_arg.NumberAttr)) + delta = attr_list_sizes[output_arg.NumberAttr]; + else if (!string.IsNullOrEmpty(output_arg.TypeListAttr)) + delta = attr_list_sizes[output_arg.TypeListAttr]; + if (delta < 0) + throw new RuntimeError("Attributes suggest that the size of an output list is less than 0"); + num_retvals += (int)delta; + } + + var retVals = new SafeEagerTensorHandle[num_retvals]; + c_api.TFE_Execute(op, retVals, out num_retvals, status); + status.Check(true); + + var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray(); + + if (op_exec_info.run_callbacks) + { + RunCallbacks(op_exec_info, + op_def.InputArg.Count(), + flattened_inputs.ToArray(), flattened_attrs.ToArray(), flat_result); + } + + return flat_result; + } + + SafeEagerOpHandle GetOp(Context ctx, string op_or_function_name, Status status) + { + if (thread_local_eager_operation_map.find(op_or_function_name, out var op)) + c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status); + else + { + op = c_api.TFE_NewOp(ctx, op_or_function_name, status); + thread_local_eager_operation_map[op_or_function_name] = op; + } + + status.Check(true); + return op; + /*var op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle); + status.Check(true); + return op;*/ + } + + bool HasAccumulator() + { + //return !GetAccumulatorSet()->empty(); + return false; + } + + bool HasGradientTape() + { + return tf.GetTapeSet().Count > 0; + } + + bool HasAccumulatorOrTape() + { + return HasGradientTape() || HasAccumulator(); + } + + /// + /// Adds input and type attr to the op, and to the list of flattened + /// inputs/attrs. + /// + /// + /// + /// + /// + /// + /// + bool AddInputToOp(object inputs, + bool add_type_attr, + ArgDef input_arg, + List flattened_attrs, + List flattened_inputs, + SafeEagerOpHandle op, + Status status) + { + var tensor = tf.convert_to_tensor(inputs); + flattened_inputs.Add(tensor); + + if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) + { + var dtype = tensor.dtype; + c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype); + flattened_attrs.Add(input_arg.TypeAttr); + flattened_attrs.Add(dtype); + } + + c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status); + status.Check(true); + + return true; + } + + public void SetOpAttrs(SafeEagerOpHandle op, params object[] attrs) + { + var status = tf.Status; + var len = attrs.Length; + for (int i = 0; i < len; i += 2) + { + var key = attrs[i].ToString(); + var value = attrs[i + 1]; + + byte is_list = 0; + var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status); + if (!status.ok()) return; + if (is_list != 0) + SetOpAttrList(tf.Context, op, key, value as object[], type, null, status); + else + SetOpAttrScalar(tf.Context, op, key, value, type, null, status); + status.Check(true); + } + } + + /// + /// This function will set the op attrs required. If an attr has the value of + /// None, then it will read the AttrDef to get the default value and set that + /// instead. Any failure in this function will simply fall back to the slow + /// path. + /// + /// + /// + /// + /// + /// + /// + /// + void SetOpAttrWithDefaults(Context ctx, SafeEagerOpHandle op, AttrDef attr, + string attr_name, object attr_value, + Dictionary attr_list_sizes, + Status status) + { + byte is_list = 0; + var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status); + if (status.Code != TF_Code.TF_OK) return; + + if (attr_value == null) + { + + } + else + { + if (is_list != 0) + SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes, status); + else + SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes, status); + } + } + + bool SetOpAttrList(Context ctx, SafeEagerOpHandle op, + string key, object values, TF_AttrType type, + Dictionary attr_list_sizes, + Status status) + { + if (type == TF_AttrType.TF_ATTR_STRING && values is string[] values3) + { + c_api.TFE_OpSetAttrStringList(op, key, values3, values3.Select(x => Convert.ToUInt64(x.Length)).ToArray(), values3.Length); + attr_list_sizes[key] = values3.Length; + } + else if (type == TF_AttrType.TF_ATTR_SHAPE && values is Shape[] values1) + { + // Make one pass through the input counting the total number of + // dims across all the input lists. + var num_values = values1.Length; + attr_list_sizes[key] = num_values; + var dims = new IntPtr[num_values]; + var num_dims = values1.Select(x => x.ndim).ToArray(); + + for (int i = 0; i < num_values; ++i) + { + dims[i] = Marshal.AllocHGlobal(sizeof(long) * values1[i].ndim); + tf.memcpy(dims[i], values1[i].dims, values1[i].ndim * sizeof(long)); + } + + c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status); + Array.ForEach(dims, x => Marshal.FreeHGlobal(x)); + } + else if (type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] values2) + { + c_api.TFE_OpSetAttrTypeList(op, key, values2, values2.Length); + attr_list_sizes[key] = values2.Length; + } + else if (type == TF_AttrType.TF_ATTR_INT && values is int[] values4) + { + c_api.TFE_OpSetAttrIntList(op, key, values4.Select(x => Convert.ToInt64(x)).ToArray(), values4.Length); + attr_list_sizes[key] = values4.Length; + } + else + { + throw new NotImplementedException(""); + } + + return true; + } + + bool SetOpAttrScalar(Context ctx, SafeEagerOpHandle op, + string key, object value, TF_AttrType type, + Dictionary attr_list_sizes, + Status status) + { + switch (type) + { + case TF_AttrType.TF_ATTR_STRING: + c_api.TFE_OpSetAttrString(op, key, value.ToString(), (ulong)value.ToString().Length); + break; + case TF_AttrType.TF_ATTR_TYPE: + c_api.TFE_OpSetAttrType(op, key, (TF_DataType)value); + break; + case TF_AttrType.TF_ATTR_BOOL: + c_api.TFE_OpSetAttrBool(op, key, Convert.ToBoolean(value)); + break; + case TF_AttrType.TF_ATTR_INT: + var size = Convert.ToInt64(value); + c_api.TFE_OpSetAttrInt(op, key, size); + if (attr_list_sizes != null) + attr_list_sizes[key] = size; + break; + case TF_AttrType.TF_ATTR_FLOAT: + c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value)); + break; + case TF_AttrType.TF_ATTR_SHAPE: + long[] dims; + if (value is Shape shape) dims = shape.dims.ToArray(); + else if (value is long[] longs) dims = longs.ToArray(); + else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray(); + else dims = ((long[])value).ToArray(); + c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status); + status.Check(true); + break; + case TF_AttrType.TF_ATTR_FUNC: + if (value is ConcreteFunction func) + c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); + else if(value is string str) + c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length); + else + throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); + break; + default: + throw new NotImplementedException($"SetOpAttrScalar for {type}"); + } + + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs new file mode 100644 index 000000000..3515fed83 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs @@ -0,0 +1,197 @@ +using OneOf.Types; +using System; +using Tensorflow.Gradients; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow.Eager +{ + /// + /// python\eager\pywrap_tfe_src.cc + /// + public partial class EagerRunner + { + /// + /// + /// + /// + /// + /// + /// + /// determines the value returned if the target and + /// sources are unconnected.When 'none' the value returned is None wheras when + /// 'zero' a zero tensor in the same shape as the sources is returned. + /// + /// + public Tensor[] TFE_TapeGradient(ITape tape, + Tensor[] target, + Tensor[] sources, + List output_gradients, + Tensor[] sources_raw, + string unconnected_gradients) + { + if (!tape.Persistent) + { + var tape_set = tf.GetTapeSet(); + if (tape_set.Contains(tape)) + { + throw new RuntimeError("gradient() cannot be invoked within the " + + "GradientTape context (i.e., while operations are being " + + "recorded). Either move the call to gradient() to be " + + "outside the 'with tf.GradientTape' block, or " + + "use a persistent tape: " + + "'with tf.GradientTape(persistent=true)'"); + } + } + + var target_vec = MakeTensorIDList(target); + var sources_vec = MakeTensorIDList(sources); + HashSet sources_set = new HashSet(sources_vec); + var source_tensors_that_are_targets = new UnorderedMap(); + + int len = target.Length; + for(int i = 0; i < len; i++) + { + var target_id = target_vec[i]; + if (sources_set.Contains(target_id)) + { + var tensor = target[i]; + source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor); + } + } + + List outgrad_vec = new(); + if(output_gradients is not null) + { + outgrad_vec = output_gradients.ToList(); + } + var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true); + + + bool unconnected_gradients_zero = unconnected_gradients == "zero"; + Tensor[] sources_obj = null; + if (unconnected_gradients_zero) + { + sources_obj = MakeTensorList(sources_raw); + } + + if (result.Length > 0) + { + for(int i = 0; i < result.Length; i++) + { + if (result[i] is null && unconnected_gradients_zero) + { + var dtype = sources_obj[i].dtype; + result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike(); + } + } + } + return result; + } + + Tensor[] MakeTensorList(IEnumerable tensors) + { + return tensors.ToArray(); + } + + long[] MakeTensorIDList(Tensor[] tensors) + { + int len = tensors.Length; + long[] ids = new long[len]; + for(int i = 0; i < len; i++) + { + var tensor = tensors[i]; + ids[i] = tensor.Id; + } + return ids; + } + + TF_DataType[] MakeTensorDtypeList(Tensor[] tensors) + { + int len = tensors.Length; + TF_DataType[] dtypes = new TF_DataType[len]; + for (int i = 0; i < len; i++) + { + var tensor = tensors[i]; + dtypes[i] = tensor.dtype; + } + return dtypes; + } + + TapeTensor TapeTensorFromTensor(Tensor tensor) + { + long id = tensor.Id; + var dtype = tensor.dtype; + if (tensor is EagerTensor) + { + var handle = tensor.EagerTensorHandle; + if (DTypeNeedsHandleData(dtype)) + { + return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor); + } + + Status status = new(); + int num_dims = c_api.TFE_TensorHandleNumDims(handle, status); + long[] dims = new long[num_dims]; + for(int i = 0; i < num_dims; i++) + { + dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); + } + + if(status.Code != TF_Code.TF_OK) + { + return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null); + } + else + { + Shape tensor_shape = new(dims); + return new TapeTensor(id, dtype, tensor_shape); + } + } + var shape_tuple = tensor.shape.dims; + if(ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype)) + { + return new TapeTensor(id, dtype, tensor); + } + long[] l = new long[shape_tuple.Length]; + for(int i = 0; i < shape_tuple.Length; i++) + { + if (shape_tuple[i] < 0) + { + l[i] = 0; + } + else + { + l[i] = shape_tuple[i]; + } + } + return new TapeTensor(id, dtype, new Shape(l)); + } + + bool DTypeNeedsHandleData(TF_DataType dtype) + { + return dtype == dtypes.variant || dtype == dtypes.resource; + } + + bool ListContainNone(long[]? list) + { + if(list is null) + { + return true; + } + int len = list.Length; + if(len == 0) + { + return true; + } + for(int i = 0; i < len; i++) + { + if (list[i] == -1) + { + return true; + } + } + return false; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetPossibleGradientTypes.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetPossibleGradientTypes.cs new file mode 100644 index 000000000..0a23cdd48 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetPossibleGradientTypes.cs @@ -0,0 +1,15 @@ +using System; +using Tensorflow.Gradients; +using static Tensorflow.Binding; +using static Tensorflow.tensorflow; + +namespace Tensorflow.Eager +{ + public partial class EagerRunner + { + public int TapeSetPossibleGradientTypes(params Tensor[] args) + { + return 1; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs new file mode 100644 index 000000000..9bcc8fe2e --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs @@ -0,0 +1,26 @@ +using System; +using Tensorflow.Gradients; +using static Tensorflow.Binding; + +namespace Tensorflow.Eager +{ + public partial class EagerRunner + { + void TapeSetRecordBackprop(string op_type, + TapeTensor[] output_info, + long[] input_ids, + TF_DataType[] input_detyps, + BackwardFunction backward_function) + { + if (!CouldBackprop()) + { + return; + } + + foreach (var tape in tf.GetTapeSet()) + { + tape.RecordOperation(op_type, output_info, input_ids, input_detyps, backward_function); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordForwardprop.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordForwardprop.cs new file mode 100644 index 000000000..0490447d9 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordForwardprop.cs @@ -0,0 +1,22 @@ +using System; +using Tensorflow.Gradients; +using static Tensorflow.tensorflow; + +namespace Tensorflow.Eager +{ + public partial class EagerRunner + { + bool TapeSetRecordForwardprop(string op_type, + Tensor[] input_tensors, + TapeTensor[] output_tensors, + BackwardFunction backward_function_getter) + { + if (!CouldForwardprop()) + { + return true; + } + + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs new file mode 100644 index 000000000..3987e7a3d --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs @@ -0,0 +1,37 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Gradients; + +namespace Tensorflow.Eager +{ + public partial class EagerRunner + { + public bool TapeSetRecordOperation(string op_type, + Tensor[] input_tensors, + Tensor[] output_tensors, + long[] input_ids, + TF_DataType[] input_dtypes, + BackwardFunction backward_function) + { + var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray(); + if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info, + backward_function)) + return false; + + TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes, + backward_function); + + return true; + } + + public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, + Tensor[] input_tensors, BackwardFunction backward_function) + { + var input_ids = MakeTensorIDList(input_tensors); + var input_dtypes = MakeTensorDtypeList(input_tensors); + TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes, + backward_function); + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.cs new file mode 100644 index 000000000..5a0e20be4 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.cs @@ -0,0 +1,10 @@ +namespace Tensorflow.Eager +{ + /// + /// Eager mode runner + /// + public partial class EagerRunner : IEagerRunner + { + + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs new file mode 100644 index 000000000..c7d71de38 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -0,0 +1,98 @@ +using Tensorflow.NumPy; +using System; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Eager +{ + public partial class EagerTensor + { + public EagerTensor(SafeEagerTensorHandle handle) + { + _id = ops.uid(); + _eagerTensorHandle = handle; + } + + #region scalar eager tensor + public EagerTensor(bool value) : base(value) + => NewEagerTensorHandle(_handle); + public EagerTensor(byte value) : base(value) + => NewEagerTensorHandle(_handle); + public EagerTensor(sbyte value) : base(value) + => NewEagerTensorHandle(_handle); + public EagerTensor(short value) : base(value) + => NewEagerTensorHandle(_handle); + public EagerTensor(int value) : base(value) + => NewEagerTensorHandle(_handle); + public EagerTensor(uint value) : base(value) + => NewEagerTensorHandle(_handle); + public EagerTensor(long value) : base(value) + => NewEagerTensorHandle(_handle); + public EagerTensor(ulong value) : base(value) + => NewEagerTensorHandle(_handle); + public EagerTensor(float value) : base(value) + => NewEagerTensorHandle(_handle); + public EagerTensor(double value) : base(value) + => NewEagerTensorHandle(_handle); + #endregion + + public EagerTensor(object value, Shape? shape = null, string device_name = null, TF_DataType dtype = TF_DataType.TF_UINT8) : base((float[])value) + => NewEagerTensorHandle(_handle); + + public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype) + => NewEagerTensorHandle(_handle); + + public EagerTensor(Array array, Shape shape) : base(array, shape) + => NewEagerTensorHandle(_handle); + + public EagerTensor(byte[] bytes, Shape shape, TF_DataType dtype) : base(bytes, shape, dtype) + => NewEagerTensorHandle(_handle); + + public EagerTensor(IntPtr data_ptr, Shape shape, TF_DataType dtype) : base(data_ptr, shape, dtype) + => NewEagerTensorHandle(_handle); + + void NewEagerTensorHandle(SafeTensorHandle h) + { + _id = ops.uid(); + _eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status); +#if TRACK_TENSOR_LIFE + Console.WriteLine($"New EagerTensor {_eagerTensorHandle}"); +#endif + tf.Status.Check(true); + } + + public void Resolve() + { + if (_handle != null) + return; + _handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status); + tf.Status.Check(true); + } + + /// + /// _create_substitute_placeholder + /// + /// + public Tensor AsPlaceholder(string name = null) + { + var placeholder = tf_with(ops.control_dependencies(null), _ => tf.placeholder(dtype, name: name)); + copy_handle_data(placeholder); + return placeholder; + } + + public Tensor AsConstant(string name = null) + { + return tf_with(ops.control_dependencies(null), _ => tf.constant(numpy(), name: name)); + } + + void copy_handle_data(Tensor target_t) + { + if (target_t.dtype == TF_DataType.TF_RESOURCE || + target_t.dtype == TF_DataType.TF_VARIANT) + { + // need to export + // c_api.TF_GraphSetOutputHandleShapesAndTypes(target_t.graph, target_t._as_tf_output(), 0, new IntPtr[0], new int[0], new DataType[0], tf.Status.Handle); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Implicit.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Implicit.cs new file mode 100644 index 000000000..d68522702 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Implicit.cs @@ -0,0 +1,10 @@ +using System; + +namespace Tensorflow.Eager +{ + public partial class EagerTensor + { + public static implicit operator IntPtr(EagerTensor tensor) + => tensor.EagerTensorHandle.DangerousGetHandle(); + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs new file mode 100644 index 000000000..71b3075aa --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs @@ -0,0 +1,20 @@ +using Tensorflow.NumPy; + +namespace Tensorflow.Eager +{ + public partial class EagerTensor + { + public override string ToString() + { + var nd = new NDArray(this); + var str = NDArrayRender.ToString(nd); + return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; + } + public string ToString(int maxLength) + { + var nd = new NDArray(this); + var str = NDArrayRender.ToString(nd, maxLength); + return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs new file mode 100644 index 000000000..02bd0bdf2 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs @@ -0,0 +1,85 @@ +using System; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow.Eager +{ + public partial class EagerTensor : Tensor + { + public override SafeTensorHandle Handle + { + get + { + Resolve(); + return _handle; + } + } + + public override IntPtr buffer + { + get + { + Resolve(); + return base.buffer; + } + } + + public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status)); + public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); + + public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status); + + public override ulong bytesize + { + get + { + Resolve(); + return base.bytesize; + } + } + + public override IntPtr TensorDataPointer + { + get + { + Resolve(); + return base.TensorDataPointer; + } + } + + protected override Shape GetShapeInternal() + { + var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status)]; + for (int i = 0; i < dims.Length; i++) + dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status); + return dims; + } + + protected override void SetShapeInternal(Shape value) + { + if (!shape.is_compatible_with(value)) + throw new ValueError($"Tensor's shape is not compatible."); + } + + public static int GetRank(IntPtr handle) + { + var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); + return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status); + } + + public static int[] GetDims(IntPtr handle) + { + var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); + var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status)]; + for (int i = 0; i < dims.Length; i++) + dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status); + return dims; + } + + public override T[] ToArray() + { + Resolve(); + return base.ToArray(); + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs b/src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs new file mode 100644 index 000000000..307ca2ce4 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Contexts; + +namespace Tensorflow +{ + public class FastPathOpExecInfo + { + public Context ctx { get; set; } + public string device_name { get; set; } + public string op_name { get; set; } + public string name { get; set; } + public object[] args { get; set; } + public Dictionary attrs { get; set; } + public bool run_gradient_callback { get; set; } + public bool run_post_exec_callbacks { get; set; } + public bool run_callbacks { get; set; } + public Action callbacks { get; set; } + + public FastPathOpExecInfo(Context ctx, string opName, string name, params object[] inputArgs) + { + this.ctx = ctx; + this.op_name = opName; + this.name = name; + this.args = inputArgs; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs b/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs new file mode 100644 index 000000000..2c20cfe9b --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs @@ -0,0 +1,25 @@ +using Tensorflow; + +internal static class GraphOnlyOps +{ + /// + /// Graph-only version of tf.compat.v1.placeholder(), for internal use only. + /// + /// + /// + /// + /// + internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null) + { + var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() }; + var shape_value = new AttrValue() { Shape = shape.as_proto() }; + var g = ops.get_default_graph(); + Dictionary attrs = new(); + attrs["dtype"] = dtype_value; + attrs["shape"] = shape_value; + var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype }, + new TF_DataType[0], attrs: attrs, name: name); + var result = op.outputs[0]; + return result; + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Eager/IEagerRunner.cs b/src/TensorFlowNET.Core/Eager/IEagerRunner.cs new file mode 100644 index 000000000..21a336690 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/IEagerRunner.cs @@ -0,0 +1,53 @@ +using System; +using Tensorflow.Contexts; +using Tensorflow.Gradients; +using static Tensorflow.tensorflow; + +namespace Tensorflow.Eager +{ + public interface IEagerRunner + { + Tensor[] Execute(Context ctx, string op_name, + int num_outputs, + Tensor[] inputs, + object[] attrs, + string name = null); + + (TF_DataType, Tensor[]) ArgsToMatchingEager(Context ctx, + TF_DataType default_dtype = TF_DataType.DtInvalid, + object[] args = null); + + Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info); + + Tensor[] TFE_Execute(Context ctx, + string device_name, + string op_name, + Tensor[] inputs, + object[] attrs, + int num_outputs); + + Tensor[] TFE_TapeGradient(ITape tape, + Tensor[] target, + Tensor[] sources, + List output_gradients, + Tensor[] sources_raw, + string unconnected_gradients); + + void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, + Tensor[] input_tensors, BackwardFunction backward_function); + + int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors); + + bool RecordGradient(string op_name, + Tensor[] inputs, + object[] attrs, + Tensor[] results, + BackwardFunction getBackwardFunction = null); + + bool MustRecordGradient(); + + int TapeSetPossibleGradientTypes(params Tensor[] args); + + void ClearEagerOperationMap(); + } +} diff --git a/src/TensorFlowNET.Core/Eager/SafeContextHandle.cs b/src/TensorFlowNET.Core/Eager/SafeContextHandle.cs new file mode 100644 index 000000000..de5cd2f15 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/SafeContextHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow.Eager +{ + public sealed class SafeContextHandle : SafeTensorflowHandle + { + private SafeContextHandle() + { + } + + public SafeContextHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TFE_DeleteContext(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/SafeContextOptionsHandle.cs b/src/TensorFlowNET.Core/Eager/SafeContextOptionsHandle.cs new file mode 100644 index 000000000..6a6d1d76b --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/SafeContextOptionsHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow.Eager +{ + public sealed class SafeContextOptionsHandle : SafeTensorflowHandle + { + private SafeContextOptionsHandle() + { + } + + public SafeContextOptionsHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TFE_DeleteContextOptions(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/SafeEagerOpHandle.cs b/src/TensorFlowNET.Core/Eager/SafeEagerOpHandle.cs new file mode 100644 index 000000000..66c84d747 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/SafeEagerOpHandle.cs @@ -0,0 +1,42 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow.Eager +{ + public sealed class SafeEagerOpHandle : SafeTensorflowHandle + { + private SafeEagerOpHandle() + { + + } + + public SafeEagerOpHandle(IntPtr handle) + : base(handle) + { + + } + + protected override bool ReleaseHandle() + { + c_api.TFE_DeleteOp(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs b/src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs new file mode 100644 index 000000000..025e65114 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/SafeEagerTensorHandle.cs @@ -0,0 +1,44 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow.Eager +{ + public sealed class SafeEagerTensorHandle : SafeTensorflowHandle + { + private SafeEagerTensorHandle() + { + } + + public SafeEagerTensorHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { +#if TRACK_TENSOR_LIFE + print($"Delete EagerTensorHandle 0x{handle.ToString("x16")}"); +#endif + c_api.TFE_DeleteTensorHandle(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/SafeExecutorHandle.cs b/src/TensorFlowNET.Core/Eager/SafeExecutorHandle.cs new file mode 100644 index 000000000..cf6601e7e --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/SafeExecutorHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow.Eager +{ + public sealed class SafeExecutorHandle : SafeTensorflowHandle + { + private SafeExecutorHandle() + { + } + + public SafeExecutorHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TFE_DeleteExecutor(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/backprop_util.cs b/src/TensorFlowNET.Core/Eager/backprop_util.cs new file mode 100644 index 000000000..0d726e1de --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/backprop_util.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations; + +namespace Tensorflow.Eager +{ + internal static class backprop_util + { + // TODO: add quantized_dtypes (after being supported). + private static HashSet _trainable_dtypes = new HashSet(new TF_DataType[] + { + dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128, + dtypes.resource, dtypes.variant, TF_DataType.TF_BFLOAT16 + }); + public static bool IsTrainable(Tensor tensor) + { + var dtype = _DTypeFromTensor(tensor); + return _trainable_dtypes.Contains(dtype); + } + public static bool IsTrainable(TF_DataType dtype) + { + return _trainable_dtypes.Contains(dtype); + } + + private static TF_DataType _DTypeFromTensor(Tensor tensor) + { + var dtype = tensor.dtype; + if(dtype.as_base_dtype() == TF_DataType.TF_VARIANT) + { + CppShapeInferenceResult.Types.HandleData handle_data; + if (tensor is EagerTensor) + { + handle_data = tensor.HandleData; + } + else + { + handle_data = handle_data_util.get_resource_handle_data(tensor); + } + if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null && + handle_data.ShapeAndType.Count > 0) + { + var first_type = handle_data.ShapeAndType[0].Dtype; + if(first_type != DataType.DtInvalid && handle_data.ShapeAndType.All(x => x.Dtype == first_type)) + { + return first_type.as_tf_dtype(); + } + } + } + return dtype; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs new file mode 100644 index 000000000..11de49600 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -0,0 +1,490 @@ +using Google.Protobuf; +using System; +using System.Runtime.InteropServices; +using Tensorflow.Contexts; +using Tensorflow.Device; +using Tensorflow.Eager; +using Tensorflow.Util; + +namespace Tensorflow +{ + public partial class c_api + { + /// + /// Return a new options object. + /// + /// TFE_ContextOptions* + [DllImport(TensorFlowLibName)] + public static extern SafeContextOptionsHandle TFE_NewContextOptions(); + + /// + /// Set the config in TF_ContextOptions.options. + /// config should be a serialized tensorflow.ConfigProto proto. + /// If config was not parsed successfully as a ConfigProto, record the + /// error information in *status. + /// + /// TFE_ContextOptions* + /// + /// size_t + /// SafeStatusHandle + [DllImport(TensorFlowLibName)] + public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, byte[] serialized_function_def, ulong size, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); + + /// + /// Destroy an options object. + /// + /// TFE_ContextOptions* + [DllImport(TensorFlowLibName)] + public static extern void TFE_DeleteContextOptions(IntPtr options); + + /// + /// Configure device placement policy logging for the eager executor. Note this + /// policy is applied to any subsequent op executions. + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TFE_ContextSetLogDevicePlacement(SafeContextHandle ctx, bool enable, SafeStatusHandle status); + + /// + /// + /// + /// TFE_Op* + /// const char* + /// unsigned char* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern TF_AttrType TFE_OpGetAttrType(SafeEagerOpHandle op, string attr_name, ref byte is_list, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); + + /// + /// Returns the length (number of tensors) of the input argument `input_name` + /// found in the provided `op`. + /// + /// TFE_Op* + /// const char* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status); + + /// + /// Returns the length (number of tensors) of the output argument `output_name` + /// found in the provided `op`. + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status); + + /// + /// + /// + /// TFE_Op* + /// TFE_TensorHandle** + /// int + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status); + + /// + /// + /// + /// const TFE_ContextOptions* + /// TF_Status* + /// TFE_Context* + [DllImport(TensorFlowLibName)] + public static extern SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status); + + /// + /// Adds a function (created from TF_GraphToFunction or + /// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with + /// TFE_Execute by creating an op with the same name as the function. + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, SafeFuncGraphHandle function, SafeStatusHandle status); + + /// + /// Removes a function from the context. Once removed, you can no longer + /// TFE_Execute it or TFE_Execute any TFE_Op which has it as an attribute or any + /// other function which calls it as an attribute. + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TFE_ContextRemoveFunction(SafeContextHandle ctx, string name, SafeStatusHandle status); + + /// + /// Checks whether a function is registered under `name`. + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern bool TFE_ContextHasFunction(SafeContextHandle ctx, string name); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_ContextStartStep(SafeContextHandle ctx); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_ContextEndStep(SafeContextHandle ctx); + + /// + /// + /// + /// TFE_Context* + [DllImport(TensorFlowLibName)] + public static extern void TFE_DeleteContext(IntPtr ctx); + + /// + /// Execute the operation defined by and return handles to computed + /// tensors in . + /// + /// + /// Upon successful return, the first slots in will + /// contain handle instances which the caller is responsible for disposing once they are no longer in use. + /// + /// + /// + /// + /// + public static void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status) + { + unsafe + { + num_retvals = retvals?.Length ?? 0; + var rawReturns = stackalloc IntPtr[num_retvals]; + TFE_Execute(op, rawReturns, ref num_retvals, status); + for (var i = 0; i < num_retvals; i++) + { + // A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be + // non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return + // values. + retvals[i] = new SafeEagerTensorHandle(rawReturns[i]); + } + } + } + + /// + /// Execute the operation defined by 'op' and return handles to computed + /// tensors in `retvals`. + /// + /// TFE_Op* + /// TFE_TensorHandle** + /// int* + /// TF_Status* + [DllImport(TensorFlowLibName)] + private static unsafe extern void TFE_Execute(SafeEagerOpHandle op, IntPtr* retvals, ref int num_retvals, SafeStatusHandle status); + + /// + /// + /// + /// TFE_Context* + /// const char* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); + + /// + /// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This + /// is for performance optimization by reusing an exiting unused op rather than + /// creating a new op every time. If `raw_device_name` is `NULL` or empty, it + /// does not set the device name. If it's not `NULL`, then it attempts to parse + /// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster + /// than separately calling it because if the existing op has the same + /// `raw_device_name`, it skips parsing and just leave as it is. + /// + /// TFE_Op* + /// const char* + /// const char* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpReset(SafeEagerOpHandle op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); + + /// + /// + /// + /// TFE_Op* + [DllImport(TensorFlowLibName)] + public static extern void TFE_DeleteOp(IntPtr op); + + /// + /// + /// + /// TFE_Op* + /// const char* + /// TF_DataType + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrInt(SafeEagerOpHandle op, string attr_name, long value); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrFloat(SafeEagerOpHandle op, string attr_name, float value); + + /// + /// + /// + /// TFE_Op* + /// const char* + /// const int64_t* + /// const int + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrShape(SafeEagerOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrShapeList(SafeEagerOpHandle op, string attr_name, IntPtr[] dims, int[] num_dims, int num_values, SafeStatusHandle out_status); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrStringList(SafeEagerOpHandle op, string attr_name, string[] values, ulong[] lengths, int num_values); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrBool(SafeEagerOpHandle op, string attr_name, bool value); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrFunctionName(SafeEagerOpHandle op, string attr_name, string data, int length); + + /// + /// + /// + /// TFE_Op* + /// const char* + /// const void* + /// size_t + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrString(SafeEagerOpHandle op, string attr_name, string value, ulong length); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrTypeList(SafeEagerOpHandle op, string attr_name, TF_DataType[] values, int num_values); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrValueProto(IntPtr op, string attr_name, IntPtr proto, ulong proto_len, SafeStatusHandle status); + + /// + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status); + + /// + /// + /// + /// TFE_Op* + /// TFE_TensorHandle* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status); + + /// + /// + /// + /// const tensorflow::Tensor& + /// TFE_TensorHandle* + [DllImport(TensorFlowLibName)] + public static extern SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern SafeEagerTensorHandle TFE_EagerTensorHandle(IntPtr t); + + /// + /// Sets the default execution mode (sync/async). Note that this can be + /// overridden per thread using TFE_ContextSetExecutorForThread. + /// + /// TFE_ContextOptions* + /// unsigned char + [DllImport(TensorFlowLibName)] + public static extern void TFE_ContextOptionsSetAsync(SafeContextOptionsHandle opts, byte enable); + + /// + /// + /// + /// TFE_TensorHandle* + /// + [DllImport(TensorFlowLibName)] + public static extern TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h); + + /// + /// This function will block till the operation that produces `h` has + /// completed. The memory returned might alias the internal memory used by + /// TensorFlow. + /// + /// TFE_TensorHandle* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status); + + + /// + /// This function will block till the operation that produces `h` has completed. + /// + /// TFE_TensorHandle* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern int TFE_TensorHandleDim(SafeEagerTensorHandle h, int dim, SafeStatusHandle status); + + /// + /// Returns the device of the operation that produced `h`. If `h` was produced by + /// a copy, returns the destination device of the copy. Note that the returned + /// device name is not always the device holding the tensor handle's memory. If + /// you want the latter, use TFE_TensorHandleBackingDeviceName. This function + /// will block till the operation that produces `h` has completed. + /// + /// TFE_TensorHandle* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status); + + /// + /// Returns the name of the device in whose memory `h` resides. + /// + /// TFE_TensorHandle* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status); + + /// + /// + /// + /// TFE_Context* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); + + /// + /// Clears the internal caches in the TFE context. Useful when reseeding random ops. + /// + /// TFE_Context* + [DllImport(TensorFlowLibName)] + public static extern void TFE_ContextClearCaches(SafeContextHandle ctx); + + /// + /// + /// + /// TFE_TensorHandle* + [DllImport(TensorFlowLibName)] + public static extern void TFE_DeleteTensorHandle(IntPtr h); + + /// + /// + /// + /// TFE_TensorHandle* + [DllImport(TensorFlowLibName)] + public static extern void TFE_DeleteEagerTensor(IntPtr h); + + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteBindingArray(IntPtr h); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_DeleteBindingTensorArray(IntPtr h); + + /// + /// Creates a new eager Executor. Nodes in one executor are guaranteed to be + /// executed in sequence. Assigning nodes to different executors allows executing + /// nodes in parallel. + /// + /// + /// TFE_Executor* + [DllImport(TensorFlowLibName)] + public static extern SafeExecutorHandle TFE_NewExecutor(bool is_async); + + /// + /// Deletes the eager Executor without waiting for enqueued nodes. Please call + /// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to + /// make sure all nodes are finished. + /// + /// TFE_Executor* + [DllImport(TensorFlowLibName)] + public static extern void TFE_DeleteExecutor(IntPtr executor); + + /// + /// Causes the calling thread to block till all ops dispatched in this executor + /// have been executed. Note that "execution" here refers to kernel execution / + /// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee + /// that lower level device queues (like GPU streams) have been flushed. + /// + /// This call may not block for execution of ops enqueued concurrently with this + /// call. + /// + /// TFE_Executor* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status); + + /// + /// Sets a custom Executor for current thread. All nodes created by this thread + /// will be added to this Executor. It will override current executor. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, SafeExecutorHandle executor); + + /// + /// Returns the Executor for current thread. + /// + /// + /// TFE_Executor* + [DllImport(TensorFlowLibName)] + public static extern SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_TapeSetRemove(IntPtr tape); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_TapeWatch(IntPtr tape, IntPtr variable); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_TapeVariableAccessed(IntPtr variable); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TFE_TapeWatchedVariables(IntPtr tape); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr ResourceVariable_Handle(IntPtr variable); + + [DllImport(TensorFlowLibName)] + public static extern SafeStatusHandle TFE_TapeGradient(IntPtr tape, + IntPtr[] target, int target_size, + IntPtr[] sources, int source_size, + IntPtr[] outputs, int output_size); + + [DllImport(TensorFlowLibName)] + public static extern bool TFE_IsCustomDevice(SafeContextHandle ctx, string device_name); + } +} diff --git a/src/TensorFlowNET.Core/Eager/execute.cs b/src/TensorFlowNET.Core/Eager/execute.cs new file mode 100644 index 000000000..e981c6c51 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/execute.cs @@ -0,0 +1,45 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Xml.Linq; +using Tensorflow.Contexts; +using static Tensorflow.ApiDef.Types; +using static Tensorflow.CostGraphDef.Types; +using static Tensorflow.Binding; +using Tensorflow.Gradients; + +namespace Tensorflow.Eager +{ + internal static class _execute + { + public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) + { + var v = values.Select(t => ops.convert_to_tensor(t, ctx:ctx)); + var types = v.Select(t => t.dtype.as_datatype_enum()); + return (types.ToArray(), v.ToArray()); + } + public static Tensor[] execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) + { + return quick_execute(op_name, num_outputs, inputs, attrs, ctx, name); + } + public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) + { + string device_name = ctx.DeviceName; + + ctx.ensure_initialized(); + var tensors = tf.Runner.TFE_Execute(ctx, device_name, op_name, inputs, attrs, num_outputs); + + return tensors; + } + public static bool must_record_gradient() + { + return tf.GetTapeSet().Count != 0; + } + + public static bool record_gradient(string op_name, Tensor[] inputs, object[] attrs, Tensor[] results) + { + return tf.Runner.RecordGradient(op_name, inputs, attrs, results); + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/forwardprop_util.cs b/src/TensorFlowNET.Core/Eager/forwardprop_util.cs new file mode 100644 index 000000000..a53026d42 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/forwardprop_util.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Eager +{ + public class TangentInfo + { + // TODO(Rinne): implement it. + public object Indices { get; set; } + public object Tangents { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/AssertionError.cs b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs new file mode 100644 index 000000000..977fe2340 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs @@ -0,0 +1,14 @@ +namespace Tensorflow.Exceptions; + +public class AssertionError : TensorflowException +{ + public AssertionError() : base() + { + + } + + public AssertionError(string message) : base(message) + { + + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/InaccessibleTensorError.cs b/src/TensorFlowNET.Core/Exceptions/InaccessibleTensorError.cs new file mode 100644 index 000000000..5195fa6b1 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/InaccessibleTensorError.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Exceptions +{ + public class InaccessibleTensorError : TensorflowException + { + public InaccessibleTensorError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/InvalidArgumentError.cs b/src/TensorFlowNET.Core/Exceptions/InvalidArgumentError.cs new file mode 100644 index 000000000..d5d131564 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/InvalidArgumentError.cs @@ -0,0 +1,15 @@ +namespace Tensorflow +{ + public class InvalidArgumentError : TensorflowException + { + public InvalidArgumentError() : base() + { + + } + + public InvalidArgumentError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/KeyError.cs b/src/TensorFlowNET.Core/Exceptions/KeyError.cs new file mode 100644 index 000000000..5f9bbc79b --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/KeyError.cs @@ -0,0 +1,15 @@ +namespace Tensorflow +{ + public class KeyError : TensorflowException + { + public KeyError() : base() + { + + } + + public KeyError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/LookupError.cs b/src/TensorFlowNET.Core/Exceptions/LookupError.cs new file mode 100644 index 000000000..5d5418a57 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/LookupError.cs @@ -0,0 +1,15 @@ +namespace Tensorflow +{ + public class LookupError : TensorflowException + { + public LookupError() : base() + { + + } + + public LookupError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs b/src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs new file mode 100644 index 000000000..c283c1a45 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Exceptions +{ + public class NotOkStatusException : TensorflowException + { + public NotOkStatusException() : base() + { + + } + + public NotOkStatusException(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/OutOfRangeError.cs b/src/TensorFlowNET.Core/Exceptions/OutOfRangeError.cs new file mode 100644 index 000000000..f330de821 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/OutOfRangeError.cs @@ -0,0 +1,15 @@ +namespace Tensorflow +{ + public class OutOfRangeError : TensorflowException + { + public OutOfRangeError() : base() + { + + } + + public OutOfRangeError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs b/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs new file mode 100644 index 000000000..964534aa3 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs @@ -0,0 +1,15 @@ +namespace Tensorflow +{ + public class RuntimeError : TensorflowException + { + public RuntimeError() : base() + { + + } + + public RuntimeError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/StopIteration.cs b/src/TensorFlowNET.Core/Exceptions/StopIteration.cs new file mode 100644 index 000000000..bdfed2554 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/StopIteration.cs @@ -0,0 +1,15 @@ +namespace Tensorflow +{ + public class StopIteration : TensorflowException + { + public StopIteration() : base() + { + + } + + public StopIteration(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/TensorflowException.cs b/src/TensorFlowNET.Core/Exceptions/TensorflowException.cs new file mode 100644 index 000000000..ee9eca696 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/TensorflowException.cs @@ -0,0 +1,36 @@ +using System; +using System.Runtime.Serialization; + +namespace Tensorflow +{ + + /// + /// Serves as a base class to all exceptions of Tensorflow.NET. + /// + [Serializable] + public class TensorflowException : Exception + { + /// Initializes a new instance of the class. + public TensorflowException() + { } + + /// Initializes a new instance of the class with serialized data. + /// The that holds the serialized object data about the exception being thrown. + /// The that contains contextual information about the source or destination. + /// The info parameter is null. + /// The class name is null or is zero (0). + protected TensorflowException(SerializationInfo info, StreamingContext context) : base(info, context) + { } + + /// Initializes a new instance of the class with a specified error message. + /// The message that describes the error. + public TensorflowException(string message) : base(message) + { } + + /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception, or a null reference (Nothing in Visual Basic) if no inner exception is specified. + public TensorflowException(string message, Exception innerException) : base(message, innerException) + { } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Exceptions/TypeError.cs b/src/TensorFlowNET.Core/Exceptions/TypeError.cs new file mode 100644 index 000000000..da340e4eb --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/TypeError.cs @@ -0,0 +1,15 @@ +namespace Tensorflow +{ + public class TypeError : TensorflowException + { + public TypeError() : base() + { + + } + + public TypeError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/ValueError.cs b/src/TensorFlowNET.Core/Exceptions/ValueError.cs new file mode 100644 index 000000000..df9833b30 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/ValueError.cs @@ -0,0 +1,15 @@ +namespace Tensorflow +{ + public class ValueError : TensorflowException + { + public ValueError() : base() + { + + } + + public ValueError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/CompositeTensor.cs b/src/TensorFlowNET.Core/Framework/CompositeTensor.cs new file mode 100644 index 000000000..8a942e5a4 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/CompositeTensor.cs @@ -0,0 +1,9 @@ +namespace Tensorflow.Framework +{ + /// + /// Abstract base class for Tensor-like objects that are composed from Tensors. + /// + public abstract class CompositeTensor + { + } +} diff --git a/src/TensorFlowNET.Core/Framework/ConfigImpl.cs b/src/TensorFlowNET.Core/Framework/ConfigImpl.cs new file mode 100644 index 000000000..7d8e088a9 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/ConfigImpl.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using Tensorflow.Device; + +namespace Tensorflow.Framework +{ + public class ConfigImpl + { + /// + /// Return a list of physical devices visible to the host runtime. + /// + /// CPU, GPU, TPU + /// + public PhysicalDevice[] list_physical_devices(string device_type = null) + => tf.Context.list_physical_devices(device_type: device_type); + + public Experimental experimental => new Experimental(); + + public class Experimental + { + public void set_memory_growth(PhysicalDevice device, bool enable) + => tf.Context.set_memory_growth(device, enable); + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs new file mode 100644 index 000000000..bac5e6fb1 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs @@ -0,0 +1,73 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow.Framework +{ + /// + /// A sparse representation of a set of tensor slices at given indices. + /// + public class IndexedSlices : CompositeTensor + { + Tensor _values; + public Tensor values => _values; + Tensor _indices; + public Tensor indices => _indices; + Tensor _dense_shape; + public Tensor dense_shape => _dense_shape; + + public string name => _values.name; + + public string device => _values.Device; + + public Operation op => _values.op; + + public TF_DataType dtype => _values.dtype; + + public Graph graph => _values.graph; + + public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null) + { + _values = values; + _indices = indices; + _dense_shape = dense_shape; + + _values.Tag = this; + } + + public static implicit operator Tensor(IndexedSlices indexedSlices) + { + return _indexed_slices_to_tensor(indexedSlices); + } + + public static implicit operator IndexedSlices(Tensor tensor) + { + return tensor.Tag as IndexedSlices; + } + + /// + /// Converts an IndexedSlices object `value` to a Tensor. + /// + /// + /// + /// + /// + /// + public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false) + { + return gen_math_ops.unsorted_segment_sum(indexedSlices.values, indexedSlices.indices, indexedSlices.dense_shape.slice(0)); + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/Models/AutotuneAlgorithm.cs b/src/TensorFlowNET.Core/Framework/Models/AutotuneAlgorithm.cs new file mode 100644 index 000000000..5289de71e --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/Models/AutotuneAlgorithm.cs @@ -0,0 +1,8 @@ +namespace Tensorflow.Framework.Models +{ + public enum AutotuneAlgorithm + { + HILL_CLIMB = 0, + GRADIENT_DESCENT = 1, + } +} diff --git a/src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs b/src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs new file mode 100644 index 000000000..5a89b90ed --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs @@ -0,0 +1,30 @@ +namespace Tensorflow.Framework.Models +{ + /// + /// Describes a dense object with shape, dtype, and name. + /// + public class DenseSpec : TypeSpec + { + protected Shape _shape; + public Shape shape + { + get { return _shape; } + set { _shape = value; } + } + protected TF_DataType _dtype; + public TF_DataType dtype => _dtype; + + protected string _name; + public string name => _name; + + public DenseSpec(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + _shape = shape; + _dtype = dtype; + _name = name; + } + + public override string ToString() + => $"shape={_shape}, dtype={_dtype.as_numpy_name()}, name={_name}"; + } +} diff --git a/src/TensorFlowNET.Core/Framework/Models/ScopedTFGraph.cs b/src/TensorFlowNET.Core/Framework/Models/ScopedTFGraph.cs new file mode 100644 index 000000000..d6d24875b --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/Models/ScopedTFGraph.cs @@ -0,0 +1,10 @@ +namespace Tensorflow.Framework.Models +{ + public class ScopedTFGraph : Graph + { + public ScopedTFGraph() : base() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs b/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs new file mode 100644 index 000000000..ac099ae2b --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs @@ -0,0 +1,41 @@ +using System.Linq; +using Tensorflow.Eager; + +namespace Tensorflow.Framework.Models +{ + public class TensorSpec : DenseSpec + { + public TensorSpec(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) : + base(shape, dtype, name) + { + + } + + public TensorSpec _unbatch() + { + if (_shape.ndim == 0) + throw new ValueError("Unbatching a tensor is only supported for rank >= 1"); + + return new TensorSpec(_shape.dims.Skip(1).ToArray(), _dtype); + } + + public TensorSpec _batch(int dim = -1) + { + var shapes = shape.dims.ToList(); + shapes.Insert(0, dim); + return new TensorSpec(shapes.ToArray(), _dtype); + } + + public static TensorSpec FromTensor(Tensor tensor, string? name = null) + { + if(tensor is EagerTensor) + { + return new TensorSpec(tensor.shape, tensor.dtype, name); + } + else + { + return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/Models/TypeSpec.cs b/src/TensorFlowNET.Core/Framework/Models/TypeSpec.cs new file mode 100644 index 000000000..84fd6e256 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/Models/TypeSpec.cs @@ -0,0 +1,9 @@ +namespace Tensorflow.Framework.Models +{ + /// + /// Specifies a TensorFlow value type. + /// + public class TypeSpec + { + } +} diff --git a/src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs b/src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs new file mode 100644 index 000000000..11e920f86 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Framework +{ + internal class ScopedTFFunction + { + SafeFuncGraphHandle _handle; + string _name; + public ScopedTFFunction(SafeFuncGraphHandle func, string name) + { + _handle = func; + _name = name; + } + + public SafeFuncGraphHandle Get() + { + return _handle; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs b/src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs new file mode 100644 index 000000000..28d9e5008 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs @@ -0,0 +1,89 @@ +using Tensorflow.Graphs; + +namespace Tensorflow.Framework +{ + internal static class auto_control_deps_utils + { + public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"; + public static List get_read_only_resource_input_indices_graph(FuncGraph func_graph) + { + List result = new List(); + // A cache to store the read only resource inputs of an Op. + // Operation -> ObjectIdentitySet of resource handles. + Dictionary> opReadOnlyResourceInputs = + new Dictionary>(); + + for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++) + { + Tensor t = func_graph.Inputs[inputIndex]; + if (t.dtype != dtypes.resource) + continue; + + bool readOnly = true; + foreach (var op in t.consumers()) + { + if (opReadOnlyResourceInputs.ContainsKey(op)) + { + if (!opReadOnlyResourceInputs[op].Contains(t)) + { + readOnly = false; + break; + } + } + else + { + List indices = _get_read_only_resource_input_indices_op(op); + opReadOnlyResourceInputs[op] = new HashSet( + indices.Select(i => op.inputs[i])); + if (!opReadOnlyResourceInputs[op].Contains(t)) + { + readOnly = false; + break; + } + } + } + + if (readOnly) + result.Add(inputIndex); + } + + return result; + } + + private static List _get_read_only_resource_input_indices_op(Operation op) + { + // ignore the RESOURCE_READ_OPS + + int[] read_only_input_indices; + + try + { + read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR); + } + catch (InvalidArgumentError) + { + return new List(); + } + + int read_only_index = 0; + List result = new(); + for (int i = 0; i < op.inputs.Length; i++) + { + if (read_only_index >= read_only_input_indices.Length) + { + break; + } + if (op.inputs[i].dtype != dtypes.resource) + { + continue; + } + if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index]) + { + result.Add(i); + read_only_index++; + } + } + return result; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/c_api_util.cs b/src/TensorFlowNET.Core/Framework/c_api_util.cs new file mode 100644 index 000000000..e21c3b019 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/c_api_util.cs @@ -0,0 +1,145 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.IO; +using System.IO.Compression; +using System.Net; +using System.Threading; +using System.Threading.Tasks; + +namespace Tensorflow +{ + public class c_api_util + { + static bool isDllDownloaded = false; + static object locker = new object(); + public static void DownloadLibrary() + { + string dll = c_api.TensorFlowLibName; + string directory = AppDomain.CurrentDomain.BaseDirectory; + string file = ""; + string url = ""; + + switch (Environment.OSVersion.Platform) + { + case PlatformID.Win32NT: + dll = $"{dll}.dll"; + file = Path.Combine(directory, "libtensorflow-cpu-windows-x86_64-1.14.0.zip"); + url = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-1.14.0.zip"; + break; + case PlatformID.Unix: + dll = $"lib{dll}.so"; + file = Path.Combine(directory, "libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz"); + url = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz"; + break; + default: + throw new RuntimeError($"Unknown OS environment: {Environment.OSVersion.Platform}"); + } + + if (isDllDownloaded || File.Exists($"{directory}/{dll}")) + { + isDllDownloaded = true; + return; + } + + lock (locker) + { + if (!File.Exists(file)) + { + var wc = new WebClient(); + Binding.tf_output_redirect.WriteLine($"Downloading Tensorflow library from {url}..."); + var download = Task.Run(() => wc.DownloadFile(url, file)); + while (!download.IsCompleted) + { + Thread.Sleep(1000); + Binding.tf_output_redirect.Write("."); + } + Binding.tf_output_redirect.WriteLine(""); + Binding.tf_output_redirect.WriteLine($"Downloaded successfully."); + } + + Binding.tf_output_redirect.WriteLine($"Extracting..."); + var task = Task.Run(() => + { + switch (Environment.OSVersion.Platform) + { + case PlatformID.Win32NT: + ZipFile.ExtractToDirectory(file, directory); + Util.CmdHelper.Command($"move lib\\* .\\"); + Util.CmdHelper.Command($"rm -r lib"); + Util.CmdHelper.Command($"rm -r include"); + break; + case PlatformID.Unix: + Util.CmdHelper.Bash($"tar xvzf {file} ./lib/"); + Util.CmdHelper.Bash($"mv {directory}/lib/* {directory}"); + Util.CmdHelper.Bash($"rm -r {directory}/lib"); + break; + default: + throw new RuntimeError($"Unknown OS environment: {Environment.OSVersion.Platform}"); + } + }); + + while (!task.IsCompleted) + { + Thread.Sleep(100); + Binding.tf_output_redirect.Write("."); + } + + Binding.tf_output_redirect.WriteLine(""); + Binding.tf_output_redirect.WriteLine("Extraction is completed."); + } + + isDllDownloaded = true; + } + + public static TF_Output tf_output(IntPtr c_op, int index) => new TF_Output(c_op, index); + + public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions(); + + public static Buffer tf_buffer(byte[] data = null) + { + if(data is not null) + { + return new Buffer(data); ; + } + else + { + return new Buffer(); + } + } + + public static IEnumerable new_tf_operations(Graph graph) + { + foreach (var c_op in tf_operations(graph)) + { + if (graph._get_operation_by_tf_operation(c_op) == null) + yield return c_op; + } + } + + public static IEnumerable tf_operations(Graph graph) + { + uint pos = 0; + IntPtr c_op; + while ((c_op = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) + { + yield return new Operation(c_op, graph); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/common_shapes.py.cs b/src/TensorFlowNET.Core/Framework/common_shapes.py.cs new file mode 100644 index 000000000..9bb793da6 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/common_shapes.py.cs @@ -0,0 +1,57 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow.Framework +{ + public static class common_shapes + { + /// + /// Returns the broadcasted shape between `shape_x` and `shape_y + /// + /// + /// + public static Tensor broadcast_shape(Tensor shape_x, Tensor shape_y) + { + var return_dims = _broadcast_shape_helper(shape_x, shape_y); + // return tensor_shape(return_dims); + throw new NotFiniteNumberException(); + } + /// + /// Helper functions for is_broadcast_compatible and broadcast_shape. + /// + /// A `Shape` + /// A `Shape` + /// Returns None if the shapes are not broadcast compatible, + /// a list of the broadcast dimensions otherwise. + /// + public static Tensor _broadcast_shape_helper(Tensor shape_x, Tensor shape_y) + { + throw new NotFiniteNumberException(); + } + + public static int? rank(Tensor tensor) + { + return tensor.rank; + } + + public static bool has_fully_defined_shape(Tensor tensor) + { + return tensor.shape.IsFullyDefined; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/function_def_lib.cs b/src/TensorFlowNET.Core/Framework/function_def_lib.cs new file mode 100644 index 000000000..488c6b654 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/function_def_lib.cs @@ -0,0 +1,297 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Security.Cryptography; +using System.Text; +using Tensorflow.Graphs; +using Tensorflow.Common.Extensions; +using static Tensorflow.Binding; +using static Tensorflow.CppShapeInferenceResult.Types; + +namespace Tensorflow.Framework +{ + public class function_def_lib + { + // TODO(Rinne): process signatures and structured outputs. + public static FuncGraph function_def_to_graph(FunctionDef fdef, object? structured_input_signature, + object? structured_outputs, List input_shapes = null) + { + var func_graph = new FuncGraph(fdef.Signature.Name); + if(input_shapes is null) + { + if(fdef.Attr.TryGetValue("_input_shapes", out var input_shapes_attr)) + { + var raw_input_shapes = input_shapes_attr.List.Shape; + input_shapes = new List(); + foreach(var (input_shape, arg_def) in raw_input_shapes.Zip(fdef.Signature.InputArg, (x, y) => (x, y))) + { + if(arg_def.Type == DataType.DtResource && arg_def.HandleData is not null && arg_def.HandleData.Count > 0) + { + input_shapes.Add(null); + } + else + { + input_shapes.Add(input_shape); + } + } + } + } + + var (graph_def, nested_to_flat_tensor_name) = function_def_to_graph_def(fdef, input_shapes); + + func_graph.as_default(); + importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); + var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); + func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); + + var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); + func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); + // TODO(Rinne): func_graph.ControlOutputs + _set_handle_data(func_graph, fdef); + + foreach(var node in graph_def.Node) + { + if(node.Attr.TryGetValue("_output_shapes", out var output_shapes)) + { + var op = func_graph.get_operation_by_name(node.Name); + foreach(var (output_index, shape) in enumerate(output_shapes.List.Shape.Take(op.outputs.Length))) + { + op.outputs[output_index].shape = new Shape(shape); + } + } + } + Dictionary output_names = new(); + foreach(var (ret_arg_def, tensor_name) in zip(fdef.Signature.OutputArg, output_tensor_names)) + { + output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name; + } + func_graph._output_names = output_names; + + func_graph.Exit(); + return func_graph; + } + + public static (GraphDef, Dictionary) function_def_to_graph_def(FunctionDef fdef, List input_shapes) + { + var graph_def = new GraphDef() + { + Versions = new VersionDef() + { + Producer = versions.GRAPH_DEF_VERSION, + MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER + } + }; + + var default_graph = ops.get_default_graph(); + + if(input_shapes is not null && input_shapes.Count > 0 && input_shapes.Count != fdef.Signature.InputArg.Count) + { + throw new ValueError($"Length of `input_shapes` must match the number " + + $"of `input_arg`s in `fdef`. Got {input_shapes.Count} `input_shapes` and " + + $"{fdef.Signature.InputArg.Count} `input_arg`s."); + } + + foreach(var (i, arg_def) in enumerate(fdef.Signature.InputArg)) + { + NodeDef node_def = new(); + node_def.Name = arg_def.Name; + node_def.Op = "Placeholder"; + node_def.Attr["dtype"] = new AttrValue() + { + Type = arg_def.Type + }; + if(input_shapes is not null && input_shapes.Count > 0 && input_shapes[i] is not null) + { + var input_shape = input_shapes[i]; + // skip the condition that input_shape is not `TensorShapeProto`. + AttrValue shape = new AttrValue() + { + Shape = new TensorShapeProto() + }; + shape.Shape = new TensorShapeProto(input_shape); + node_def.Attr["shape"] = shape; + } + if (!fdef.ArgAttr.ContainsKey((uint)i)) + { + fdef.ArgAttr[(uint)i] = new FunctionDef.Types.ArgAttrs(); + } + var arg_attrs = fdef.ArgAttr[(uint)i].Attr; + foreach(var k in arg_attrs.Keys) + { + if(k == "_output_shapes") + { + if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.List) + { + node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].List.Shape[0]); + } + else if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.Shape) + { + node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].Shape); + } + } + else if (k.StartsWith("_")) + { + if (!node_def.Attr.ContainsKey(k)) + { + node_def.Attr[k] = new AttrValue(); + } + node_def.Attr[k] = new AttrValue(arg_attrs[k]); + } + } + + graph_def.Node.Add(node_def); + } + + graph_def.Node.AddRange(fdef.NodeDef); + + Dictionary nested_to_flat_tensor_name = new(); + foreach(var arg_def in fdef.Signature.InputArg) + { + nested_to_flat_tensor_name[arg_def.Name] = $"{arg_def.Name}:0"; + string control_name = "^" + arg_def.Name; + nested_to_flat_tensor_name[control_name] = control_name; + } + + foreach(var node_def in fdef.NodeDef) + { + var graph = default_graph; + while (true) + { + if(graph is null) + { + break; + } + var f = graph.Functions.GetOrDefault(node_def.Op, null); + if(f is not null && graph.OuterGraph is null) + { + break; + } + graph = graph.OuterGraph; + } + + var op_def = default_graph.GetOpDef(node_def.Op); + + foreach(var attr in op_def.Attr) + { + if(attr.Type == "func") + { + var fname = node_def.Attr[attr.Name].Func.Name; + if (!is_function(fname)) + { + throw new ValueError($"Function {fname} was not found. Please make sure " + + $"the FunctionDef `fdef` is correct."); + } + } + else if(attr.Type == "list(func)") + { + foreach(var fn in node_def.Attr[attr.Name].List.Func) + { + var fname = fn.Name; + if (!is_function(fname)) + { + throw new ValueError($"Function {fname} was not found. Please make " + + $"sure the FunctionDef `fdef` is correct."); + } + } + } + } + + int flattened_index = 0; + foreach(var arg_def in op_def.OutputArg) + { + var num_args = _get_num_args(arg_def, node_def); + for(int i = 0; i < num_args; i++) + { + var nested_name = $"{node_def.Name}:{arg_def.Name}:{i}"; + var flat_name = $"{node_def.Name}:{flattened_index}"; + nested_to_flat_tensor_name[nested_name] = flat_name; + flattened_index++; + } + } + string control_name = "^" + node_def.Name; + nested_to_flat_tensor_name[control_name] = control_name; + } + + foreach(var node_def in graph_def.Node) + { + for(int i = 0; i < node_def.Input.Count; i++) + { + node_def.Input[i] = nested_to_flat_tensor_name[node_def.Input[i]]; + } + } + + return (graph_def, nested_to_flat_tensor_name); + } + + private static void _set_handle_data(FuncGraph func_graph, FunctionDef fdef) + { + foreach(var (tensor, arg_def) in zip(func_graph.Inputs, fdef.Signature.InputArg).Concat(zip(func_graph.Outputs, fdef.Signature.OutputArg))) + { + if(arg_def.HandleData is not null && arg_def.HandleData.Count > 0) + { + tensor.shape = Shape.Scalar; + + var shape_and_type = arg_def.HandleData[0]; + var handle_data = new HandleData(); + handle_data.IsSet = true; + handle_data.ShapeAndType.Add(new HandleShapeAndType() + { + Shape = shape_and_type.Shape, + Dtype = shape_and_type.Dtype + }); + resource_variable_ops._set_handle_shapes_and_types(tensor, handle_data, true); + } + } + } + + private static long _get_num_args(OpDef.Types.ArgDef arg_def, NodeDef node_def) + { + if (!string.IsNullOrEmpty(arg_def.NumberAttr)) + { + return node_def.Attr[arg_def.NumberAttr].I; + } + else if(!string.IsNullOrEmpty(arg_def.TypeListAttr)) + { + return node_def.Attr[arg_def.TypeListAttr].List.Type.Count; + } + else if(arg_def.TypeAttr is not null || arg_def.Type != DataType.DtInvalid) + { + return 1; + } + else + { + throw new ValueError($"Invalid arg_def:\n\n{arg_def}. Please make sure the " + + $"FunctionDef `fdef` is correct."); + } + } + + public static bool is_function(string fname) + { + if (tf.Context.executing_eagerly()) + { + return tf.Context.has_function(fname); + } + else + { + var graph = ops.get_default_graph(); + while(graph is not null) + { + if (graph.IsFunction(fname)) + { + return true; + } + if(graph.OuterGraph is not null) + { + graph = graph.OuterGraph; + } + else + { + return false; + } + } + } + throw new ValueError("Unexpected behavior happened in runtime, please submit an issue to " + + "https://github.com/SciSharp/TensorFlow.NET/issues"); + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/graph_util_impl.cs b/src/TensorFlowNET.Core/Framework/graph_util_impl.cs new file mode 100644 index 000000000..af87c578d --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/graph_util_impl.cs @@ -0,0 +1,227 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class graph_util_impl + { + /// + /// Replaces all the variables in a graph with constants of the same values. + /// + /// Active TensorFlow session containing the variables. + /// GraphDef object holding the network. + /// List of name strings for the result nodes of the graph. + /// + /// + /// GraphDef containing a simplified version of the original. + public GraphDef convert_variables_to_constants(Session sess, + GraphDef input_graph_def, + string[] output_node_names, + string[] variable_names_whitelist = null, + string[] variable_names_blacklist = null) + { + // This graph only includes the nodes needed to evaluate the output nodes, and + // removes unneeded nodes like those involved in saving and assignment. + var inference_graph = extract_sub_graph(input_graph_def, output_node_names); + + // Identify the ops in the graph. + var map_name_to_node = new Dictionary(); + inference_graph.Node.Select(x => map_name_to_node[x.Name] = x).ToArray(); + + // Get list of variables. + var variable_names = new List(); + var variable_dict_names = new List(); + + foreach (var node in inference_graph.Node) + { + if (new string[] { "Variable", "VariableV2", "VarHandleOp" }.Contains(node.Op)) + { + var variable_name = node.Name; + + variable_dict_names.Add(variable_name); + if (node.Op == "VarHandleOp") + variable_names.Add(variable_name + "/Read/ReadVariableOp:0"); + else + variable_names.Add(variable_name + ":0"); + } + else if (new string[] { "ReadVariableOp", "ResourceGather" }.Contains(node.Op)) + { + // There can be one or more Identity ops in between the ReadVariableOp and + // VarHandleOp. Store the Identity ops with the associated dtypes. + var source_op_name = get_input_name(node); + while (map_name_to_node[source_op_name].Op == "Identity") + { + throw new NotImplementedException("map_name_to_node[source_op_name].Op"); + /*resource_identity_types[source_op_name] = node.attr["dtype"]; + source_op_name = get_input_name(map_name_to_node[source_op_name]);*/ + } + } + } + + // Gets map of variables and the associated data. + NDArray[] returned_variables = null; + if (variable_names != null) + returned_variables = sess.run(variable_names); + + var variables_data_map = new Dictionary(); + foreach (var (i, name) in enumerate(variable_dict_names)) + variables_data_map[name] = returned_variables[i]; + print($"Froze {len(returned_variables)} variables."); + + // Reconstruct the graph with constants in place of variables. + var output_graph_def = new GraphDef(); + int how_many_converted = 0; + foreach (var input_node in inference_graph.Node) + { + var output_node = new NodeDef(); + if (variables_data_map.ContainsKey(input_node.Name)) + { + var data = variables_data_map[input_node.Name]; + output_node = create_const_op(input_node.Name, input_node.Attr["dtype"], + data, data.dims.Select(x => Convert.ToInt32(x)).ToArray()); + how_many_converted += 1; + } + // else if (resource_identity_types.ContainsKey(input_node.Name)) + else if (input_node.Op == "ReadVariableOp") + { + output_node.Op = "Identity"; + output_node.Name = input_node.Name; + output_node.Input.AddRange(new[] { input_node.Input[0] }); + output_node.Attr["T"] = input_node.Attr["dtype"]; + } + else if (input_node.Op == "ResourceGather") + { + + } + else + { + output_node.MergeFrom(input_node); + } + + output_graph_def.Node.AddRange(new[] { output_node }); + } + + output_graph_def.Library = inference_graph.Library; + print($"Converted {how_many_converted} variables to const ops."); + return output_graph_def; + } + + private NodeDef create_const_op(string node_name, AttrValue dtype, NDArray data, int[] data_shape = null) + { + var output_node = new NodeDef + { + Op = "Const", + Name = node_name + }; + output_node.Attr["dtype"] = dtype; + output_node.Attr["value"] = new AttrValue() + { + Tensor = tensor_util.make_tensor_proto(data, + dtype: dtype.Type.as_tf_dtype(), + shape: data_shape) + }; + + return output_node; + } + + /// + /// Gets the name of the first input. Errors if suffix is not :0. + /// + /// + /// + private string get_input_name(NodeDef node) + { + var details = node.Input[0].Split(':'); + if (details.Length == 1 || int.Parse(details[1]) == 0) + return details[0]; + // While it is valid for input tensors to have a suffix that is not :0, this + // method is used to find the associated ops, not tensors, and therefore it + // is not valid. + throw new ValueError($"Tensor name '{node.Input[0]}' is invalid."); + } + + + private GraphDef extract_sub_graph(GraphDef graph_def, string[] dest_nodes) + { + var (name_to_input_name, name_to_node, name_to_seq_num) = _extract_graph_summary( + graph_def); + + var nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name); + var nodes_to_keep_list = nodes_to_keep.OrderBy(n => name_to_seq_num[n]).ToArray(); + // Now construct the output GraphDef + var output = new GraphDef(); + foreach (var n in nodes_to_keep_list) + output.Node.Add(name_to_node[n]); // need deep clone? + output.Library = graph_def.Library; + output.Versions = graph_def.Versions; + + return output; + } + + private string[] _bfs_for_reachable_nodes(string[] target_nodes, Dictionary name_to_input_name) + { + var nodes_to_keep = new List(); + var next_to_visit = target_nodes.Select(x => x).ToList(); + while (next_to_visit.Count > 0) + { + var node = next_to_visit[0]; + next_to_visit.RemoveAt(0); + if (nodes_to_keep.Contains(node)) + continue; + nodes_to_keep.Add(node); + if (name_to_input_name.Keys.Contains(node)) + next_to_visit.AddRange(name_to_input_name[node]); + } + + return nodes_to_keep.ToArray(); + } + + private (Dictionary, Dictionary, Dictionary) _extract_graph_summary(GraphDef graph_def) + { + var name_to_input_name = new Dictionary(); + var name_to_node = new Dictionary(); + var name_to_seq_num = new Dictionary(); + + int seq = 0; + foreach (var node in graph_def.Node) + { + var n = _node_name(node.Name); + name_to_node[n] = node; + name_to_input_name[n] = node.Input.Select(x => _node_name(x)).ToArray(); + name_to_seq_num[n] = seq; + seq++; + } + + return (name_to_input_name, name_to_node, name_to_seq_num); + } + + private string _node_name(string n) + { + return n.StartsWith("^") ? n.Substring(1) : n.Split(':')[0]; + } + + private string get_input_name(string node) + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/importer.cs b/src/TensorFlowNET.Core/Framework/importer.cs new file mode 100644 index 000000000..e7e7cf394 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/importer.cs @@ -0,0 +1,307 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using static Tensorflow.Binding; +using static Tensorflow.OpDef.Types; + +namespace Tensorflow +{ + public class importer + { + public static ITensorOrOperation[] import_graph_def_for_function(GraphDef graph_def, string name = null) + { + return import_graph_def(graph_def, validate_colocation_constraints: false, name: name); + } + public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, + Dictionary input_map = null, + string[] return_elements = null, + bool validate_colocation_constraints = true, + string name = null, + OpList producer_op_list = null) + { + var op_dict = op_def_registry.get_registered_ops(); + + graph_def = _ProcessGraphDefParam(graph_def, op_dict); + input_map = _ProcessInputMapParam(input_map); + return_elements = _ProcessReturnElementsParam(return_elements); + + if (producer_op_list != null) + _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def); + + string prefix = ""; + var graph = ops.get_default_graph(); + tf_with(ops.name_scope(name, "import", input_map.Values), scope => + { + prefix = scope; + /*if (!string.IsNullOrEmpty(prefix)) + prefix = prefix.Substring(0, prefix.Length - 1); + else + prefix = "";*/ + + // Generate any input map tensors inside name scope + input_map = _ConvertInputMapValues(name, input_map); + }); + + TF_ImportGraphDefResults results = null; + var bytes = graph_def.ToByteString().ToArray(); + var buffer = c_api_util.tf_buffer(bytes); + var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); + var status = new Status(); + + _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements, validate_colocation_constraints ); + // need to create a class ImportGraphDefWithResults with IDisposal + results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); + status.Check(true); + + _ProcessNewOps(graph); + + if (return_elements == null) + return null; + else + return _GatherReturnElements(return_elements, graph, results); + } + + private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, + Graph graph, + TF_ImportGraphDefResults results) + { + var return_outputs = results.return_tensors; + var return_opers = results.return_opers; + + var combined_return_elements = new List(); + int outputs_idx = 0; +#pragma warning disable CS0219 // Variable is assigned but its value is never used + int opers_idx = 0; +#pragma warning restore CS0219 // Variable is assigned but its value is never used + foreach (var name in requested_return_elements) + { + if (name.Contains(":")) + { + combined_return_elements.append(graph.get_tensor_by_tf_output(return_outputs[outputs_idx])); + outputs_idx += 1; + } + else + { + throw new NotImplementedException("_GatherReturnElements"); + // combined_return_elements.append(graph._get_operation_by_tf_operation(return_opers[opers_idx])); + } + } + + return combined_return_elements.ToArray(); + } + + private static void _ProcessNewOps(Graph graph) + { + foreach (var new_op in graph._add_new_tf_operations()) + { + var original_device = new_op.Device; + new_op._set_device(original_device); + } + } + + public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, + string prefix, + Dictionary input_map, + string[] return_elements, + bool validate_colocation_constraints) + { + c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); + c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Options, true); + + foreach (var input in input_map) + { + var input_src = tf.compat.as_str(input.Key); + var input_dst = input.Value; + if (input_src.StartsWith("^")) + { + var src_name = tf.compat.as_str(input_src.Substring(1)); + var dst_op = input_dst._as_tf_output().oper; + c_api.TF_ImportGraphDefOptionsRemapControlDependency(options.Options, src_name, dst_op); + } + else + { + var (src_name, src_index) = _ParseTensorName(input.Key); + src_name = tf.compat.as_str(src_name); + var dst_output = input_dst._as_tf_output(); + c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Options, src_name, src_index, dst_output); + } + } + + if (return_elements == null) + return_elements = new string[0]; + + foreach (var name in return_elements) + { + if (name.Contains(":")) + { + var (op_name, index) = _ParseTensorName(name); + op_name = tf.compat.as_str(op_name); + c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Options, op_name, index); + } + else + { + c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Options, tf.compat.as_str(name)); + } + } + + c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options.Options, validate_colocation_constraints); + } + + private static (string, int) _ParseTensorName(string tensor_name) + { + var components = tensor_name.Split(':'); + if (components.Length == 2) + return (components[0], int.Parse(components[1])); + else if (components.Length == 1) + return (components[0], 0); + else + throw new ValueError($"Cannot convert {tensor_name} to a tensor name."); + } + + public static Dictionary _ConvertInputMapValues(string name, Dictionary input_map) + { + return input_map; + } + + public static GraphDef _ProcessGraphDefParam(GraphDef graph_def, Dictionary op_dict) + { + foreach (var node in graph_def.Node) + { + if (!op_dict.ContainsKey(node.Op)) + continue; + + var op_def = op_dict[node.Op]; + _SetDefaultAttrValues(node, op_def); + } + + return graph_def; + } + + private static GraphDef _ProcessGraphDefParam(GraphDef graph_def) + { + var old_graph_def = graph_def; + graph_def = new GraphDef(old_graph_def); + + return graph_def; + } + + private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def) + { + foreach (var attr_def in op_def.Attr) + { + var key = attr_def.Name; + if (attr_def.DefaultValue != null) + { + if (node_def.Attr.ContainsKey(key)) + { + var value = node_def.Attr[key]; + if (value == null) + node_def.Attr[key] = attr_def.DefaultValue; + } + else + { + node_def.Attr[key] = attr_def.DefaultValue; + } + } + } + } + + private static Dictionary _ProcessInputMapParam(Dictionary input_map) + { + if (input_map == null) + return new Dictionary(); + + return input_map; + } + + private static string[] _ProcessReturnElementsParam(string[] return_elements) + { + if (return_elements == null) + return null; + + return return_elements; + } + + private static void _RemoveDefaultAttrs(Dictionary op_dict, OpList producer_op_list, GraphDef graph_def) + { + var producer_op_dict = new Dictionary(); + producer_op_list.Op.Select(op => + { + producer_op_dict[op.Name] = op; + return op; + }).ToArray(); + + foreach (var node in graph_def.Node) + { + // Remove any default attr values that aren't in op_def. + if (producer_op_dict.ContainsKey(node.Op)) + { + var op_def = op_dict[node.Op]; + var producer_op_def = producer_op_dict[node.Op]; + foreach (var key in node.Attr) + { + if (_FindAttrInOpDef(key.Key, op_def) == null) + { + var attr_def = _FindAttrInOpDef(key.Key, producer_op_def); + if (attr_def != null && attr_def.DefaultValue != null && + node.Attr[key.Key] == attr_def.DefaultValue) + node.Attr[key.Key].ClearValue(); + } + } + } + } + } + + private static void _RemoveDefaultAttrs(OpList producer_op_list, GraphDef graph_def) + { + var producer_op_dict = producer_op_list.Op.ToDictionary(x => x.Name, x => x); + + foreach (var node in graph_def.Node) + { + // Remove any default attr values that aren't in op_def. + if (producer_op_dict.ContainsKey(node.Op)) + { + var op_def = op_def_registry.GetOpDef(node.Op); + if(op_def is null) + { + continue; + } + var producer_op_def = producer_op_dict[node.Op]; + foreach (var key in node.Attr.Keys) + { + if (_FindAttrInOpDef(key, op_def) is null) + { + var attr_def = _FindAttrInOpDef(key, producer_op_def); + if (attr_def != null && attr_def.DefaultValue != null && + node.Attr[key] == attr_def.DefaultValue) + node.Attr[key].ClearValue(); + } + } + } + } + } + + private static AttrDef _FindAttrInOpDef(string name, OpDef op_def) + { + return op_def.Attr.FirstOrDefault(x => x.Name == name); + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs new file mode 100644 index 000000000..c3616fafd --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -0,0 +1,433 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Tensorflow.Operations; +using static Tensorflow.Binding; +using static Tensorflow.CollectionDef; +using static Tensorflow.MetaGraphDef.Types; + +namespace Tensorflow +{ + public class meta_graph + { + public static MetaGraphDef read_meta_graph_file(string filename) + { + var bytes = File.ReadAllBytes(filename); + var meta_graph_def = MetaGraphDef.Parser.ParseFrom(bytes); + return meta_graph_def; + } + + public static (Dictionary, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, + bool clear_devices = false, + string import_scope = "", + Dictionary input_map = null, + string unbound_inputs_col_name = "unbound_inputs", + string[] return_elements = null) + { + var meta_graph_def = meta_graph_or_file; + + if (!string.IsNullOrEmpty(unbound_inputs_col_name)) + { + foreach (var col in meta_graph_def.CollectionDef) + { + if (col.Key == unbound_inputs_col_name) + { + throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); + } + } + } + + // Sets graph to default graph if it's not passed in. + var graph = ops.get_default_graph(); + + // Gathers the list of nodes we are interested in. + OpList producer_op_list = null; + if (meta_graph_def.MetaInfoDef.StrippedOpList != null) + producer_op_list = meta_graph_def.MetaInfoDef.StrippedOpList; + var input_graph_def = meta_graph_def.GraphDef; + // Remove all the explicit device specifications for this node. This helps to + // make the graph more portable. + if (clear_devices) + foreach (var node in input_graph_def.Node) + node.Device = ""; + + var scope_to_prepend_to_names = graph.unique_name("", mark_as_used: false); + var imported_return_elements = importer.import_graph_def(input_graph_def, + name: scope_to_prepend_to_names, + input_map: input_map, + producer_op_list: producer_op_list, + return_elements: return_elements); + + // Restores all the other collections. + var variable_objects = new Dictionary(); + foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key)) + { + // Don't add unbound_inputs to the new graph. + if (col.Key == unbound_inputs_col_name) + continue; + + switch (col.Value.KindCase) + { + case KindOneofCase.NodeList: + foreach (var value in col.Value.NodeList.Value) + { + var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names)); + graph.add_to_collection(col.Key, col_op); + } + break; + case KindOneofCase.BytesList: + //var proto_type = ops.get_collection_proto_type(key) + if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) + { + foreach (var value in col.Value.BytesList.Value) + { + IVariableV1 variable = null; + if (!variable_objects.ContainsKey(value)) + { + var proto = VariableDef.Parser.ParseFrom(value); + if (proto.IsResource) + variable = new ResourceVariable(variable_def: proto, import_scope: scope_to_prepend_to_names); + else + variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names); + variable_objects[value] = variable; + } + variable = variable_objects[value]; + graph.add_to_collection(col.Key, variable); + } + } + else + { + foreach (var value in col.Value.BytesList.Value) + { + switch (col.Key) + { + case "cond_context": + { + var proto = CondContextDef.Parser.ParseFrom(value); + var condContext = new CondContext().from_proto(proto, import_scope); + graph.add_to_collection(col.Key, condContext); + } + break; + case "while_context": + { + var proto = WhileContextDef.Parser.ParseFrom(value); + var whileContext = new WhileContext().from_proto(proto, import_scope); + graph.add_to_collection(col.Key, whileContext); + } + break; + default: + Binding.tf_output_redirect.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}"); + continue; + } + } + } + + break; + default: + Binding.tf_output_redirect.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping."); + break; + } + } + + var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, + scope: scope_to_prepend_to_names); + var var_list = new Dictionary(); + variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v); + + return (var_list, imported_return_elements); + } + + /// + /// Returns `MetaGraphDef` proto. Optionally writes it to filename. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static (MetaGraphDef, Dictionary) export_scoped_meta_graph(string filename = "", + GraphDef graph_def = null, + bool as_text = false, + string unbound_inputs_col_name = "unbound_inputs", + bool clear_devices = false, + SaverDef saver_def = null, + bool clear_extraneous_savers = false, + bool strip_default_attrs = false, + byte[] meta_info_def = null) + { + var graph = ops.get_default_graph(); + + var var_list = new Dictionary(); + var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES); + + if (variables != null) + { + foreach (var v in variables) + { + var_list[v.Name] = v; + } + } + + var scoped_meta_graph_def = create_meta_graph_def( + graph_def: graph_def, + export_scope: "", + exclude_nodes: "", + clear_extraneous_savers: clear_extraneous_savers, + saver_def: saver_def, + strip_default_attrs: strip_default_attrs); + + if (!string.IsNullOrEmpty(filename)) + graph_io.write_graph(scoped_meta_graph_def, "", filename, as_text: as_text); + + return (scoped_meta_graph_def, var_list); + } + + private static bool _should_include_node() + { + return true; + } + + private static MetaGraphDef create_meta_graph_def(MetaInfoDef meta_info_def = null, + GraphDef graph_def = null, + string export_scope = "", + string exclude_nodes = "", + SaverDef saver_def = null, + bool clear_extraneous_savers = false, + bool strip_default_attrs = false) + { + // Sets graph to default graph if it's not passed in. + var graph = ops.get_default_graph().as_default(); + // Creates a MetaGraphDef proto. + var meta_graph_def = new MetaGraphDef(); + if (meta_info_def == null) + meta_info_def = new MetaInfoDef(); + + // Set the tf version strings to the current tf build. + meta_info_def.TensorflowVersion = tf.VERSION; + meta_info_def.TensorflowGitVersion = "unknown"; + meta_graph_def.MetaInfoDef = meta_info_def; + + // Adds graph_def or the default. + if (graph_def == null) + meta_graph_def.GraphDef = graph.as_graph_def(add_shapes: true); + else + meta_graph_def.GraphDef = graph_def; + + // Fills in meta_info_def.stripped_op_list using the ops from graph_def. + if (meta_graph_def.MetaInfoDef.StrippedOpList == null || + meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0) + meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef); + + var clist = graph.get_all_collection_keys(); + foreach (var ctype in clist) + { + if (clear_extraneous_savers) + { + throw new NotImplementedException("create_meta_graph_def clear_extraneous_savers"); + } + else + { + add_collection_def(meta_graph_def, ctype, graph); + } + } + + return meta_graph_def; + } + + private static void add_collection_def(MetaGraphDef meta_graph_def, + string key, + Graph graph = null, + string export_scope = "") + { + if (!meta_graph_def.CollectionDef.ContainsKey(key)) + meta_graph_def.CollectionDef[key] = new CollectionDef(); + var col_def = meta_graph_def.CollectionDef[key]; + + switch (graph.get_collection(key)) + { + case List collection_list: + col_def.BytesList = new Types.BytesList(); + foreach (var x in collection_list) + { + if (x is RefVariable x_ref_var) + { + var proto = x_ref_var.to_proto(export_scope); + col_def.BytesList.Value.Add(proto.ToByteString()); + } + else if (x is ResourceVariable x_res_var) + { + var proto = x_res_var.to_proto(export_scope); + col_def.BytesList.Value.Add(proto.ToByteString()); + } + } + break; + case List collection_list: + col_def.BytesList = new Types.BytesList(); + foreach (var x in collection_list) + { + var proto = x.to_proto(export_scope); + col_def.BytesList.Value.Add(proto.ToByteString()); + } + + break; + case List collection_list: + col_def.NodeList = new Types.NodeList(); + foreach (var x in collection_list) + if (x is ITensorOrOperation x2) + col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope)); + break; + case List collection_list: + break; + } + } + + public static OpList stripped_op_list_for_graph(GraphDef graph_def) + { + var used_ops = ops_used_by_graph_def(graph_def); + + // Verify that all used ops are registered. + // var registered_ops = op_def_registry.get_registered_ops(); + + var op_list = new OpList(); + /*used_ops.OrderBy(x => x).Select(x => { + + }).ToArray();*/ + + return op_list; + } + + /// + /// Collect the list of ops used by a graph. + /// + /// + /// + private static string[] ops_used_by_graph_def(GraphDef graph_def) + { + var used_ops = new List(); + + Action mark_op_as_used = (op) => + { + if (!used_ops.Contains(op)) + { + + } + + used_ops.Add(op); + }; + + foreach (var node in graph_def.Node) + { + mark_op_as_used(node.Op); + } + + return used_ops.ToArray(); + } + + private static bool is_default_attr_value(OpDef op_def, string attr_name, AttrValue attr_value) + { + foreach (var attr_def in op_def.Attr) + { + if (attr_def.Name == attr_name) + { + if (attr_def.DefaultValue is null) return false; + // TODO: add new c_api `EqualAttrValueWrapper` and complete the check. + return true; + } + } + + return false; + } + + public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def) + { + Dictionary op_name_to_function = new(); + foreach (var function_def in meta_graph_def.GraphDef.Library.Function) + { + op_name_to_function[function_def.Signature.Name] = function_def; + } + + Action _strip_node_default_valued_attrs = (node_def) => + { + if (op_name_to_function.ContainsKey(node_def.Op)) return; + + var op_def = op_def_registry.GetOpDef(node_def.Op); + if(op_def is null) return; + + HashSet attrs_to_strip = new(); + foreach (var attr in node_def.Attr) + { + if (is_default_attr_value(op_def, attr.Key, attr.Value)) + { + attrs_to_strip.Add(attr.Key); + } + } + + foreach (var attr in attrs_to_strip) + { + node_def.Attr.Remove(attr); + } + }; + + foreach (var node_def in meta_graph_def.GraphDef.Node) + { + _strip_node_default_valued_attrs(node_def); + } + + foreach (var function_def in meta_graph_def.GraphDef.Library.Function) + { + foreach (var function_node_def in function_def.NodeDef) + { + _strip_node_default_valued_attrs(function_node_def); + } + } + + meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + } + + /// + /// Extract the Op name from a Tensor name. + /// + /// + /// + public static string op_name(string tensor_name) + { + if (string.IsNullOrEmpty(tensor_name)) + { + throw new ValueError($"Tensor name cannot be empty or None. Received: {tensor_name}."); + } + + if (tensor_name.StartsWith("^")) + { + tensor_name = tensor_name.Substring(1); + } + if (tensor_name.Contains(":")) + { + return tensor_name.Split(':')[0]; + } + return tensor_name; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs new file mode 100644 index 000000000..111719aad --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs @@ -0,0 +1,52 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using Tensorflow.Util; + +namespace Tensorflow +{ + public class op_def_registry + { + static Dictionary _registered_ops = new Dictionary(); + + public static Dictionary get_registered_ops() + { + if (_registered_ops.Count == 0) + { + lock (_registered_ops) + { + // double validation to avoid multi-thread executing + if (_registered_ops.Count > 0) + return _registered_ops; + + var buffer = new Buffer(c_api.TF_GetAllOpList()); + var op_list = OpList.Parser.ParseFrom(buffer.ToArray()); + foreach (var op_def in op_list.Op) + _registered_ops[op_def.Name] = op_def; + } + } + + return _registered_ops; + } + + public static OpDef GetOpDef(string type) + { + var ops = get_registered_ops(); + return ops[type]; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/random_seed.cs b/src/TensorFlowNET.Core/Framework/random_seed.cs new file mode 100644 index 000000000..ccc09fb25 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/random_seed.cs @@ -0,0 +1,87 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class random_seed + { + private static int DEFAULT_GRAPH_SEED = 87654321; + private static Dictionary _graph_to_seed_dict = new Dictionary(); + + public static (int?, int?) get_seed(int? op_seed = null) + { + int? global_seed; + + if (tf.executing_eagerly()) + global_seed = tf.Context.global_seed(); + else + global_seed = ops.get_default_graph().seed; + + if (global_seed.HasValue) + { + if (!op_seed.HasValue) + if (tf.executing_eagerly()) + op_seed = tf.Context.internal_operation_seed(); + else + { + if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed)) + seed = 0; + _graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1; + op_seed = seed; + } + + return (global_seed, op_seed); + } + + if (op_seed.HasValue) + return (DEFAULT_GRAPH_SEED, op_seed); + else + return (null, null); + } + + public static (Tensor, Tensor) get_seed_tensor(int? op_seed = null) + { + var (seed, seed2) = get_seed(op_seed); + Tensor _seed, _seed2; + if (seed is null) + _seed = constant_op.constant(0L, name: "seed"); + else + _seed = constant_op.constant((long)seed.Value, name: "seed"); + + if (seed2 is null) + _seed2 = constant_op.constant(0L, name: "seed2"); + else + { + _seed2 = tf_with(ops.name_scope("seed2"), scope => + { + _seed2 = constant_op.constant((long)seed2.Value); + return array_ops.where_v2( + math_ops.logical_and( + math_ops.equal(_seed, 0L), + math_ops.equal(_seed2, 0L)), + constant_op.constant(2^31L - 1), + _seed2, + name: scope); + }); + } + + return (_seed, _seed2); + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/smart_module.cs b/src/TensorFlowNET.Core/Framework/smart_module.cs new file mode 100644 index 000000000..e1f84d7eb --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/smart_module.cs @@ -0,0 +1,69 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Framework +{ + public class smart_module + { + public static Tensor[] smart_cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, + string name = null) + { + var pred_value = smart_constant_value(pred); + if (pred_value.HasValue) + { + if (pred_value.Value) + return true_fn() as Tensor[]; + else + return false_fn() as Tensor[]; + } + else + return control_flow_ops.cond(pred, + true_fn: true_fn, + false_fn: false_fn, + name: name); + } + + public static Tensor smart_cond(bool pred, + Func true_fn = null, + Func false_fn = null, + string name = null) + { + return pred ? true_fn() : false_fn(); + } + + public static bool? smart_constant_value(Tensor pred) + { + var pred_value = tensor_util.constant_value(pred); + if (pred_value is null) + { + var result = range(pred.op.NumOutputs).Select(x => IntPtr.Zero).ToArray(); + var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status); + if (!evaluated || c_api.TF_GetCode(tf.Status) != TF_Code.TF_OK) + return null; + else + throw new NotImplementedException(""); + } + + return pred_value; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/tensor_shape.cs b/src/TensorFlowNET.Core/Framework/tensor_shape.cs new file mode 100644 index 000000000..b2cb45464 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/tensor_shape.cs @@ -0,0 +1,79 @@ +using Tensorflow.NumPy; +using System; +using System.Linq; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Framework +{ + public static class tensor_shape + { + public static void assert_is_compatible_with(this Tensor self, Tensor other) + { + /*if (!self.is_compatible_with(other)) + { + var selfDim = self.shape + .Aggregate(new StringBuilder("{"), (sb, i) => sb.Append(i).Append(", "), sb => sb.ToString()) + .Replace(", }", "}"); + + var otherDim = other.shape + .Aggregate(new StringBuilder("{"), (sb, i) => sb.Append(i).Append(", "), sb => sb.ToString()) + .Replace(", }", "}"); + + throw new ArgumentException($"Dimensions {selfDim} and {otherDim} are not compatible"); + }*/ + } + + public static bool is_compatible_with(this Tensor self, Tensor other) + { + bool _shape_is_compatible_0dim(Shape _this, Shape _other) + { + var __other = _other; + if (_this.dims == null || __other.dims == null) + return true; + + if (_this.ndim != __other.ndim) + return false; + + foreach (var (x_dim, y_dim) in _this.dims.Zip(__other.dims, (x_dim, y_dim) => (x_dim, y_dim))) + { + if (x_dim != y_dim) + return false; + } + + return true; + } + + if (other is SparseTensor) + { + return self.dtype.is_compatible_with(other.dtype); + } + + return self.dtype.is_compatible_with(other.dtype) && + _shape_is_compatible_0dim(self.shape, other.shape) && + !(self is SparseTensor); + } + + public static Dimension dimension_at_index(Shape shape, int index) + { + return shape.ndim < 0 ? + new Dimension(-1) : + new Dimension(shape.dims[index]); + } + + public static int dimension_value(Dimension dimension) + => (int)dimension.value; + + public static Shape most_specific_compatible_shape(this Shape self, Shape other) + { + var dims = range(self.ndim).Select(x => -1L).ToArray(); + foreach(var (i, (d1, d2)) in enumerate(zip(self.dims, other.dims))) + { + if (d1 == d2) + dims[i] = d1; + } + + return new Shape(dims); + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/versions.cs b/src/TensorFlowNET.Core/Framework/versions.cs new file mode 100644 index 000000000..e91f08a2c --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/versions.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Framework +{ + public class versions + { + public static int GRAPH_DEF_VERSION = 1286; + public static int GRAPH_DEF_VERSION_MIN_CONSUMER = 0; + } +} diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs new file mode 100644 index 000000000..8742e4535 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -0,0 +1,329 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Eager; +using Tensorflow.Framework.Models; +using Tensorflow.Gradients; +using Tensorflow.Graphs; +using Tensorflow.Train; +using Tensorflow.Util; +using Tensorflow.Common.Extensions; +using static Tensorflow.Binding; + +namespace Tensorflow.Functions +{ + /// + /// + /// + public class ConcreteFunction: Trackable + { + protected IEnumerable _captured_inputs; + protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; + protected Dictionary _attrs; + protected FunctionSpec _function_spec; + protected FunctionSpec _pre_initialized_function_spec = null; + protected EagerDefinedFunction _inference_function; + protected Dictionary _tape_functions_cache = new(); + internal FuncGraph func_graph; + internal ForwardBackwardCall forward_backward; + public Tensor[] Inputs => func_graph.Inputs; + public Tensor[] CapturedInputs => func_graph.external_captures; + + public string Name => _delayed_rewrite_functions.Forward().Name; + + public Tensor[] Outputs => func_graph.Outputs; + public Type ReturnType; + public TensorSpec[] OutputStructure; + public IEnumerable ArgKeywords { get; set; } + public long NumPositionArgs { get; set; } + public FunctionDef FunctionDef => _delayed_rewrite_functions.Forward().Definition; + public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; + public IEnumerable Variables => func_graph.Variables; + public IEnumerable TrainableVariables => func_graph.TrainableVariables; + internal NameAttrList AsNameAttrList + { + get + { + NameAttrList ret = new() { Name = this.Name }; + foreach (var (name, value) in _attrs) + { + ret.Attr[name] = value; + } + return ret; + } + } + + public ConcreteFunction(string name) + { + func_graph = new FuncGraph(name); + _captured_inputs = func_graph.external_captures; + _attrs= new Dictionary(); + _set_infer_function(); + } + + public ConcreteFunction(FuncGraph graph, Dictionary attrs = null) + { + func_graph = graph; + _captured_inputs = func_graph.external_captures; + + //ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); + _attrs = attrs; + _set_infer_function(); + } + + public ConcreteFunction(Func func, TF_DataType dtype) + { + string func_name = $"{func.Method.Name}_{ops.uid_function()}"; + + func_graph = new FuncGraph(func_name); + func_graph.as_default(); + var input = tf.placeholder(dtype); + var output = func(input); + + var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); + func_graph.ToGraph(opers, + new[] { input }, + new[] { output }, + null); + func_graph.Exit(); + _captured_inputs = func_graph.external_captures; + _attrs = new Dictionary(); + _set_infer_function(); + } + + public ConcreteFunction(Func func, TF_DataType dtype) + { + string func_name = $"{func.Method.Name}_{ops.uid_function()}"; + + func_graph = new FuncGraph(func_name); + func_graph.as_default(); + + var input = tf.placeholder(dtype); + var output = func(input); + + OutputStructure = output.structure; + + var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); + func_graph.ToGraph(opers, + new[] { input }, + new[] { output.variant_tensor }, + null); + func_graph.Exit(); + _captured_inputs = func_graph.external_captures; + _attrs = new Dictionary(); + _set_infer_function(); + } + + /*public ConcreteFunction(Func func, + TF_DataType[] dtypes, Shape[] shapes) + { + string func_name = $"{func.Method.Name}_{ops.uid_function()}"; + + // IntPtr func_handle; + func_graph = new FuncGraph(func_name); + func_graph.as_default(); + + var inputs = new Tensors(); + foreach(var (i, dtype) in enumerate(dtypes)) + inputs.Add(tf.placeholder(dtypes[i], shape: shapes[i], name: "args")); + Outputs = func(inputs); + OutputStructure = Outputs.Select(x => x.ToTensorSpec()).ToArray(); + + var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); + func_graph.ToGraph(opers, inputs, Outputs, null); + func_graph.Exit(); + }*/ + + public void ToGraph(Tensors inputs, Tensors outputs) + { + var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); + func_graph.ToGraph(opers, + inputs, + outputs, + null); + OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray(); + } + + public void Enter() + { + func_graph.as_default(); + } + + public void Exit() + { + func_graph.Exit(); + } + + public Tensors FilteredCall(Tensors inputs) + { + return CallFlat(inputs, CapturedInputs); + } + + /// + /// Executes the wrapped function. + /// + /// + /// + /// + public Tensors CallFlat(Tensor[] args, Tensor[] captured_inputs) + { + var executing_eagerly = tf.Context.executing_eagerly(); + var default_graph = ops.get_default_graph(); + // TODO(Rinne): deal with `default_graph.building_function` + + var tempvv = func_graph.Variables; + if(tf.GetTapeSet().Count > 0 || default_graph is FuncGraph) + { + foreach(var v in this.func_graph.Variables) + { + resource_variable_ops.variable_accessed(v); + } + } + + var tensor_inputs = new Tensors(); + foreach (var (i, arg) in enumerate(args)) + { + tensor_inputs.Add(arg); + // If we're graph building, shape inference is on. + } + if (!executing_eagerly) + { + // TODO(Rinne): add the check + } + tensor_inputs.AddRange(captured_inputs); + + args = tensor_inputs.ToArray(); + + var possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args); + // No tape is watching; skip to running the function. + if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE && executing_eagerly) + { + return _build_call_outputs(_inference_function.Call(args)); + } + + forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); + var (forward_function, args_with_tangents) = forward_backward.Forward(); + Tensors flat_outputs = null; + if (executing_eagerly) + { + flat_outputs = forward_function.Call(args_with_tangents); + } + else + { + tf_with(default_graph._override_gradient_function(new Dictionary>(){ + { "PartitionedCall", _get_gradient_function() }, { "StatefulPartitionedCall", _get_gradient_function() } + }), _ => + { + flat_outputs = forward_function.Call(args_with_tangents); + }); + } + forward_backward.Record(flat_outputs); + return _build_call_outputs(flat_outputs); + } + + public void AddTograph(Graph? g = null) + { + if(!tf.Context.executing_eagerly() && g is null) + { + g = ops.get_default_graph(); + } + _delayed_rewrite_functions.Forward().AddToGraph(g); + } + + public void SetExternalCaptures(IEnumerable captures) + { + _captured_inputs = captures; + } + + ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) + { + TangentInfo input_tangents; + if (executing_eagerly) + { + // TODO(Rinne): check if it needs to be implemented. + input_tangents = new TangentInfo(); + } + else + { + input_tangents = new TangentInfo(); + } + if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER) + { + if(input_tangents.Indices is not null || executing_eagerly) + { + string cache_key = "first_order"; + if(!_tape_functions_cache.TryGetValue(cache_key, out var functions)) + { + functions = new FirstOrderTapeGradientFunctions(func_graph, false); + _tape_functions_cache[cache_key] = functions; + } + return new ForwardBackwardCall(functions, args, tape_watching: true); + } + else + { + return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: true); + } + } + else if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER) + { + throw new NotImplementedException(); + } + + // TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall. + return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); + } + + internal void set_variables(IEnumerable variables) + { + func_graph.Variables = variables; + } + + internal void _set_infer_function() + { + _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); + _inference_function = _delayed_rewrite_functions.Forward(); + } + + internal void _set_function_spec(FunctionSpec spec) + { + _function_spec = null; + _pre_initialized_function_spec = spec; + _initialize_function_spec(); + } + + internal void _initialize_function_spec() + { + if(_pre_initialized_function_spec is null) + { + return; + } + Debug.Assert(_function_spec is null, "already initialized"); + var spec = _pre_initialized_function_spec; + //var args = spec.Fullargspec.DictValue.Fields["args"]; + // TODO(Rinne): self.structured_input_signature + + _function_spec = new FunctionSpec() + { + Fullargspec = spec.Fullargspec, + IsMethod = spec.IsMethod, + InputSignature = spec.InputSignature + }; + } + + internal Func _get_gradient_function() + { + return _delayed_rewrite_functions._rewrite_forward_and_call_backward; + } + + private Tensors _build_call_outputs(Tensors result) + { + // TODO(Rinne): deal with `func_graph.structured_outputs` + + return result; + } + + public override string ToString() + => Name; + } +} diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs new file mode 100644 index 000000000..d547b6120 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs @@ -0,0 +1,232 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using Tensorflow.Graphs; +using Tensorflow.Operations; +using Tensorflow.Util; +using Tensorflow.Common.Extensions; +using static Tensorflow.Binding; +using Tensorflow.Framework; +using System.Buffers; +using Tensorflow.Gradients; + +namespace Tensorflow.Functions +{ + public class EagerDefinedFunction: IDisposable + { + public int _num_outputs; + FuncGraph _graph; + FunctionDef _definition; + OpDef _signature; + string _name; + internal ScopedTFFunction _c_func; + internal Tensor[] _func_graph_outputs; + internal string _grad_func_name; + internal Func csharp_grad_func; + internal EagerDefinedFunction _grad_func; + internal bool _registered_on_context = false; + public string Name => _name; + public DataType[] OutputTypes { get; protected set; } + public Shape[] OutputShapes { get; protected set; } + public FunctionDef Definition + { + get + { + if(_definition is null) + { + _definition = _get_definition(); + } + return _definition; + } + } + + public OpDef Signature + { + get + { + if( _signature is null) + { + _signature = Definition.Signature; + } + return _signature; + } + } + public unsafe EagerDefinedFunction(string name, FuncGraph graph, + Tensors inputs, Tensors outputs, + Dictionary attrs) + { + var input_ops = inputs.Select(x => x.op).ToArray(); + var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) + .Select(x => x as Operation).ToArray(); + var graph_output_names = graph._output_names; + string[] output_names; + if(graph_output_names is not null && outputs.All(t => graph_output_names.ContainsKey(ops.tensor_id(t)))) + { + output_names = outputs.Select(t => graph_output_names[ops.tensor_id(t)]).ToArray(); + if(output_names.Distinct().Count() != output_names.Length) + { + output_names = new string[0]; + } + } + else + { + output_names = new string[0]; + } + + Status status = new Status(); + var fn = c_api.TF_GraphToFunction(graph.c_graph, + name, + false, + operations.Length, + operations.Length == 0 ? new IntPtr[0] : operations.Select(x => (IntPtr)x).ToArray(), + inputs.Length, + inputs.Select(t => t._as_tf_output()).ToArray(), + outputs.Length, + outputs.Select(t => t._as_tf_output()).ToArray(), + output_names.Length != outputs.Length ? null : output_names, + IntPtr.Zero, // warning: the control output hasbben totally ignored. + null, + status); + status.Check(true); + + _c_func = new ScopedTFFunction(fn, name); + + foreach(var (attr_name, attr_value) in attrs) + { + var serialized = attr_value.ToByteArray(); + c_api.TF_FunctionSetAttrValueProto(fn, attr_name, serialized, serialized.Length, status); + status.Check(true); + } + + var signature = _get_definition().Signature; + _name = signature.Name; + tf_with(ops.init_scope(), s => + { + tf.Context.add_function(fn); + _registered_on_context = true; + }); + + _num_outputs = signature.OutputArg.Count; + OutputTypes = signature.OutputArg.Select(x => x.Type).ToArray(); + OutputShapes = outputs.Select(x => x.shape).ToArray(); + _func_graph_outputs = new List(outputs).ToArray(); + csharp_grad_func = null; + _graph = graph; + } + + public unsafe Tensors Call(Tensors args) + { + // TODO(Rinne): Add arg `CancellationManager`. + // TODO(Rinne): Check the arg length. + var function_call_options = tf.Context.FunctionCallOptions; + string config = ""; // TODO(Rinne): revise it. The following code should work but not, for unclear reasons. + + //if (function_call_options.config_proto_serialized().Length == 0) + //{ + // config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); + //} + //else + //{ + // config = function_call_options.config_proto_serialized().ToStringUtf8(); + //} + + string executor_type = function_call_options.ExecutorType ?? ""; + var executing_eagerly = tf.Context.executing_eagerly(); + + var attrs = new object[] + { + "executor_type", executor_type, + "config_proto", config + }; + + Tensor[] outputs; + if (executing_eagerly) + { + outputs = _execute.execute( + Signature.Name, + _num_outputs, + args, + attrs, + tf.Context); + } + else + { + if(tf.GetTapeSet().Count == 0) + { + outputs = functional_ops.partitioned_call(args, this, OutputTypes, + executing_eagerly, config, ""); + } + else + { + var tape = tf.GetTapeSet().Peek(); + tape.StopRecord(); + outputs = functional_ops.partitioned_call(args, this, OutputTypes, + executing_eagerly, config, ""); + tape.StartRecord(); + } + } + foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs)) + { + handle_data_util.copy_handle_data(func_graph_output, outputs[i]); + } + if (executing_eagerly) + { + return outputs; + } + else + { + foreach(var (i, shape) in enumerate(OutputShapes)) + { + outputs[i].shape = shape; + } + return outputs; + } + } + + public void AddToGraph(Graph g = null) + { + if(g is null && tf.Context.executing_eagerly()) + { + var ctx = tf.Context; + if (!ctx.has_function(this.Name)) + { + ctx.add_function_def(Definition); + } + } + else + { + if (!g.IsFunction(Name)) + { + g.AddFunction(this); + } + foreach(var f in _graph.Functions.Values) + { + if (!g.IsFunction(f.Name)) + { + g.AddFunction(f); + } + } + } + } + + private FunctionDef _get_definition() + { + var buffer = c_api_util.tf_buffer(); + Status status = new(); + c_api.TF_FunctionToFunctionDef(_c_func.Get(), buffer, status); + status.Check(true); + var proto_data = c_api.TF_GetBuffer(buffer); + return FunctionDef.Parser.ParseFrom(proto_data.AsSpan()); + } + + public void Dispose() + { + tf.Context.remove_function(Name); + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs new file mode 100644 index 000000000..bfb0defcb --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Graphs; + +namespace Tensorflow.Functions +{ + public class FirstOrderTapeGradientFunctions : TapeGradientFunctions + { + public FirstOrderTapeGradientFunctions(FuncGraph func_graph, + bool need_gradients_for_jvps) : base(func_graph, + need_gradients_for_jvps) + { + + } + + public override (EagerDefinedFunction, FuncGraph, ConcreteFunction, List, int) + ForwardAndBackwardFunctions(Tensors inference_args) + { + var outputs = _func_graph.Outputs.Take(_num_inference_outputs).ToArray(); + return BuildFunctionsForOutputs(outputs, inference_args); + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs b/src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs new file mode 100644 index 000000000..392c06951 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Functions +{ + /// + /// Holds the state of a function call between execution and recording. + /// + public class ForwardBackwardCall + { + TapeGradientFunctions _functions; + Tensors _inference_args; + Tensors _input_tangents; + bool _tape_watching; + EagerDefinedFunction forward_function; + + public ForwardBackwardCall(TapeGradientFunctions functions, + Tensors inference_args, + bool tape_watching) + { + _functions = functions; + _inference_args = inference_args; + _tape_watching = tape_watching; + } + + public (EagerDefinedFunction, Tensors) Forward() + { + if (forward_function == null) + forward_function = _functions.Forward(_inference_args); + return (forward_function, _inference_args); + } + + public void Record(Tensors flat_outputs) + { + if (_tape_watching && flat_outputs != null) + _functions.Record(flat_outputs, _inference_args); + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs new file mode 100644 index 000000000..e301048a8 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/Function.cs @@ -0,0 +1,84 @@ +using System; +using Tensorflow.Functions; +using Tensorflow.Train; + +namespace Tensorflow +{ + public class Function: Trackable, IGenericFunction + { +#pragma warning disable CS0169 // The field 'Function._handle' is never used + private IntPtr _handle; +#pragma warning restore CS0169 // The field 'Function._handle' is never used + + protected Func _csharp_function; + protected ConcreteFunction _concrete_variable_creation_fn; + protected bool _autograph; + protected TracingCompiler _variable_creation_fn; + public string Name { get; set; } + public Function(Func csharp_function, + string name, bool auto_graph = true) + { + _csharp_function = csharp_function; + Name = name; + _autograph = auto_graph; + } + + public virtual Tensors Apply(Tensors inputs) + { + if (_run_functions_eagerly()) + { + return _csharp_function(inputs); + } + + var result = _call(inputs); + return result; + } + + public ConcreteFunction get_concrete_function(params Tensor[] args) + { + return _get_concrete_function_garbage_collected(args); + } + + protected virtual Tensors _call(Tensors inputs) + { + if(_variable_creation_fn is not null) + { + return _variable_creation_fn.Apply(inputs); + } + _initialize(inputs); + + return _concrete_variable_creation_fn.CallFlat(inputs, + _concrete_variable_creation_fn.CapturedInputs); + } + + protected TracingCompiler _compiler(Func fn) + { + var name = nameof(fn); + return new TracingCompiler(fn, name, autograph: _autograph); + } + + protected virtual bool _run_functions_eagerly() + { + return false; + } + + protected ConcreteFunction _get_concrete_function_garbage_collected(Tensor[] args) + { + if(_variable_creation_fn is null) + { + _initialize(args); + // TODO(Rinne): _initialize_uninitialized_variables + } + + var concrete = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); + return concrete; + } + + private void _initialize(Tensor[] args) + { + _variable_creation_fn = _compiler(_csharp_function); + _variable_creation_fn._name = this.Name; + _concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/IGenericFunction.cs b/src/TensorFlowNET.Core/Functions/IGenericFunction.cs new file mode 100644 index 000000000..f046731bf --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/IGenericFunction.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Functions +{ + public interface IGenericFunction + { + Tensors Apply(Tensors args); + ConcreteFunction get_concrete_function(params Tensor[] args); + } +} diff --git a/src/TensorFlowNET.Core/Functions/TF_Function.cs b/src/TensorFlowNET.Core/Functions/TF_Function.cs new file mode 100644 index 000000000..a63dedcce --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/TF_Function.cs @@ -0,0 +1,10 @@ +using System.Runtime.InteropServices; + +namespace Tensorflow.Functions +{ + [StructLayout(LayoutKind.Sequential)] + public struct TF_Function + { + FunctionDef fdef; + } +} diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs new file mode 100644 index 000000000..3895226ef --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -0,0 +1,253 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Eager; +using Tensorflow.Gradients; +using Tensorflow.Graphs; +using Tensorflow.NumPy; +using Tensorflow.Operations; +using static Tensorflow.Binding; +using static Tensorflow.tensorflow; + +namespace Tensorflow.Functions +{ + /// + /// Caches forward and backward functions compatible with eager gradients. + /// + public abstract class TapeGradientFunctions + { + protected string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; + protected string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; + protected string _FORWARD_PREFIX = "__forward_"; + protected string _BACKWARD_PREFIX = "__backward_"; + protected string _INFERENCE_PREFIX = "__inference_"; + + protected FuncGraph _func_graph; + protected EagerDefinedFunction _forward; + protected FuncGraph _forward_graph; + protected List _forwardprop_input_indices; + protected List _forwardprop_output_indices; + protected int _num_forwardprop_outputs; + protected int _num_inference_outputs; + protected int _num_outputs; + protected int _num_trainable_inference_outputs; + protected ConcreteFunction _backward; + BackwardFunction _backward_function_wrapper; + + public TapeGradientFunctions(FuncGraph func_graph, + bool need_gradients_for_jvps) + { + _func_graph = func_graph; + _forward_graph = null; + _forward = null; + _backward = null; + _num_outputs = func_graph.Outputs.Length; + _forwardprop_output_indices = null; + _num_forwardprop_outputs = 0; + _num_inference_outputs = func_graph.Outputs.Length; + _num_trainable_inference_outputs = func_graph.Outputs.Where(t => backprop_util.IsTrainable(t)).Count(); + } + + public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null) + { + // TODO(Rinne): add input_tangents arg. + if(_forward is null) + { + (_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) + = ForwardAndBackwardFunctions(inference_args); + } + return _forward; + } + + /// + /// Record the function call operation. + /// + /// + /// + public virtual void Record(Tensors flat_outputs, Tensors inference_args) + { + // TODO(Rinne): add arg `input_tagents`. + var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); + if(_forwardprop_output_indices is not null && _forwardprop_output_indices.Count > 0) + { + // TODO(Rinne): implement it. + throw new NotImplementedException(); + } + tf.Runner.TFE_TapeSetRecordOperation(_forward.Signature.Name, to_record, inference_args, backward_function); + } + + /// + /// Create a backward function given `outputs` from the forward function. + /// + /// + /// + /// + /// + (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) + { + var capture_mapping = zip(forward_graph.Outputs.Select(t => ops.tensor_id(t)), outputs) + .ToDictionary(x => x.Item1, x => x.Item2); + var captured_inputs = backward.CapturedInputs; + var remapped_captures = captured_inputs.Select(c => + { + if (capture_mapping.TryGetValue(ops.tensor_id(c), out var value)) + { + return value; + } + else + { + return c; + } + }).ToArray(); + if(remapped_captures.Where(t => t is not EagerTensor).Any(t => t.graph == forward_graph)) + { + var incorrect_mapping = remapped_captures.Where(t => t is not EagerTensor && t.graph != forward_graph); + throw new RuntimeError($"Failed to map all backward graph captures to " + + $"the forward graph. Incorrectly mapped: {string.Join(", ", incorrect_mapping)}"); + } + + Dictionary variant_zeros_like = new Dictionary(); + var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; + var recorded_outputs = new Tensors(); + int trainable_recorded_outputs = 0; + var skip_positions = new HashSet(); + var relevant_outputs = outputs; + foreach (var (output_index, output) in enumerate(relevant_outputs)) + { + if (trainable_recorded_outputs < backward_function_inputs) + recorded_outputs.Add(output); + if (backprop_util.IsTrainable(output)) + trainable_recorded_outputs++; + else + skip_positions.Add(output_index); + if (output.dtype == dtypes.variant) + variant_zeros_like[output_index] = default_gradient.zeros_like(output); + } + + _backward_function_wrapper = (args, unneeded_gradients) => + { + if(backward.Outputs is null || backward.Outputs.Length == 0) + { + return backward.FlatStructuredOutputs; + } + + var processed_args = new Tensors(); + int input_index = 0; + foreach (var (output_index, arg) in enumerate(args)) + { + if (skip_positions.Contains(output_index)) + continue; + if (arg is null) + { + var input_placeholder = backward.Inputs[input_index]; + Tensor variant_arg; + if (input_placeholder.dtype == dtypes.variant) + { + variant_arg = variant_zeros_like[output_index]; + } + else + { + var (shape, type) = default_gradient.shape_and_dtype(input_placeholder); + + variant_arg = array_ops.zeros(shape, type); + } + processed_args.Add(variant_arg); + } + else + { + processed_args.Add(arg); + } + input_index++; + if (input_index >= backward_function_inputs) + break; + } + + tf.Logger.Debug($"Invoke backward function: {backward.Name}"); + var gradients = backward.CallFlat(processed_args, remapped_captures); + + foreach (var unneeded_gradient_index in unneeded_gradients) + { + var index = Convert.ToInt32(unneeded_gradient_index); + if (gradients.Length <= index) + gradients.Insert(index, null); + } + + return gradients; + }; + + return (_backward_function_wrapper, recorded_outputs); + } + + protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List, int) + BuildFunctionsForOutputs(Tensors outputs, Tensors inference_args) + { + var trainable_outputs = new List(); + var trainable_indices = new List(); + foreach(var (index, output) in enumerate(outputs)) + { + if (backprop_util.IsTrainable(output)) + { + trainable_outputs.Add(output); + trainable_indices.Add(index); + } + } + + var backwards_graph = new FuncGraph(monomorphic_function_utils._backward_name(_func_graph.Name)); + backwards_graph.as_default(); + var gradients_wrt_outputs = new List(); + foreach (var output in trainable_outputs) + { + var (gradient_shape, gradient_dtype) = default_gradient.shape_and_dtype(output); + var gradient_placeholder = tf.placeholder(gradient_dtype, gradient_shape); + gradients_wrt_outputs.Add(gradient_placeholder); + handle_data_util.copy_handle_data(output, gradient_placeholder); + } + // TODO(Rinne): with ops.device(None) + var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), + _func_graph.Inputs, + grad_ys: gradients_wrt_outputs.ToArray(), + src_graph: _func_graph); + + var captures_from_forward = backwards_graph.external_captures + .Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph) + .ToArray(); + HashSet existing_outputs = new(_func_graph.Outputs); + foreach(var capture in captures_from_forward) + { + if (!existing_outputs.Contains(capture)) + { + existing_outputs.Add(capture); + _func_graph.Outputs.Add(capture); + } + } + backwards_graph.Exit(); + + backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray(); + backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null)); + + var (wrapped_forward_function, wrapped_backward_function) = + monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph); + //var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; + //var backward_function_attr = new Dictionary(); + //backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; + + //var backward_function = new ConcreteFunction(backwards_graph, + // monomorphic_function_utils._parse_func_attrs(backward_function_attr)); + + //var forward_function_attr = new Dictionary(); + //forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; + //var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, + // _func_graph.Inputs, _func_graph.Outputs, + // monomorphic_function_utils._parse_func_attrs(forward_function_attr)); + + return (wrapped_forward_function, _func_graph, wrapped_backward_function, null, 0); + } + + public virtual (EagerDefinedFunction, FuncGraph, ConcreteFunction, List, int) + ForwardAndBackwardFunctions(Tensors inference_args) + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/TracingCompiler.cs b/src/TensorFlowNET.Core/Functions/TracingCompiler.cs new file mode 100644 index 000000000..aa30c9f79 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/TracingCompiler.cs @@ -0,0 +1,84 @@ +using System; +using System.Collections.Generic; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using Tensorflow.Graphs; + +namespace Tensorflow.Functions +{ + public class TracingCompiler + { + Func _csharp_function; + //FunctionSpec _function_spec; + internal string _name; + bool _autograph; + Dictionary _function_cache; + Dictionary _function_attributes; + int _tracing_count; + + public TracingCompiler(Func csharp_function, string name, object? input_signatures = null, + Dictionary attributes = null, bool autograph = true, object? autograph_options = null, + bool reduce_retracing = false, bool capture_by_value = false) + { + _csharp_function = csharp_function; + bool pure_function = attributes is not null && attributes.Count > 0 && attributes.ContainsKey(monomorphic_function_utils.IMPLEMENTS_ATTRIBUTE_NAME); + _name = name; + _autograph = autograph; + _function_attributes = attributes ?? new Dictionary(); + _function_cache = new Dictionary(); + _tracing_count = 0; + } + + public Tensor[] Apply(Tensor[] inputs) + { + // TODO(Rinne): add lock here. + var (concrete_function, filtered_flat_args) = _maybe_define_function(inputs); + return concrete_function.CallFlat(filtered_flat_args, concrete_function.CapturedInputs); + } + + internal ConcreteFunction _get_concrete_function_internal_garbage_collected(Tensor[] args) + { + var (concrete_function, _) = _maybe_define_concrete_function(args); + return concrete_function; + } + + private (ConcreteFunction, Tensor[]) _maybe_define_concrete_function(Tensor[] args) + { + return _maybe_define_function(args); + } + + private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args) + { + var lookup_func_key = make_cache_key(args); + if(_function_cache.TryGetValue(lookup_func_key, out var concrete_function)) + { + return (concrete_function, args); + } + concrete_function = _create_concrete_function(args); + _function_cache[lookup_func_key] = concrete_function; + return (concrete_function, args); + } + + private ConcreteFunction _create_concrete_function(Tensor[] args) + { + _tracing_count++; + + int arglen = args.Length; + var concrete_function = new ConcreteFunction(FuncGraph.func_graph_from_func( + _name, x => _csharp_function(x.Where(y => y is Tensor).Select(y => (Tensor)y).ToArray()), + args, new Dictionary(), autograph: _autograph + ), _function_attributes); + return concrete_function; + } + + private static string make_cache_key(Tensor[] inputs) + { + //string res = ""; + //foreach (var input in inputs) + //{ + // res += $"{input.name}_{input.Id}"; + //} + return inputs.Length.ToString(); + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/c_api.function.cs b/src/TensorFlowNET.Core/Functions/c_api.function.cs new file mode 100644 index 000000000..04d102b5f --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/c_api.function.cs @@ -0,0 +1,63 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Runtime.InteropServices; +using Tensorflow.Functions; + +namespace Tensorflow +{ + public partial class c_api + { + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteFunction(IntPtr handle); + + /// + /// Write out a serialized representation of `func` (as a FunctionDef protocol + /// message) to `output_func_def` (allocated by TF_NewBuffer()). + /// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() + /// is called. + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_FunctionToFunctionDef(SafeFuncGraphHandle func, SafeBufferHandle output_func_def, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern SafeFuncGraphHandle TF_GraphToFunction(SafeGraphHandle fn_body, string fn_name, + bool append_hash_to_fn_name, + int num_opers, IntPtr[] opers, + int ninputs, TF_Output[] inputs, + int noutputs, TF_Output[] outputs, + string[] output_names, + IntPtr opts, + string description, + SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_FunctionSetAttrValueProto(SafeFuncGraphHandle func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func); + + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, SafeFuncGraphHandle grad, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern int TF_GraphGetFunctions(SafeGraphHandle g, IntPtr[] funcs, int max_func, SafeStatusHandle status); + } +} diff --git a/src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs b/src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs new file mode 100644 index 000000000..7994bef11 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs @@ -0,0 +1,50 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Framework; +using Tensorflow.Framework.Models; +using Tensorflow.Util; + +namespace Tensorflow.Functions +{ + internal static class composite_tensor_utils + { + public static List flatten_with_variables(object inputs) + { + List flat_inputs = new(); + foreach(var value in nest.flatten(inputs)) + { + if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) + { + throw new NotImplementedException("The composite tensor has not been fully supported."); + } + else + { + flat_inputs.Add(value); + } + } + return flat_inputs; + } + public static List flatten_with_variables_or_variable_specs(object arg) + { + List flat_inputs = new(); + foreach(var value in nest.flatten(arg)) + { + if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) + { + throw new NotImplementedException("The composite tensor has not been fully supported."); + } + // TODO(Rinne): deal with `VariableSpec`. + else if(value is TypeSpec type_spec && value is not TensorSpec) + { + throw new NotImplementedException("The TypeSpec has not been fully supported."); + } + else + { + flat_inputs.Add(value); + } + } + return flat_inputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs b/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs new file mode 100644 index 000000000..b3caef96c --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs @@ -0,0 +1,94 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations; +using Tensorflow.Train; +using Tensorflow.Variables; +using static Tensorflow.Binding; + +namespace Tensorflow.Functions +{ + public static class function_saved_model_utils + { + /// + /// + /// + /// + /// a list tensors or other objects (such as variables) which + /// contain tensors that were originally captured by the function + public static void restore_captures(ConcreteFunction concrete_function, IEnumerable inputs) + { + var bound_inputs = inputs?.Select(obj => + { + if(obj is Tensor tensor) + { + return get_tensor_from_node(tensor); + } + else if(obj is IVariableV1 variable) + { + return get_tensor_from_node(variable); + } + else + { + throw new TypeError("Encountered an type error, please submit an issue to " + + "https://github.com/SciSharp/TensorFlow.NET/issues"); + } + }); + var bound_variables = inputs.Where(obj => obj is IVariableV1).Select(x => (IVariableV1)x); + + List captured_inputs_list = new(); + concrete_function.set_variables(bound_variables); + + if (bound_inputs is not null) + { + foreach(var (bound_input, internal_capture) in zip(bound_inputs, concrete_function.Inputs.Skip(concrete_function.Inputs.Length - bound_inputs.Count()))) + { + if(hasattr(bound_input, "__tf_experimental_restore_capture__")) + { + throw new NotImplementedException(); + } + else + { + captured_inputs_list.Add(bound_input); + concrete_function.func_graph.replace_capture(bound_input, internal_capture); + if(internal_capture.dtype == dtypes.resource) + { + if (resource_variable_ops.is_resource_variable(bound_input)) + { + handle_data_util.copy_handle_data(bound_input.Handle, internal_capture); + } + else + { + handle_data_util.copy_handle_data(bound_input, internal_capture); + } + } + concrete_function.func_graph.capture(bound_input); + } + } + } + + if(captured_inputs_list.Any(inp => inp is null)) + { + // TODO(Rinne): add warnings. + } + concrete_function.SetExternalCaptures(captured_inputs_list); + } + + public static Tensor get_tensor_from_node(Tensor node) + { + return node; + } + public static Tensor get_tensor_from_node(IVariableV1 node) + { + if (resource_variable_ops.is_resource_variable(node)) + { + return node.Handle; + } + else + { + throw new TypeError("Encountered an type error, please submit an issue to " + + "https://github.com/SciSharp/TensorFlow.NET/issues"); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/monomorphic_function.cs b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs new file mode 100644 index 000000000..7cb5c7050 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs @@ -0,0 +1,282 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Eager; +using Tensorflow.Framework.Models; +using Tensorflow.Gradients; +using Tensorflow.Graphs; +using Tensorflow.Common.Extensions; +using Tensorflow.Operations; +using Tensorflow.Framework; +using static Tensorflow.Binding; +using System.Diagnostics; + +namespace Tensorflow.Functions +{ + internal static class monomorphic_function_utils + { + internal static string _FORWARD_PREFIX = "__forward_"; + internal static string _BACKWARD_PREFIX = "__backward_"; + internal static string _INFERENCE_PREFIX = "__inference_"; + internal static string IMPLEMENTS_ATTRIBUTE_NAME = "_implements"; + internal static string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; + internal static string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; + public static string _inference_name(string name) + { + return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; + } + public static string _forward_name(string name) + { + return $"{_FORWARD_PREFIX}{name}_{ops.uid()}"; + } + public static string _backward_name(string name) + { + return $"{_BACKWARD_PREFIX}{name}_{ops.uid()}"; + } + + public static (EagerDefinedFunction, ConcreteFunction) _create_forward_backward_with_graph(Dictionary attrs, + FuncGraph forward_graph, FuncGraph backwards_graph) + { + string forward_function_name = _forward_name(forward_graph.Name); + Dictionary common_attributes; + if(attrs is null) + { + common_attributes = new Dictionary(); + } + else + { + common_attributes = new Dictionary(attrs); + } + + if (common_attributes.ContainsKey(IMPLEMENTS_ATTRIBUTE_NAME)) + { + common_attributes.Remove(IMPLEMENTS_ATTRIBUTE_NAME); + } + var backward_function_attr = _parse_func_attrs(new Dictionary() + { + {FORWARD_FUNCTION_ATTRIBUTE_NAME, forward_function_name } + }); + backward_function_attr.Update(common_attributes); + var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); + var forward_function_attr = _parse_func_attrs(new Dictionary() + { + {BACKWARD_FUNCTION_ATTRIBUTE_NAME, backward_function.Name } + }); + forward_function_attr.Update(common_attributes); + var forward_function = new EagerDefinedFunction(forward_function_name, forward_graph, + forward_graph.Inputs, forward_graph.Outputs, forward_function_attr); + return (forward_function, backward_function); + } + + public static Dictionary _parse_func_attrs(Dictionary attributes) + { + Dictionary attrs = new(); + foreach(var item in attributes) + { + var key = item.Key; + var value = item.Value; + if (value is AttrValue attr_value) + { + attrs[key] = attr_value; + } + else if (value is bool b) + { + attrs[key] = new AttrValue() { B = b }; + } + else if (value is int i) + { + attrs[key] = new AttrValue() { I = i }; + } + else if (value is float f) + { + attrs[key] = new AttrValue() { F = f }; + } + else if(value is string s) + { + attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(s) }; + } + else if (value is byte[] bytes) + { + attrs[key] = new AttrValue() { S = ByteString.CopyFrom(bytes) }; + } + else + { + throw new ValueError($"Attribute {key} must be bool, int, float, string, or " + + $"AttrValue. Got {value.GetType()}."); + } + } + return attrs; + } + + public static Dictionary _parse_func_attrs(Dictionary attributes) + { + Dictionary attrs = new(); + foreach (var item in attributes) + { + var key = item.Key; + var value = item.Value; + attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(value) }; + } + return attrs; + } + } + public class DelayedRewriteGradientFunctions : TapeGradientFunctions + { + EagerDefinedFunction _inference_function; + Dictionary _attrs; + int _num_inference_outputs; + Dictionary _cached_function_pairs = new(); + public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary attrs) + : base(func_graph, false) + { + _func_graph = func_graph; + _inference_function = new EagerDefinedFunction(monomorphic_function_utils._inference_name(_func_graph.Name), + _func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs); + _attrs = attrs; + _num_inference_outputs = _func_graph.Outputs.Length; + } + + public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null) + { + if (input_tangents is not null) + { + throw new InvalidArgumentError($"unexpectedly got forwardprop information in " + + $"a class that does not support forwardprop."); + } + return _inference_function; + } + + public override void Record(Tensors flat_outputs, Tensors inference_args) + { + var (backward_function, to_record) = _backward(flat_outputs); + foreach(var tape in tf.GetTapeSet()) + { + tape.RecordOperation(_inference_function.Signature.Name, to_record, + inference_args, backward_function); + } + } + + public (EagerDefinedFunction, ConcreteFunction) forward_backward(int num_doutputs = -2) + { + if(num_doutputs == -2) + { + num_doutputs = _num_inference_outputs; + } + if(_cached_function_pairs.TryGetValue(num_doutputs, out var target)) + { + return target; + } + var (forward, backward) = _construct_forward_backward(num_doutputs); + _cached_function_pairs[num_doutputs] = (forward, backward); + return (forward, backward); + + } + + private (BackwardFunction, Tensors) _backward(Tensors outputs) + { + Tensor[] backward_function(Tensor[] args, long[] unneeded_gradients) + { + var call_op = outputs[0].op; + return _rewrite_forward_and_call_backward(call_op, args); + } + return (backward_function, outputs); + } + + internal Tensor[] _rewrite_forward_and_call_backward(Operation op, params object[] doutputs) + { + var (forward_function, backward_function) = forward_backward(doutputs.Length); + if(backward_function.Outputs is null || backward_function.Outputs.Length == 0) + { + return backward_function.FlatStructuredOutputs; + } + forward_function.AddToGraph(op.graph); + + op._set_func_attr("f", forward_function.Name); + op._set_type_list_attr("Tout", forward_function.OutputTypes); + op._add_outputs(forward_function.OutputTypes.Select(x => x.as_tf_dtype()). + Skip(op.outputs.Length).ToArray(), forward_function.OutputShapes.Skip(op.outputs.Length).ToArray() + ); + for(int i = 0; i < op.outputs.Length; i++) + { + var func_graph_output = forward_function._func_graph_outputs[i]; + handle_data_util.copy_handle_data(func_graph_output, op.outputs[i]); + } + + var capture_mapping = zip(_func_graph.Outputs.Select(t => ops.tensor_id(t)), op.outputs). + ToDictionary(x => x.Item1, x => x.Item2); + var remapped_captures = backward_function.CapturedInputs.Select( + x => capture_mapping.GetOrDefault(ops.tensor_id(x), x) + ); + + List cleaned_doutputs = new(); + foreach(var (doutput, placeholder) in zip(doutputs, _func_graph.Outputs)) + { + if (backprop_util.IsTrainable(placeholder)) + { + if(doutput is IndexedSlices) + { + cleaned_doutputs.Add(ops.convert_to_tensor(doutput)); + } + else if(doutput is null) + { + cleaned_doutputs.Add(default_gradient.zeros_like(placeholder)); + } + else if(doutput is Tensor tensor) + { + cleaned_doutputs.Add(tensor); + } + else + { + throw new ValueError($"Unsupported type {doutput.GetType()} in function _rewrite_forward_and_call_backward"); + } + } + } + + return backward_function.CallFlat(cleaned_doutputs.ToArray(), remapped_captures.ToArray()); + } + + private (EagerDefinedFunction, ConcreteFunction) _construct_forward_backward(int num_doutputs) + { + var trainable_outputs = _func_graph.Outputs.Take(num_doutputs).Where(x => backprop_util.IsTrainable(x)); + + List signature = new(); + foreach(var t in trainable_outputs) + { + var (shape, dtype) = default_gradient.shape_and_dtype(t); + signature.Add(new TensorSpec(shape, dtype)); + } + + Tensor[] _backprop_function(Tensor[] grad_ys) + { + return gradients_util._GradientsHelper(trainable_outputs.ToArray(), _func_graph.Inputs, + grad_ys, src_graph: _func_graph); + } + + _func_graph.as_default(); + FuncGraph backwards_graph = new(monomorphic_function_utils._backward_name(_func_graph.Name)); + FuncGraph.func_graph_from_func(backwards_graph.Name, x => _backprop_function(x.Select(y => + { + Debug.Assert(y is Tensor); + return (Tensor)y; + }).ToArray()), new object[0], new Dictionary(), signature.ToArray(), backwards_graph); + var backwards_graph_captures = backwards_graph.external_captures; + var captures_from_forward = backwards_graph_captures.Where(c => c is not EagerTensor && c.graph == _func_graph); + + HashSet existing_outputs = new HashSet(_func_graph.Outputs); + foreach(var capture in captures_from_forward) + { + if (!existing_outputs.Contains(capture)) + { + existing_outputs.Add(capture); + _func_graph.Outputs.Add(capture); + } + } + + var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph( + _attrs, _func_graph, backwards_graph); + _func_graph.Exit(); + return (forward_function, backward_function); + } + } +} diff --git a/src/TensorFlowNET.Core/GlobalUsing.cs b/src/TensorFlowNET.Core/GlobalUsing.cs new file mode 100644 index 000000000..7e02c9083 --- /dev/null +++ b/src/TensorFlowNET.Core/GlobalUsing.cs @@ -0,0 +1,9 @@ +global using System; +global using System.Collections.Generic; +global using System.Text; +global using System.Collections; +global using System.Data; +global using System.Linq; +global using Tensorflow.Keras.Engine; +global using Tensorflow.Framework.Models; +global using static Tensorflow.Binding; \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Gradients/AccumulatorCallState.cs b/src/TensorFlowNET.Core/Gradients/AccumulatorCallState.cs new file mode 100644 index 000000000..1806a455d --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/AccumulatorCallState.cs @@ -0,0 +1,14 @@ +namespace Tensorflow.Gradients +{ + public class AccumulatorCallState + { + GradientTape backward_tape; + bool accumulating; + + public AccumulatorCallState(GradientTape backward_tape, bool accumulating) + { + this.backward_tape = backward_tape; + this.accumulating = accumulating; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/AggregationMethod.cs b/src/TensorFlowNET.Core/Gradients/AggregationMethod.cs new file mode 100644 index 000000000..6d3414f10 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/AggregationMethod.cs @@ -0,0 +1,11 @@ +namespace Tensorflow +{ + public class AggregationMethod + { + public static int ADD_N = 0; + public static int DEFAULT = ADD_N; + // The following are experimental and may not be supported in future releases. + public static int EXPERIMENTAL_TREE = 1; + public static int EXPERIMENTAL_ACCUMULATE_N = 2; + } +} diff --git a/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs b/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs new file mode 100644 index 000000000..743ed0d8e --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs @@ -0,0 +1,26 @@ +using Tensorflow.Util; + +namespace Tensorflow.Gradients +{ + public class BackpropInitialState + { + public OpTape op_tape { get; set; } + /// + /// Map from tensor to how many references still exist for this tensor in + /// the tape. + /// + public UnorderedMap tensor_usage_counts { get; set; } + /// + /// Maps from op ID to how many output tensors of this op still need to have + /// their gradients computed. + /// + public UnorderedMap op_missing_tensor { get; set; } + + public BackpropInitialState() + { + op_tape = new OpTape(); + tensor_usage_counts = new UnorderedMap(); + op_missing_tensor = new UnorderedMap(); + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/GradientTape.cs b/src/TensorFlowNET.Core/Gradients/GradientTape.cs new file mode 100644 index 000000000..a714436a3 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/GradientTape.cs @@ -0,0 +1,161 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Gradients +{ + /// + /// Gradient Tape Set + /// Record operations for automatic differentiation. + /// + /// Operations are recorded if they are executed within this context manager and + /// at least one of their inputs is being "watched". + /// + /// Trainable variables (created by `tf.Variable` or `tf.compat.v1.get_variable`, + /// where `trainable=True` is default in both cases) are automatically watched. + /// Tensors can be manually watched by invoking the `watch` method on this context + /// manager. + /// + public class GradientTape : IDisposable + { + int _nextTapeId; + ITape _tape => _tapeSet.Peek(); + Stack _tapeSet; + + public GradientTape() + { + _tapeSet = new Stack(); + } + + /// + /// New tape onto the tape stack. + /// + public ITape PushTape(bool persistent = false, + bool watch_accessed_variables = true) + { + // Enters a context inside which operations are recorded on this tape. + if (tf.Context.executing_eagerly()) + tf.Context.ensure_initialized(); + + var tape = new Tape(persistent, watch_accessed_variables); + tape.SetTapeId(_nextTapeId++); + _tapeSet.Push(tape); + return tape; + } + + public void PushTape(ITape tape) + { + // Enters a context inside which operations are recorded on this tape. + if (tf.Context.executing_eagerly()) + tf.Context.ensure_initialized(); + + _tapeSet.Push(tape); + } + + ITape PopTape() + { + _tape.StopRecord(); + return _tapeSet.Pop(); + } + + /// + /// Marks this tensor to be watched by the given tape. + /// + /// + public void watch(Tensor x) + { + if (!_tapeSet.Any()) + return; + _tape.Watch(x); + } + + /// + /// Computes the gradient using operations recorded in context of this tape. + /// + /// + /// + /// + public Tensor gradient(Tensor target, Tensor source, List output_gradients = null, + string unconnected_gradients = null) + { + if(_tape is null) + { + throw new RuntimeError("A non-persistent GradientTape can only be used to " + + "compute one set of gradients (or jacobians)."); + } + + ITape tape = stop_recording(); + + var results = tf.Runner.TFE_TapeGradient(tape, + new[] { target }, + new[] { source }, + output_gradients, + new[] { source }, + unconnected_gradients); + + return results[0]; + } + + public Tensor gradient(Tensor target, ResourceVariable source, List output_gradients = null, + string unconnected_gradients = null) + { + var results = gradient(target, new List { source }, output_gradients, unconnected_gradients); + + return results[0]; + } + + public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources, List output_gradients = null, + string unconnected_gradients = null) + { + var results = gradient(target, new List { sources.Item1, sources.Item2 }, output_gradients, unconnected_gradients); + + return (results[0], results[1]); + } + + public Tensor[] gradient(Tensor target, IEnumerable sources, List output_gradients = null, + string unconnected_gradients = null) + { + if (_tape is null) + { + throw new RuntimeError("A non-persistent GradientTape can only be used to " + + "compute one set of gradients (or jacobians)."); + } + var tape = stop_recording(); + + var results = tf.Runner.TFE_TapeGradient(tape, + new[] { target }, + sources.Select(x => x.Handle).ToArray(), + output_gradients, + sources.Select(x => x.Handle).ToArray(), + unconnected_gradients); + + if (!tape.Persistent) + { + // Keep track of watched variables before setting tape to None + // _watched_variables = _tape.WatchedVariables(); + } + + return results; + } + + /// + /// Temporarily stops recording operations on this tape. + /// + public ITape stop_recording() + { + var tape = _tape; + if (!tape.Persistent) + tape = PopTape(); + return tape; + } + + public Stack GetTapeSet() + => _tapeSet; + + public void Dispose() + { + _tapeSet.Clear(); + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/ITape.cs b/src/TensorFlowNET.Core/Gradients/ITape.cs new file mode 100644 index 000000000..07594dabd --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/ITape.cs @@ -0,0 +1,36 @@ +using System; +using Tensorflow.Util; + +namespace Tensorflow.Gradients +{ + public interface ITape + { + void SetTapeId(int id); + bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes); + void StartRecord(); + void StopRecord(); + bool Persistent { get; } + void RecordOperation(string op_type, + TapeTensor[] output_tensors, + long[] input_tensor_id, + TF_DataType[] input_dtypes, + BackwardFunction backward_function); + + void RecordOperation(string op_type, + Tensor[] outputs, + Tensor[] inputs, + BackwardFunction backward_function); + + void VariableAccessed(IVariableV1 variable); + + void Watch(Tensor x); + + IVariableV1[] WatchedVariables(); + + Tensor[] ComputeGradient(long[] target_tensor_ids, + long[] source_tensor_ids, + UnorderedMap sources_that_are_targets, + List output_gradients, + bool build_default_zeros_grads); + } +} diff --git a/src/TensorFlowNET.Core/Gradients/OpTape.cs b/src/TensorFlowNET.Core/Gradients/OpTape.cs new file mode 100644 index 000000000..cb9d0de73 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/OpTape.cs @@ -0,0 +1,12 @@ +using Tensorflow.Util; + +namespace Tensorflow.Gradients +{ + /// + /// Map from operation-id to tape entry. + /// + public class OpTape : UnorderedMap + { + + } +} diff --git a/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs b/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs new file mode 100644 index 000000000..7665fa017 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs @@ -0,0 +1,17 @@ +using System.Linq; + +namespace Tensorflow.Gradients +{ + /// + /// Represents an entry in the tape. + /// + public class OpTapeEntry + { + public string op_type { get; set; } + public TapeTensor[] output_tensor_info { get; set; } + public long[] input_tensor_id { get; set; } + public BackwardFunction backward_function { get; set; } + public override string ToString() + => $"{op_type}, inputs: {string.Join(",", input_tensor_id)}"; + } +} diff --git a/src/TensorFlowNET.Core/Gradients/RegisterGradient.cs b/src/TensorFlowNET.Core/Gradients/RegisterGradient.cs new file mode 100644 index 000000000..08a67373c --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/RegisterGradient.cs @@ -0,0 +1,30 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow.Gradients +{ + public class RegisterGradient : Attribute + { + public string Name { get; set; } + + public RegisterGradient(string name) + { + Name = name; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/RegisterGradientEager.cs b/src/TensorFlowNET.Core/Gradients/RegisterGradientEager.cs new file mode 100644 index 000000000..0c6217509 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/RegisterGradientEager.cs @@ -0,0 +1,30 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow.Gradients +{ + public class RegisterGradientEager : Attribute + { + public string Name { get; set; } + + public RegisterGradientEager(string name) + { + Name = name; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs b/src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs new file mode 100644 index 000000000..d573e317e --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs @@ -0,0 +1,33 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow.Gradients +{ + /// + /// REGISTER_NO_GRADIENT_OP(""); + /// + public class RegisterNoGradient : Attribute + { + public string Name { get; set; } + + public RegisterNoGradient(string name) + { + Name = name; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/Tape.CallBackwardFunction.cs b/src/TensorFlowNET.Core/Gradients/Tape.CallBackwardFunction.cs new file mode 100644 index 000000000..9dc1b6662 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/Tape.CallBackwardFunction.cs @@ -0,0 +1,18 @@ +using System.Collections.Generic; + +namespace Tensorflow.Gradients +{ + public partial class Tape + { + public Tensor[] CallBackwardFunction(BackwardFunction backward_function, + List unneeded_gradients, + List output_gradients) + { + // var grads = new Tensor[output_gradients.Count]; + var result = backward_function(output_gradients.ToArray(), + unneeded_gradients.ToArray()); + + return result; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs new file mode 100644 index 000000000..8a4a41f62 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs @@ -0,0 +1,284 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow.Gradients +{ + public partial class Tape + { + static readonly int kMinAggregateCount = 4; + static readonly int kMinAggregateBytes = 128 * 1024 * 1024; + private static UnorderedMap> _functionsAcceptingNoneForIndicesMap; + + static Tape() + { + _functionsAcceptingNoneForIndicesMap = new(); + _functionsAcceptingNoneForIndicesMap.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet(new[] { 1 })); + _functionsAcceptingNoneForIndicesMap.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet(new[] { 1 })); + _functionsAcceptingNoneForIndicesMap.Add("FusedBatchNorm", new UnorderedSet(new[] { 1, 2, 3, 4 })); + } + + public Tensor[] ComputeGradient(long[] target_tensor_ids, + long[] source_tensor_ids, + UnorderedMap sources_that_are_targets, + List output_gradients, + bool build_default_zeros_grads) + { + UnorderedSet sources_set = new(source_tensor_ids); + BackpropInitialState state = PrepareBackprop(target_tensor_ids, tensor_tape_, op_tape_, sources_set, Persistent); + var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); + var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape); + UnorderedMap gradients_size = new(); + while(op_stack.Count > 0) + { + long op = op_stack.Dequeue(); + if(!state.op_tape.TryGetValue(op, out var op_it)) + { + continue; + } + var trace = op_it; + state.op_tape.erase(op); + List out_gradients = new(); + List unneeded_gradients = new(); + for(int i = 0, end = trace.input_tensor_id.Length; i < end; i++) + { + long in_tensor_id = trace.input_tensor_id[i]; + if(!tensor_tape_.find(in_tensor_id) && !sources_set.find(in_tensor_id)) + { + unneeded_gradients.Add(i); + } + } + + bool any_gradient_nonzero = false; + List zero_indices = new(); + for(int i = 0, end = trace.output_tensor_info.Length; i < end; i++) + { + long id = trace.output_tensor_info[i].GetID(); + if(!gradients.TryGetValue(id, out var grad_it)) + { + out_gradients.Add(null); + if (build_default_zeros_grads) + { + if(!_functionsAcceptingNoneForIndicesMap.TryGetValue(trace.op_type, out var func_name_it) || + !func_name_it.find(i)) + { + zero_indices.Add(i); + } + } + } + else + { + any_gradient_nonzero = true; + Tensor new_gradients; + if (grad_it.Count == 1) + { + new_gradients = grad_it[0]; + } + else + { + new_gradients = AggregateGradients(grad_it); + } + if (!sources_set.find(id)) + { + gradients.Remove(id); + } + else + { + grad_it.Clear(); + grad_it.Add(new_gradients); + // MarkAsResult + } + out_gradients.Add(new_gradients); + } + } + + Tensor[] in_gradients = new Tensor[0]; + if (any_gradient_nonzero) + { + foreach(var i in zero_indices) + { + out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); + } + in_gradients = CallBackwardFunction(trace.backward_function, unneeded_gradients, out_gradients); + } + else + { + out_gradients.Clear(); + } + + for(int i = 0, end = in_gradients.Length; i < end; i++) + { + long id = trace.input_tensor_id[i]; + if (in_gradients[i] is not null) + { + var unaggregated_grads = gradients.SetDefault(id, new List()); + unaggregated_grads.Add(in_gradients[i]); + if(unaggregated_grads.Count > kMinAggregateCount) + { + if(!gradients_size.TryGetValue(id, out var size)) + { + size = NumElements(unaggregated_grads[0]); + gradients_size.emplace(id, size); + } + if(unaggregated_grads.Count * size * 4 > kMinAggregateBytes) + { + Tensor grad = AggregateGradients(unaggregated_grads); + unaggregated_grads.Clear(); + unaggregated_grads.Add(grad); + } + } + } + if(!state.tensor_usage_counts.find(id)) + { + continue; + } + state.tensor_usage_counts[id]--; + if(state.tensor_usage_counts[id] > 0) + { + continue; + } + if (!tensor_tape_.TryGetValue(id, out var tape_it)) + { + if (gradients.find(id)) + { + gradients.erase(id); + } + continue; + } + long op_id = tape_it; + if(op_id == -1) + { + continue; + } + if(state.op_missing_tensor.find(op_id)) + { + state.op_missing_tensor[op_id]--; + if(state.op_missing_tensor[op_id] == 0) + { + op_stack.Enqueue(op_id); + } + } + } + } + + if(state.op_tape.Count > 0) + { + throw new RuntimeError("Invalid tape state."); + } + Tensor[] result = new Tensor[source_tensor_ids.Length]; + for(int i = 0; i < source_tensor_ids.Length; i++) + { + long tensor_id = source_tensor_ids[i]; + if(!gradients.TryGetValue(tensor_id, out var grad_it)) + { + result[i] = null; + } + else + { + if(grad_it.Count > 1) + { + Tensor grad = AggregateGradients(grad_it); + grad_it.Clear(); + grad_it.Add(grad); + } + result[i] = grad_it[0]; + } + } + return result; + } + + UnorderedMap> FunctionsAcceptingNoneForIndicesMap() + { + return _functionsAcceptingNoneForIndicesMap; + } + + UnorderedMap> InitialGradients(long[] target_tensor_ids, + UnorderedMap sources_that_are_targets, + List output_gradients, + TensorTape tensor_tape, + OpTape op_tape) + { + var result = new UnorderedMap>(); + for(int i = 0, end = target_tensor_ids.Length; i < end; i++) + { + long id = target_tensor_ids[i]; + if( output_gradients is null ||output_gradients.Count == 0 || output_gradients[i] is null) + { + if(tensor_tape.TryGetValue(id, out var tensor_it) && tensor_it != -1) + { + if(!op_tape.TryGetValue(tensor_it, out var op_it)) + { + throw new RuntimeError("Internal state of the gradient tape is invalid: " + + "failed to find operation producing a tensor."); + } + bool found = false; + for(int j = 0; j < op_it.output_tensor_info.Length; j++) + { + if (op_it.output_tensor_info[j].GetID() == id) + { + found = true; + Tensor ones_like = BuildOnesLike(op_it.output_tensor_info[j]); + result.SetDefault(id, new List()).Add(ones_like); + break; + } + } + if (!found) + { + throw new RuntimeError("Internal state of the gradient tape is invalid: " + + "none of operations outputs match expected tensor."); + } + } + else + { + if(sources_that_are_targets.TryGetValue(id, out var source_tensor)) + { + Tensor ones_like = BuildOnesLike(source_tensor); + result.SetDefault(id, new List()).Add(ones_like); + } + } + } + else + { + result.SetDefault(id, new List()).Add(output_gradients[i]); + } + } + + return result; + } + + Queue InitialStack(OpTape op_tape, + UnorderedMap op_missing_tensor) + { + var result = new Queue(); + foreach (var op_entry in op_tape) + { + if (!op_missing_tensor.find(op_entry.Key)) + result.Enqueue(op_entry.Key); + } + return result; + } + + Tensor BuildOnesLike(TapeTensor t) + { + return t.OnesLike(); + } + + Tensor AggregateGradients(List gradient_tensors) + { + if(gradient_tensors.Count == 0) + { + return gradient_tensors[0]; + } + return tf.add_n(gradient_tensors.ToArray()); + } + + void DeleteGradient(Tensor gradient) + { + // Do not do anything here. Because GC will collect it when it has no reference. + } + + long NumElements(Tensor tensor) => 1; + } +} diff --git a/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs b/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs new file mode 100644 index 000000000..f8f356e76 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs @@ -0,0 +1,67 @@ +using System.Collections.Generic; +using Tensorflow.Util; + +namespace Tensorflow.Gradients +{ + public partial class Tape + { + public BackpropInitialState PrepareBackprop(long[] target, + TensorTape tensor_tape, + OpTape op_tape, + UnorderedSet sources_set, + bool persistent_tape) + { + Stack tensor_stack = new Stack(); + foreach(var t in target) + { + tensor_stack.Push(t); + } + BackpropInitialState result = new BackpropInitialState(); + while(tensor_stack.Count > 0) + { + long tensor_id = tensor_stack.Pop(); + if(!tensor_tape.TryGetValue(tensor_id, out var op_id)) + { + continue; + } + if(op_id == -1 || !op_tape.TryGetValue(op_id, out var op_it) + || result.op_tape.find(op_id)) + { + continue; + } + result.op_tape.emplace(op_id, op_it); + foreach(var it in op_it.input_tensor_id) + { + if(result.tensor_usage_counts.find(it)) + { + result.tensor_usage_counts[it]++; + } + else + { + result.tensor_usage_counts[it] = 1; + if (tensor_tape.find(it)) + { + tensor_stack.Push(it); + } + } + } + if (!persistent_tape) + { + op_tape.erase(op_id); + } + } + foreach(var pair in result.tensor_usage_counts) + { + if(tensor_tape.TryGetValue(pair.Key, out var it) && it != -1) + { + result.op_missing_tensor[it]++; + } + } + if (!persistent_tape) + { + op_tape.Clear(); + } + return result; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs new file mode 100644 index 000000000..708b9121d --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow.Gradients +{ + public partial class Tape + { + long next_op_id_ = 0; + UnorderedMap tensor_usage_; + + public void RecordOperation(string op_type, + TapeTensor[] output_tensors, + long[] input_tensor_id, + TF_DataType[] input_dtypes, + BackwardFunction backward_function) + { + if (!ShouldRecord(input_tensor_id, input_dtypes)) + return; + + foreach (var i in input_tensor_id) + { + tensor_usage_[i]++; + } + long op_id = next_op_id_++; + + foreach (var o in output_tensors) + { + tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}"); + tensor_tape_[o.GetID()] = op_id; + tensor_usage_[o.GetID()] = 1; + } + + op_tape_[op_id] = new OpTapeEntry + { + op_type = op_type, + output_tensor_info = output_tensors.ToArray(), + input_tensor_id = input_tensor_id.ToArray(), + backward_function = backward_function + }; + } + + public void RecordOperation(string op_type, + Tensor[] outputs, + Tensor[] inputs, + BackwardFunction backward_function) + { + tf.Runner.TFE_TapeSetRecordOperation(op_type, outputs, inputs, backward_function); + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/Tape.cs b/src/TensorFlowNET.Core/Gradients/Tape.cs new file mode 100644 index 000000000..648666bbf --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/Tape.cs @@ -0,0 +1,115 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow.Gradients +{ + public partial class Tape : ITape + { + int _id; + // static int tape_nesting_id_counter = 0; + bool _persistent; + public bool Persistent => _persistent; + bool _recording; + bool _created_eagerly; + TensorTape tensor_tape_; + OpTape op_tape_; + + /// + /// A deque-backed stack, whose element references are not invalidated by + /// pushes and pops at the back. + /// + // Stack call_state_; + + public Tape(bool persistent, bool watch_accessed_variables) + { + _persistent = persistent; + _created_eagerly = tf.Context.executing_eagerly(); + tensor_tape_ = new TensorTape(); + op_tape_ = new OpTape(); + tensor_usage_ = new UnorderedMap(); + if(_created_eagerly) + tf.Context.start_step(); + // nesting_id = ++tape_nesting_id_counter; + } + + /// + /// Marks this tensor to be watched by the given tape. + /// + /// + public void Watch(Tensor x) + { + tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}"); + tensor_tape_.emplace(x.Id, -1); + } + + public bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes) + { + Debug.Assert(tensor_ids.Length == tensor_dtypes.Length); + for (int i = 0; i < tensor_ids.Length; ++i) + { + if (tensor_tape_.find(tensor_ids[i]) && IsDtypeTrainable(tensor_dtypes[i])) + { + return true; + } + } + return false; + } + + public void VariableAccessed(IVariableV1 variable) + { + Watch(variable.Handle); + } + + public IVariableV1[] WatchedVariables() + { + return null; + } + + public bool IsDtypeTrainable(TF_DataType dtype) + { + switch (dtype) + { + case TF_DataType.TF_HALF: + case TF_DataType.TF_BFLOAT16: + case TF_DataType.TF_FLOAT: + case TF_DataType.TF_DOUBLE: + case TF_DataType.TF_COMPLEX64: + case TF_DataType.TF_COMPLEX128: + case TF_DataType.TF_RESOURCE: + case TF_DataType.TF_VARIANT: + return true; + default: + return false; + } + } + + public void StartRecord() + { + if (_recording) + throw new ValueError("Tape is still recording, This can happen if you try to " + + "re-enter an already-active tape."); + _recording = true; + } + + public void StopRecord() + { + if (!_recording) + throw new ValueError("Tape is not recording."); + if (_created_eagerly) + tf.Context.end_step(); + _recording = false; + } + + public void SetTapeId(int id) + { + _id = id; + } + + public override string ToString() + => $"Tape {_id} {(_recording ? "Recording" : "Stopped")}"; + } +} diff --git a/src/TensorFlowNET.Core/Gradients/TapeTensor.cs b/src/TensorFlowNET.Core/Gradients/TapeTensor.cs new file mode 100644 index 000000000..3ad19768c --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/TapeTensor.cs @@ -0,0 +1,65 @@ +using OneOf; +using static Tensorflow.Binding; + +namespace Tensorflow.Gradients +{ + public class TapeTensor + { + internal Tensor tensor; + internal long id; + internal TF_DataType dtype; + internal OneOf shape; + + public TapeTensor(long id, TF_DataType dtype, Shape shape) + { + this.id = id; + this.dtype = dtype; + this.shape = shape; + } + + public TapeTensor(long id, TF_DataType dtype, Tensor shape) + { + this.id = id; + this.dtype = dtype; + this.shape = shape; + } + + public TapeTensor(Tensor tensor) + { + this.id = tensor.Id; + this.dtype = tensor.dtype; + this.shape = tensor.shape; + this.tensor = tensor; + } + + public long GetID() => id; + + public Tensor ZerosLike() + { + if(dtype == dtypes.resource) + { + return null; + } + if(shape.Index == 1) + { + return tf.zeros_like(shape.AsT1); + } + return tf.zeros(shape.AsT0, dtype); + } + + public Tensor OnesLike() + { + if (shape.Index == 1) + { + return tf.ones_like(shape.AsT1); + } + return tf.ones(shape.AsT0, dtype); + } + + //public Tensor OnesLike() + // => tf.ones(shape: shape, dtype: dtype); + + public override string ToString() + => $"{id}, {shape}, {dtype.as_numpy_name()}"; + } +} diff --git a/src/TensorFlowNET.Core/Gradients/TensorTape.cs b/src/TensorFlowNET.Core/Gradients/TensorTape.cs new file mode 100644 index 000000000..3f069082f --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/TensorTape.cs @@ -0,0 +1,14 @@ +using Tensorflow.Util; + +namespace Tensorflow.Gradients +{ + /// + /// Map from tensor to internally-defined operation-id of the operation which + /// produced this tensor. A value of -1 means that the tensor was directly + /// watched and not the result of any operation in the tape. + /// + public class TensorTape : UnorderedMap + { + + } +} diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs new file mode 100644 index 000000000..a4da60eed --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -0,0 +1,428 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Eager; +using Tensorflow.Framework; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace Tensorflow.Gradients +{ + /// + /// tensorflow\python\ops\array_grad.py + /// + [RegisterGradient("array_grad")] + public class array_grad + { + [RegisterGradient("BroadcastTo")] + public static Tensor[] _BroadcastToGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var input_value = op.inputs[0]; + var broadcast_shape = op.inputs[1]; + var input_value_shape = array_ops.shape(input_value); + var reduction_axes = gen_array_ops.broadcast_gradient_args(broadcast_shape, input_value_shape)[1]; + var updates_grad_reshaped = math_ops.reduce_sum(grad, + axis: reduction_axes, + keepdims: true); + var updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape); + + return new Tensor[] + { + updates_grad, + null + }; + } + + [RegisterGradient("ConcatV2")] + public static Tensor[] _ConcatV2Grad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + return _ConcatGradHelper(op, grad, start_value_index: 0, end_value_index: -1, dim_index: -1); + } + + /// + /// Gradient for concat op. + /// + /// An operation. + /// + /// `Tensor` or `IndexedSlices` representing the gradients with respect + /// to each output of the op. + /// + /// An integer index of the first value in the op.inputs. + /// An integer index of the last value in the op.inputs. + /// An interger index of concat_dim or axis parameter in op.inputs. + /// + /// Tensors representing the partial gradients with respect to each input + /// of the op. + /// + private static Tensor[] _ConcatGradHelper(Operation op, Tensor grad, int start_value_index, int end_value_index, int dim_index) + { + // Degenerate concatenation, just return grad. + if (len(op.inputs) == 2) + return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad }; + + var concat_dim = op.inputs[dim_index]; + var input_values = op.inputs._inputs.Skip(start_value_index) + .Take(end_value_index == -1 ? op.inputs.Length - 1 : end_value_index - start_value_index) + .ToArray(); + + var out_grads = new List(); + if(concat_dim is EagerTensor) + { + var dim_int = (int)concat_dim; + var non_neg_concat_dim = dim_int < 0 + ? input_values[0].rank + dim_int + : dim_int % input_values[0].rank; + var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray(); + out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList(); + } + else if (constant_op.is_constant(concat_dim)) + { + /*If concat_dim is a constant defined in a different context, + then we duplicate it in the current context to avoid passing it + through an Enter node. + This is a small optimization in general, but it is required when + compiling with XLA, as XLA needs the concat input to be folded into a + constant.*/ + var grad_context = control_flow_util.GetOutputContext(grad.op); + var dim_context = control_flow_util.GetOutputContext(concat_dim.op); + if (dim_context != grad_context) + { + var value = tensor_util.constant_value(concat_dim); + concat_dim = constant_op.constant(value: value, dtype: concat_dim.dtype); + } + + // Using mod here for convenience since concat_dim is already verified + // in concat implementation to be within the allowed [-rank, rank) range. + var non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0]); + + // Get the inputs' tensor shapes + var sizes = _ExtractInputShapes(input_values); + + /* The magic number of 16 was found through benchmarking a range of sizes + on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of + cases when switching implementations at N=16, but it is possible that + there will be a small number of performance regressions.*/ + if (len(sizes) > 16) + { + // extract the size of each input along the concat dimension + var slice = array_ops.slice(array_ops.stack(sizes, axis: 1), + new Tensor[] { non_neg_concat_dim, tf.constant(0) }, + new Tensor[] { tf.constant(1), tf.constant(-1) }); + var squeeze_sizes = array_ops.squeeze(slice); + out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList(); + } + else + { + var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes); + foreach (var (begin, size) in zip(offset, sizes)) + out_grads.Add(gen_array_ops.slice(grad, begin, size)); + } + } + + return (end_value_index <= dim_index ? + out_grads.ToArray().Concat(new Tensor[] { null }) : + new Tensor[] { null }.Concat(out_grads)).ToArray(); + } + + [RegisterGradient("ExpandDims")] + public static Tensor[] _ExpandDimsGrad(Operation op, Tensor[] grads) + { + return new Tensor[] { _ReshapeToInput(op, grads[0]), null }; + } + + /// + /// Extract the shapes of a set of input tensors. + /// + /// + /// + private static Tensor[] _ExtractInputShapes(Tensor[] inputs) + { + var sizes = new Tensor[inputs.Length]; + bool fully_known = true; + for (int i = 0; i < inputs.Length; i++) + { + var x = inputs[i]; + + var input_shape = array_ops.shape(x); + if (!(input_shape is Tensor) || input_shape.op.type != "Const") + { + fully_known = false; + break; + } + + sizes[i] = input_shape; + } + + if (fully_known) + return sizes; + else + return gen_array_ops.shape_n(inputs); + } + + /// + /// Gradient for GatherV2 op. + /// + /// + /// + /// + [RegisterGradient("GatherV2")] + public static Tensor[] _GatherV2Grad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var @params = op.inputs[0]; + ops.colocate_with(@params); + + var params_shape = array_ops.shape(@params, out_type: tf.int64); + params_shape = math_ops.cast(params_shape, tf.int32); + + var indices = op.inputs[1]; + var indices_size = array_ops.expand_dims(array_ops.size(indices), 0); + var axis = op.inputs[2]; + var axis_static = tensor_util.constant_value(axis); + + // For axis 0 gathers, build an appropriately shaped IndexedSlices. + if ((int)axis_static == 0) + { + var params_tail_shape = params_shape.slice(new Slice(start: 1)); + var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0); + var values = array_ops.reshape(grad, values_shape); + indices = array_ops.reshape(indices, indices_size); + return new Tensor[] + { + new IndexedSlices(values, indices, params_shape), + null, + null + }; + } + + return new Tensor[] { null, null }; + } + + [RegisterGradient("Reshape")] + public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads) + { + return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; + } + + [RegisterGradient("Pack")] + public static Tensor[] _PackGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var num = op.get_attr("N"); + var axis = op.get_attr("axis"); + return array_ops.unstack(grad, num: num, axis: axis); + } + + [RegisterGradient("Unpack")] + public static Tensor[] _UnpackGrad(Operation op, Tensor[] grads) + { + var axis = op.get_attr("axis"); + return new[] { array_ops.stack(grads, axis: axis) }; + } + + [RegisterGradient("Pad")] + public static Tensor[] _PadGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + var a = op.inputs[1]; + var size = array_ops.stack(new Tensor[] { array_ops.rank(x), constant_op.constant(1) }); + var begin = constant_op.constant(new[] { 0, 0 }); + var pad_before = array_ops.slice(a, begin, size); + + // Make it a 1-D tensor. + begin = array_ops.reshape(pad_before, new[] { -1 }); + size = array_ops.shape(x); + var x_grad = array_ops.slice(grad, begin, size); + + if (len(op.inputs) == 3) + return new Tensor[] { x_grad, null, null }; + else + return new Tensor[] { x_grad, null }; + } + + [RegisterGradient("Split")] + public static Tensor[] _SplitGrad(Operation op, Tensor[] grads) + { + return new Tensor[] { null, array_ops.concat(list(grads), op.inputs[0]) }; + } + + [RegisterGradient("Slice")] + public static Tensor[] _SliceGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var input_vec = op.inputs[0]; + var begin_vec = op.inputs[1]; + var input_rank = array_ops.rank(input_vec); + var slice_size = array_ops.shape(op.outputs[0]); + + var shape = array_ops.stack(new Tensor[] { input_rank, ops.convert_to_tensor(1) }); + var before_pad = array_ops.reshape(begin_vec, shape); + var after_pad = array_ops.reshape(array_ops.shape(input_vec) - slice_size - begin_vec, shape); + var paddings = array_ops.concat(new Tensor[] { before_pad, after_pad }, 1); + return new Tensor[] + { + array_ops.pad(grad, paddings), + null, + null + }; + } + + [RegisterGradient("Squeeze")] + public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) + { + return new Tensor[] { _ReshapeToInput(op, grads[0]) }; + } + + [RegisterGradient("StopGradient")] + public static Tensor[] _NoGradient(Operation op, Tensor[] grads) + { + return new Tensor[] { null }; + } + + /// + /// Gradient for StridedSlice op. + /// + /// + /// + /// + [RegisterGradient("StridedSlice")] + public static Tensor[] _StridedSliceGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var begin = op.inputs[1]; + var end = op.inputs[2]; + var strides = op.inputs[3]; + + var x = array_ops.shape(op.inputs[0], out_type: begin.dtype); + var x_static = tensor_util.constant_value(x); + var begin_static = tensor_util.constant_value(begin); + var end_static = tensor_util.constant_value(end); + var strides_static = tensor_util.constant_value(strides); + + return new Tensor[] + { + array_ops.strided_slice_grad( + x_static, + begin_static, + end_static, + strides_static, + grad, + begin_mask: op.get_attr("begin_mask"), + end_mask: op.get_attr("end_mask"), + ellipsis_mask: op.get_attr("ellipsis_mask"), + new_axis_mask: op.get_attr("new_axis_mask"), + shrink_axis_mask: op.get_attr("shrink_axis_mask")), + null, + null, + null + }; + } + + [RegisterGradient("StridedSliceGrad")] + public static Tensor[] _StridedSliceGradGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var begin = op.inputs[1]; + var end = op.inputs[2]; + var strides = op.inputs[3]; + + return new Tensor[] + { + null, + null, + null, + array_ops.strided_slice( + grad, + begin, + end, + strides, + begin_mask: (int)op.get_attr("begin_mask"), + end_mask: (int)op.get_attr("end_mask"), + ellipsis_mask: (int)op.get_attr("ellipsis_mask"), + new_axis_mask: (int)op.get_attr("new_axis_mask"), + shrink_axis_mask: (int)op.get_attr("shrink_axis_mask")) + }; + } + + private static Tensor _ReshapeToInput(Operation op, Tensor grad) + { + return array_ops.reshape(grad, array_ops.shape(op.inputs[0])); + } + + [RegisterGradient("Transpose")] + public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads) + { + var p = op.inputs[1]; + return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null }; + } + + [RegisterGradient("ReverseV2")] + public static Tensor[] _ReverseV2Grad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var axis = op.inputs[1]; + return new Tensor[] { array_ops.reverse(grad, axis), null }; + } + + [RegisterGradient("Tile")] + public static Tensor[] _TileGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var input_shape = array_ops.shape(op.inputs[0], out_type: op.inputs[1].dtype); + var split_shape = array_ops.reshape(array_ops.transpose(array_ops.stack(new Tensor[] { op.inputs[1], input_shape })), new Shape(-1)); + var axes = math_ops.range(0, array_ops.size(split_shape), 2); + + //# Sum reduces grad along the first dimension for IndexedSlices + //if isinstance(grad, indexed_slices_lib.IndexedSlices): + //input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype) + //grad = math_ops.unsorted_segment_sum( + // grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0) + //split_shape = array_ops.concat([[1], split_shape[1:]], axis = 0) + + var input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes); + if (!tf.Context.executing_eagerly()) + { + input_grad.set_shape(op.inputs[0].GetShape()); + } + return new Tensor[] { input_grad, null }; + } + + [RegisterGradient("GatherNd")] + public static Tensor[] _GatherNdGrad(Operation op, Tensor[] grads) + { + var @ref = op.inputs[0]; + var indices = op.inputs[1]; + var grad = grads[0]; + var ref_shape = array_ops.shape(@ref, out_type: indices.dtype); + Tensor ref_grad = null; + if (indices.shape.ndim == 2 && indices.shape.dims[indices.shape.Length - 1] == 1) + { + ref_grad = (Tensor)new IndexedSlices(grad, array_ops.squeeze(indices, axis: -1), ref_shape); + } + else + { + ref_grad = gen_array_ops.scatter_nd(indices, grad, ref_shape); + } + return new Tensor[] { ref_grad, null }; + } + + } +} diff --git a/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs b/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs new file mode 100644 index 000000000..901a33ca8 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs @@ -0,0 +1,43 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Runtime.InteropServices; + +namespace Tensorflow +{ + public partial class c_api + { + /// + /// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, + /// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... + /// This is a variant of TF_AddGradients that allows to caller to pass a custom + /// name prefix to the operations added to a graph to compute the gradients. + /// + /// TF_Graph* + /// const char* + /// TF_Output* + /// int + /// TF_Output* + /// int + /// TF_Output* + /// TF_Status* + /// TF_Output* + [DllImport(TensorFlowLibName)] + public static extern void TF_AddGradientsWithPrefix(SafeGraphHandle g, string prefix, TF_Output[] y, int ny, + TF_Output[] x, int nx, TF_Output[] dx, SafeStatusHandle status, IntPtr[] dy); + } +} diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs new file mode 100644 index 000000000..eba821d2c --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs @@ -0,0 +1,271 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow.Gradients +{ + /// + /// Gradients for operators defined in control_flow_ops.py.cs + /// + [RegisterGradient("control_flow_grad")] + public class control_flow_grad + { + /// + /// Gradients for a Switch op is calculated using a Merge op. + /// + /// If the switch is a loop switch, it will be visited twice. We create + /// the merge on the first visit, and update the other input of the merge + /// on the second visit. A next_iteration is also added on second visit. + /// + /// + [RegisterGradient("Switch")] + public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var graph = ops.get_default_graph(); + var op_ctxt = op._get_control_flow_context(); + var grad_ctxt = graph._get_control_flow_context(); + switch (op_ctxt) + { + case WhileContext cwhile: + { + var merge_grad = grad_ctxt.grad_state.switch_map.get(op); + if (merge_grad != null) + { + if (grads[1] != null) + control_flow_ops._AddNextAndBackEdge(merge_grad, grads[1], + enforce_shape_invariant: false); + return new Tensor[] { null, null }; + } + else if (grads[0] != null) + { + merge_grad = merge(new[] { grads[0], grads[0] }, name: "b_switch")[0]; + grad_ctxt.grad_state.switch_map[op] = merge_grad; + return new Tensor[] { merge_grad, null }; + } + else + return new Tensor[] { null, null }; + } + case CondContext ccond: + { + var zero_grad = grads[1 - op_ctxt.branch]; + // At this point, we have created zero_grad guarded by the right switch. + // Unfortunately, we may still get None here for not trainable data types. + if (zero_grad == null) + { + throw new NotImplementedException("_SwitchGrad CondContext zero_grad"); + } + + return new Tensor[] + { + merge(grads, name: "cond_grad")[0], + null + }; + } + default: + throw new NotImplementedException("_SwitchGrad WhileContext"); + } + throw new NotImplementedException("_SwitchGrad"); + } + + /// + /// Returns the value of an available element of `inputs`. + /// + /// + /// + /// + internal static MergeOutput merge(Tensor[] inputs, string name = null) + { + return tf_with(ops.name_scope(name, "Merge", inputs), scope => + { + name = scope; + if (inputs.Count(x => x.dtype.is_ref_dtype()) == inputs.Length) + return gen_control_flow_ops.ref_merge(inputs, name: name); + else + return gen_control_flow_ops.merge(inputs, name: name); + }); + } + + /// + /// Gradients for a Merge op are calculated using a Switch op. + /// + [RegisterGradient("Merge")] + public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var input_op = op.inputs[0].op; + var graph = ops.get_default_graph(); + var op_ctxt = control_flow_util.GetOutputContext(input_op); + var grad_ctxt = graph._get_control_flow_context(); + switch (op_ctxt) + { + case WhileContext cwhile: + { + return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot); + } + case CondContext ccond: + { + var pred = ccond.pred; + if (grad_ctxt != null && grad_ctxt.grad_state != null) + { + //# This Merge node is part of a cond within a loop. + //# The backprop needs to have the value of this predicate for every + //# iteration. So we must have its values accumulated in the forward, and + //# use the accumulated values as the predicate for this backprop switch. + var grad_state = grad_ctxt.grad_state; + var real_pred = grad_state.history_map[pred.name] as Tensor; + if (real_pred == null) + { + //# Remember the value of pred for every iteration. + grad_ctxt = grad_state.grad_context; + grad_ctxt.Exit(); + var history_pred = grad_state.AddForwardAccumulator(pred); + grad_ctxt.Enter(); + + //# Add the stack pop op. If pred.op is in a (outer) CondContext, + //# the stack pop will be guarded with a switch. + real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred); + grad_state.history_map[pred.name] = real_pred; + } + pred = real_pred; + } + var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); + return results; + } + default: + { + var num_inputs = op.inputs.Length; + var cond = new Tensor[num_inputs]; + for (int i = 0; i < num_inputs; i++) + cond[i] = math_ops.equal(op.outputs[1], i); + var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray(); + return result; + } + } + + } + + [RegisterGradient("RefMerge")] + public static Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) + { + return _MergeGrad(op, grads); + } + + /// + /// Gradients for an exit op are calculated using an Enter op. + /// + [RegisterGradient("Exit")] + public static Tensor[] _ExitGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var graph = ops.get_default_graph(); + var op_ctxt = op._get_control_flow_context(); + var grad_ctxt = graph._get_control_flow_context() as WhileContext; + // The flag `back_prop` is set by users to suppress gradient + // computation for this loop. If the attribute `back_prop` is false, + // no gradient computation. + if (!grad_ctxt.back_prop) + return null; + + if (op_ctxt.grad_state != null) + throw new TypeError("Second-order gradient for while loops not supported."); + + grad_ctxt.AddName(grad.name); + + grad_ctxt.Enter(); + var result = control_flow_ops._Enter( + grad, grad_ctxt.Name, is_constant: false, + parallel_iterations: grad_ctxt.parallel_iterations, + name: "b_exit"); + + grad_ctxt.loop_enters.append(result); + grad_ctxt.Exit(); + return new[] { result }; + } + + /// + /// A forward next_iteration is translated into a backprop identity. + /// + /// Note that the backprop next_iteration is added in switch grad. + /// + [RegisterGradient("NextIteration")] + public static Tensor[] _NextIterationGrad(Operation op, Tensor[] grads) + { + return grads; + } + + [RegisterGradient("RefNextIteration")] + public static Tensor[] _RefNextIterationGrad(Operation op, Tensor[] grads) + { + return grads; + } + + /// + /// Gradients for an Enter are calculated using an Exit op. + /// + /// For loop variables, grad is the gradient so just add an exit. + /// For loop invariants, we need to add an accumulator loop. + /// + [RegisterGradient("Enter")] + public static Tensor[] _EnterGrad(Operation op, Tensor[] grads) + { + Tensor result = null; + var grad = grads[0]; + var graph = ops.get_default_graph(); + var grad_ctxt = graph._get_control_flow_context() as WhileContext; + if (!grad_ctxt.back_prop) + // Skip gradient computation, if the attribute `back_prop` is false. + return grads; + if (grad_ctxt.grad_state == null) + // Pass the gradient through if we are not in a gradient while context. + return grads; + if (op.get_attr("is_constant")) + { + // Add a gradient accumulator for each loop invariant. + result = grad_ctxt.AddBackpropAccumulator(op, grad); + } + else + { + result = control_flow_ops.exit(grad); + grad_ctxt.loop_exits.append(result); + grad_ctxt.ExitResult(new[] { result }); + } + + return new Tensor[] { result }; + } + + + [RegisterGradient("RefEnter")] + public Tensor[] _RefEnterGrad(Tensor op, Tensor[] grad) + { + return _EnterGrad(op, grad); + } + + /// + /// Stop backprop for the predicate of a while loop. + /// + [RegisterGradient("LoopCond")] + public Tensor[] _LoopCondGrad(Tensor op, Tensor[] grad) + { + return null; + } + + } +} diff --git a/src/TensorFlowNET.Core/Gradients/custom_gradient.cs b/src/TensorFlowNET.Core/Gradients/custom_gradient.cs new file mode 100644 index 000000000..0a248086b --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/custom_gradient.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Gradients +{ + public class custom_gradient + { + public static string generate_name() + { + return $"CustomGradient-{ops.uid()}"; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/default_gradient.cs b/src/TensorFlowNET.Core/Gradients/default_gradient.cs new file mode 100644 index 000000000..e6c22e369 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/default_gradient.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Gradients +{ + internal static class default_gradient + { + public static (Shape, TF_DataType) shape_and_dtype(Tensor t) + { + if(t.dtype == dtypes.resource) + { + var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); + if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) + { + throw new ValueError($"Internal error: Tried to take gradients (or similar) " + + $"of a variable without handle data:\n{t}"); + } + return (new Shape(handle_data.ShapeAndType[0].Shape), handle_data.ShapeAndType[0].Dtype.as_tf_dtype()); + } + return (t.shape, t.dtype); + } + + public static Tensor zeros_like(Tensor t) + { + if(t.dtype == dtypes.resource) + { + var (shape, dtype) = shape_and_dtype(t); + return array_ops.zeros(shape, dtype); + } + else + { + return array_ops.zeros_like(t); + } + } + + public static TF_DataType get_zeros_dtype(Tensor t) + { + if(t.dtype == dtypes.resource) + { + var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); + if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) + { + throw new ValueError($"Internal error: Tried to take gradients (or similar) " + + $"of a variable without handle data:\n{t}"); + } + return handle_data.ShapeAndType[0].Dtype.as_tf_dtype(); + } + return t.dtype; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs b/src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs new file mode 100644 index 000000000..53d09eea6 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs @@ -0,0 +1,30 @@ +namespace Tensorflow.Gradients +{ + public class gradient_exclustions + { + public static int[] OpGradientUnusedInputIndices(string op_name) + => op_name switch + { + "FusedBatchNorm" => new[] { 2 }, + "FusedBatchNormGradV3" => new[] { 5 }, + "FusedBatchNormV2" => new[] { 2 }, + "FusedBatchNormV3" => new[] { 2 }, + "ReadVariableOp" => new int[0], + _ => null + }; + + public static int[] OpGradientUnusedOutputIndices(string op_name) + => op_name switch + { + "FusedBatchNormV3" => new[] { 0, 1, 2 }, + "ReadVariableOp" => new int[0], + "SoftmaxCrossEntropyWithLogits" => new[] { 0 }, + "TensorArrayConcat" => new[] { 0 }, + "TensorArrayConcatV2" => new[] { 0 }, + "TensorArrayConcatV3" => new[] { 0 }, + "Mul" => new int[0], + "Sum" => new int[0], + _ => null + }; + } +} diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs new file mode 100644 index 000000000..e91bafe88 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -0,0 +1,51 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; + +namespace Tensorflow +{ + public class gradients_impl + { + public static Tensor[] gradients(Tensor[] ys, + Tensor[] xs, + Tensor[] grad_ys = null, + string name = "gradients", + bool colocate_gradients_with_ops = false, + bool gate_gradients = false, + int? aggregation_method = null) + { + return gradients_util._GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients); + } + + private static List _AsList(object ys) + { + List ret = null; + + switch (ys) + { + case Tensor value: + ret = new List { value }; + break; + case List value: + ret = value; + break; + } + + return ret; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs new file mode 100644 index 000000000..1fb327788 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -0,0 +1,767 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Gradients; +using Tensorflow.Graphs; +using Tensorflow.Operations; +using Tensorflow.Operations.ControlFlows; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class gradients_util + { + // Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are + // unfortunately too slow to use here. + public static int POSSIBLE_GRADIENT_TYPES_NONE = 0; + public static int POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1; + public static int POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2; + public static Tensor[] _GradientsHelper(Tensor[] ys, + Tensor[] xs, + Tensor[] grad_ys = null, + string name = "gradients", + bool colocate_gradients_with_ops = false, + bool gate_gradients = false, + int aggregation_method = 0, + Tensor[] stop_gradients = null, + Graph src_graph = null) + { + if (src_graph == null) + src_graph = ops.get_default_graph(); + + // If src_graph is a _FuncGraph (i.e. a function body), gather it and all + // ancestor graphs. This is necessary for correctly handling captured values. + var func_graphs = new List(); + var curr_graph = src_graph; + if (src_graph is FuncGraph func_graph) + { + func_graphs.append(func_graph); + curr_graph = func_graph.OuterGraph; + } + + + if (stop_gradients == null) + stop_gradients = new Tensor[0]; + if (grad_ys == null) + grad_ys = new Tensor[ys.Length]; + + // Iterate over the collected ops. + /* + * grads: op => list of gradients received on each output endpoint of the + * op. The gradients for each endpoint are initially collected as a list. + * When it is time to call the op's gradient function, for each endpoint we + * aggregate the list of received gradients into a Add() Operation if there + * is more than one. + */ + var grads = new Dictionary>>(); + Operation[] reachable_to_ops = null; + ControlFlowState loop_state = null; + Dictionary pending_count = null; + + tf_with(ops.name_scope(name, "gradients", + values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope => + { + string grad_scope = scope; + // Get a uid for this call to gradients that can be used to help + // cluster ops for compilation. + var gradient_uid = curr_graph.unique_name("uid"); + ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y"); + xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true); + grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid); + + /* + * The approach we take here is as follows: Create a list of all ops in the + * subgraph between the ys and xs. Visit these ops in reverse order of ids + * to ensure that when we visit an op the gradients w.r.t its outputs have + * been collected. Then aggregate these gradients if needed, call the op's + * gradient function, and add the generated gradients to the gradients for + * its input. + */ + + // Initialize the pending count for ops in the connected subgraph from ys + // to the xs. + var to_ops = ys.Select(x => x.op).ToList(); + var from_ops = xs.Select(x => x.op).ToList(); + var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); + (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs , xs); + + // Add the initial gradients for the ys. + foreach (var (y, grad_y) in zip(ys, grad_ys)) + _SetGrad(grads, y, grad_y); + + // Initialize queue with to_ops. + var queue = new Queue(); + // Add the ops in 'to_ops' into the queue. + var to_ops_set = new List(); + foreach (var op in to_ops) + { + // 'ready' handles the case where one output gradient relies on + // another output's gradient. + if (!pending_count.ContainsKey(op.name)) + pending_count[op.name] = 0; + bool ready = pending_count[op.name] == 0; + if (ready && !to_ops_set.Contains(op) && reachable_to_ops.Contains(op)) + { + to_ops_set.Add(op); + queue.Enqueue(op); + } + } + + if (loop_state != null) + { + var loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set); + foreach (var y in loop_exits) + { + //if(IsTrainable(y)) + throw new NotImplementedException(""); + } + } + + var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs); + while (queue.Count > 0) + { + // generate gradient subgraph for op. + var op = queue.Dequeue(); + + _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); + { + if (loop_state != null) + loop_state.EnterGradWhileContext(op, before: true); + var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method); + if (loop_state != null) + loop_state.ExitGradWhileContext(op, before: true); + + Tensor[] in_grads = null; + Func grad_fn = null; + var is_partitioned_call = _IsPartitionedCall(op); + var is_func_call = src_graph.IsFunction(op.type) || is_partitioned_call; + var has_out_grads = out_grads.Exists(x => x != null); + if (has_out_grads && !stop_ops.Contains(op)) + { + // A grad_fn must be defined, either as a function or as None + // for ops that do not have gradients. + try + { + grad_fn = ops.get_gradient_function(op); + } + catch (LookupError) + { + if (is_func_call) + { + EagerDefinedFunction func_call = null; + if (is_partitioned_call) + { + var func_attr = op.get_attr("f"); + Debug.Assert(func_attr is NameAttrList); + var func_name = ((NameAttrList)func_attr).Name; + func_call = src_graph._get_function(func_name); + if(func_call is null && src_graph.OuterGraph is not null) + { + var graph = src_graph.OuterGraph; + while(graph is not null) + { + func_call = graph._get_function(func_name); + if(func_call is not null) + { + break; + } + if(graph.OuterGraph is not null) + { + graph = graph.OuterGraph; + } + else + { + break; + } + } + } + } + else + { + func_call = src_graph._get_function(op.type); + } + // skip the following codes: + // `func_call = getattr(op, "__defun", func_call)` + grad_fn = func_call.csharp_grad_func; + } + else + { + throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})"); + } + } + } + + if (loop_state != null) + loop_state.EnterGradWhileContext(op, before: false); + + if ((is_func_call || grad_fn != null) && has_out_grads) + { + // NOTE: If _AggregatedGrads didn't compute a value for the i'th + // output, it means that the cost does not depend on output[i], + // therefore dC/doutput[i] is 0. + foreach (var (i, out_grad) in enumerate(out_grads)) + { + if (out_grad == null && + (grad_fn == null || _IsTrainable(op.outputs[i]))) + { + // Only trainable outputs or outputs for a function call that + // will use SymbolicGradient get a zero gradient. Gradient + // functions should ignore the gradient for other outputs. + if (loop_state != null) + out_grads[i] = new List { loop_state.ZerosLike(op, i) }; + else + out_grads[i] = new List { control_flow_ops.ZerosLikeOutsideLoop(op, i) }; + } + } + + tf_with(ops.name_scope(op.name + "_grad"), scope1 => + { + if (grad_fn != null) + { + in_grads = _MaybeCompile(grad_scope, + op, + out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), + null, + grad_fn); + } + else + { + in_grads = _MaybeCompile(grad_scope, op, out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), + null, (x, y) => _SymGrad(x, y)); + throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); + } + _VerifyGeneratedGradients(in_grads, op); + if (gate_gradients && in_grads.Count(x => x != null) > 1) + { + ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true); + in_grads = control_flow_ops.tuple(in_grads); + } + }); + } + else + { + // If no grad_fn is defined or none of out_grads is available, + // just propagate a list of None backwards. + in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; + } + + var inputs = _NonEagerInputs(op, xs).ToList(); + foreach (var (t_in, in_grad) in zip(inputs, in_grads)) + { + if (in_grad != null) + { + if (!(in_grad is null) && + in_grad.Tag == null && // maybe a IndexedSlice + t_in.dtype != TF_DataType.TF_RESOURCE) + { + in_grad.shape = t_in.shape; + } + + _SetGrad(grads, t_in, in_grad); + } + } + + if (loop_state != null) + loop_state.ExitGradWhileContext(op, before: false); + } + + // Update pending count for the inputs of op and enqueue ready ops. + _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs); + } + }); + + if (loop_state != null) + loop_state.PostProcessing(); + return xs.Select(x => _GetGrad(grads, x)).ToArray(); + } + + /// + /// Fill in default values for grad_ys. + /// + /// List of gradients, can contain None. + /// List of tensors. + /// + /// + private static Tensor[] _DefaultGradYs(Tensor[] grad_ys, Tensor[] ys, bool colocate_gradients_with_ops, string gradient_uid = "__unsupported__") + { + var new_grad_ys = new List(); + + foreach(var (i, (y, grad_y)) in enumerate(zip(ys, grad_ys))) + { + _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops); + + if (grad_y == null) + { + if (y.dtype.is_complex()) + throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); + var shape = array_ops.shape(y); + var constant = constant_op.constant(1, y.dtype, name: $"grad_ys_{i}"); + var fill = gen_array_ops.fill(shape, constant); + new_grad_ys.append(fill); + continue; + } + + if (y.dtype.is_floating() || y.dtype.is_integer()) + { + + } + + // Create a grad_y tensor in the name scope of the gradient. + new_grad_ys.append(array_ops.identity(grad_y, name: $"grad_ys_{i}")); + } + + return new_grad_ys.ToArray(); + } + + private static void _maybe_colocate_with(Operation op, string gradient_uid, bool colocate_gradients_with_ops) + { + + } + + /// + /// Initialize the pending count for ops between two lists of Operations. + /// 'pending_count[op]' indicates the number of backprop inputs + /// to this operation. + /// + /// + /// + /// + /// + /// + private static (Operation[], Dictionary, ControlFlowState) _PendingCount(List to_ops, + List from_ops, + bool colocate_gradients_with_ops, + List func_graphs, + Tensor[] xs) + { + // Mark reachable ops from from_ops. + var reached_ops = new List(); + _MarkReachedOps(from_ops, reached_ops, func_graphs); + // X in reached_ops iff X is reachable from from_ops by a path of zero or more + // backpropagatable tensors. + + var reachable_to_ops = to_ops.Where(x => reached_ops.Contains(x)).Select(x => x).ToArray(); + + var between_ops = new List(); + var between_op_list = new List(); + + Queue queue = new Queue(to_ops); + while (queue.Count > 0) + { + var op = queue.Dequeue(); + + if (reached_ops.Contains(op)) + { + between_ops.Add(op); + between_op_list.Insert(between_op_list.Count, op); + // Clear the boolean so we won't add the inputs again. + reached_ops.Remove(op); + foreach (var inp in _NonEagerInputs(op, xs)) + queue.Enqueue(inp.op); + } + } + // X in between_ops iff X is on a path of zero or more backpropagatable tensors + // between from_ops and to_ops + + // 'loop_state' is None if there are no while loops. + var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops); + + // Initialize pending count for between ops. + var pending_count = new Dictionary(); + foreach (var op in between_op_list) + { + foreach (Tensor x in _NonEagerInputs(op, xs)) + { + if (between_ops.Contains(x.op)) + { + if (!pending_count.ContainsKey(x.op.name)) + pending_count[x.op.name] = 0; + + pending_count[x.op.name] += 1; + } + } + } + + return (reachable_to_ops.ToArray(), pending_count, loop_state); + } + + /// + /// Sets gradient "grad" in "grads" for tensor "t". + /// + /// + /// + /// + private static void _SetGrad(Dictionary>> grads, Tensor t, Tensor grad) + { + var op = t.op; + var op_grads = grads.ContainsKey(op.name) ? grads[op.name] : null; + if (op_grads == null) + { + op_grads = op.outputs.Select(x => new List()).ToList(); + grads[op.name] = op_grads; + } + var t_grads = op_grads[t.value_index]; + if (t_grads.Count > 0 && + control_flow_util.IsLoopSwitch(op)) + op_grads[t.value_index][0] = grad; + else + t_grads.Add(grad); + } + + private static IEnumerable _NonEagerInputs(Operation op, Tensor[] xs) + { + for (int i = 0; i < op.inputs.Length; i++) + yield return op.inputs[i]; + } + + private static List> _AggregatedGrads(Dictionary>> grads, Operation op, string gradient_uid, + ControlFlowState loop_state, int aggregation_method = 0) + { + var out_grads = _GetGrads(grads, op); + + foreach (var (i, out_grad) in enumerate(out_grads)) + { + if (loop_state != null) + { + if (out_grads.Count > 1 && + out_grads[1].Count > 0 && + control_flow_util.IsLoopSwitch(op)) + continue; + } + + // Aggregate multiple gradients, and convert [] to None. + if (out_grad.Count > 0) + { +#pragma warning disable CS0219 // Variable is assigned but its value is never used + string used = ""; +#pragma warning restore CS0219 // Variable is assigned but its value is never used + if (out_grad.Count < 2) + { + used = "nop"; + if (out_grad.Count == 0) + { + throw new ValueError("_AggregatedGrads out_grad.Length == 0"); + } + + out_grads[i] = new List { out_grad[0] }; + } + else + { + used = "add_n"; + out_grads[i] = new List { _MultiDeviceAddN(out_grad.ToArray(), gradient_uid) }; + } + } + else + { + out_grads[i] = null; + } + } + + return out_grads; + } + + /// + /// Adds tensors from potentially multiple devices. + /// + /// + /// + /// + private static Tensor _MultiDeviceAddN(Tensor[] tensor_list, string gradient_uid) + { + // Basic function structure comes from control_flow_ops.group(). + // Sort tensors according to their devices. + var tensors_on_device = new Dictionary>(); + + foreach (var tensor in tensor_list) + { + if (!tensors_on_device.ContainsKey(tensor.Device)) + tensors_on_device[tensor.Device] = new List(); + + tensors_on_device[tensor.Device].Add(tensor); + } + + // For each device, add the tensors on that device first. + var summands = new List(); + foreach (var dev in tensors_on_device.Keys) + { + var tensors = tensors_on_device[dev]; + ops._colocate_with_for_gradient(tensors[0].op, gradient_uid, ignore_existing: true); + summands.Add(math_ops.add_n(tensors.ToArray())); + } + + return math_ops.add_n(summands.ToArray()); + } + + /// + /// The set of ops that terminate the gradient computation. + /// + /// list of Operations. + /// list of Operations never to backprop through. + /// mapping from operation to number of backprop inputs. + /// list of Tensors. + /// The set of operations. + private static Operation[] _StopOps(List from_ops, List stop_gradient_ops, Dictionary pending_count, Tensor[] xs) + { + var stop_ops = new List(); + + foreach (var op in from_ops) + { + bool is_stop_op = true; + foreach (var inp in _NonEagerInputs(op, xs)) + { + if (!pending_count.ContainsKey(inp.op.name)) + pending_count[inp.op.name] = 0; + + if (pending_count[inp.op.name] > 0) + { + is_stop_op = false; + break; + } + } + if (is_stop_op) + stop_ops.Insert(0, op); + } + stop_ops.AddRange(stop_gradient_ops.Where(x => !stop_ops.Contains(x))); + return stop_ops.ToArray(); + } + + private static Tensor _GetGrad(Dictionary>> grads, Tensor t) + { + var op = t.op; + if (!grads.ContainsKey(op.name)) + return null; + var op_grads = grads[op.name]; + var t_grad = op_grads[t.value_index]; + return t_grad[0]; + } + + private static List> _GetGrads(Dictionary>> grads, Operation op) + { + if (grads.ContainsKey(op.name)) + return grads[op.name]; + else + return op.outputs.Select(x => new List()).ToList(); + } + + /// + /// Mark all ops reached from "from_ops" + /// + /// + /// + /// + private static void _MarkReachedOps(List from_ops, List reached_ops, List func_graphs) + { + Queue queue = new Queue(from_ops); + while (queue.Count > 0) + { + var op = queue.Dequeue(); + + if (!reached_ops.Contains(op)) + { + reached_ops.Add(op); + foreach (var output in op.outputs) + { + if (_IsBackpropagatable(output)) + { + var c = output.consumers().ToList(); + c.ForEach(x => queue.Enqueue(x)); + } + } + } + } + } + + private static bool _IsBackpropagatable(Tensor tensor) + { + if (_IsTrainable(tensor)) + { + return true; + } + else + { + var dtype = tensor.dtype.as_base_dtype(); + return new TF_DataType[] { TF_DataType.TF_BFLOAT16, TF_DataType.TF_VARIANT }.Contains(dtype); + } + } + + private static bool _IsTrainable(Tensor tensor) + { + var dtype = tensor.dtype.as_base_dtype(); + return new TF_DataType[] {TF_DataType.TF_HALF, TF_DataType.TF_FLOAT, TF_DataType.TF_DOUBLE, + TF_DataType.TF_COMPLEX64, TF_DataType.TF_COMPLEX128, TF_DataType.TF_RESOURCE}.Contains(dtype); + } + + private static bool _IsPartitionedCall(Operation op) + { + return op.OpType == "PartitionedCall" || op.OpType == "StatefulPartitionedCall"; + } + + /// + /// Update pending count for the inputs of op and enqueue ready ops. + /// + /// + /// + /// + /// + /// + /// + private static void _UpdatePendingAndEnqueueReady(Dictionary>> grads, + Operation op, + Queue queue, + Dictionary pending_count, + ControlFlowState loop_state, + Tensor[] xs) + { + foreach (var x in _NonEagerInputs(op, xs)) + { + if (!pending_count.ContainsKey(x.op.name)) + pending_count[x.op.name] = 0; + + pending_count[x.op.name] -= 1; + + var ready = pending_count[x.op.name] == 0; + + if (loop_state != null && !ready) + { + ready = pending_count[x.op.name] > 0 && control_flow_util.IsLoopSwitch(x.op); + } + + if (ready) + { + // if x is an exit without real gradient, defer processing them. + if (control_flow_util.IsLoopExit(x.op)) + { + var grad_state = loop_state.GetGradState(x.op, before: false); + grad_state.deferred_exits.append(x); + grad_state.pending_exits_count -= 1; + // We now have all the exits so process them. + if (grad_state.pending_exits_count == 0) + { + var has_not_none_grad = false; + foreach (var y in grad_state.deferred_exits) + { + if (_HasAnyNotNoneGrads(grads, y.op)) + { + has_not_none_grad = true; + queue.Enqueue(y.op); + } + else + grad_state.unused_exits.append(y); + } + if (has_not_none_grad) + { + // For an unused exit, if it has trainable outputs, backprop + // a zero gradient. Otherwise, just ignore it. + foreach (var y in grad_state.unused_exits) + { + if (IsTrainable(y)) + _SetGrad(grads, y, loop_state.ZerosLikeForExit(y)); + queue.Enqueue(y.op); + } + } + else + { + // All exits are "unused" so use None as gradient. + foreach (var y in grad_state.unused_exits) + queue.Enqueue(y.op); + } + } + } + else + { + queue.Enqueue(x.op); + } + } + } + } + + public static bool IsTrainable(Tensor tensor) + { + var dtype = tensor.dtype.as_base_dtype(); + return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64, + dtypes.complex64, dtypes.complex128, + dtypes.resource, dtypes.variant}.Contains(dtype); + } + + public static int PossibleTapeGradientTypes(Tensor[] tensors) + { + return tf.Runner.TFE_TapeSetPossibleGradientTypes(tensors); + } + + /// + /// Return true if op has real gradient. + /// + /// + /// + /// + private static bool _HasAnyNotNoneGrads(Dictionary>> grads, Operation op) + { + var out_grads = _GetGrads(grads, op); + foreach (var out_grad in out_grads) + { + if (out_grad.Exists(g => g != null)) + return true; + } + return false; + } + + + private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func grad_fn) + { + // scope = scope.TrimEnd('/').Replace('/', '_'); + return grad_fn(op, out_grads); + } + + private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op) + { + if (op.type == "While" || op.type == "StatelessWhile") + return; + + if (grads.Count() != op.inputs._inputs.Count()) + throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + + $"inputs {op.inputs._inputs.Count()}"); + } + + private static Tensor[] _SymGrad(Operation op, Tensor[] out_grads) + { + var f_in = ((Tensor[])op.inputs).Concat(out_grads).ToArray(); + var f_types = ((Tensor[])op.inputs).Select(x => default_gradient.get_zeros_dtype(x)).ToArray(); + NameAttrList f = new(); + if (_IsPartitionedCall(op)) + { + var func_attr = op.get_attr("f"); + Debug.Assert(func_attr is NameAttrList); + f.Name = ((NameAttrList)func_attr).Name; + } + else + { + f.Name = op.type; + } + foreach(var k in op.node_def.Attr.Keys) + { + f.Attr[k] = AttrValue.Parser.ParseFrom(op.node_def.Attr[k].ToByteArray()); + } + var in_grads = gen_functional_ops.symbolic_gradient(f_in, f_types, f); + return in_grads; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/image_grad.cs b/src/TensorFlowNET.Core/Gradients/image_grad.cs new file mode 100644 index 000000000..7b5fb521c --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/image_grad.cs @@ -0,0 +1,50 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; + +namespace Tensorflow.Gradients +{ + [RegisterGradient("image_grad")] + public class image_grad + { + [RegisterGradient("ResizeNearestNeighbor")] + public static Tensor[] _ResizeNearestNeighborGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var image = op.inputs[0]; + var shape = new Shape(image.shape.dims.Skip(1).Take(2).ToArray()); + Tensor image_shape = null; + if (shape.IsFullyDefined) + image_shape = constant_op.constant(image.shape.as_int_list().Skip(1).Take(2).ToArray()); + else + image_shape = array_ops.shape(image)["1:3"]; + + grad = gen_image_ops.resize_nearest_neighbor_grad( + grad, + image_shape, + align_corners: op.get_attr("align_corners"), + half_pixel_centers: op.get_attr("half_pixel_centers")); + + return new Tensor[] + { + grad, + null + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs new file mode 100644 index 000000000..8c3f0f8bd --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -0,0 +1,1002 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Linq; +using Tensorflow.Eager; +using static Tensorflow.Binding; + +namespace Tensorflow.Gradients +{ + /// + /// Gradients for operators defined in math_ops.py. + /// + [RegisterGradient("math_grad")] + public class math_grad + { + [RegisterGradient("Abs")] + public static Tensor[] _AbsGrad(Operation op, Tensor[] grads) + { + var x = op.inputs[0]; + var grad = grads[0]; + + return new Tensor[] { grad * math_ops.sign(x) }; + } + + [RegisterGradient("AddV2")] + public static Tensor[] _AddV2Grad(Operation op, Tensor[] grads) + => _AddGrad(op, grads); + + [RegisterGradient("Add")] + public static Tensor[] _AddGrad(Operation op, Tensor[] grads) + { + var x = op.inputs[0]; + var y = op.inputs[1]; + var grad = grads[0]; + if (grad is Tensor && + _ShapesFullySpecifiedAndEqual(x, y, grad)) + return new Tensor[] { grad, grad }; + + var sx = array_ops.shape(x); + var sy = array_ops.shape(y); + var args = gen_array_ops.broadcast_gradient_args(sx, sy); + var (rx, ry) = (args[0], args[1]); + + var sum1 = math_ops.reduce_sum(grad, rx); + var r1 = gen_array_ops.reshape(sum1, sx); + var sum2 = math_ops.reduce_sum(grad, ry); + var r2 = gen_array_ops.reshape(sum2, sy); + + return new Tensor[] { r1, r2 }; + } + + /// + /// Copies the gradient to all inputs. + /// + /// + /// + /// + [RegisterGradient("AddN")] + public static Tensor[] _AddNGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + + return Enumerable.Range(0, len(op.inputs)) + .Select(x => grad) + .ToArray(); + } + + [RegisterGradient("Cumsum")] + public static Tensor[] _CumsumGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var axis = op.inputs[1]; + var exclusive = op.get_attr("exclusive"); + var reverse = op.get_attr("reverse"); + return new Tensor[] + { + math_ops.cumsum(grad, axis, exclusive: exclusive, reverse: !reverse), + null + }; + } + + [RegisterGradient("DivNoNan")] + public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + var y = op.inputs[1]; + var sx = array_ops.shape(x); + var sy = array_ops.shape(y); + var args = gen_array_ops.broadcast_gradient_args(sx, sy); + var (rx, ry) = (args[0], args[1]); + x = math_ops.conj(x); + y = math_ops.conj(y); + + var reduce_sum1 = math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx); + var reduce_sum2 = math_ops.reduce_sum(grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), ry); + + return new Tensor[] + { + array_ops.reshape(reduce_sum1, sx), + array_ops.reshape(reduce_sum2, sy) + }; + } + + public static string ellipsis = "..."; + [RegisterGradient("Einsum")] + public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads) + { + // Gradient for Einsum. + string equation = (string)op.get_attr("equation"); + string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None); + var input_subs = split_equation[0]; + var output_subs = split_equation[1]; + + if (op.inputs.Length == 1) + { + var input_shape = array_ops.shape(op.inputs[0]); + var reduced_label_set = new HashSet(new HashSet(input_subs).Except(new HashSet(output_subs + ellipsis))); + if (reduced_label_set.Count == 0) + return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) }; + return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) }; + } + + string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None); + var x_subs = split_input_subs[0]; + var y_subs = split_input_subs[1]; + // Add ellipsis for broadcasted dimensions if any operand does not have it. + // This is because the equation "...ij,jk->ik" may be valid if the 0th input's + // batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid + // because only the output subscripts contain ellipsis. + if (output_subs.Contains(ellipsis)) + { + if (!x_subs.Contains(ellipsis)) + x_subs += ellipsis; + if (!y_subs.Contains(ellipsis)) + y_subs += ellipsis; + } + // Obtain the gradients wrt the inputs x and y, without taking into account + // the unbroadcasting. + var x = op.inputs[0]; + var y = op.inputs[1]; + if (grads.GetDataType().is_complex()) + { + x = math_ops.conj(x); + y = math_ops.conj(y); + } + + var x_shape = array_ops.shape(x); + var y_shape = array_ops.shape(y); + var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs); + var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs); + + if (!output_subs.Contains(ellipsis)) + return new Tensor[] { grad_x, grad_y }; + var bx = _GetBcastSubshape(x_subs); + int bx_start = bx[0], bx_end = bx[1]; + var by = _GetBcastSubshape(y_subs); + int by_start = by[0], by_end = by[1]; + + var x_shape_static = x.shape; + var y_shape_static = y.shape; + if(x_shape_static.IsFullyDefined && + y_shape_static.IsFullyDefined && + x_shape_static[string.Format("{0}:{1}",bx_start,bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)]) + return new Tensor[] { grad_x, grad_y }; + + var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)], + y_shape[string.Format("{0}:{1}", by_start, by_end)]); + var rx = r[0]; + var ry = r[1]; + grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape); + grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape); + return new Tensor[] { grad_x, grad_y }; + } + protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape, + string input_subs, string other_subs, string output_subs) + { + var reduced_label_set = new HashSet(new HashSet(input_subs).Except(new HashSet(output_subs + other_subs + "."))); + var left_subs = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); + var grad_reduced = math_ops.einsum(string.Format("{0},{1}->{2}", output_subs, other_subs, left_subs), new Tensors((Tensors)output_grads, other_operand)); + if (reduced_label_set.Count == 0) + return grad_reduced; + return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape, reduced_label_set); + } + protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet reduced_label_set) + { + string reduced_subs; + Tensor reduced_dims; + List reduced_axes; + _GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes); + bool has_repeated_labels = ( + new HashSet(input_subs).Count + new HashSet(output_subs).Count < + input_subs.Length + output_subs.Length); + var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); + + if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs) + { + var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes)); + return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape); + } + else + { + var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0); + var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0); + var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels); + return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad)); + } + } + protected static void _GetReducedSubscripts(HashSet reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List reduced_axes) + { + reduced_subs = string.Join("", reduced_label_set.Select(c => c.ToString())); + reduced_axes = reduced_subs.Select(s => _GetAxisFromLabel(subscripts, s)).ToList(); + reduced_dims = array_ops.stack(reduced_axes.Select(ax => input_shape[ax]).ToList()); + } + protected static int _GetAxisFromLabel(string subscripts, char label) + { + var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None); + var index = splits[0].IndexOf(label); + if (index != -1) return index; + if (splits.Length < 2) throw new OutOfRangeError(); + index = splits[1].IndexOf(label); + if (index != -1) return index; + throw new ValueError(); + } + protected static int[] _GetBcastSubshape(string subscripts) + { + int start = subscripts.IndexOf(ellipsis); + if (start == -1) return new int[] { 0, 0 }; + int remaining = subscripts.Length - (start + ellipsis.Length); + int end; + if (remaining > 0) end = remaining; + else throw new Exception(); + return new int[] { start, end }; + } + + /// + /// Returns grad * exp(x). + /// + /// + /// + /// + [RegisterGradient("Exp")] + public static Tensor[] _ExpGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var y = op.outputs[0]; // y = e^x + return tf_with(ops.control_dependencies(new Operation[] { grad }), dp => + { + y = math_ops.conj(y); + // forward_compatible(2019, 9, 14) + // return new Tensor[] { math_ops.mul_no_nan(y, grad) }; + return new Tensor[] { grad * y }; + }); + } + + [RegisterNoGradient("GreaterEqual")] + public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null; + + [RegisterNoGradient("OnesLike")] + public static Tensor[] _OnesLike(Operation op, Tensor[] grads) => null; + + [RegisterNoGradient("ZerosLike")] + public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null; + + [RegisterGradient("Identity")] + public static Tensor[] _IdGrad(Operation op, Tensor[] grads) + { + return new Tensor[] { grads[0] }; + } + + [RegisterGradient("Lgamma")] + public static Tensor[] _LgammaGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + return tf_with(ops.control_dependencies(new Operation[] { grad }), dp => + { + x = math_ops.conj(x); + return new Tensor[] { grad * math_ops.digamma(x) }; + }); + } + + [RegisterGradient("Log")] + public static Tensor[] _LogGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + return tf_with(ops.control_dependencies(new Operation[] { grad }), dp => + { + x = math_ops.conj(x); + return new Tensor[] { grad * math_ops.reciprocal(x) }; + }); + } + + [RegisterGradient("Log1p")] + public static Tensor[] _Log1pGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + return tf_with(ops.control_dependencies(new Operation[] { grad }), dp => + { + x = math_ops.conj(x); + return new Tensor[] { grad * math_ops.reciprocal(1 + x) }; + }); + } + + [RegisterGradient("Mul")] + public static Tensor[] _MulGrad(Operation op, Tensor[] grads) + { + var x = op.inputs[0]; + var y = op.inputs[1]; + var grad = grads[0]; + + if (op is EagerOperation op_eager && + op_eager.SkipInputIndices.Contains(1) && + y.ndim == 0) + { + return new Tensor[] + { + gen_math_ops.mul(grad, math_ops.conj(y)), + null + }; + } + + if (grad is Tensor && + _ShapesFullySpecifiedAndEqual(x, y, grad) && + new TF_DataType[] { tf.int32, tf.float32 }.Contains(grad.dtype)) + { + return new Tensor[] + { + gen_math_ops.mul(grad, y), + gen_math_ops.mul(grad, x) + }; + } + + var broads = SmartBroadcastGradientArgs(x, y, grad); + var (sx, rx, must_reduce_x) = broads[0]; + var (sy, ry, must_reduce_y) = broads[1]; + + x = math_ops.conj(x); + y = math_ops.conj(y); + + Tensor gx = null, gy = null; + + if (op is EagerOperation op_eager1 && + op_eager1.SkipInputIndices.Contains(0)) + gy = null; + else if (!must_reduce_x) + gx = gen_math_ops.mul(grad, y); + else + gx = array_ops.reshape( + math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx); + + if (op is EagerOperation op_eager2 && + op_eager2.SkipInputIndices.Contains(1)) + gy = null; + else if (!must_reduce_y) + gy = gen_math_ops.mul(x, grad); + else + gy = array_ops.reshape( + math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy); + + return new Tensor[] { gx, gy }; + } + + [RegisterGradient("MatMul")] + public static Tensor[] _MatMulGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + Tensor grad_a = null, grad_b = null; + + var t_a = (bool)op.get_attr("transpose_a"); + var t_b = (bool)op.get_attr("transpose_b"); + var a = math_ops.conj(op.inputs[0]); + var b = math_ops.conj(op.inputs[1]); + if (!t_a && !t_b) + { + grad_a = gen_math_ops.mat_mul(grad, b, transpose_b: true); + grad_b = gen_math_ops.mat_mul(a, grad, transpose_a: true); + } + else if (!t_a && t_b) + { + grad_a = gen_math_ops.mat_mul(grad, b); + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a: true); + } + else if (t_a && !t_b) + { + grad_a = gen_math_ops.mat_mul(grad, b); + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a: true); + } + else if (t_a && t_b) + { + grad_a = gen_math_ops.mat_mul(b, grad, transpose_a: true, transpose_b: true); + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a: true, transpose_b: true); + } + + return new Tensor[] { grad_a, grad_b }; + } + + [RegisterGradient("BatchMatMul")] + public static Tensor[] _BatchMatMul(Operation op, Tensor[] grads) + { + var grad = grads[0]; + Tensor grad_a = null, grad_b = null; + + var t_a = (bool)op.get_attr("adj_x"); + var t_b = (bool)op.get_attr("adj_y"); + var a = math_ops.conj(op.inputs[0]); + var b = math_ops.conj(op.inputs[1]); + if (!t_a && !t_b) + { + grad_a = math_ops.batch_matmul(grad, b, adj_y: true); + grad_b = math_ops.batch_matmul(a, grad, adj_x: true); + } + else if (!t_a && t_b) + { + grad_a = math_ops.batch_matmul(grad, b); + grad_b = math_ops.batch_matmul(grad, a, adj_x: true); + } + else if (t_a && !t_b) + { + grad_a = math_ops.batch_matmul(grad, b); + grad_b = math_ops.batch_matmul(grad, a, adj_x: true); + } + else if (t_a && t_b) + { + grad_a = math_ops.batch_matmul(b, grad, adj_x: true, adj_y: true); + grad_b = math_ops.batch_matmul(grad, a, adj_x: true, adj_y: true); + } + + return new Tensor[] { grad_a, grad_b }; + } + + [RegisterGradient("Mean")] + public static Tensor[] _MeanGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var sum_grad = _SumGrad(op, grads)[0]; + var input_shape = op.inputs[0]._shape_tuple(); + var output_shape = op.outputs[0]._shape_tuple(); + + Tensor result, factor_tensor; + if (tf.executing_eagerly() + && input_shape != null + && output_shape != null) + { + var input_size = np.prod(input_shape); + var output_size = np.prod(output_shape); + var factor = (int)input_size / Math.Max((int)output_size, 1); + factor_tensor = constant_op.constant(factor, dtype: sum_grad.dtype); + } + else + { + var input_shape_tensor = array_ops.shape(op.inputs[0]); + var output_shape_tensor = array_ops.shape(op.outputs[0]); + factor_tensor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor)); + } + + result = math_ops.truediv(sum_grad, math_ops.cast(factor_tensor, sum_grad.dtype)); + return new Tensor[] { result, null }; + } + + /// + /// Gradient for Max. + /// + /// + /// + /// + [RegisterGradient("Max")] + public static Tensor[] _MaxGrad(Operation op, Tensor[] grads) + { + return _MinOrMaxGrad(op, grads); + } + + /// + /// Gradient for Min. + /// + /// + /// + /// + [RegisterGradient("Min")] + public static Tensor[] _MinGrad(Operation op, Tensor[] grads) + { + return _MinOrMaxGrad(op, grads); + } + + private static Tensor[] _MinOrMaxGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var input_shape = array_ops.shape(op.inputs[0]); + var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]); + var y = op.outputs[0]; + y = array_ops.reshape(y, output_shape_kept_dims); + grad = array_ops.reshape(grad, output_shape_kept_dims); + + // Compute the number of selected (maximum or minimum) elements in each + // reduction dimension. If there are multiple minimum or maximum elements + // then the gradient will be divided between them. + var indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype); + var num_selected = array_ops.reshape(math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims); + + return new Tensor[] { math_ops.div(indicators, num_selected) * grad, null }; + } + + /// + /// Returns grad*(x > y, x <= y) with type of grad. + /// + /// + /// + /// + [RegisterGradient("Maximum")] + public static Tensor[] _MaximumGrad(Operation op, Tensor[] grads) + { + return _MaximumMinimumGrad(true, op, grads[0]); + } + + /// + /// Returns grad*(x < y, x >= y) with type of grad. + /// + /// + /// + /// + [RegisterGradient("Minimum")] + public static Tensor[] _MinimumGrad(Operation op, Tensor[] grads) + { + return _MaximumMinimumGrad(false, op, grads[0]); + } + + /// + /// Factor out the code for the gradient of Maximum or Minimum. + /// + /// + /// + /// + private static Tensor[] _MaximumMinimumGrad(bool isMaximum, Operation op, Tensor grad) + { + var x = op.inputs[0]; + var y = op.inputs[1]; + var gdtype = grad.dtype; + var sx = array_ops.shape(x); + var sy = array_ops.shape(y); + var gradshape = array_ops.shape(grad); + var zeros = array_ops.zeros(gradshape, gdtype); + var xmask = + isMaximum + ? gen_math_ops.greater_equal(x, y) + : gen_math_ops.less_equal(x, y); + var args = gen_array_ops.broadcast_gradient_args(sx, sy); + var (rx, ry) = (args[0], args[1]); + var xgrad = array_ops.where(xmask, grad, zeros); + var gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx); + var ygrad = array_ops.where(xmask, zeros, grad); + var gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy); + return new Tensor[] { gx, gy }; + } + + [RegisterGradient("Neg")] + public static Tensor[] _NegGrad(Operation op, Tensor[] grads) + { + return new Tensor[] { -grads[0] }; + } + + [RegisterGradient("Select")] + public static Tensor[] _SelectGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var c = op.inputs[0]; + var x = op.inputs[1]; + var zeros = array_ops.zeros_like(x); + return new Tensor[] + { + null, + array_ops.where(c, grad, zeros), + array_ops.where(c, zeros, grad) + }; + } + + private static Tensor _safe_shape_div(Tensor x, Tensor y) + { + return math_ops.floordiv(x, gen_math_ops.maximum(y, ops.convert_to_tensor(1))); + } + + [RegisterGradient("Sub")] + public static Tensor[] _SubGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + var y = op.inputs[1]; + if (grad is Tensor && + _ShapesFullySpecifiedAndEqual(x, y, grad)) + return new Tensor[] { grad, -grad }; + + var broads = SmartBroadcastGradientArgs(x, y, grad); + var (sx, rx, must_reduce_x) = broads[0]; + var (sy, ry, must_reduce_y) = broads[1]; + + var gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx); + var gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy); + + return new Tensor[] { gx, gy }; + } + + public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad) + { + var x_shape = x._shape_tuple(); + var y_shape = y._shape_tuple(); + var grad_shape = grad._shape_tuple(); + return x_shape != null && + y_shape != null && + Enumerable.SequenceEqual(x_shape, y_shape) && + Enumerable.SequenceEqual(y_shape, grad_shape) && + !x_shape.Contains(-1); + } + + [RegisterGradient("Sum")] + public static Tensor[] _SumGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var input_0_shape = op.inputs[0]._shape_tuple(); + Tensor input_shape = null; + + if (input_0_shape != null) + { + var axes = tensor_util.constant_value(op.inputs[1]); + if (!(axes is null)) + { + var rank = input_0_shape.Length; + if (Enumerable.SequenceEqual(Enumerable.Range(0, rank), axes.ToArray())) + { + if (tf.Context.executing_eagerly()) + { + // should add ones_rank_cache + var new_shape = constant_op.constant(range(0, rank).Select(x => 1).ToArray(), dtype: TF_DataType.TF_INT32); + grad = array_ops.reshape(grad, new_shape); + } + else + { + var new_shape = range(rank).Select(x => 1).ToArray(); + grad = array_ops.reshape(grad, new_shape); + } + + // If shape is not fully defined (but rank is), we use Shape. + if (!input_0_shape.Contains(-1)) + input_shape = constant_op.constant(input_0_shape); + else + input_shape = array_ops.shape(op.inputs[0]); + return new Tensor[] { gen_array_ops.tile(grad, input_shape), null }; + } + else if (!input_0_shape.Contains(-1) && !tf.Context.executing_eagerly()) + { + axes = axes.reshape(new Shape(-1)); + var shape_tensor = tf.constant(op.inputs[0].shape.as_int_list()); + var output_shape_kept_dims = math_ops.reduced_shape(shape_tensor, axes); + var tile_scaling = _safe_shape_div(shape_tensor, output_shape_kept_dims); + grad = array_ops.reshape(grad, output_shape_kept_dims); + return new Tensor[] { array_ops.tile(grad, tile_scaling), null }; + } + } + } + + input_shape = array_ops.shape(op.inputs[0]); + + if (tf.executing_eagerly()) + { + if (!op.get_attr("keep_dims")) + { + ops.colocate_with(input_shape); + var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]); + // var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims); + grad = gen_array_ops.reshape(grad, output_shape_kept_dims); + } + + return new Tensor[] { gen_array_ops.broadcast_to(grad, input_shape), null }; + } + else + { + ops.colocate_with(input_shape); + var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]); + var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims); + grad = gen_array_ops.reshape(grad, output_shape_kept_dims); + + return new Tensor[] { gen_array_ops.tile(grad, tile_scaling), null }; + } + } + + [RegisterGradient("RealDiv")] + public static Tensor[] _RealDivGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + var y = op.inputs[1]; + + var sx = array_ops.shape(x); + var sy = array_ops.shape(y); + var args = gen_array_ops.broadcast_gradient_args(sx, sy); + var (rx, ry) = (args[0], args[1]); + x = math_ops.conj(x); + y = math_ops.conj(y); + + var reshape1 = array_ops.reshape( + math_ops.reduce_sum( + math_ops.realdiv(grad, y), rx), + sx); + var reshape2 = array_ops.reshape( + math_ops.reduce_sum( + grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), + sy); + + return new Tensor[] { reshape1, reshape2 }; + } + + [RegisterGradient("Sigmoid")] + public static Tensor[] _SigmoidGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var y = op.outputs[0]; + + return tf_with(ops.control_dependencies(grads), delegate + { + y = math_ops.conj(y); + return new Tensor[] { gen_math_ops.sigmoid_grad(y, grad) }; + }); + } + + [RegisterGradient("Sign")] + public static Tensor[] _SignGrad(Operation op, Tensor[] grads) + { + var x = op.inputs[0]; + var zero = constant_op.constant(0.0f, x.dtype, x.shape); + + return new Tensor[] { zero }; + } + + [RegisterGradient("Square")] + public static Tensor[] _SquareGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + + return tf_with(ops.control_dependencies(grads), delegate + { + x = math_ops.conj(x); + var y = constant_op.constant(2.0, dtype: x.dtype); + return new Tensor[] { math_ops.multiply(grad, math_ops.multiply(x, y)) }; + }); + } + + [RegisterGradient("Sqrt")] + public static Tensor[] _SqrtGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var y = op.outputs[0]; + + return tf_with(ops.control_dependencies(grads), delegate + { + y = math_ops.conj(y); + var factor = constant_op.constant(0.5f, dtype: y.dtype); + return new Tensor[] { grad * (factor * math_ops.reciprocal(y)) }; + }); + } + + [RegisterGradient("Rsqrt")] + public static Tensor[] _RsqrtGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var y = op.outputs[0]; + + return tf_with(ops.control_dependencies(grads), delegate + { + y = math_ops.conj(y); + var factor = constant_op.constant(-0.5f, dtype: y.dtype); + return new Tensor[] { grad * (factor * math_ops.square(y) * y) }; + }); + } + + [RegisterGradient("Asin")] + public static Tensor[] _ASinGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + + return tf_with(ops.control_dependencies(grads), delegate + { + x = math_ops.conj(x); + // the derivative of + // y = asin(x) + // is + // d/dx asin(x) = 1 / sqrt(1-x*x) + return new Tensor[] { math_ops.multiply(grad, 1 / gen_math_ops.sqrt(1 - gen_math_ops.square(x))) }; + }); + } + + [RegisterGradient("Sin")] + public static Tensor[] _SinGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + + return tf_with(ops.control_dependencies(grads), delegate + { + x = math_ops.conj(x); + return new Tensor[] { math_ops.multiply(grad, gen_math_ops.cos(x)) }; + }); + } + + [RegisterGradient("Sinh")] + public static Tensor[] _SinhGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + + return tf_with(ops.control_dependencies(grads), delegate + { + x = math_ops.conj(x); + return new Tensor[] { math_ops.multiply(grad, gen_math_ops.cosh(x)) }; + }); + } + + [RegisterGradient("Acos")] + public static Tensor[] _ACosGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + + return tf_with(ops.control_dependencies(grads), delegate + { + // the derivative of + // y = acos(x) + // is + // d/dx acos(x) = -1 / sqrt(1-x*x) = -d/dx asin(x) + x = math_ops.conj(x); + return new Tensor[] { math_ops.multiply(grad, -1 / gen_math_ops.sqrt(1 - gen_math_ops.square(x))) }; + }); + } + + [RegisterGradient("Cast")] + public static Tensor[] _CastGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + + var src_type = x.dtype.as_base_dtype(); + var dst_type = grad.dtype.as_base_dtype(); + if (src_type.is_value_dtype() && dst_type.is_value_dtype()) + return new Tensor[] { math_ops.cast(grad, src_type) }; + else + return new Tensor[0]; + } + + [RegisterGradient("Cos")] + public static Tensor[] _CosGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + + return tf_with(ops.control_dependencies(grads), delegate + { + x = math_ops.conj(x); + return new Tensor[] { math_ops.multiply(grad, -gen_math_ops.sin(x)) }; + }); + } + + [RegisterGradient("Cosh")] + public static Tensor[] _CoshGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + + return tf_with(ops.control_dependencies(grads), delegate + { + x = math_ops.conj(x); + return new Tensor[] { math_ops.multiply(grad, gen_math_ops.sinh(x)) }; + }); + } + + [RegisterGradient("Atan")] + public static Tensor[] _ATanGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + + return tf_with(ops.control_dependencies(grads), delegate + { + // the derivative of + // y = atan(x) + // is + // d/dx atan(x) = 1 / (1 + x*x) + x = math_ops.conj(x); + return new Tensor[] { math_ops.multiply(grad, 1 / (1 + gen_math_ops.square(x))) }; + }); + } + + [RegisterGradient("Tanh")] + public static Tensor[] _TanhGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var y = op.outputs[0]; + + return tf_with(ops.control_dependencies(grads), delegate + { + y = math_ops.conj(y); + return new Tensor[] { gen_math_ops.tanh_grad(y, grad) }; + }); + } + + [RegisterGradient("Pow")] + public static Tensor[] _PowGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + var y = op.inputs[1]; + + if (op is EagerOperation op_eager && + op_eager.SkipInputIndices.Contains(1) && + y.ndim == 0) + { + x = math_ops.conj(x); + y = math_ops.conj(y); + return new Tensor[] + { + grad * y * math_ops.pow(x, y - 1), + null + }; + } + + var z = op.outputs[0]; + + var broads = SmartBroadcastGradientArgs(x, y, grad); + var (sx, rx, must_reduce_x) = broads[0]; + var (sy, ry, must_reduce_y) = broads[1]; + + x = math_ops.conj(x); + y = math_ops.conj(y); + z = math_ops.conj(z); + var mul = grad * y * math_ops.pow(x, y - 1.0f); + var reduce_sum = math_ops.reduce_sum(mul, rx); + var gx = gen_array_ops.reshape(reduce_sum, sx); + + // Avoid false singularity at x = 0 + Tensor mask = null; + if (x.dtype.is_complex()) + throw new NotImplementedException("x.dtype.is_complex()"); + else + mask = x > 0.0f; + var ones = array_ops.ones_like(x); + var safe_x = array_ops.where(mask, x, ones); + var x1 = math_ops.log(safe_x); + var y1 = array_ops.zeros_like(x); + var log_x = array_ops.where(mask, x1, y1); + var mul1 = grad * z * log_x; + var reduce_sum1 = math_ops.reduce_sum(mul1, ry); + var gy = gen_array_ops.reshape(reduce_sum1, sy); + + return new Tensor[] { gx, gy }; + } + + /// + /// Optimized version of `broadcast_gradient_args` that caches results. + /// + /// + /// + /// + public static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad) + { + Tensor sx, sy; + if (x.shape.IsFullyDefined && + y.shape.IsFullyDefined) + { + sx = array_ops.shape(x); + sy = array_ops.shape(y); + } + else + { + sx = array_ops.shape_internal(x, optimize: false); + sy = array_ops.shape_internal(y, optimize: false); + } + + var args = gen_array_ops.broadcast_gradient_args(sx, sy); + var (rx, ry) = (args[0], args[1]); + return new[] + { + (sx, rx, !x.shape.Equals(grad.shape)), + (sy, ry, !y.shape.Equals(grad.shape)) + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/math_grad_eager.cs b/src/TensorFlowNET.Core/Gradients/math_grad_eager.cs new file mode 100644 index 000000000..f8b16090f --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/math_grad_eager.cs @@ -0,0 +1,71 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Eager; + +namespace Tensorflow.Gradients +{ + /// + /// Gradients for operators defined in math_ops.py. + /// + [RegisterGradientEager("math_grad")] + public class math_grad_eager + { + [RegisterGradientEager("Mul")] + public static Tensor[] _MulGrad(EagerOperation op, IntPtr[] grads) + { + var x = op.InputHandles[0]; + var y = op.InputHandles[1]; + var grad = grads[0]; + + if (op.SkipInputIndices.Contains(1) && + EagerTensor.GetRank(grad) == 0) + { + return new Tensor[] + { + null,//gen_math_ops.mul(grad, math_ops.conj(y)), + null + }; + } + + if (_ShapesFullySpecifiedAndEqual(x, y, grad)) + { + return new Tensor[] + { + math_ops.multiply(grad, y), + math_ops.multiply(grad, x) + }; + } + + throw new NotImplementedException(""); + } + + public static bool _ShapesFullySpecifiedAndEqual(IntPtr x, IntPtr y, IntPtr grad) + { + var x_shape = EagerTensor.GetDims(x); + var y_shape = EagerTensor.GetDims(y); + + var grad_shape = EagerTensor.GetDims(grad); + return x_shape != null && + y_shape != null && + Enumerable.SequenceEqual(x_shape, y_shape) && + Enumerable.SequenceEqual(y_shape, grad_shape) && + !x_shape.Contains(-1); + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs new file mode 100644 index 000000000..87646a9ea --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs @@ -0,0 +1,464 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow.Gradients +{ + /// + /// + /// + [RegisterGradient("math_grad")] + public class nn_grad + { + /// + /// Return the gradients for the 2 inputs of bias_op. + /// + /// + /// + /// + [RegisterGradient("BiasAdd")] + public static Tensor[] _BiasAddGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + string data_format = op.get_attr("data_format")?.ToString(); + var bias_add_grad = gen_nn_ops.bias_add_grad(out_backprop: grad, data_format: data_format); + return new Tensor[] { grad, bias_add_grad }; + } + + [RegisterGradient("Relu")] + public static Tensor[] _ReluGrad(Operation op, Tensor[] grads) + { + return new Tensor[] { gen_nn_ops.relu_grad(grads[0], op.outputs[0]) }; + } + + [RegisterGradient("LeakyRelu")] + public static Tensor[] _LeakyReluGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + var alpha = (float)op.get_attr("alpha"); + return new Tensor[] { gen_nn_ops.leaky_relu_grad(grad, x, alpha: alpha) }; + } + + /// + /// The derivative of the softmax nonlinearity. + /// + /// + /// + /// + [RegisterGradient("Softmax")] + public static Tensor[] _SoftmaxGrad(Operation op, Tensor[] grads) + { + var grad_softmax = grads[0]; + + var softmax = op.outputs[0]; + var mul = grad_softmax * softmax; + var sum_channels = math_ops.reduce_sum(mul, axis: constant_op.constant(-1), keepdims: true); + var sub = grad_softmax - sum_channels; + return new Tensor[] { sub * softmax }; + } + + /// + /// Gradient function for SoftmaxCrossEntropyWithLogits. + /// + /// + /// + /// + [RegisterGradient("SoftmaxCrossEntropyWithLogits")] + public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads) + { + var grad_loss = grads[0]; + var grad_grad = grads[1]; + var softmax_grad = op.outputs[1]; + var grad = _BroadcastMul(grad_loss, softmax_grad); + + var logits = op.inputs[0]; + if (grad_grad != null && !IsZero(grad_grad)) + { + throw new NotImplementedException("_SoftmaxCrossEntropyWithLogitsGrad"); + } + + return new Tensor[] + { + grad, + _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits)) + }; + } + + [RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")] + public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads) + { + var sparse_softmax_grad_without_gradient = array_ops.prevent_gradient( + op.outputs[1], + message: "Currently there is no way to take the second " + + "derivative of sparse_softmax_cross_entropy_with_logits due to the fused " + + "implementation's interaction with tf.gradients()"); + + var grad_0 = grads[0]; + + return new Tensor[] + { + _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient), + null + }; + } + + [RegisterGradient("Softplus")] + public static Tensor[] _SoftplusGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + + var softplus = grad * math_ops.sigmoid(x); + return new Tensor[] { softplus }; + } + + [RegisterGradient("SquaredDifference")] + public static Tensor[] _SquaredDifferenceGrad(Operation op, Tensor[] grads) + { + Tensor x = op.inputs[0]; + Tensor y = op.inputs[1]; + var grad = grads[0]; + var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype); + var x_grad = math_ops.scalar_mul(scale, grad) * (x - y); + if (math_grad._ShapesFullySpecifiedAndEqual(x, y, grad)) + { + return new Tensor[] { x_grad, -x_grad }; + } + var broadcast_info = math_grad.SmartBroadcastGradientArgs(x, y, grad); + Debug.Assert(broadcast_info.Length == 2); + var (sx, rx, must_reduce_x) = broadcast_info[0]; + var (sy, ry, must_reduce_y) = broadcast_info[1]; + Tensor gx, gy; + if (must_reduce_x) + { + gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx); + } + else + { + gx = x_grad; + } + if (must_reduce_y) + { + gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy); + } + else + { + gy = -x_grad; + } + return new Tensor[] { gx, gy }; + } + + /// + /// The derivatives for deconvolution. + /// + /// The Deconvolution op. + /// The tensor representing the gradient w.r.t. the output + /// The gradients w.r.t. the input and the filter + [RegisterGradient("Conv2DBackpropInput")] + public static Tensor[] _Conv2DBackpropInputGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var dilations = op.get_attr_list("dilations"); + var strides = op.get_attr_list("strides"); + var padding = op.get_attr("padding"); + var explicit_paddings = op.get_attr_list("explicit_paddings"); + var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu"); + var data_format = op.get_attr("data_format"); + + return new Tensor[] + { + gen_nn_ops.conv2d_backprop_filter(grad, array_ops.shape(op.inputs[1]), op.inputs[2], + strides, padding, + use_cudnn_on_gpu: use_cudnn_on_gpu, + explicit_paddings: explicit_paddings, + dilations: dilations, + data_format: data_format), + gen_nn_ops.conv2d(grad, op.inputs[1], strides, padding, + use_cudnn_on_gpu, explicit_paddings, data_format, dilations) + }; + } + + /// + /// Gradient function for Conv2D. + /// + /// + /// + /// + [RegisterGradient("Conv2D")] + public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads) + { + var dilations = op.get_attr_list("dilations"); + var strides = op.get_attr_list("strides"); + var padding = op.get_attr("padding"); + var explicit_paddings = op.get_attr_list("explicit_paddings"); + var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu"); + var data_format = op.get_attr("data_format"); + var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); + + return new Tensor[] + { + gen_nn_ops.conv2d_backprop_input(shape[0], op.inputs[1], grads[0], + strides, padding, use_cudnn_on_gpu, explicit_paddings, + dilations: dilations, + data_format: data_format), + gen_nn_ops.conv2d_backprop_filter(op.inputs[0], shape[1], grads[0], + strides, padding, + dilations: dilations, + explicit_paddings: explicit_paddings, + use_cudnn_on_gpu: use_cudnn_on_gpu, + data_format: data_format) + }; + } + + /// + /// Gradient function for Conv2D. + /// + /// + /// + /// + [RegisterGradient("DepthwiseConv2dNative")] + public static Tensor[] _DepthwiseConv2DGrad(Operation op, Tensor[] grads) + { + var dilations = op.get_attr_list("dilations"); + var strides = op.get_attr_list("strides"); + var padding = op.get_attr("padding"); + var explicit_paddings = op.get_attr_list("explicit_paddings"); + var data_format = op.get_attr("data_format"); + var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); + + return new Tensor[] + { + gen_nn_ops.depthwise_conv2d_native_backprop_input( + shape[0], op.inputs[1], grads[0], + strides, padding, explicit_paddings, + dilations: dilations, + data_format: data_format), + gen_nn_ops.depthwise_conv2d_native_backprop_filter(op.inputs[0], shape[1], grads[0], + strides, padding, + dilations: dilations, + explicit_paddings: explicit_paddings, + data_format: data_format) + }; + } + + [RegisterGradient("FusedBatchNorm")] + public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads) + => _BaseFusedBatchNormGrad(op, 0, grads); + + [RegisterGradient("FusedBatchNormV2")] + public static Tensor[] _FusedBatchNormV2Grad(Operation op, Tensor[] grads) + => _BaseFusedBatchNormGrad(op, 1, grads); + + [RegisterGradient("FusedBatchNormV3")] + public static Tensor[] _FusedBatchNormV3Grad(Operation op, Tensor[] grads) + => _BaseFusedBatchNormGrad(op, 2, grads); + + /// + /// Return the gradients for the 3 inputs of BatchNorm. + /// + /// + /// + /// + /// + public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads) + { + var x = op.inputs[0]; + var grad_y = grads[0]; + var scale = op.inputs[1]; + var epsilon = op.get_attr("epsilon"); + var data_format = op.get_attr("data_format"); + var is_training = op.get_attr("is_training"); + Func grad_fun = (p) => + { + if(version == 2) + { + return gen_nn_ops.fused_batch_norm_grad_v3(p.YBackprop, p.X, p.Scale, + p.ReserveSpace1, p.ReserveSpace2, p.ReserveSpace3, p.Epsilon, + p.DataFormat, p.IsTraining, p.Name); + } + else if(version == 1) + { + return gen_nn_ops.fused_batch_norm_grad_v2(p.YBackprop, p.X, p.Scale, + p.ReserveSpace1, p.ReserveSpace2, p.Epsilon, p.DataFormat, + p.IsTraining, p.Name); + } + else + { + return gen_nn_ops.fused_batch_norm_grad(p.YBackprop, p.X, p.Scale, + p.ReserveSpace1, p.ReserveSpace2, p.Epsilon, p.DataFormat, + p.IsTraining, p.Name); + } + }; + + if (is_training) + { + return grad_fun(new FusedBatchNormParams + { + YBackprop = grad_y, + X = x, + Scale = scale, + ReserveSpace1 = op.outputs[3], + ReserveSpace2 = op.outputs[4], + ReserveSpace3 = version == 2 ? op.outputs[5] : null, + Epsilon = epsilon, + DataFormat = data_format, + IsTraining = is_training + }); + } + else + { + var pop_mean = op.inputs[3]; + var pop_var = op.inputs[4]; + if (data_format == "NCHW") + throw new NotImplementedException(""); + + var results = grad_fun(new FusedBatchNormParams + { + YBackprop = grad_y, + X = x, + Scale = scale, + ReserveSpace1 = pop_mean, + ReserveSpace2 = pop_var, + ReserveSpace3 = version == 2 ? op.outputs[5] : null, + Epsilon = epsilon, + DataFormat = data_format, + IsTraining = is_training + }); + + var (dx, dscale, doffset) = (results[0], results[1], results[2]); + if (data_format == "NCHW") + throw new NotImplementedException(""); + + return new Tensor[] + { + dx, + dscale, + doffset, + null, + null + }; + } + } + + [RegisterGradient("BatchNormWithGlobalNormalization")] + public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads) + { + throw new NotImplementedException("BatchNormWithGlobalNormalization"); + } + + private static bool IsZero(Tensor g) + { + if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type)) + return true; + + throw new NotImplementedException("IsZero"); + } + + private static Tensor _BroadcastMul(Tensor vec, Tensor mat) + { + vec = array_ops.expand_dims(vec, -1); + return vec * mat; + } + + [RegisterGradient("MaxPool")] + public static Tensor[] _MaxPoolGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + return new Tensor[] + { + gen_nn_ops.max_pool_grad( + op.inputs[0], + op.outputs[0], + grad, + op.get_attr_list("ksize"), + op.get_attr_list("strides"), + padding: op.get_attr("padding").ToString(), + data_format: op.get_attr("data_format").ToString()) + }; + } + + [RegisterGradient("AvgPool")] + public static Tensor[] _AvgPoolGrad(Operation op, Tensor[] grads) + { + Tensor grad = grads[0]; + + return new Tensor[] + { + gen_nn_ops.avg_pool_grad( + array_ops.shape(op.inputs[0]), + grad, + op.get_attr_list("ksize"), + op.get_attr_list("strides"), + op.get_attr("padding"), + op.get_attr("data_format")) + }; + } + + /// + /// Return the gradients for TopK. + /// + /// + /// + /// + [RegisterGradient("TopK")] + public static Tensor[] _TopKGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var _ = grads[1]; + + var in_shape = array_ops.shape(op.inputs[0]); + var ind_shape = array_ops.shape(op.outputs[1]); + + // int32 is not supported on GPU hence up-casting + var cast = math_ops.cast(ind_shape, TF_DataType.TF_INT64); + var size = array_ops.size(ind_shape) - 1; + var ind_lastdim = array_ops.gather(cast, size); + + // Flatten indices to 2D. + var stack = array_ops.stack(new object[] { -1L, ind_lastdim }); + var ind_2d = array_ops.reshape(op.outputs[1], stack); + + var in_lastdim = array_ops.gather(math_ops.cast(in_shape, TF_DataType.TF_INT64), + array_ops.size(in_shape) - 1); + var outerdim = array_ops.shape(ind_2d).slice(0); + + // Compute linear indices(flattened to 1D). + var cast1 = math_ops.cast(outerdim, TF_DataType.TF_INT64); + var range2 = math_ops.range(tf.constant(0L), cast1 * in_lastdim, in_lastdim); + var dim2 = array_ops.expand_dims(range2, -1); + var cast2 = math_ops.cast(dim2, TF_DataType.TF_INT32); + var ind = array_ops.reshape(ind_2d + cast2, new int[] { -1 }); + + // Substitute grad to appropriate locations and fill the rest with zeros, + // finally reshaping it to the original input shape. + var scatter = gen_array_ops.scatter_nd(array_ops.expand_dims(ind, -1), + array_ops.reshape(grad, new int[] { -1 }), + math_ops.reduce_prod(in_shape)); + + return new Tensor[] + { + array_ops.reshape(scatter, in_shape), + array_ops.zeros(new int[0], dtype: TF_DataType.TF_INT32) + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs new file mode 100644 index 000000000..7d3ea1715 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs @@ -0,0 +1,120 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Tensorflow.Gradients; + +namespace Tensorflow +{ + public partial class ops + { + public static Dictionary> gradientFunctions = null; + + public static void RegisterFromAssembly() + { + if (gradientFunctions == null) + { + gradientFunctions = new Dictionary>(); + + var gradGroups = Assembly.GetExecutingAssembly() + .GetTypes() + .Where(x => x.GetCustomAttribute() != null) + .ToArray(); + + foreach (var g in gradGroups) + { + var methods = g.GetMethods() + .Where(x => x.GetCustomAttribute() != null) + .ToArray(); + + foreach (var m in methods) + { + RegisterGradientFunction(m.GetCustomAttribute().Name, + (oper, out_grads) => + { + // tf.Logger.Debug($"Caculate Gradient: {oper.name} {m.Name}"); + + var results = g.InvokeMember(m.Name, + BindingFlags.InvokeMethod, + null, + null, + args: new object[] { oper, out_grads }) as Tensor[]; + + // foreach (var result in results.Where(x => x != null)) + // tf.Logger.Debug($"Gradient: {result.name} {result.shape}"); + + return results; + } + ); + } + + // REGISTER_NO_GRADIENT_OP + methods = g.GetMethods() + .Where(x => x.GetCustomAttribute() != null) + .ToArray(); + + foreach (var m in methods) + RegisterNoGradientFunction(m.GetCustomAttribute().Name); + } + } + } + + /// + /// Regiter new gradient function + /// + /// operation type + /// function delegate + public static void RegisterGradientFunction(string name, Func func) + { + RegisterFromAssembly(); + + gradientFunctions[name] = func; + } + + public static void RegisterNoGradientFunction(string name) + { + RegisterFromAssembly(); + + gradientFunctions[name] = null; + } + + public static Func get_gradient_function(Operation op) + { + if (op.inputs == null) return null; + + var gradient_function = op._gradient_function; + if(gradient_function is null) + { + RegisterFromAssembly(); + + if (!gradientFunctions.ContainsKey(op.type)) + throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); + + return gradientFunctions[op.type]; + } + + Tensor[] wrapped_gradient_function(Operation operation, Tensor[] args) + { + return gradient_function(operation, args); + } + // TODO(Rinne): check if this needs to be registered. + return wrapped_gradient_function; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/resource_variable_grad.cs b/src/TensorFlowNET.Core/Gradients/resource_variable_grad.cs new file mode 100644 index 000000000..5ab55011b --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/resource_variable_grad.cs @@ -0,0 +1,28 @@ +/***************************************************************************** + Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow.Gradients +{ + [RegisterGradient("resource_variable_grad")] + public class resource_variable_grad + { + [RegisterGradient("ReadVariableOp")] + public static Tensor[] _ReadGrad(Operation op, Tensor[] grads) + { + return new Tensor[] { grads[0] }; + } + } +} diff --git a/src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs b/src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs new file mode 100644 index 000000000..f662b4486 --- /dev/null +++ b/src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs @@ -0,0 +1,40 @@ +using Google.Protobuf; + +namespace Tensorflow +{ + public class GraphTransformer + { + /// + /// Graph Transform Tool + /// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md + /// + /// GraphDef object containing a model to be transformed + /// the model inputs + /// the model outputs + /// transform names and parameters + /// + public GraphDef TransformGraph(GraphDef input_graph_def, + string[] inputs, + string[] outputs, + string[] transforms) + { + var input_graph_def_string = input_graph_def.ToByteArray(); + var inputs_string = string.Join(",", inputs); + var outputs_string = string.Join(",", outputs); + var transforms_string = string.Join(" ", transforms); + var status = new Status(); + var buffer = new Buffer(); + var len = c_api.TransformGraphWithStringInputs(input_graph_def_string, + input_graph_def_string.Length, + inputs_string, + outputs_string, + transforms_string, + buffer, + status); + + status.Check(false); + var bytes = buffer.ToArray(); + return GraphDef.Parser.ParseFrom(bytes); + } + } +} diff --git a/src/TensorFlowNET.Core/GraphTransformation/c_api.transform_graph.cs b/src/TensorFlowNET.Core/GraphTransformation/c_api.transform_graph.cs new file mode 100644 index 000000000..3b3508399 --- /dev/null +++ b/src/TensorFlowNET.Core/GraphTransformation/c_api.transform_graph.cs @@ -0,0 +1,32 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Runtime.InteropServices; + +namespace Tensorflow +{ + public partial class c_api + { + [DllImport(TensorFlowLibName)] + public static extern int TransformGraphWithStringInputs(byte[] graph_def_string, + int graph_def_string_len, + string inputs_string, + string outputs_string, + string transforms_string, + SafeBufferHandle output_buffer, + SafeStatusHandle status); + } +} diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs new file mode 100644 index 000000000..48d14d6bd --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs @@ -0,0 +1,83 @@ +using System; +using System.Diagnostics; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Graphs +{ + public class AutoGraph + { + public Func to_graph(Func func, TF_DataType dtype = TF_DataType.TF_INT32) + { + string func_name = $"{func.Method.Name}_{ops.uid_function()}"; + + var graph = new FuncGraph(func_name); + graph.as_default(); + + var input = tf.placeholder(dtype); + var output = func(input); + + var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); + graph.ToGraph(opers, + new[] { input }, + new[] { output }, + null); + graph.Exit(); + + + return (Tensor input) => + { + if (tf.executing_eagerly()) + { + var result = tf.Runner.TFE_Execute(tf.Context, + tf.Context.DeviceName, + func_name, + new[] { input }, + null, + 1); + return result[0]; + } + var s = tf.Session(input.graph); + var output = func(input); + return output; + }; + } + + public Func to_graph(Func func, params TF_DataType[] dtypes) + { + string func_name = $"{func.Method.Name}_{ops.uid_function()}"; + + var graph = new FuncGraph(func_name); + graph.as_default(); + + var input1 = tf.placeholder(dtypes.Length >= 1 ? dtypes[0] : tf.int32); + var input2 = tf.placeholder(dtypes.Length >= 2 ? dtypes[1] : tf.int32); + var output = func(input1, input2); + + var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); + graph.ToGraph(opers, + new[] { input1, input2 }, + new[] { output }, + null); + graph.Exit(); + + return (Tensor a, Tensor b) => + { + if (tf.executing_eagerly()) + { + var result = tf.Runner.TFE_Execute(tf.Context, + tf.Context.DeviceName, + func_name, + new[] { a, b }, + null, + 1); + return result[0]; + } + var s = tf.Session(a.graph); + Debug.Assert(a.graph == b.graph); + var output = func(a, b); + return output; + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs new file mode 100644 index 000000000..cc283db4e --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs @@ -0,0 +1,125 @@ +using MethodBoundaryAspect.Fody.Attributes; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Tensorflow.Eager; +using Tensorflow.Functions; +using static Tensorflow.Binding; + +namespace Tensorflow.Graphs +{ + /// + /// func_graph.py func_graph_from_py_func + /// + [AllowChangingInputArguments] + public sealed class AutoGraphAttribute : OnMethodBoundaryAspect + { + ConcreteFunction function; + Tensors originalInputs; + string func_name; + static Dictionary functions = new Dictionary(); + + public override void OnEntry(MethodExecutionArgs args) + { + // TODO: func_name can be cache in FullName + Args + func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}"; + + if (functions.ContainsKey(func_name)) + { + function = functions[func_name]; + if (args.Arguments[0] is Tensors tensor_inputs) + args.ReturnValue = ConvertReturnValue(function.FilteredCall(tensor_inputs)); + else + args.ReturnValue = ConvertReturnValue(function.FilteredCall(args.Arguments.Select(x => x as Tensor).ToArray())); + args.FlowBehavior = FlowBehavior.Return; + return; + } + + // make function as an Operation by autograph + // need to restore mode when exits + function = new ConcreteFunction(func_name); + function.Enter(); + + // convert to Tensors + if (args.Arguments[0] is Tensors inputs) + { + originalInputs = inputs; + var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: "inputs")).ToArray(); + args.Arguments[0] = new Tensors(new_inputs); + } + else + { + originalInputs = new Tensors(); + // convert args to placeholder + for (var i = 0; i < args.Arguments.Length; i++) + { + if (args.Arguments[i] is EagerTensor tensor) + { + originalInputs.Add(tensor); + args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.shape, name: "inputs"); + } + } + } + } + + public override void OnExit(MethodExecutionArgs args) + { + if (args.ReturnValue is Tensors outputs) + { + Tensors inputs = null; + outputs = mark_as_return(outputs); + if (args.Arguments[0] is Tensors inputs1) + inputs = inputs1; + else + inputs = args.Arguments.Select(x => x as Tensor).ToArray(); + + inputs = inputs.Where(x => x.op.OpType == "Placeholder" + && x.op.name.StartsWith("inputs")).ToArray(); + + function.ToGraph(inputs, outputs); + } + else if (args.ReturnValue is Tensor output) + { + var inputs = args.Arguments.Select(x => x as Tensor) + .Where(x => x.op.type == "Placeholder" && x.op.name.StartsWith("inputs")) + .ToArray(); + var outputs2 = array_ops.identity(output); + function.ToGraph(inputs, outputs2); + } + + function.Exit(); + + // cache function. + function.ReturnType = args.ReturnValue.GetType(); + function._set_infer_function(); + functions[func_name] = function; + + // run function + args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs)); + } + + object ConvertReturnValue(Tensors tensors) + { + if (function.ReturnType == typeof(Tensor)) + return (Tensor)tensors; + else + return tensors; + } + + /// + /// Acts like identity but marks the `Tensor` as a return value. + /// + /// + /// + public Tensors mark_as_return(Tensors tensors) + { + if (tensors == null) + return null; + var result = new Tensors(); + foreach (var tensor in tensors) + result.Add(array_ops.identity(tensor)); + return result; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs new file mode 100644 index 000000000..622b00713 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs @@ -0,0 +1,60 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; + +namespace Tensorflow +{ + /// + /// Serves as a stack for determining current default graph. + /// + public class DefaultGraphStack + { + Stack _stack = new Stack(); + + public Graph get_default() + { + if (_stack.Count == 0) + _stack.Push(new Graph()); + + return _stack.Peek(); + } + + public Graph get_controller(Graph g) + { + _stack.Push(g); + return g; + } + + public Graph peak_controller() + { + if (_stack.Count == 0) + return null; + return _stack.Peek(); + } + + public void pop() + { + _stack.Pop(); + } + + public void reset() + { + _stack.Clear(); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Graphs/FreezeGraph.cs b/src/TensorFlowNET.Core/Graphs/FreezeGraph.cs new file mode 100644 index 000000000..49815edbc --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/FreezeGraph.cs @@ -0,0 +1,35 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public class FreezeGraph + { + public static void freeze_graph(string input_graph, + string input_saver, + bool input_binary, + string input_checkpoint, + string output_node_names, + string restore_op_name, + string filename_tensor_name, + string output_graph, + bool clear_devices, + string initializer_nodes) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs new file mode 100644 index 000000000..6f7fa9c5f --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -0,0 +1,609 @@ +using Google.Protobuf; +using System; +using System.Buffers; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Eager; +using Tensorflow.Exceptions; +using Tensorflow.Framework; +using Tensorflow.Framework.Models; +using Tensorflow.Functions; +using Tensorflow.NumPy; +using Tensorflow.Operations; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow.Graphs; + +/// +/// Graph representing a function body. +/// +public class FuncGraph : Graph, IDisposable +{ + internal SafeFuncGraphHandle _func_graph_handle; + internal HashSet _resource_tensor_inputs; + internal HashSet> _watched_variables; + internal IEnumerable> _weak_variables; + internal object[] _structured_outputs; + internal Dictionary _output_names; + public string FuncName => _graph_key; + + public Tensors Inputs { get; set; } = new Tensors(); + public Tensors Outputs { get; set; } = new Tensors(); + public Tensors FlatStructuredOutputs + { + get + { + List res = new(); + foreach(var obj in _structured_outputs) + { + if(obj is Tensor tensor) + { + res.Add(tensor); + } + else if(obj is IEnumerable tensors) + { + res.AddRange(tensors); + } + else + { + throw new TypeError("The structured outputs member should be tensor or tensors."); + } + } + return res; + } + } + public string Name { get; set; } + public IEnumerable Variables + { + get + { + return _weak_variables.Select(v => + { + if (v.TryGetTarget(out var target)) + { + return target; + } + else + { + throw new AssertionError("Called a function referencing variables which have been deleted. " + + "This likely means that function-local variables were created and " + + "not referenced elsewhere in the program. This is generally a " + + "mistake; consider storing variables in an object attribute on first call."); + } + }); + } + internal set + { + _weak_variables = value.Select(x => new WeakReference(x)); + } + } + public IEnumerable TrainableVariables => Variables.Where(v => v.Trainable); + public Dictionary Attrs { get; set; } + + internal Dictionary _captures + = new Dictionary(); + + public Tensor[] external_captures + => _captures.Select(x => x.Value.Item1).ToArray(); + public (Tensor, Tensor)[] captures + => _captures.Values.Select(x => x).ToArray(); + + public Tensor[] internal_captures + => _captures.Select(x => x.Value.Item2).ToArray(); + + public Tensor[] captured_inputs + => external_captures; + + /// + /// Construct a new FuncGraph. + /// + public FuncGraph(string name) : base() + { + outer_graph = ops.get_default_graph(); + while (outer_graph.building_function) + outer_graph = outer_graph.OuterGraph; + _graph_key = Name = name; + building_function = true; + _weak_variables = new List>(); + _resource_tensor_inputs = new HashSet(); + _watched_variables = new HashSet>(); + } + + public FuncGraph(SafeGraphHandle handle, string name, Dictionary attrs) : base() + { + outer_graph = ops.get_default_graph(); + while (outer_graph.building_function) + outer_graph = outer_graph.OuterGraph; + _graph_key = Name = name; + building_function = true; + Attrs = attrs; + // Will to test if FuncGraph has memory leak + // c_api.TF_DeleteGraph(_handle); + _handle = handle; + _weak_variables = new List>(); + _resource_tensor_inputs = new HashSet(); + _watched_variables = new HashSet>(); + } + + public void replace_capture(Tensor tensor, Tensor placeholder) + { + _captures[tensor.Id] = (tensor, placeholder); + } + + public unsafe void ToGraph(Operation[] opers, + Tensor[] inputs, Tensor[] outputs, + string[] output_names) + { + var status = new Status(); + if (output_names is null) + { + output_names = new string[0]; + }; + + _func_graph_handle = c_api.TF_GraphToFunction(_handle, + _graph_key, + false, + opers.Length, + opers.Select(x => (IntPtr)x).ToArray(), + inputs.Length, + inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), + outputs.Length, + outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), + output_names.Length != outputs.Length ? null : output_names, + IntPtr.Zero, + null, + status); + status.Check(true); + + SetAttrs(); + + // c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle); + // status.Check(true); + + c_api.TFE_ContextAddFunction(tf.Context, _func_graph_handle, status); + status.Check(true); + + _graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle)); + + Inputs = inputs; + // mark_as_return + Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray(); + } + + public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary attrs = null, OpDef op_def = null, bool compute_device = true) + { + foreach(var (i, inp) in enumerate(inputs)) + inputs[i] = capture(inp); + + return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); + } + + const int _EAGER_CONST_THRESHOLD = 128; + public Tensor capture(Tensor tensor, string name = null, Shape shape = null) + { + if(tensor is EagerTensor or NDArray) + { + if (name == null) + name = ops.uid().ToString(); + + // Small EagerTensors are captured with Const ops + if (dtypes.is_value_dtype(tensor.dtype) + && (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD)) + return capture_eager_tensor(tensor, name); + + // Large EagerTensors and resources are captured with Placeholder ops + return _capture_helper(tensor, name, shape: shape); + } + + if(tensor.graph != this) + { + if (name == null) + name = tensor.op.name; + var inner_graph = tensor.graph; + while(inner_graph != null && inner_graph is FuncGraph inner_func_graph) + { + if (inner_graph == this) + throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" + + " in another function or code block. Use return values," + + " explicit Python locals or TensorFlow collections to access" + + $" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}."); + inner_graph = inner_func_graph.outer_graph; + } + return _capture_helper(tensor, name); + } + + return tensor; + } + + public void watch_variable(IVariableV1 v) + { + if (_resource_tensor_inputs.Contains(v.Handle)) + { + return; + } + _watched_variables.Add(new WeakReference(v)); + //this = this.outer_graph; + } + + Tensor capture_eager_tensor(Tensor tensor, string name) + { + Tensor graph_const = null; + if (!_captures.ContainsKey(tensor.Id)) + { + graph_const = tf_with(ops.control_dependencies(null), ctl + => constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name)); + add_capture(tensor, graph_const); + } + else + { + graph_const = _captures[tensor.Id].Item2; + } + + BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => + { + return output_grads; + }; + + tf.Runner.RecordGradient("captured_value", + new[] { graph_const }, null, + new[] { tensor }, + getBackwardFunction: _backward_function_wrapper + /*getForwardFunction: forward_function*/); + + return graph_const; + } + + Tensor _capture_helper(Tensor tensor, string name, Shape shape = null) + { + Tensor placeholder = null; + if (!_captures.ContainsKey(tensor.Id)) + { + placeholder = _create_substitute_placeholder(tensor, + name: name, + dtype: tensor.dtype, + shape: shape); + add_capture(tensor, placeholder); + } + else + { + placeholder = _captures[tensor.Id].Item2; + } + + BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => + { + return output_grads; + }; + + tf.Runner.RecordGradient("captured_value", + new[] { placeholder }, null, + new[] { tensor }, + getBackwardFunction: _backward_function_wrapper + /*getForwardFunction: forward_function*/); + + return placeholder; + } + + void add_capture(Tensor tensor, Tensor placeholder) + { + _captures.Add(tensor.Id, (tensor, placeholder)); + Inputs.Add(placeholder); + } + + Tensor pop_capture(Tensor tensor) + { + if(_captures.TryGetValue(tensor.Id, out var capture)) + { + _captures.Remove(tensor.Id); + return capture.Item2; + } + else + { + return null; + } + } + + Tensor _create_substitute_placeholder(Tensor value, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + Shape shape = null) + { + if (shape is null) + shape = value.shape; + if (dtype == TF_DataType.DtInvalid) + dtype = value.dtype; + + var placeholder = tf_with(ops.control_dependencies(null), ctl + => array_ops.placeholder(dtype, shape: shape, name: name)); + // custom_gradient.copy_handle_data(value, placeholder) + return placeholder; + } + + void SetAttrs() + { + if (Attrs == null) + return; + + foreach (var (_name, attr_value) in enumerate(Attrs)) + { + var serialized = attr_value.ToByteArray(); + c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status); + tf.Status.Check(true); + } + } + + public override Graph as_default() + { + tf.Context.graph_mode(isFunc: true); + ops.set_default_graph(this); + return this; + } + + public override void Exit() + { + tf.Context.restore_mode(); + ops.pop_graph(); + } + + public void Dispose() + { + c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status); + } + + public static FuncGraph func_graph_from_func(string name, Func func, + object[] args, Dictionary kwargs, TensorSpec[] signature = null, + FuncGraph func_graph = null, bool autograph = false, object autograph_options = null, + bool add_control_dependencies = true, string[] arg_names = null, + Tensor op_return_value = null, bool capture_by_value = false, + bool acd_record_initial_resource_uses = false) + { + if(func_graph is null) + { + func_graph = new FuncGraph(name); + } + + // TODO(Rinne): deal with control dependencies. + + func_graph.as_default(); + var current_scope = variable_scope.get_variable_scope(); + var default_use_resource = current_scope.use_resource; + current_scope.use_resource = true; + + if(signature is not null) + { + args = signature; + kwargs = new Dictionary(); + } + var func_args = _get_defun_inputs_from_args(args, arg_names); + var func_kwargs = _get_defun_inputs_from_kwargs(kwargs); + + if(func_kwargs is not null && func_kwargs.Count > 0) + { + throw new NotImplementedException("The keyword args has not been supported in `func_graph_from_func`."); + } + + foreach(var arg in nest.flatten(new object[] { func_args, func_kwargs })) + { + if(arg is Tensor tensor && tensor.dtype == dtypes.resource) + { + func_graph._resource_tensor_inputs.Add(tensor); + } + else if (arg is ResourceVariable variable) + { + func_graph._resource_tensor_inputs.Add(variable.Handle); + } + } + + // skip the assignment of `func_graph.structured_input_signature`. + + var flat_func_args = nest.flatten(func_args as object); + var flat_func_kwargs = nest.flatten(func_kwargs as object); + func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) + .Where(x => x is Tensor).Select(x => (Tensor)x).ToArray()); + + //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); + //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); + + Tensor convert(object x) + { + if (x is null) return null; + Tensor res = null; + if(op_return_value is not null && x is Operation) + { + tf_with(ops.control_dependencies(new object[] { x }), _ => + { + res = array_ops.identity(op_return_value); + }); + } + else if(x is not TensorArray) + { + Debug.Assert(x is Tensor); + res = ops.convert_to_tensor_or_composite(x as Tensor); + } + else + { + throw new NotImplementedException($"The `TensorArray` is not supported here currently."); + } + if (add_control_dependencies) + { + // TODO(Rinne): `x = deps_ctx.mark_as_return(x)`. + } + return res; + } + + if (autograph) + { + throw new NotImplementedException("The autograph of `func_graph_from_func` has not been supported."); + } + + var func_outputs = func(func_args); + func_outputs = variable_utils.convert_variables_to_tensors(func_outputs); + func_outputs = func_outputs.Select(x => convert(x)).ToArray(); + // TODO(Rinne): `check_func_mutation`. + + current_scope.use_resource = default_use_resource; + + var graph_variables = func_graph._watched_variables.ToList(); + HashSet arg_variables = new HashSet(); + List inputs = new(); + foreach(var arg in composite_tensor_utils.flatten_with_variables(func_args)) + { + if(arg is BaseResourceVariable variable) + { + var resource_placeholder = func_graph.pop_capture(variable.Handle); + if(resource_placeholder is null) + { + continue; + } + Debug.Assert(variable is IVariableV1); + arg_variables.Add(variable as IVariableV1); + inputs.Add(resource_placeholder); + } + else if(arg is Tensor tensor) + { + inputs.Add(tensor); + } + } + var variables = graph_variables.Select(v => + { + if (v.TryGetTarget(out var target)) + { + return target; + } + else + { + return null; + } + }).Where(v => v is not null && !arg_variables.Contains(v)); + func_graph.Inputs = inputs.Concat(func_graph.internal_captures).ToArray(); + func_graph._structured_outputs = func_outputs; + func_graph.Outputs.AddRange(func_graph.FlatStructuredOutputs.Where(x => x is not null) + .Select(x => func_graph.capture(x))); + + func_graph.Variables = variables; + + func_graph.Exit(); + + if (add_control_dependencies) + { + // TODO(Rinne): implement it. + } + return func_graph; + } + + private static object[] _get_defun_inputs_from_args(object[] args, string[] names) + { + return _get_defun_inputs(args, names, args) as object[]; + } + + private static Dictionary _get_defun_inputs_from_kwargs(Dictionary kwargs) + { + // TODO(Rinne): implement it. + Debug.Assert(kwargs is null || kwargs.Count == 0); + return kwargs; + //string[] names; + //object[] args; + //if(kwargs is not null && kwargs.Count > 0) + //{ + // var sorted_kwargs = kwargs.OrderBy(x => x.Key); + // names = sorted_kwargs.Select(x => x.Key).ToArray(); + // args = sorted_kwargs.Select(x => x.Value).ToArray(); + //} + //else + //{ + // names = new string[0]; + // args = new object[0]; + //} + //return _get_defun_inputs(args, names, kwargs) as Dictionary; + } + + private static object _get_defun_inputs(object[] args, string[] names, object structured_args) + { + List function_inputs = new(); + if(names is null) + { + names = new string[args.Length]; + } + + foreach(var (arg_value, name) in zip(args, names)) + { + foreach(var val in composite_tensor_utils.flatten_with_variables_or_variable_specs(arg_value)) + { + function_inputs.Add(_get_defun_input(val, name)); + } + } + return nest.pack_sequence_as(structured_args, nest.flatten(function_inputs), true); + } + + private static object _get_defun_input(object arg, string name) + { + var func_graph = ops.get_default_graph() as FuncGraph; + Debug.Assert(func_graph is not null); + if (arg is Tensor tensor) + { + Tensor placeholder; + try + { + placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape, name); + } + catch (ValueError ex) + { + tf.Logger.Warning(ex.ToString()); + placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape); + } + handle_data_util.copy_handle_data(tensor, placeholder); + if (name is not null) + { + placeholder.op._set_attr("_user_specified_name", new AttrValue() + { + S = tf.compat.as_bytes(name) + }); + } + return placeholder; + } + else if (arg is TensorSpec spec) + { + string requested_name; + if (!string.IsNullOrEmpty(spec.name)) + { + requested_name = spec.name; + } + else + { + requested_name = name; + } + Tensor placeholder; + try + { + placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape, requested_name); + } + catch (ValueError) + { + // TODO(Rinne): Add warning here. + placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape); + } + if (name is not null) + { + placeholder.op._set_attr("_user_specified_name", new AttrValue() + { + S = tf.compat.as_bytes(requested_name) + }); + } + return placeholder; + } + else if (arg is BaseResourceVariable variable) + { + var placeholder = func_graph.capture(variable.Handle, name); + placeholder.op._set_attr("_user_specified_name", new AttrValue() + { + S = tf.compat.as_bytes(name) + }); + return arg; + } + // TODO(Rinne): deal with `VariableSpec`. + else + { + return arg; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs new file mode 100644 index 000000000..15cf90f10 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -0,0 +1,150 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Graph + { + // Current control flow context. It could be either CondContext or WhileContext + public ControlFlowContext _control_flow_context; + + // represents the nested with(...) statements + public List<_ControlDependenciesController> _control_dependencies_stack { get; set; } = new List<_ControlDependenciesController>(); + + /// + /// For an op that takes `input_ops` as inputs, compute control inputs. + /// + /// The data input ops for an op to be created. + /// A list of control inputs for the op to be created. + public ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) + { + var ret = new List(); + + foreach (var controller in _control_dependencies_stack) + { + bool dominated = false; + // If any of the input_ops already depends on the inputs from controller, + // we say that the new op is dominated (by that input), and we therefore + // do not need to add control dependencies for this controller's inputs. + foreach (var op in input_ops) + { + if (controller.op_in_group(op)) + { + dominated = true; + break; + } + } + + if (!dominated) + ret.AddRange(controller.control_inputs.Where(x => !input_ops.Contains(x))); + } + + return ret.ToArray(); + } + + /// + /// Returns a context manager that specifies control dependencies. + /// + /// Use with the `with` keyword to specify that all operations constructed + /// within the context should have control dependencies on + /// `control_inputs`. + /// + [SuppressMessage("ReSharper", "CoVariantArrayConversion")] + public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) + => control_dependencies((object[])control_inputs); + + /// + /// Returns a context manager that specifies control dependencies. + /// + /// Use with the `with` keyword to specify that all operations constructed + /// within the context should have control dependencies on + /// `control_inputs`. + /// + public _ControlDependenciesController control_dependencies(object[] control_inputs) + { + if (control_inputs == null || tf.Context.executing_eagerly()) + return new _ControlDependenciesController(this, null); + + var control_ops = new List(); + foreach (var c in control_inputs) + { + switch (c) + { + // TODO: implement IndexedSlices + //case IndexedSlices islice: + // control_ops.Add(islice.op); + // break; + case Tensor t: + control_ops.Add(t.op); + break; + case Operation op: + control_ops.Add(op); + break; + default: + var t1 = _as_graph_element(c); + if (t1 == null) + throw new TypeError($"Control input must be Operation or Tensor:{c}"); + control_ops.Add(t1.op); + break; + } + } + return new _ControlDependenciesController(this, control_ops); + } + + /// + /// Returns the current control flow context. + /// + /// A context object. + public ControlFlowContext _get_control_flow_context() + { + return _control_flow_context; + } + + /// + /// Sets the current control flow context. + /// + /// a context object. + public void _set_control_flow_context(ControlFlowContext ctx) + { + _control_flow_context = ctx; + } + + public void _push_control_dependencies_controller(_ControlDependenciesController controller) + { + _control_dependencies_stack.Add(controller); + } + + public void _pop_control_dependencies_controller(_ControlDependenciesController controller) + { + _control_dependencies_stack.RemoveAt(_control_dependencies_stack.Count - 1); + } + + /// + /// Record that the given op depends on all registered control dependencies. + /// + public void _record_op_seen_by_control_dependencies(Operation op) + { + foreach (var controller in _control_dependencies_stack) + controller.add_op(op); + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs new file mode 100644 index 000000000..a11d91e73 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs @@ -0,0 +1,52 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using Tensorflow.Util; + +namespace Tensorflow +{ + public partial class Graph + { + public Buffer ToGraphDef(Status s) + { + var buffer = new Buffer(); + c_api.TF_GraphToGraphDef(_handle, buffer, s); + s.Check(true); + + return buffer; + } + + private GraphDef _as_graph_def(bool add_shapes = false) + { + GraphDef def; + var status = new Status(); + var buffer = ToGraphDef(status); + status.Check(true); + // limit size to 250M, recursion to max 100 + var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock, 250 * 1024 * 1024, 100); + def = GraphDef.Parser.ParseFrom(inputStream); + + // Strip the experimental library field iff it's empty. + // if(def.Library.Function.Count == 0) + + return def; + } + + public GraphDef as_graph_def(bool add_shapes = false) + => _as_graph_def(add_shapes); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs b/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs new file mode 100644 index 000000000..bed8b35ca --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs @@ -0,0 +1,17 @@ +using Tensorflow.Graphs; + +namespace Tensorflow +{ + public partial class Graph + { + public void _colocate_with_for_gradient(Operation op, string gradient_uid, bool ignore_existing = false) + { + + } + + internal GraphOverrideGradientContext _override_gradient_function(Dictionary> gradient_function_map) + { + return new GraphOverrideGradientContext(this, gradient_function_map); + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs new file mode 100644 index 000000000..b80e26590 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs @@ -0,0 +1,69 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.IO; +using System.Runtime.InteropServices; + +namespace Tensorflow +{ + public partial class Graph + { + public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s) + { + as_default(); + var num_return_outputs = opts.NumReturnOutputs; + var return_outputs = new TF_Output[num_return_outputs]; + int size = Marshal.SizeOf(); + var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); + + c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); + + var tf_output_ptr = (TF_Output*)return_output_handle; + for (int i = 0; i < num_return_outputs; i++) + return_outputs[i] = *(tf_output_ptr + i); + + Marshal.FreeHGlobal(return_output_handle); + + return return_outputs; + } + + public bool Import(string file_path, string prefix = "") + { + var bytes = File.ReadAllBytes(file_path); + return Import(bytes, prefix: prefix); + } + + public bool Import(byte[] bytes, string prefix = "") + { + var opts = new ImportGraphDefOptions(); + var status = new Status(); + var graph_def = new Buffer(bytes); + + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix); + c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status); + status.Check(true); + return status.Code == TF_Code.TF_OK; + } + + public Graph ImportGraphDef(string file_path, string name = null) + { + as_default(); + var graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(file_path)); + importer.import_graph_def(graph_def, name: name); + return this; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs new file mode 100644 index 000000000..c788aaf01 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -0,0 +1,163 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Graph + { + public OpDef GetOpDef(string type) + => op_def_registry.GetOpDef(type); + + public OperationDescription NewOperation(string opType, string opName) + { + return c_api.TF_NewOperation(_handle, opType, opName); + } + + public Operation[] ReturnOperations(SafeImportGraphDefResultsHandle results) + { + TF_Operation return_oper_handle = new TF_Operation(); + int num_return_opers = 0; + c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); + Operation[] return_opers = new Operation[num_return_opers]; + var tf_op_size = Marshal.SizeOf(); + for (int i = 0; i < num_return_opers; i++) + { + unsafe + { + var handle = return_oper_handle.node + tf_op_size * i; + return_opers[i] = new Operation(*(IntPtr*)handle); + } + } + + return return_opers; + } + + /// + /// Get operation with given + /// + /// When is not found current graph. + /// When tf.get_default_graph() is not current graph. + /// + /// graph.GetOperationByName("CustomInputName"); + /// + public Operation OperationByName(string operName) + { + if (operName == null) + throw new ArgumentNullException(nameof(operName)); + + var handle = c_api.TF_GraphOperationByName(_handle, operName); + if (handle == IntPtr.Zero) + throw new ValueError($"Could not find operation \"{operName}\" inside graph \"{_graph_key}\"."); + + /*var defaultKey = tf.get_default_graph().graph_key; + if (tf.get_default_graph().GetType().Name == "Graph" && graph_key != defaultKey) + { + throw new RuntimeError($"Current graph is not default graph. Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}"); + }*/ + + return new Operation(handle, g: this); + } + + public ITensorOrOperation[] get_operations() + { + return _nodes_by_name.Values.ToArray(); + } + + /// + /// Returns the `Operation` with the given `name`. + /// + /// This method may be called concurrently from multiple threads. + /// + /// The name of the `Operation` to return. + public Operation get_operation_by_name(string name) + => as_graph_element(name, allow_tensor: false, allow_operation: true) as Operation; + + public ITensorOrOperation _get_operation_by_name_unsafe(string name) + { + return _nodes_by_name.TryGetValue(name, out var val) ? val : null; + } + + public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) + { + var op_name = Marshal.PtrToStringAnsi(c_api.TF_OperationName(tf_oper)); + return _get_operation_by_name_unsafe(op_name); + } + + /// + /// Creates an `Operation` in this graph from the supplied TF_Operation. + /// + /// This method is like create_op() except the new Operation is constructed + /// using `c_op`. The returned Operation will have `c_op` as its _c_op + /// field.This is used to create Operation objects around TF_Operations created + /// indirectly by the C API(e.g.by TF_ImportGraphDef, TF_FinishWhile). + /// + /// This function does not call Operation._control_flow_post_processing or + /// Graph._control_dependencies_for_inputs (since the inputs may not be + /// available yet). The caller is responsible for calling these methods. + /// + /// a wrapped TF_Operation + /// (Optional.) If True, device functions will be executed + /// to compute the device property of the Operation. + /// An `Operation` object. + public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true, OperationDescription desc = null) + { + var ret = new Operation(c_op, this); + _add_op(ret); + + var name_key = ret.name.ToLower(); + if (!_names_in_use.ContainsKey(name_key)) + _names_in_use[name_key] = 1; + + _create_op_helper(ret, compute_device: compute_device); + + return ret; + } + + /// + /// Creates `Operations` in this graph for any new TF_Operations. + /// + /// This is useful for when TF_Operations are indirectly created by the C API + /// outside of the Operation constructor (e.g. by TF_ImportGraphDef, + /// TF_FinishWhile). This ensures there are corresponding Operations for all + /// TF_Operations in the underlying TF_Graph. + /// + /// + /// + public IEnumerable _add_new_tf_operations(bool compute_devices = true) + { + var new_ops = c_api_util.new_tf_operations(this) + .Select(c_op => _create_op_from_tf_operation(c_op, compute_device: compute_devices)) + .ToArray(); + + foreach (var op in new_ops) + { + var new_control_inputs = _control_dependencies_for_inputs(op.inputs) + .Select(x => x as Operation) + .ToArray(); + op._add_control_inputs(new_control_inputs); + op._control_flow_post_processing(); + } + + return new_ops; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 96131b8a1..9e879a0f0 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -1,48 +1,213 @@ -using System; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections; using System.Collections.Generic; +using System.Collections.Specialized; using System.Linq; -using System.Runtime.InteropServices; -using System.Text; -using TF_DataType = Tensorflow.DataType; +using Tensorflow.Framework; +using Tensorflow.Functions; +using Tensorflow.Common.Extensions; +using Tensorflow.Graphs; +using static Tensorflow.Binding; namespace Tensorflow { + /* + A TensorFlow computation, represented as a dataflow graph. + + A `Graph` contains a set of + `tf.Operation` objects, + which represent units of computation; and + `tf.Tensor` objects, which represent + the units of data that flow between operations. + + A default `Graph` is always registered, and accessible by calling + `tf.get_default_graph`. + To add an operation to the default graph, simply call one of the functions + that defines a new `Operation`: + + ```python + c = tf.constant(4.0) + assert c.graph is tf.get_default_graph() + ``` + + Another typical usage involves the + `tf.Graph.as_default` + context manager, which overrides the current default graph for the + lifetime of the context: + + ```python + g = tf.Graph() + with g.as_default(): + # Define operations and tensors in `g`. + c = tf.constant(30.0) + assert c.graph is g + ``` + + Important note: This class *is not* thread-safe for graph construction. All + operations should be created from a single thread, or external + synchronization must be provided. Unless otherwise specified, all methods + are not thread-safe. + + A `Graph` instance supports an arbitrary number of "collections" + that are identified by name. For convenience when building a large + graph, collections can store groups of related objects: for + example, the `tf.Variable` uses a collection (named + `tf.GraphKeys.GLOBAL_VARIABLES`) for + all variables that are created during the construction of a graph. The caller + may define additional collections by specifying a new name. + */ + /// - /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. - /// This leads to a low-level programming model in which you first define the dataflow graph, - /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. - /// https://www.tensorflow.org/guide/graphs + /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. + /// This leads to a low-level programming model in which you first define the dataflow graph, + /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. /// - public class Graph + /// https://www.tensorflow.org/guide/graphs

https://www.tensorflow.org/api_docs/python/tf/Graph
+ public partial class Graph : IEnumerable { - private IntPtr _c_graph; - public IntPtr Handle => _c_graph; - private Dictionary _nodes_by_id; - private Dictionary _nodes_by_name; + protected new SafeGraphHandle _handle; + private Dictionary _nodes_by_id; + public Dictionary _nodes_by_name; private Dictionary _names_in_use; public int _version; private int _next_id_counter; - private List _unfetchable_ops = new List(); + private List _unfetchable_ops = new List(); + private List _unfeedable_tensors = new List(); + private Dictionary _functions = new(); + internal Dictionary> _gradient_function_map = new(); + private VersionDef _graph_def_versions = new VersionDef() + { + Producer = versions.GRAPH_DEF_VERSION, + MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER + }; + + public string _name_stack = ""; + protected string _graph_key; + public string graph_key => _graph_key; + public string _last_loss_reduction; + public bool _is_loss_scaled_by_optimizer { get; set; } + + /// + /// True if the graph is considered "finalized". In that case no + /// new operations can be added. + /// + private bool _finalized = false; + + /// + /// Arbitrary collections of objects. + /// + private Dictionary _collections = new Dictionary(); + + public bool building_function; + + string _container = ""; + public string Container => _container; + + int _seed; + public int seed + { + get => _seed; + set + { + _seed = value; + } + } + + internal Graph outer_graph; + public Graph OuterGraph => outer_graph; + public Dictionary Functions => _functions; + public SafeGraphHandle c_graph => _handle; - public Graph(IntPtr graph) + public Graph() { - this._c_graph = graph; - _nodes_by_id = new Dictionary(); - _nodes_by_name = new Dictionary(); + _handle = c_api.TF_NewGraph(); + _nodes_by_id = new Dictionary(); + _nodes_by_name = new Dictionary(); _names_in_use = new Dictionary(); + _graph_key = $"graph-{ops.GraphUniqueId()}/"; } - public T as_graph_element(T obj, bool allow_tensor = true, bool allow_operation = true) + public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) { return _as_graph_element_locked(obj, allow_tensor, allow_operation); } - private Func _as_graph_element(object obj) + /// + /// Returns a context manager that makes this `Graph` the default graph. + /// Must call Exit() to pop graph + /// + /// + public virtual Graph as_default() + { + tf.Context.graph_mode(isFunc: false); + return ops.set_default_graph(this); + } + + public bool IsFunction(string name) + { + return _functions.ContainsKey(tf.compat.as_str(name)); + } + + internal void AddFunction(EagerDefinedFunction function) + { + _check_not_finalized(); + + var name = function.Name; + if(function._grad_func_name is not null && function.csharp_grad_func is not null) + { + throw new ValueError($"Gradient defined twice for function {name}"); + } + + var c_graph = this.c_graph; + var func = function._c_func.Get(); + Status status = new(); + if (function._grad_func is not null) + { + var gradient = function._grad_func._c_func.Get(); + c_api.TF_GraphCopyFunction(c_graph, func, gradient, status); + status.Check(true); + } + else + { + c_api.TF_GraphCopyFunction(c_graph, func, new SafeFuncGraphHandle(IntPtr.Zero), status); + status.Check(true); + } + + _functions[tf.compat.as_str(name)] = function; + + if(_graph_def_versions.MinConsumer < 12) + { + _graph_def_versions.MinConsumer = 12; + } + } + + private Tensor _as_graph_element(object obj) { + if (obj is RefVariable var) + return var._as_graph_element(); + else if (obj is ResourceVariable resVar) + return resVar.GraphElement; + return null; } - private T _as_graph_element_locked(T obj, bool allow_tensor = true, bool allow_operation = true) + private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) { string types_str = ""; @@ -60,50 +225,150 @@ private T _as_graph_element_locked(T obj, bool allow_tensor = true, bool allo } var temp_obj = _as_graph_element(obj); + if (temp_obj != null) + obj = temp_obj; + + // If obj appears to be a name... + if (obj is string name) + { + if (name.Contains(":") && allow_tensor) + { + string op_name = name.Split(':')[0]; + int out_n = int.Parse(name.Split(':')[1]); + + if (_nodes_by_name.ContainsKey(op_name)) + return _nodes_by_name[op_name].outputs[out_n]; + else + throw new KeyError($"The name {name} refers to a Tensor which does not " + + $"exist. The operation, {op_name}, does not exist in the " + + "graph."); + } + else if (!name.Contains(":") & allow_operation) + { + if (!_nodes_by_name.ContainsKey(name)) + throw new KeyError($"The name {name} refers to an Operation not in the graph."); + return _nodes_by_name[name]; + } + else if (!name.Contains(":") & !allow_operation) + { + // Looks like an Operation name but can't be an Operation. + if (_nodes_by_name.ContainsKey(name)) + // Yep, it's an Operation name + throw new ValueError($"The name {name} refers to an Operation, not a {types_str}."); + else + throw new ValueError( + $"The name {name} looks like an (invalid) Operation name, not a {types_str}" + + " Tensor names must be of the form \":\"."); + } + } - if(obj is Tensor && allow_tensor) + if (obj is Tensor tensor && allow_tensor) { - if ((obj as Tensor).graph.Equals(this)) + if (tensor.graph.Equals(this)) { - return obj; + return tensor; } else { throw new Exception($"Tensor {obj} is not an element of this graph."); } } + else if (obj is Operation op && allow_operation) + { + if (op.graph.Equals(this)) + { + return op; + } + else + { + throw new Exception($"Operation {obj} is not an element of this graph."); + } + } - throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); + throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}."); } - public unsafe Operation create_op(string op_type, List inputs, TF_DataType[] dtypes, - TF_DataType[] input_types = null, string name = "", - Dictionary attrs = null, OpDef op_def = null) + public void add_to_collection(string name, T value) { - if (String.IsNullOrEmpty(name)) - { + _check_not_finalized(); + if (_collections.ContainsKey(name)) + (_collections[name] as List).Add(value); + else + _collections[name] = new List { value }; + } + + public void add_to_collections(List names, T value) + { + foreach (string name in names) + add_to_collection(name, value); + } + + private void _check_not_finalized() + { + if (_finalized) + throw new RuntimeError("Graph is finalized and cannot be modified."); + } + + public virtual Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, + TF_DataType[] input_types = null, string name = null, + Dictionary attrs = null, OpDef op_def = null, + bool compute_device = true) + { + if (inputs == null) + inputs = new Tensor[0]; + + if (string.IsNullOrEmpty(name)) name = op_type; - } - name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); - var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); + // If a names ends with a '/' it is a "name scope" and we use it as-is, + // after removing the trailing '/'. + // This was causing duplicate graph node name errors, when testing a conv2d autoencoder + // https://keras.io/guides/functional_api/#:~:text=keras.,graph%20(DAG)%20of%20layers. + // name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); + name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); + var node_def = ops._NodeDef(op_type, name, attrs: attrs); + + var input_ops = inputs.Select(x => x.op).ToArray(); + var control_inputs = _control_dependencies_for_inputs(input_ops); - var op = new Operation(node_def, + var op = new Operation(node_def, this, inputs: inputs, output_types: dtypes, - control_inputs: new object[] { }, + control_inputs: control_inputs, input_types: input_types, original_op: null, op_def: op_def); + _create_op_helper(op, compute_device); + return op; } + public ITensorFlowObject device(string device_name) + { + return new GraphDeviceContext(this, device_name); + } + + private void add_device_to_stack(string device_name, int offset = 0) + { + // TODO(Rinne): deal with device spec. + int total_offset = offset + 1; + } + + private void _create_op_helper(Operation op, bool compute_device = true) + { + // high priority + // TODO(Rinne): complete the implementation. + op._gradient_function = _gradient_function_map.GetOrDefault(op.type, null); + _record_op_seen_by_control_dependencies(op); + } + public void _add_op(Operation op) { + op._id_value = _next_id(); _nodes_by_id[op._id] = op; - //_nodes_by_name[op.name] = op; + _nodes_by_name[op.name] = op; _version = Math.Max(_version, op._id); } @@ -114,38 +379,234 @@ public int _next_id() public bool is_fetchable(T tensor_or_op) { - if (tensor_or_op is Tensor) + if (tensor_or_op is Tensor tensor) { - return !_unfetchable_ops.Contains((tensor_or_op as Tensor).name); ; + return !_unfetchable_ops.Contains(tensor); ; } - else if (tensor_or_op is Operation) + else if (tensor_or_op is Operation op) { - return !_unfetchable_ops.Contains((tensor_or_op as Operation).name); + return !_unfetchable_ops.Contains(op); } return false; } - public string unique_name(string name) + public string get_name_scope() + { + return _name_stack; + } + + public string name_scope(string name) + { + string new_stack = ""; + + if (string.IsNullOrEmpty(name)) + new_stack = ""; + else if (name.EndsWith("/")) + new_stack = ops.name_from_scope_name(name); + else + new_stack = unique_name(name); + + _name_stack = new_stack; + + return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/"; + } + + /// + /// Return a unique operation name for `name`. + /// + /// Note: You rarely need to call `unique_name()` directly.Most of + /// the time you just need to create `with g.name_scope()` blocks to + /// generate structured names. + /// + /// `unique_name` is used to generate structured names, separated by + /// `"/"`, to help identify operations when debugging a graph. + /// Operation names are displayed in error messages reported by the + /// TensorFlow runtime, and in various visualization tools such as + /// TensorBoard. + /// + /// If `mark_as_used` is set to `True`, which is the default, a new + /// unique name is created and marked as in use.If it's set to `False`, + /// the unique name is returned without actually being marked as used. + /// This is useful when the caller simply wants to know what the name + /// to be created will be. + /// + /// The name for an operation. + /// Whether to mark this name as being used. + /// A string to be passed to `create_op()` that will be used + /// to name the operation being created. + public string unique_name(string name, bool mark_as_used = true) { + if (!String.IsNullOrEmpty(_name_stack)) + name = _name_stack + "/" + name; + // For the sake of checking for names in use, we treat names as case + // insensitive (e.g. foo = Foo). var name_key = name.ToLower(); + int i = 0; if (_names_in_use.ContainsKey(name_key)) + i = _names_in_use[name_key]; + // Increment the number for "name_key". + if (mark_as_used) + _names_in_use[name_key] = i + 1; + if (i > 0) { - _names_in_use[name_key]++; + // Make sure the composed name key is not already used. + var base_name_key = name_key; + while (_names_in_use.ContainsKey(name_key)) + { + name_key = $"{base_name_key}_{i}"; + i += 1; + } + // Mark the composed name_key as used in case someone wants + // to call unique_name("name_1"). + if (mark_as_used) + _names_in_use[name_key] = 1; + + // Return the new name with the original capitalization of the given name. + name = $"{name}_{i - 1}"; } - else + return name; + } + + public TF_Output[] ReturnOutputs(SafeImportGraphDefResultsHandle results) + { + IntPtr return_output_handle = IntPtr.Zero; + int num_return_outputs = 0; + c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); + TF_Output[] return_outputs = new TF_Output[num_return_outputs]; + unsafe { - _names_in_use[name_key] = 1; - return name; + var tf_output_ptr = (TF_Output*)return_output_handle; + for (int i = 0; i < num_return_outputs; i++) + return_outputs[i] = *(tf_output_ptr + i); + return return_outputs; } - + } - return $"{name}_{_names_in_use[name_key]}"; + public string[] get_all_collection_keys() + { + return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); + } + + public object get_collection(string name, string scope = null) + { + return _collections.ContainsKey(name) ? _collections[name] : null; } - public Operation[] get_operations() + public List get_collection(string name, string scope = null) + { + List t = default; + var collection = _collections.ContainsKey(name) ? _collections[name] : new List(); + switch (collection) + { + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + default: + throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); + } + return t; + } + + public List get_collection_ref(string name) + { + if (!_collections.ContainsKey(name)) + _collections[name] = new List(); + return _collections[name] as List; + } + + public void prevent_feeding(Tensor tensor) + { + _unfeedable_tensors.Add(tensor); + } + + public void prevent_fetching(Operation op) + { + _unfetchable_ops.Add(op); + } + + public Tensor get_tensor_by_tf_output(TF_Output tf_output) + { + var op = _get_operation_by_tf_operation(tf_output.oper); + return op.outputs[tf_output.index]; + } + + /// + /// Returns the with the given . + /// This method may be called concurrently from multiple threads. + /// + /// The name of the `Tensor` to return. + /// If does not correspond to a tensor in this graph. + /// The `Tensor` with the given . + public Tensor get_tensor_by_name(string name) + { + return (Tensor)this.as_graph_element(name, allow_tensor: true, allow_operation: false); + } + + public Shape GetTensorShape(TF_Output output) + { + var status = tf.Status; + var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status); + status.Check(); + + if (ndim == -1) + return Shape.Null; + + var dims = new long[ndim]; + c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status); + status.Check(); + + return new Shape(dims.Select(x => (int)x).ToArray()); + } + + public virtual void Exit() + { + tf.Context.restore_mode(); + ops.pop_graph(); + } + + internal EagerDefinedFunction _get_function(string name) + { + return _functions.GetOrDefault(name, null); + } + + string debugString = string.Empty; + public override string ToString() + { + return $"{graph_key}, 0x{_handle.DangerousGetHandle().ToString("x16")}"; + /*if (string.IsNullOrEmpty(debugString)) + { + int len = 0; + debugString = c_api.TF_GraphDebugString(_handle, out len); + } + + return debugString;*/ + } + + private IEnumerable GetEnumerable() + => c_api_util.tf_operations(this); + + IEnumerator IEnumerable.GetEnumerator() + => GetEnumerable().GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() + => throw new NotImplementedException(); + + public static implicit operator SafeGraphHandle(Graph graph) { - return _nodes_by_name.Values.Select(x => x).ToArray(); + return graph._handle; } } } diff --git a/src/TensorFlowNET.Core/Graphs/GraphDeviceContext.cs b/src/TensorFlowNET.Core/Graphs/GraphDeviceContext.cs new file mode 100644 index 000000000..2754c2b36 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/GraphDeviceContext.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Graphs +{ + public class GraphDeviceContext : ITensorFlowObject + { + private Graph _graph; + + public GraphDeviceContext(Graph graph, string device_name) + { + _graph = graph; + } + + public void __enter__() + { + + } + + public void __exit__() + { + + } + + public void Dispose() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs b/src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs new file mode 100644 index 000000000..2befbbff6 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs @@ -0,0 +1,37 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; + +namespace Tensorflow.Graphs +{ + internal class GraphOverrideGradientContext: ITensorFlowObject + { + Graph _graph; + Dictionary> _new_gradient_function_map; + public GraphOverrideGradientContext(Graph graph, + Dictionary> new_gradient_function_map) + { + _graph = graph; + _new_gradient_function_map = new_gradient_function_map; + } + + [DebuggerStepThrough] + public void __enter__() + { + Debug.Assert(_graph._gradient_function_map.Count == 0); + _graph._gradient_function_map = _new_gradient_function_map; + } + + [DebuggerStepThrough] + public void __exit__() + { + _graph._gradient_function_map = new Dictionary>(); + } + + public void Dispose() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs new file mode 100644 index 000000000..a7ce6ff5f --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs @@ -0,0 +1,42 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow; + +public sealed class ImportGraphDefOptions +{ + SafeImportGraphDefOptionsHandle _handle { get; } + + public int NumReturnOutputs + => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); + + public ImportGraphDefOptions() + { + _handle = c_api.TF_NewImportGraphDefOptions(); + } + + public SafeImportGraphDefOptionsHandle Options => _handle; + + public void AddReturnOutput(string name, int index) + { + c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); + } + + public static implicit operator SafeImportGraphDefOptionsHandle(ImportGraphDefOptions opt) + { + return opt._handle; + } +} diff --git a/src/TensorFlowNET.Core/Graphs/SafeFuncGraphHandle.cs b/src/TensorFlowNET.Core/Graphs/SafeFuncGraphHandle.cs new file mode 100644 index 000000000..f38301b64 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/SafeFuncGraphHandle.cs @@ -0,0 +1,22 @@ +using Tensorflow.Util; + +namespace Tensorflow; + +public sealed class SafeFuncGraphHandle : SafeTensorflowHandle +{ + private SafeFuncGraphHandle() + { + } + + public SafeFuncGraphHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteFunction(handle); + SetHandle(IntPtr.Zero); + return true; + } +} diff --git a/src/TensorFlowNET.Core/Graphs/SafeGraphHandle.cs b/src/TensorFlowNET.Core/Graphs/SafeGraphHandle.cs new file mode 100644 index 000000000..a6da01987 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/SafeGraphHandle.cs @@ -0,0 +1,22 @@ +using Tensorflow.Util; + +namespace Tensorflow; + +public sealed class SafeGraphHandle : SafeTensorflowHandle +{ + private SafeGraphHandle() + { + } + + public SafeGraphHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteGraph(handle); + SetHandle(IntPtr.Zero); + return true; + } +} diff --git a/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefOptionsHandle.cs b/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefOptionsHandle.cs new file mode 100644 index 000000000..9fc62f1d2 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefOptionsHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow +{ + public sealed class SafeImportGraphDefOptionsHandle : SafeTensorflowHandle + { + private SafeImportGraphDefOptionsHandle() + { + } + + public SafeImportGraphDefOptionsHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteImportGraphDefOptions(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefResultsHandle.cs b/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefResultsHandle.cs new file mode 100644 index 000000000..8a84eff8c --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/SafeImportGraphDefResultsHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow +{ + public sealed class SafeImportGraphDefResultsHandle : SafeTensorflowHandle + { + private SafeImportGraphDefResultsHandle() + { + } + + public SafeImportGraphDefResultsHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteImportGraphDefResults(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs b/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs new file mode 100644 index 000000000..7c186f94b --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs @@ -0,0 +1,177 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Graphs +{ + public class SubGraphUtility + { + /// + /// Copies the tensor and all its inputs recursively to the outer graph. + /// + /// + /// + /// + /// + /// + /// + public static Dictionary lift_to_graph(Tensors init_tensors, + FuncGraph graph, + List sources, + bool add_sources = false, + bool handle_captures = false, + Graph base_graph = null, + Dictionary op_map = null) + { + base_graph = base_graph ?? init_tensors[0].graph; + op_map = op_map ?? new Dictionary(); + var visited_ops = sources.Select(x => x.op).ToList(); + foreach (var init_tensor in init_tensors) + { + var src = map_subgraph(init_tensor, sources, visited_ops, add_sources); + sources.AddRange(src); + } + + var ops_to_copy = new List(); + var marked_ops = new List(); + var ops_to_visit = new Stack(init_tensors.Select(x => x.op)); + var unvisited_ops = new List(ops_to_visit.ToList()); + while (unvisited_ops.Count > 0) + { + while(ops_to_visit.Count > 0) + { + var op = ops_to_visit.Pop(); + if (marked_ops.Contains(op)) + continue; + marked_ops.Add(op); + ops_to_copy.append(op); + foreach(var inp in op.inputs) + { + + } + } + // difference_update + unvisited_ops.difference_update(marked_ops); + if (unvisited_ops.Count > 0) + ops_to_visit.Push(unvisited_ops.Last()); + } + + // When lifting from one FuncGraph to another, we will need to capture the + // relevant tensors as well. + var inverse_captures = new Dictionary(); + Tensor[] internal_captures = null; + if (base_graph is FuncGraph base_func_graph) + { + var captures = base_func_graph.captures; + foreach (var (external_capture, internal_capture) in captures) + inverse_captures[internal_capture] = external_capture; + internal_captures = base_func_graph.internal_captures; + } + + graph.as_default(); + var source_ops = new List(); + // Add the sources in the same order as the original graph. + foreach (var s in internal_captures) + { + if (sources.Contains(s)) + { + sources.Remove(s); + source_ops.Add(s.op); + _copy_source(s: s, + graph: graph, + op_map: op_map, + handle_captures: handle_captures, + inverse_captures: inverse_captures, + base_graph: base_graph); + } + } + + foreach(var op in reversed(ops_to_copy)) + { + if (source_ops.Contains(op) || op_map.ContainsKey(op)) + continue; + _copy_non_source(op, graph, op_map, base_graph); + } + + graph.Exit(); + + return op_map; + } + + static void _copy_source(Tensor s, + FuncGraph graph, + Dictionary op_map, + bool handle_captures, + Dictionary inverse_captures, + Graph base_graph) + { + Tensor copied_placeholder = null; + if (handle_captures && inverse_captures.ContainsKey(s)) + copied_placeholder = graph.capture(inverse_captures[s], name: s.op.name); + else + throw new NotImplementedException(""); + op_map[s] = copied_placeholder; + // Add an entry for the op of the source tensor so that if there are any nodes + // depending on that op via control dependencies it can work correctly. + op_map[s.op] = copied_placeholder.op; + } + + static void _copy_non_source(Operation op, FuncGraph graph, Dictionary op_map, Graph base_graph) + { + Operation copied_op = null; + var copied_inputs = new Tensors(); + tf_with(ops.control_dependencies(new object[] { op }), delegate + { + // Create a new op in the destination graph if it doesn't exist before. + var attrs = new Dictionary(); + foreach (var attr_def in op.node_def.Attr) + attrs[attr_def.Key] = attr_def.Value; + + copied_op = graph.create_op(op.type, + copied_inputs, + dtypes: op.outputs.Select(x => x.dtype).ToArray(), + attrs: attrs, + name: op.name); + }); + op_map[op] = copied_op; + foreach (var (i, o) in enumerate(op.outputs)) + op_map[o] = copied_op.outputs[i]; + } + + /// + /// Walk a Graph and capture the subgraph between init_tensor and sources. + /// + /// + /// + public static List map_subgraph(Tensor init_tensor, + List sources, + List visited_ops, + bool add_sources) + { + var ops_to_visit = new Stack(); + ops_to_visit.Push(init_tensor.op); + var extra_sources = new List(); + while (ops_to_visit.Count > 0) + { + var op = ops_to_visit.Pop(); + if (visited_ops.Contains(op)) + continue; + visited_ops.Add(op); + bool should_raise = false; + if (should_raise) + throw new RuntimeError($"Unable to lift tensor {init_tensor.name}."); + if(op.type == "Placeholder") + { + extra_sources.AddRange(op.outputs); + } + foreach(var inp in op.inputs) + { + + } + } + return extra_sources; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs b/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs new file mode 100644 index 000000000..eff8be94b --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs @@ -0,0 +1,75 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow +{ + public sealed class TF_ImportGraphDefResults : IDisposable + { + /*public IntPtr return_nodes; + public IntPtr missing_unused_key_names; + public IntPtr missing_unused_key_indexes; + public IntPtr missing_unused_key_names_data;*/ + + private SafeImportGraphDefResultsHandle Handle { get; } + + public TF_ImportGraphDefResults(SafeImportGraphDefResultsHandle handle) + { + Handle = handle; + } + + public TF_Output[] return_tensors + { + get + { + IntPtr return_output_handle = IntPtr.Zero; + int num_outputs = -1; + c_api.TF_ImportGraphDefResultsReturnOutputs(Handle, ref num_outputs, ref return_output_handle); + TF_Output[] return_outputs = new TF_Output[num_outputs]; + unsafe + { + var tf_output_ptr = (TF_Output*)return_output_handle; + for (int i = 0; i < num_outputs; i++) + return_outputs[i] = *(tf_output_ptr + i); + return return_outputs; + } + } + } + + public TF_Operation[] return_opers + { + get + { + return new TF_Operation[0]; + /*TF_Operation return_output_handle = new TF_Operation(); + int num_outputs = -1; + c_api.TF_ImportGraphDefResultsReturnOperations(_handle, ref num_outputs, ref return_output_handle); + TF_Operation[] return_outputs = new TF_Operation[num_outputs]; + unsafe + { + var tf_output_ptr = (TF_Operation*)return_output_handle; + for (int i = 0; i < num_outputs; i++) + return_outputs[i] = *(tf_output_ptr + i); + return return_outputs; + }*/ + } + } + + public void Dispose() + => Handle.Dispose(); + } +} diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs new file mode 100644 index 000000000..66f90d5c5 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs @@ -0,0 +1,127 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using Tensorflow.Operations; + +namespace Tensorflow +{ + /// + /// Context manager for `control_dependencies()` + /// + public class _ControlDependenciesController : ITensorFlowObject + { + private Graph _graph; + private List _control_inputs_val; + private List _seen_nodes; + private List<_ControlDependenciesController> _old_stack; + private bool _new_stack; + private ControlFlowContext _old_control_flow_context; + + public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); + + /// + /// Create a new `_ControlDependenciesController`. + /// + /// A `_ControlDependenciesController` is the context manager for + /// `with tf.control_dependencies()` blocks.These normally nest, + /// as described in the documentation for `control_dependencies()`. + /// + /// The `control_inputs` argument list control dependencies that must be + /// added to the current set of control dependencies.Because of + /// uniquification the set can be empty even if the caller passed a list of + /// ops.The special value `None` indicates that we want to start a new + /// empty set of control dependencies instead of extending the current set. + /// + /// In that case we also clear the current control flow context, which is an + /// additional mechanism to add control dependencies. + /// + /// The graph that this controller is managing. + /// List of ops to use as control inputs in addition + /// to the current control dependencies.None to indicate that + /// the dependencies should be cleared. + /// + public _ControlDependenciesController(Graph graph, List control_inputs) + { + _graph = graph; + if (control_inputs == null) + { + _control_inputs_val = new List(); + _new_stack = true; + } + else + { + _control_inputs_val = control_inputs; + _new_stack = false; + } + + _seen_nodes = new List(); + _old_stack = null; + _old_control_flow_context = null; + } + + public void add_op(ITensorOrOperation op) + { + _seen_nodes.Add(op); + } + + public bool op_in_group(ITensorOrOperation op) + { + return _seen_nodes.Contains(op); + } + + public void __enter__() + { + if (_new_stack) + { + // Clear the control_dependencies graph. + _old_stack = _graph._control_dependencies_stack; + _graph._control_dependencies_stack = new List<_ControlDependenciesController>(); + + // Clear the control_flow_context too. + _old_control_flow_context = _graph._get_control_flow_context(); + _graph._set_control_flow_context(null); + } + + _graph._push_control_dependencies_controller(this); + } + + public void __exit__() + { + _graph._pop_control_dependencies_controller(this); + if (_new_stack) + { + _graph._control_dependencies_stack = _old_stack; + _graph._set_control_flow_context(_old_control_flow_context); + } + } + + public void Dispose() + { + + } + + public void __init__() + { + + } + + public void __del__() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 337fdea9d..e0c58966d 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -1,16 +1,342 @@ -using System; -using System.Collections.Generic; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; using System.Runtime.InteropServices; -using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { + /// + /// Destroy an options object. Graph will be deleted once no more + /// TFSession's are referencing it. + /// + /// [DllImport(TensorFlowLibName)] - public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); + public static extern void TF_DeleteGraph(IntPtr graph); + + /// + /// + /// + /// TF_ImportGraphDefOptions* + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteImportGraphDefOptions(IntPtr opts); + + /// + /// Deletes a results object returned by TF_GraphImportGraphDefWithResults(). + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteImportGraphDefResults(IntPtr results); + + [DllImport(TensorFlowLibName)] + public static extern string TF_GraphDebugString(IntPtr graph, out int len); + + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, SafeBufferHandle output_op_def, SafeStatusHandle status); + + /// + /// Returns the shape of the Tensor referenced by `output` in `graph` + /// into `dims`. `dims` must be an array large enough to hold `num_dims` + /// entries (e.g., the return value of TF_GraphGetTensorNumDims). + /// + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphGetTensorShape(SafeGraphHandle graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); + + /// + /// Import the graph serialized in `graph_def` into `graph`. + /// Convenience function for when only return outputs are needed. + /// + /// `num_return_outputs` must be the number of return outputs added (i.e. the + /// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If + /// `num_return_outputs` is non-zero, `return_outputs` must be of length + /// `num_return_outputs`. Otherwise it can be null. + /// + /// TF_Graph* graph + /// const TF_Buffer* + /// const TF_ImportGraphDefOptions* + /// TF_Output* + /// int + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); + + /// + /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and + /// a bad status on error. Otherwise, returns a populated + /// TF_ImportGraphDefResults instance. The returned instance must be deleted via + /// TF_DeleteImportGraphDefResults(). + /// + /// TF_Graph* + /// const TF_Buffer* + /// const TF_ImportGraphDefOptions* + /// TF_Status* + /// TF_ImportGraphDefResults* + [DllImport(TensorFlowLibName)] + public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); + + /// + /// Import the graph serialized in `graph_def` into `graph`. + /// + /// TF_Graph* + /// TF_Buffer* + /// TF_ImportGraphDefOptions* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphImportGraphDef(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); + + /// + /// Iterate through the operations of a graph. + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_GraphNextOperation(SafeGraphHandle graph, ref uint pos); + + /// + /// Returns the operation in the graph with `oper_name`. Returns nullptr if + /// no operation found. + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_GraphOperationByName(SafeGraphHandle graph, string oper_name); + + /// + /// Sets the shape of the Tensor referenced by `output` in `graph` to + /// the shape described by `dims` and `num_dims`. + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphSetTensorShape(SafeGraphHandle graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); + + /// + /// Write out a serialized representation of `graph` (as a GraphDef protocol + /// message) to `output_graph_def` (allocated by TF_NewBuffer()). + /// + /// TF_Graph* + /// TF_Buffer* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphToGraphDef(SafeGraphHandle graph, SafeBufferHandle output_graph_def, SafeStatusHandle status); + + /// + /// Returns the number of dimensions of the Tensor referenced by `output` + /// in `graph`. + /// + /// If the number of dimensions in the shape is unknown, returns -1. + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_GraphGetTensorNumDims(SafeGraphHandle graph, TF_Output output, SafeStatusHandle status); + + /// + /// Cause the imported graph to have a control dependency on `oper`. `oper` + /// should exist in the graph being imported into. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsAddControlDependency(SafeImportGraphDefOptionsHandle opts, IntPtr oper); + + /// + /// Set any imported nodes with input `src_name:src_index` to have that input + /// replaced with `dst`. `src_name` refers to a node in the graph to be imported, + /// `dst` references a node already existing in the graph being imported into. + /// `src_name` is copied and has no lifetime requirements. + /// + /// TF_ImportGraphDefOptions* + /// const char* + /// int + /// TF_Output + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsAddInputMapping(SafeImportGraphDefOptionsHandle opts, string src_name, int src_index, TF_Output dst); + + /// + /// Add an operation in `graph_def` to be returned via the `return_opers` output + /// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no + /// lifetime requirements. + /// + /// TF_ImportGraphDefOptions* opts + /// const char* + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); + + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsSetValidateColocationConstraints(SafeImportGraphDefOptionsHandle options, bool validate_colocation_constraints); + + /// + /// Add an output in `graph_def` to be returned via the `return_outputs` output + /// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input + /// mapping, the corresponding existing tensor in `graph` will be returned. + /// `oper_name` is copied and has no lifetime requirements. + /// + /// TF_ImportGraphDefOptions* + /// const char* + /// int + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsAddReturnOutput(SafeImportGraphDefOptionsHandle opts, string oper_name, int index); + + /// + /// Returns the number of return operations added via + /// TF_ImportGraphDefOptionsAddReturnOperation(). + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_ImportGraphDefOptionsNumReturnOperations(SafeImportGraphDefOptionsHandle opts); + + /// + /// Returns the number of return outputs added via + /// TF_ImportGraphDefOptionsAddReturnOutput(). + /// + /// const TF_ImportGraphDefOptions* + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(SafeImportGraphDefOptionsHandle opts); + + /// + /// Set any imported nodes with control input `src_name` to have that input + /// replaced with `dst`. `src_name` refers to a node in the graph to be imported, + /// `dst` references an operation already existing in the graph being imported + /// into. `src_name` is copied and has no lifetime requirements. + /// + /// TF_ImportGraphDefOptions* + /// const char* + /// TF_Operation* + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsRemapControlDependency(SafeImportGraphDefOptionsHandle opts, string src_name, IntPtr dst); + + /// + /// Set the prefix to be prepended to the names of nodes in `graph_def` that will + /// be imported into `graph`. `prefix` is copied and has no lifetime + /// requirements. + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsSetPrefix(SafeImportGraphDefOptionsHandle ops, string prefix); + + /// + /// Set whether to uniquify imported operation names. If true, imported operation + /// names will be modified if their name already exists in the graph. If false, + /// conflicting names will be treated as an error. Note that this option has no + /// effect if a prefix is set, since the prefix will guarantee all names are + /// unique. Defaults to false. + /// + /// TF_ImportGraphDefOptions* + /// unsigned char + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, bool uniquify_prefix); + + /// + /// Fetches the return operations requested via + /// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched + /// operations is returned in `num_opers`. The array of return operations is + /// returned in `opers`. `*opers` is owned by and has the lifetime of `results`. + /// + /// TF_ImportGraphDefResults* + /// int* + /// TF_Operation*** + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefResultsReturnOperations(SafeImportGraphDefResultsHandle results, ref int num_opers, ref TF_Operation opers); + + /// + /// Fetches the return outputs requested via + /// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is + /// returned in `num_outputs`. The array of return outputs is returned in + /// `outputs`. `*outputs` is owned by and has the lifetime of `results`. + /// + /// TF_ImportGraphDefResults* results + /// int* + /// TF_Output** + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefResultsReturnOutputs(SafeImportGraphDefResultsHandle results, ref int num_outputs, ref IntPtr outputs); + + /// + /// This function creates a new TF_Session (which is created on success) using + /// `session_options`, and then initializes state (restoring tensors and other + /// assets) using `run_options`. + /// + /// const TF_SessionOptions* + /// const TF_Buffer* + /// const char* + /// const char* const* + /// int + /// TF_Graph* + /// TF_Buffer* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern SafeSessionHandle TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, + string export_dir, string[] tags, int tags_len, + SafeGraphHandle graph, IntPtr meta_graph_def, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern SafeGraphHandle TF_NewGraph(); + + [DllImport(TensorFlowLibName)] + public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions(); + + /// + /// Set the shapes and types of the output's handle. + /// + /// TF_Graph* + /// TF_Output + /// int + /// const int64_t** + /// const int* + /// const TF_DataType* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphSetOutputHandleShapesAndTypes(SafeGraphHandle graph, TF_Output output, + int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types, + SafeStatusHandle status); + + /// + /// Updates 'dst' to consume 'new_src'. + /// + /// TF_Graph* + /// + /// + /// TF_Status* + [DllImport(TensorFlowLibName)] + + public static extern void TF_UpdateEdge(IntPtr graph, TF_Output new_src, TF_Input dst, SafeStatusHandle status); + /// + /// Attempts to evaluate `output`. This will only be possible if `output` doesn't + /// depend on any graph inputs (this function is safe to call if this isn't the + /// case though). + /// + /// + /// + /// + /// + /// [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_NewGraph(); + public static extern bool TF_TryEvaluateConstant(SafeGraphHandle graph, TF_Output output, IntPtr[] result, SafeStatusHandle status); } } diff --git a/src/TensorFlowNET.Core/Graphs/graph_io.py.cs b/src/TensorFlowNET.Core/Graphs/graph_io.py.cs new file mode 100644 index 000000000..2d1a352e9 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/graph_io.py.cs @@ -0,0 +1,45 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using System.IO; + +namespace Tensorflow +{ + public class graph_io + { + public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) + { + var graph_def = graph.as_graph_def(); + string path = Path.Combine(logdir, name); + if (as_text) + File.WriteAllText(path, graph_def.ToString()); + else + File.WriteAllBytes(path, graph_def.ToByteArray()); + return path; + } + + public static string write_graph(MetaGraphDef graph_def, string logdir, string name, bool as_text = true) + { + string path = Path.Combine(logdir, name); + if (as_text) + File.WriteAllText(path, graph_def.ToString()); + else + File.WriteAllBytes(path, graph_def.ToByteArray()); + return path; + } + } +} diff --git a/src/TensorFlowNET.Core/IO/MemmappedFileSystem.cs b/src/TensorFlowNET.Core/IO/MemmappedFileSystem.cs new file mode 100644 index 000000000..5c74c4814 --- /dev/null +++ b/src/TensorFlowNET.Core/IO/MemmappedFileSystem.cs @@ -0,0 +1,70 @@ +/***************************************************************************** + Copyright 2021 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.IO; +using System.IO.MemoryMappedFiles; +using System.Linq; +using Tensorflow; + +namespace Tensorflow.IO +{ + public class MemmappedFileSystem + { + public const string MEMMAPPED_PACKAGE_DEFAULT_NAME = "memmapped_package://."; + + private MemoryMappedFile _mmapFile; + private MemmappedFileSystemDirectory _directory; + + public MemmappedFileSystem(string path) + { + using (var stream = File.OpenRead(path)) + { + // Read the offset for the directory + var offsetData = new byte[sizeof(ulong)]; + stream.Seek(-sizeof(ulong), SeekOrigin.End); + stream.Read(offsetData, 0, sizeof(ulong)); + var offset = BitConverter.ToUInt64(offsetData, 0); + + var dirLength = stream.Length - (long) offset - sizeof(ulong); + if (dirLength < 0) + { + throw new InvalidDataException("Malformed mmapped filesystem!"); + } + + var dirData = new byte[dirLength]; + + stream.Seek((long) offset, SeekOrigin.Begin); + stream.Read(dirData, 0, (int) dirLength); + + _directory = MemmappedFileSystemDirectory.Parser.ParseFrom(dirData); + } + + _mmapFile = MemoryMappedFile.CreateFromFile(path, FileMode.Open); + } + + public Stream OpenMemmapped(string filename) + { + var entry = _directory.Element.FirstOrDefault(x => x.Name == filename); + if (entry == null) + { + throw new FileNotFoundException($"Missing memmaped file entry: {filename}"); + } + + return _mmapFile.CreateViewStream((long) entry.Offset, (long) entry.Length); + } + } +} diff --git a/src/TensorFlowNET.Core/IO/gfile.cs b/src/TensorFlowNET.Core/IO/gfile.cs new file mode 100644 index 000000000..142b8b64e --- /dev/null +++ b/src/TensorFlowNET.Core/IO/gfile.cs @@ -0,0 +1,79 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.IO +{ + public class GFile + { + /// + /// Recursive directory tree generator for directories. + /// + /// a Directory name + /// Traverse in order if True, post order if False. + public IEnumerable<(string, string[], string[])> Walk(string top, bool in_order = true) + { + if (!Directory.Exists(top)) + return Enumerable.Empty<(string, string[], string[])>(); + + return walk_v2(top, in_order); + } + + private IEnumerable<(string, string[], string[])> walk_v2(string top, bool topdown) + { + var subdirs = Directory.GetDirectories(top); + var files = Directory.GetFiles(top); + + var here = (top, subdirs, files); + + if (subdirs.Length == 0) + yield return here; + else + foreach (var dir in subdirs) + foreach (var f in walk_v2(dir, topdown)) + yield return f; + } + + public string[] listdir(string data_dir) + => Directory.GetDirectories(data_dir) + .Select(x => x.Split(Path.DirectorySeparatorChar).Last()) + .ToArray(); + + public string[] glob(string data_dir) + { + var dirs = new List(); + foreach(var dir in Directory.GetDirectories(data_dir)) + dirs.AddRange(Directory.GetFiles(dir)); + return dirs.ToArray(); + } + + public string join(params string[] paths) + { + Debug.Assert(paths.Length >= 1); + if (paths[0].Substring(1).Contains("://")) + { + throw new NotImplementedException("The combination of urls has not been implemented."); + } + return Path.Combine(paths); + } + } +} diff --git a/src/TensorFlowNET.Core/Interfaces/IFlatten.cs b/src/TensorFlowNET.Core/Interfaces/IFlatten.cs new file mode 100644 index 000000000..ffabf1d0d --- /dev/null +++ b/src/TensorFlowNET.Core/Interfaces/IFlatten.cs @@ -0,0 +1,7 @@ +namespace Tensorflow +{ + public interface ICanBeFlattened + { + object[] Flatten(); + } +} diff --git a/src/TensorFlowNET.Core/Interfaces/IFromMergeVars.cs b/src/TensorFlowNET.Core/Interfaces/IFromMergeVars.cs new file mode 100644 index 000000000..2dd168e1b --- /dev/null +++ b/src/TensorFlowNET.Core/Interfaces/IFromMergeVars.cs @@ -0,0 +1,7 @@ +namespace Tensorflow +{ + public interface IFromMergeVars + { + T FromMergeVars(ITensorOrTensorArray[] mergeVars); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Interfaces/IPackable.cs b/src/TensorFlowNET.Core/Interfaces/IPackable.cs new file mode 100644 index 000000000..8deffea91 --- /dev/null +++ b/src/TensorFlowNET.Core/Interfaces/IPackable.cs @@ -0,0 +1,7 @@ +namespace Tensorflow +{ + public interface IPackable + { + T Pack(object[] sequences); + } +} diff --git a/src/TensorFlowNET.Core/Interfaces/ITensorFlowObject.cs b/src/TensorFlowNET.Core/Interfaces/ITensorFlowObject.cs new file mode 100644 index 000000000..74d01558d --- /dev/null +++ b/src/TensorFlowNET.Core/Interfaces/ITensorFlowObject.cs @@ -0,0 +1,27 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow +{ + public interface ITensorFlowObject : IDisposable + { + void __enter__(); + + void __exit__(); + } +} diff --git a/src/TensorFlowNET.Core/Interfaces/ITensorOrOperation.cs b/src/TensorFlowNET.Core/Interfaces/ITensorOrOperation.cs new file mode 100644 index 000000000..9fc3be9f8 --- /dev/null +++ b/src/TensorFlowNET.Core/Interfaces/ITensorOrOperation.cs @@ -0,0 +1,34 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; + +namespace Tensorflow +{ + /// + /// in order to limit function return value + /// is Tensor or Operation + /// + public interface ITensorOrOperation + { + string Device { get; } + Operation op { get; } + string name { get; } + TF_DataType dtype { get; } + Tensor[] outputs { get; } + NDArray numpy(); + } +} diff --git a/src/TensorFlowNET.Core/Interfaces/ITensorOrTensorArray.cs b/src/TensorFlowNET.Core/Interfaces/ITensorOrTensorArray.cs new file mode 100644 index 000000000..a6f30ceb8 --- /dev/null +++ b/src/TensorFlowNET.Core/Interfaces/ITensorOrTensorArray.cs @@ -0,0 +1,27 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + /// + /// in order to limit function return value + /// is Tensor or TensorArray + /// + public interface ITensorOrTensorArray + { + + } +} diff --git a/src/TensorFlowNET.Core/Keras/Activations/Activations.cs b/src/TensorFlowNET.Core/Keras/Activations/Activations.cs new file mode 100644 index 000000000..37264104a --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Activations/Activations.cs @@ -0,0 +1,45 @@ +using Newtonsoft.Json; +using System.Reflection; +using System.Runtime.Versioning; +using Tensorflow.Keras.Saving.Common; + +namespace Tensorflow.Keras +{ + [JsonConverter(typeof(CustomizedActivationJsonConverter))] + public class Activation + { + public string Name { get; set; } + /// + /// The parameters are `features` and `name`. + /// + public Func ActivationFunction { get; set; } + + public Tensor Apply(Tensor input, string name = null) => ActivationFunction(input, name); + + public static implicit operator Activation(Func func) + { + return new Activation() + { + Name = func.GetMethodInfo().Name, + ActivationFunction = func + }; + } + } + + public interface IActivationsApi + { + Activation GetActivationFromName(string name); + Activation Linear { get; } + + Activation Relu { get; } + Activation Relu6 { get; } + + Activation Sigmoid { get; } + + Activation Softmax { get; } + + Activation Tanh { get; } + + Activation Mish { get; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs new file mode 100644 index 000000000..e830e5bf8 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs @@ -0,0 +1,12 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition { + public class ELUArgs : AutoSerializeLayerArgs + { + [JsonProperty("alpha")] + public float Alpha { get; set; } = 0.1f; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ExponentialArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ExponentialArgs.cs new file mode 100644 index 000000000..ef024971d --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ExponentialArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class ExponentialArgs : LayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/HardSigmoidArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/HardSigmoidArgs.cs new file mode 100644 index 000000000..788e0f36d --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/HardSigmoidArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class HardSigmoidArgs : LayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs new file mode 100644 index 000000000..6d9531346 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs @@ -0,0 +1,16 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class LeakyReLuArgs : AutoSerializeLayerArgs + { + /// + /// Negative slope coefficient. + /// + [JsonProperty("alpha")] + public float Alpha { get; set; } = 0.3f; + } +} diff --git a/test/TensorFlowNET.Examples/IExample.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SELUArgs.cs similarity index 51% rename from test/TensorFlowNET.Examples/IExample.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SELUArgs.cs index 8908abe00..eb0e18446 100644 --- a/test/TensorFlowNET.Examples/IExample.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SELUArgs.cs @@ -2,10 +2,10 @@ using System.Collections.Generic; using System.Text; -namespace TensorFlowNET.Examples +namespace Tensorflow.Keras.ArgsDefinition { - public interface IExample + public class SELUArgs : LayerArgs { - void Run(); + } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs new file mode 100644 index 000000000..1c1d147f1 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs @@ -0,0 +1,12 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition { + public class SoftmaxArgs : AutoSerializeLayerArgs + { + [JsonProperty("axis")] + public Axis axis { get; set; } = -1; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftplusArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftplusArgs.cs new file mode 100644 index 000000000..7b4f20795 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftplusArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class SoftplusArgs : LayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftsignArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftsignArgs.cs new file mode 100644 index 000000000..4e23d261d --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftsignArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class SoftsignArgs : LayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SwishArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SwishArgs.cs new file mode 100644 index 000000000..3dea06a23 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SwishArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class SwishArgs : LayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/TanhArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/TanhArgs.cs new file mode 100644 index 000000000..5df41b71b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/TanhArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class TanhArgs : LayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs new file mode 100644 index 000000000..4cdfb46bd --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs @@ -0,0 +1,24 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class AttentionArgs : BaseDenseAttentionArgs + { + + /// + /// If `true`, will create a scalar variable to scale the attention scores. + /// + [JsonProperty("use_scale")] + public bool use_scale { get; set; } = false; + + /// + /// Function to use to compute attention scores, one of + /// `{"dot", "concat"}`. `"dot"` refers to the dot product between the query + /// and key vectors. `"concat"` refers to the hyperbolic tangent of the + /// concatenation of the query and key vectors. + /// + [JsonProperty("score_mode")] + public string score_mode { get; set; } = "dot"; + + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs new file mode 100644 index 000000000..0ef017370 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs @@ -0,0 +1,23 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class BaseDenseAttentionArgs : AutoSerializeLayerArgs + { + + /// + /// Boolean. Set to `true` for decoder self-attention. Adds a mask such + /// that position `i` cannot attend to positions `j > i`. This prevents the + /// flow of information from the future towards the past. + /// + public bool causal { get; set; } = false; + + /// + /// Float between 0 and 1. Fraction of the units to drop for the + /// attention scores. + /// + [JsonProperty("dropout")] + public float dropout { get; set; } = 0f; + + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs new file mode 100644 index 000000000..077dea89d --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs @@ -0,0 +1,40 @@ +using Newtonsoft.Json; +using System; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class MultiHeadAttentionArgs : AutoSerializeLayerArgs + { + [JsonProperty("num_heads")] + public int NumHeads { get; set; } + [JsonProperty("key_dim")] + public int KeyDim { get; set; } + [JsonProperty("value_dim")] + public int? ValueDim { get; set; } = null; + [JsonProperty("dropout")] + public float Dropout { get; set; } = 0f; + [JsonProperty("use_bias")] + public bool UseBias { get; set; } = true; + [JsonProperty("output_shape")] + public Shape OutputShape { get; set; } = null; + [JsonProperty("attention_axes")] + public Shape AttentionAxis { get; set; } = null; + [JsonProperty("kernel_initializer")] + public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; + [JsonProperty("bias_initializer")] + public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("kernel_regularizer")] + public IRegularizer KernelRegularizer { get; set; } = null; + [JsonProperty("bias_regularizer")] + public IRegularizer BiasRegularizer { get; set; } = null; + [JsonProperty("kernel_constraint")] + public Action KernelConstraint { get; set; } = null; + [JsonProperty("bias_constraint")] + public Action BiasConstraint { get; set; } = null; + [JsonProperty("activity_regularizer")] + public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; } + + // TODO: Add `key_shape`, `value_shape`, `query_shape`. + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs new file mode 100644 index 000000000..583ab9322 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs @@ -0,0 +1,26 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition +{ + /// + /// This class has nothing but the attributes different from `LayerArgs`. + /// It's used to serialize the model to `tf` format. + /// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`, + /// then the Arg definition should inherit `AutoSerializeLayerArgs` instead of `LayerArgs`. + /// + public class AutoSerializeLayerArgs: LayerArgs + { + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] + public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv1DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv1DArgs.cs new file mode 100644 index 000000000..c461f7d27 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv1DArgs.cs @@ -0,0 +1,7 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class Conv1DArgs : ConvolutionalArgs + { + + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv2DArgs.cs new file mode 100644 index 000000000..767a5f80a --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv2DArgs.cs @@ -0,0 +1,7 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class Conv2DArgs : ConvolutionalArgs + { + + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv2DTransposeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv2DTransposeArgs.cs new file mode 100644 index 000000000..3daba9465 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv2DTransposeArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class Conv2DTransposeArgs : Conv2DArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs new file mode 100644 index 000000000..f34c63d1b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs @@ -0,0 +1,46 @@ +using Newtonsoft.Json; +using System; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class ConvolutionalArgs : AutoSerializeLayerArgs + { + public int Rank { get; set; } + [JsonProperty("filters")] + public int Filters { get; set; } + public int NumSpatialDims { get; set; } = Unknown; + [JsonProperty("kernel_size")] + public Shape KernelSize { get; set; } + + /// + /// specifying the stride length of the convolution. + /// + [JsonProperty("strides")] + public Shape Strides { get; set; } + [JsonProperty("padding")] + public string Padding { get; set; } + [JsonProperty("data_format")] + public string DataFormat { get; set; } + [JsonProperty("dilation_rate")] + public Shape DilationRate { get; set; } + [JsonProperty("groups")] + public int Groups { get; set; } + [JsonProperty("activation")] + public Activation Activation { get; set; } + [JsonProperty("use_bias")] + public bool UseBias { get; set; } + [JsonProperty("kernel_initializer")] + public IInitializer KernelInitializer { get; set; } + [JsonProperty("bias_initializer")] + public IInitializer BiasInitializer { get; set; } + [JsonProperty("kernel_regularizer")] + public IRegularizer KernelRegularizer { get; set; } + [JsonProperty("bias_regularizer")] + public IRegularizer BiasRegularizer { get; set; } + [JsonProperty("kernel_constraint")] + public Action KernelConstraint { get; set; } + [JsonProperty("bias_constraint")] + public Action BiasConstraint { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs new file mode 100644 index 000000000..0caa76ef5 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs @@ -0,0 +1,73 @@ +using Newtonsoft.Json; +using System; +using System.Xml.Linq; +using Tensorflow.Operations.Initializers; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.ArgsDefinition +{ + // TODO: `activity_regularizer` + public class DenseArgs : LayerArgs + { + /// + /// Positive integer, dimensionality of the output space. + /// + [JsonProperty("units")] + public int Units { get; set; } + + /// + /// Activation function to use. + /// + [JsonProperty("activation")] + public Activation Activation { get; set; } + + /// + /// Whether the layer uses a bias vector. + /// + [JsonProperty("use_bias")] + public bool UseBias { get; set; } = true; + + /// + /// Initializer for the `kernel` weights matrix. + /// + [JsonProperty("kernel_initializer")] + public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; + + /// + /// Initializer for the bias vector. + /// + [JsonProperty("bias_initializer")] + public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; + + /// + /// Regularizer function applied to the `kernel` weights matrix. + /// + [JsonProperty("kernel_regularizer")] + public IRegularizer KernelRegularizer { get; set; } + + /// + /// Regularizer function applied to the bias vector. + /// + [JsonProperty("bias_regularizer")] + public IRegularizer BiasRegularizer { get; set; } + + /// + /// Constraint function applied to the `kernel` weights matrix. + /// + [JsonProperty("kernel_constraint")] + public Action KernelConstraint { get; set; } + + /// + /// Constraint function applied to the bias vector. + /// + [JsonProperty("bias_constraint")] + public Action BiasConstraint { get; set; } + + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs new file mode 100644 index 000000000..e60309720 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs @@ -0,0 +1,79 @@ +using Newtonsoft.Json; +using System; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.ArgsDefinition.Core +{ + public class EinsumDenseArgs : AutoSerializeLayerArgs + { + /// + /// An equation describing the einsum to perform. This equation must + /// be a valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or + /// `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis + /// expression sequence. + /// + [JsonProperty("equation")] + public string Equation { get; set; } + + /// + /// The expected shape of the output tensor (excluding the batch + /// dimension and any dimensions represented by ellipses). You can specify + /// None for any dimension that is unknown or can be inferred from the input + /// shape. + /// + [JsonProperty("output_shape")] + public Shape OutputShape { get; set; } + + /// + /// A string containing the output dimension(s) to apply a bias to. + /// Each character in the `bias_axes` string should correspond to a character + /// in the output portion of the `equation` string. + /// + [JsonProperty("bias_axes")] + public string BiasAxes { get; set; } = null; + + /// + /// Activation function to use. + /// + [JsonProperty("activation")] + public Activation Activation { get; set; } + + /// + /// Initializer for the `kernel` weights matrix. + /// + [JsonProperty("kernel_initializer")] + public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; + + /// + /// Initializer for the bias vector. + /// + [JsonProperty("bias_initializer")] + public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; + + /// + /// Regularizer function applied to the `kernel` weights matrix. + /// + [JsonProperty("kernel_regularizer")] + public IRegularizer KernelRegularizer { get; set; } + + /// + /// Regularizer function applied to the bias vector. + /// + [JsonProperty("bias_regularizer")] + public IRegularizer BiasRegularizer { get; set; } + + /// + /// Constraint function applied to the `kernel` weights matrix. + /// + [JsonProperty("kernel_constraint")] + public Action KernelConstraint { get; set; } + + /// + /// Constraint function applied to the bias vector. + /// + [JsonProperty("bias_constraint")] + public Action BiasConstraint { get; set; } + [JsonProperty("activity_regularizer")] + public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs new file mode 100644 index 000000000..c462961b3 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs @@ -0,0 +1,22 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class EmbeddingArgs : AutoSerializeLayerArgs + { + [JsonProperty("input_dim")] + public int InputDim { get; set; } + [JsonProperty("output_dim")] + public int OutputDim { get; set; } + [JsonProperty("mask_zero")] + public bool MaskZero { get; set; } + [JsonProperty("input_length")] + public int InputLength { get; set; } = -1; + [JsonProperty("embeddings_initializer")] + public IInitializer EmbeddingsInitializer { get; set; } + [JsonProperty("activity_regularizer")] + public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; } + + // TODO: `embeddings_regularizer`, `embeddings_constraint`. + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs new file mode 100644 index 000000000..e036e1912 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs @@ -0,0 +1,22 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Serialization; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class InputLayerArgs : LayerArgs + { + [JsonIgnore] + public Tensor InputTensor { get; set; } + [JsonProperty("sparse")] + public virtual bool Sparse { get; set; } + [JsonProperty("ragged")] + public bool Ragged { get; set; } + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] + public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs new file mode 100644 index 000000000..ba0332836 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs @@ -0,0 +1,23 @@ +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class DataAdapterArgs: IKerasConfig + { + public Tensors X { get; set; } + public Tensors Y { get; set; } + public IDatasetV2 Dataset { get; set; } + public int BatchSize { get; set; } = 32; + public int Steps { get; set; } + public int Epochs { get; set; } + public bool Shuffle { get; set; } + public int MaxQueueSize { get; set; } + public int Worker { get; set; } + public bool UseMultiprocessing { get; set; } + public IModel Model { get; set; } + public Dictionary ClassWeight = null; + public NDArray SampleWeight = null; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs new file mode 100644 index 000000000..72d0bb811 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs @@ -0,0 +1,25 @@ +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class DataHandlerArgs: IKerasConfig + { + public Tensors X { get; set; } + public Tensors Y { get; set; } + public IDatasetV2 Dataset { get; set; } + public int BatchSize { get; set; } = 32; + public int StepsPerEpoch { get; set; } = -1; + public int InitialEpoch { get; set; } = 0; + public int Epochs { get; set; } = 1; + public bool Shuffle { get; set; } = false; + public int MaxQueueSize { get; set; } = 10; + public int Workers { get; set; } = 1; + public bool UseMultiprocessing { get; set; } = false; + public IModel Model { get; set; } + public IVariableV1 StepsPerExecution { get; set; } + public Dictionary ClassWeight = null; + public NDArray SampleWeight = null; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs new file mode 100644 index 000000000..11b8ba39a --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs @@ -0,0 +1,54 @@ +using Newtonsoft.Json; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition +{ + [JsonObject(MemberSerialization.OptIn)] + public class LayerArgs: IKerasConfig + { + /// + /// Indicates whether the layer's weights are updated during training + /// and whether the layer's updates are run during training. + /// + public virtual bool Trainable { get; set; } = true; + public virtual string Name { get; set; } + + /// + /// Only applicable to input layers. + /// + public virtual TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; + + /// + /// Whether the `call` method can be used to build a TF graph without issues. + /// This attribute has no effect if the model is created using the Functional + /// API. Instead, `model.dynamic` is determined based on the internal layers. + /// + public virtual bool Dynamic { get; set; } = false; + + /// + /// Only applicable to input layers. + /// + public virtual Shape InputShape { get; set; } + + /// + /// Only applicable to input layers. + /// + public virtual KerasShapesWrapper BatchInputShape { get; set; } + + public virtual int BatchSize { get; set; } = -1; + + /// + /// Initial weight values. + /// + public virtual float[] Weights { get; set; } + + /// + /// Regularizer function applied to the output of the layer(its "activation"). + /// + public virtual IRegularizer ActivityRegularizer { get; set; } + + public virtual bool Autocast { get; set; } + + public virtual bool IsFromConfig { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/AddArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/AddArgs.cs new file mode 100644 index 000000000..016d58203 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/AddArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class AddArgs : MergeArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/ConcatenateArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/ConcatenateArgs.cs new file mode 100644 index 000000000..4a81d139d --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/ConcatenateArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class ConcatenateArgs : MergeArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs new file mode 100644 index 000000000..9bcf1908e --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs @@ -0,0 +1,15 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + // TODO: complete the implementation + public class MergeArgs : AutoSerializeLayerArgs + { + public Tensors Inputs { get; set; } + [JsonProperty("axis")] + public int Axis { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/SubtractArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/SubtractArgs.cs new file mode 100644 index 000000000..1e3621cb6 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/SubtractArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class SubtractArgs : MergeArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs new file mode 100644 index 000000000..57b8bb695 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs @@ -0,0 +1,8 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class ModelArgs : LayerArgs + { + public Tensors Inputs { get; set; } + public Tensors Outputs { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs new file mode 100644 index 000000000..ad55ff612 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs @@ -0,0 +1,13 @@ +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class NodeArgs: IKerasConfig + { + public ILayer[] InboundLayers { get; set; } + public int[] NodeIndices { get; set; } + public int[] TensorIndices { get; set; } + public Tensors InputTensors { get; set; } + public Tensors Outputs { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs new file mode 100644 index 000000000..6ee91e80b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs @@ -0,0 +1,37 @@ +using Newtonsoft.Json; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class BatchNormalizationArgs : AutoSerializeLayerArgs + { + [JsonProperty("axis")] + public Shape Axis { get; set; } = -1; + [JsonProperty("momentum")] + public float Momentum { get; set; } = 0.99f; + [JsonProperty("epsilon")] + public float Epsilon { get; set; } = 1e-3f; + [JsonProperty("center")] + public bool Center { get; set; } = true; + [JsonProperty("scale")] + public bool Scale { get; set; } = true; + [JsonProperty("beta_initializer")] + public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("gamma_initializer")] + public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; + [JsonProperty("moving_mean_initializer")] + public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("moving_variance_initializer")] + public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer; + [JsonProperty("beta_regularizer")] + public IRegularizer BetaRegularizer { get; set; } + [JsonProperty("gamma_regularizer")] + public IRegularizer GammaRegularizer { get; set; } + // TODO: `beta_constraint` and `gamma_constraint`. + [JsonProperty("renorm")] + public bool Renorm { get; set; } + // TODO: `renorm_clipping` and `virtual_batch_size`. + [JsonProperty("renorm_momentum")] + public float RenormMomentum { get; set; } = 0.99f; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs new file mode 100644 index 000000000..1ac661b37 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs @@ -0,0 +1,27 @@ +using Newtonsoft.Json; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class LayerNormalizationArgs : AutoSerializeLayerArgs + { + [JsonProperty("axis")] + public Axis Axis { get; set; } = -1; + [JsonProperty("epsilon")] + public float Epsilon { get; set; } = 1e-3f; + [JsonProperty("center")] + public bool Center { get; set; } = true; + [JsonProperty("scale")] + public bool Scale { get; set; } = true; + [JsonProperty("beta_initializer")] + public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("gamma_initializer")] + public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; + [JsonProperty("beta_regularizer")] + public IRegularizer BetaRegularizer { get; set; } + [JsonProperty("gamma_regularizer")] + public IRegularizer GammaRegularizer { get; set; } + + // TODO: `beta_constraint` and `gamma_constraint`. + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/NormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/NormalizationArgs.cs new file mode 100644 index 000000000..30c901453 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/NormalizationArgs.cs @@ -0,0 +1,15 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition; + +public class NormalizationArgs : PreprocessingLayerArgs +{ + [JsonProperty("axis")] + public Axis? Axis { get; set; } + [JsonProperty("mean")] + public float? Mean { get; set; } + [JsonProperty("variance")] + public float? Variance { get; set; } + + public bool Invert { get; set; } = false; +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs new file mode 100644 index 000000000..6256fd329 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs @@ -0,0 +1,13 @@ +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class OptimizerV2Args: IKerasConfig + { + public string Name { get; set; } + public float LearningRate { get; set; } = 0.001f; + public float InitialDecay { get; set; } + public float ClipNorm { get; set; } + public float ClipValue { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/AveragePooling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/AveragePooling2DArgs.cs new file mode 100644 index 000000000..06903e370 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/AveragePooling2DArgs.cs @@ -0,0 +1,7 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class AveragePooling2DArgs : Pooling2DArgs + { + + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalAveragePooling1DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalAveragePooling1DArgs.cs new file mode 100644 index 000000000..e73aff766 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalAveragePooling1DArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class GlobalAveragePooling1DArgs : Pooling1DArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalAveragePooling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalAveragePooling2DArgs.cs new file mode 100644 index 000000000..d143cf471 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalAveragePooling2DArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class GlobalAveragePooling2DArgs : Pooling2DArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalMaxPooling1DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalMaxPooling1DArgs.cs new file mode 100644 index 000000000..e03227feb --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalMaxPooling1DArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class GlobalMaxPooling1DArgs : Pooling1DArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalMaxPooling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalMaxPooling2DArgs.cs new file mode 100644 index 000000000..a95cac836 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/GlobalMaxPooling2DArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class GlobalMaxPooling2DArgs : Pooling2DArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling1DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling1DArgs.cs new file mode 100644 index 000000000..4cfff2c15 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling1DArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class MaxPooling1DArgs : Pooling1DArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling2DArgs.cs new file mode 100644 index 000000000..c2eb9d3cb --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling2DArgs.cs @@ -0,0 +1,7 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class MaxPooling2DArgs : Pooling2DArgs + { + + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs new file mode 100644 index 000000000..c5fdca675 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs @@ -0,0 +1,40 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class Pooling1DArgs : AutoSerializeLayerArgs + { + /// + /// The pooling function to apply, e.g. `tf.nn.max_pool2d`. + /// + public IPoolFunction PoolFunction { get; set; } + + /// + /// specifying the size of the pooling window. + /// + [JsonProperty("pool_size")] + public int PoolSize { get; set; } + + /// + /// specifying the strides of the pooling operation. + /// + [JsonProperty("strides")] + public int Strides { + get { return _strides.HasValue ? _strides.Value : PoolSize; } + set { _strides = value; } + } + private int? _strides = null; + + /// + /// The padding method, either 'valid' or 'same'. + /// + [JsonProperty("padding")] + public string Padding { get; set; } = "valid"; + + /// + /// one of `channels_last` (default) or `channels_first`. + /// + [JsonProperty("data_format")] + public string DataFormat { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs new file mode 100644 index 000000000..91a372ef3 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs @@ -0,0 +1,36 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class Pooling2DArgs : AutoSerializeLayerArgs + { + /// + /// The pooling function to apply, e.g. `tf.nn.max_pool2d`. + /// + public IPoolFunction PoolFunction { get; set; } + + /// + /// specifying the size of the pooling window. + /// + [JsonProperty("pool_size")] + public Shape PoolSize { get; set; } + + /// + /// specifying the strides of the pooling operation. + /// + [JsonProperty("strides")] + public Shape Strides { get; set; } + + /// + /// The padding method, either 'valid' or 'same'. + /// + [JsonProperty("padding")] + public string Padding { get; set; } = "valid"; + + /// + /// one of `channels_last` (default) or `channels_first`. + /// + [JsonProperty("data_format")] + public string DataFormat { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/CategoryEncodingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/CategoryEncodingArgs.cs new file mode 100644 index 000000000..c282afd89 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/CategoryEncodingArgs.cs @@ -0,0 +1,16 @@ +using Newtonsoft.Json; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class CategoryEncodingArgs : AutoSerializeLayerArgs + { + [JsonProperty("num_tokens")] + public int NumTokens { get; set; } + [JsonProperty("output_mode")] + public string OutputMode { get; set; } + [JsonProperty("sparse")] + public bool Sparse { get; set; } + public NDArray CountWeights { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs new file mode 100644 index 000000000..97cb364d9 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class PreprocessingLayerArgs : AutoSerializeLayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs new file mode 100644 index 000000000..154bd8c89 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs @@ -0,0 +1,12 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class RescalingArgs : AutoSerializeLayerArgs + { + [JsonProperty("scale")] + public float Scale { get; set; } + [JsonProperty("offset")] + public float Offset { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs new file mode 100644 index 000000000..39fa52211 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs @@ -0,0 +1,10 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + // TODO: no corresponding class found in keras python, maybe obselete? + public class ResizingArgs : PreprocessingLayerArgs + { + public int Height { get; set; } + public int Width { get; set; } + public string Interpolation { get; set; } = "bilinear"; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs new file mode 100644 index 000000000..1a7149f5a --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs @@ -0,0 +1,25 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class TextVectorizationArgs : PreprocessingLayerArgs + { + [JsonProperty("standardize")] + public Func Standardize { get; set; } + [JsonProperty("split")] + public string Split { get; set; } = "standardize"; + [JsonProperty("max_tokens")] + public int MaxTokens { get; set; } = -1; + [JsonProperty("output_mode")] + public string OutputMode { get; set; } = "int"; + [JsonProperty("output_sequence_length")] + public int OutputSequenceLength { get; set; } = -1; + [JsonProperty("vocabulary")] + public string[] Vocabulary { get; set; } + + // TODO: Add `ngrams`, `sparse`, `ragged`, `idf_weights`, `encoding` + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/RMSpropArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RMSpropArgs.cs new file mode 100644 index 000000000..ac9a3d116 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RMSpropArgs.cs @@ -0,0 +1,10 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class RMSpropArgs : OptimizerV2Args + { + public float RHO { get; set; } = 0.9f; + public float Momentum { get; set; } = 0.0f; + public float Epsilon { get; set; } = 1e-7f; + public bool Centered { get; set; } = false; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs new file mode 100644 index 000000000..1c85d4936 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs @@ -0,0 +1,28 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class DropoutArgs : AutoSerializeLayerArgs + { + /// + /// Float between 0 and 1. Fraction of the input units to drop. + /// + [JsonProperty("rate")] + public float Rate { get; set; } + + /// + /// 1D integer tensor representing the shape of the + /// binary dropout mask that will be multiplied with the input. + /// + [JsonProperty("noise_shape")] + public Shape NoiseShape { get; set; } + + /// + /// random seed. + /// + [JsonProperty("seed")] + public int? Seed { get; set; } + + public bool SupportsMasking { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs new file mode 100644 index 000000000..8c2626390 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs @@ -0,0 +1,18 @@ +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition.Reshaping +{ + public class Cropping2DArgs : LayerArgs + { + /// + /// channel last: (b, h, w, c) + /// channels_first: (b, c, h, w) + /// + public enum DataFormat { channels_first = 0, channels_last = 1 } + /// + /// Accept: int[1][2], int[1][1], int[2][2] + /// + public NDArray cropping { get; set; } + public DataFormat data_format { get; set; } = DataFormat.channels_last; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs new file mode 100644 index 000000000..2d98e55db --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs @@ -0,0 +1,18 @@ +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition.Reshaping +{ + public class Cropping3DArgs : LayerArgs + { + /// + /// channel last: (b, h, w, c) + /// channels_first: (b, c, h, w) + /// + public enum DataFormat { channels_first = 0, channels_last = 1 } + /// + /// Accept: int[1][3], int[1][1], int[3][2] + /// + public NDArray cropping { get; set; } + public DataFormat data_format { get; set; } = DataFormat.channels_last; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs new file mode 100644 index 000000000..21b85966b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs @@ -0,0 +1,12 @@ +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition.Reshaping +{ + public class Cropping1DArgs : LayerArgs + { + /// + /// Accept length 1 or 2 + /// + public NDArray cropping { get; set; } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs new file mode 100644 index 000000000..91ffc2058 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs @@ -0,0 +1,10 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class FlattenArgs : AutoSerializeLayerArgs + { + [JsonProperty("data_format")] + public string DataFormat { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs new file mode 100644 index 000000000..92be10ab1 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs @@ -0,0 +1,9 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { + public class PermuteArgs : AutoSerializeLayerArgs + { + [JsonProperty("dims")] + public int[] dims { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs new file mode 100644 index 000000000..4d1123c8a --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs @@ -0,0 +1,11 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class ReshapeArgs : AutoSerializeLayerArgs + { + [JsonProperty("target_shape")] + public Shape TargetShape { get; set; } + public object[] TargetShapeObjects { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs new file mode 100644 index 000000000..504b3d46d --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs @@ -0,0 +1,17 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class UpSampling2DArgs : AutoSerializeLayerArgs + { + [JsonProperty("size")] + public Shape Size { get; set; } + [JsonProperty("data_format")] + public string DataFormat { get; set; } = "channels_last"; + /// + /// 'nearest', 'bilinear' + /// + [JsonProperty("interpolation")] + public string Interpolation { get; set; } = "nearest"; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Upsampling1DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Upsampling1DArgs.cs new file mode 100644 index 000000000..4e3dbf17a --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Upsampling1DArgs.cs @@ -0,0 +1,10 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class UpSampling1DArgs : AutoSerializeLayerArgs + { + [JsonProperty("size")] + public int Size { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs new file mode 100644 index 000000000..4831e435b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs @@ -0,0 +1,10 @@ +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition +{ + // TODO: complete the implementation + public class ZeroPadding2DArgs : LayerArgs + { + public NDArray Padding { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs new file mode 100644 index 000000000..d658a82e9 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs @@ -0,0 +1,20 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class BidirectionalArgs : AutoSerializeLayerArgs + { + [JsonProperty("layer")] + public ILayer Layer { get; set; } + [JsonProperty("merge_mode")] + public string? MergeMode { get; set; } + [JsonProperty("backward_layer")] + public ILayer BackwardLayer { get; set; } + public NDArray Weights { get; set; } + } + +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs new file mode 100644 index 000000000..cdc3097e9 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class GRUArgs : AutoSerializeLayerArgs + { + public int Units { get; set; } + public Activation Activation { get; set; } + public Activation RecurrentActivation { get; set; } + public bool UseBias { get; set; } = true; + public float Dropout { get; set; } = .0f; + public float RecurrentDropout { get; set; } = .0f; + public IInitializer KernelInitializer { get; set; } + public IInitializer RecurrentInitializer { get; set; } + public IInitializer BiasInitializer { get; set; } + public bool ReturnSequences { get;set; } + public bool ReturnState { get;set; } + public bool GoBackwards { get;set; } + public bool Stateful { get;set; } + public bool Unroll { get;set; } + public bool TimeMajor { get;set; } + public bool ResetAfter { get;set; } + public int Implementation { get; set; } = 2; + + } + +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs new file mode 100644 index 000000000..624756afe --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs @@ -0,0 +1,39 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class GRUCellArgs : AutoSerializeLayerArgs + { + [JsonProperty("units")] + public int Units { get; set; } + // TODO(Rinne): lack of initialized value of Activation. Merging keras + // into tf.net could resolve it. + [JsonProperty("activation")] + public Activation Activation { get; set; } + [JsonProperty("recurrent_activation")] + public Activation RecurrentActivation { get; set; } + [JsonProperty("use_bias")] + public bool UseBias { get; set; } = true; + [JsonProperty("dropout")] + public float Dropout { get; set; } = .0f; + [JsonProperty("recurrent_dropout")] + public float RecurrentDropout { get; set; } = .0f; + [JsonProperty("kernel_initializer")] + public IInitializer KernelInitializer { get; set; } + [JsonProperty("recurrent_initializer")] + public IInitializer RecurrentInitializer { get; set; } + [JsonProperty("bias_initializer")] + public IInitializer BiasInitializer { get; set; } + [JsonProperty("reset_after")] + public bool ResetAfter { get;set; } + [JsonProperty("implementation")] + public int Implementation { get; set; } = 2; + + + + } + +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs new file mode 100644 index 000000000..1d215576f --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class GRUOptionalArgs : RnnOptionalArgs + { + public string Identifier => "GRU"; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs new file mode 100644 index 000000000..a6beb77e8 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs @@ -0,0 +1,14 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class LSTMArgs : RNNArgs + { + // TODO: maybe change the `RNNArgs` and implement this class. + public bool UnitForgetBias { get; set; } + public int Implementation { get; set; } + + public LSTMArgs Clone() + { + return (LSTMArgs)MemberwiseClone(); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs new file mode 100644 index 000000000..f45032312 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs @@ -0,0 +1,35 @@ +using Newtonsoft.Json; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.ArgsDefinition +{ + // TODO: complete the implementation + public class LSTMCellArgs : AutoSerializeLayerArgs + { + [JsonProperty("units")] + public int Units { get; set; } + // TODO(Rinne): lack of initialized value of Activation. Merging keras + // into tf.net could resolve it. + [JsonProperty("activation")] + public Activation Activation { get; set; } + [JsonProperty("recurrent_activation")] + public Activation RecurrentActivation { get; set; } + [JsonProperty("use_bias")] + public bool UseBias { get; set; } = true; + [JsonProperty("dropout")] + public float Dropout { get; set; } = .0f; + [JsonProperty("recurrent_dropout")] + public float RecurrentDropout { get; set; } = .0f; + [JsonProperty("kernel_initializer")] + public IInitializer KernelInitializer { get; set; } + [JsonProperty("recurrent_initializer")] + public IInitializer RecurrentInitializer { get; set; } + [JsonProperty("bias_initializer")] + public IInitializer BiasInitializer { get; set; } + [JsonProperty("unit_forget_bias")] + public bool UnitForgetBias { get; set; } = true; + [JsonProperty("implementation")] + public int Implementation { get; set; } = 2; + + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMOptionalArgs.cs new file mode 100644 index 000000000..2829927c3 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMOptionalArgs.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition.Rnn +{ + public class LSTMOptionalArgs : RnnOptionalArgs + { + public string Identifier => "LSTM"; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs new file mode 100644 index 000000000..d0b73ba44 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs @@ -0,0 +1,49 @@ +using Newtonsoft.Json; +using System.Collections.Generic; +using Tensorflow.Keras.Layers; + +namespace Tensorflow.Keras.ArgsDefinition +{ + // TODO(Rinne): add regularizers. + public class RNNArgs : AutoSerializeLayerArgs + { + [JsonProperty("return_sequences")] + public bool ReturnSequences { get; set; } = false; + [JsonProperty("return_state")] + public bool ReturnState { get; set; } = false; + [JsonProperty("go_backwards")] + public bool GoBackwards { get; set; } = false; + [JsonProperty("stateful")] + public bool Stateful { get; set; } = false; + [JsonProperty("unroll")] + public bool Unroll { get; set; } = false; + [JsonProperty("time_major")] + public bool TimeMajor { get; set; } = false; + + public int? InputDim { get; set; } + public int? InputLength { get; set; } + // TODO: Add `num_constants` and `zero_output_for_mask`. + [JsonProperty("units")] + public int Units { get; set; } + [JsonProperty("activation")] + public Activation Activation { get; set; } + [JsonProperty("recurrent_activation")] + public Activation RecurrentActivation { get; set; } + [JsonProperty("use_bias")] + public bool UseBias { get; set; } = true; + public IInitializer KernelInitializer { get; set; } + public IInitializer RecurrentInitializer { get; set; } + public IInitializer BiasInitializer { get; set; } + [JsonProperty("dropout")] + public float Dropout { get; set; } = .0f; + [JsonProperty("zero_output_for_mask")] + public bool ZeroOutputForMask { get; set; } = false; + [JsonProperty("recurrent_dropout")] + public float RecurrentDropout { get; set; } = .0f; + + public RNNArgs Clone() + { + return (RNNArgs)MemberwiseClone(); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs new file mode 100644 index 000000000..a6520589d --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class RnnOptionalArgs: IOptionalArgs + { + public string Identifier => "Rnn"; + public Tensor Mask { get; set; } = null; + public Tensors Constants { get; set; } = null; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs new file mode 100644 index 000000000..e45ef79d0 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs @@ -0,0 +1,7 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class SimpleRNNArgs : RNNArgs + { + + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs new file mode 100644 index 000000000..b84ea21b3 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs @@ -0,0 +1,27 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class SimpleRNNCellArgs: AutoSerializeLayerArgs + { + [JsonProperty("units")] + public int Units { get; set; } + // TODO(Rinne): lack of initialized value of Activation. Merging keras + // into tf.net could resolve it. + [JsonProperty("activation")] + public Activation Activation { get; set; } + [JsonProperty("use_bias")] + public bool UseBias { get; set; } = true; + [JsonProperty("dropout")] + public float Dropout { get; set; } = .0f; + [JsonProperty("recurrent_dropout")] + public float RecurrentDropout { get; set; } = .0f; + [JsonProperty("kernel_initializer")] + public IInitializer KernelInitializer { get; set; } + [JsonProperty("recurrent_initializer")] + public IInitializer RecurrentInitializer { get; set; } + [JsonProperty("bias_initializer")] + public IInitializer BiasInitializer { get; set; } + + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNOptionalArgs.cs new file mode 100644 index 000000000..a8b8caf06 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNOptionalArgs.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition.Rnn +{ + public class SimpleRNNOptionalArgs : RnnOptionalArgs + { + public string Identifier => "SimpleRNN"; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs new file mode 100644 index 000000000..2600f14ee --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs @@ -0,0 +1,10 @@ +using System.Collections.Generic; +using Tensorflow.Keras.Layers; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class StackedRNNCellsArgs : LayerArgs + { + public bool ReverseStateOrder = false; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs new file mode 100644 index 000000000..ec8e16d59 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs @@ -0,0 +1,24 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; + + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class WrapperArgs : AutoSerializeLayerArgs + { + [JsonProperty("layer")] + public ILayer Layer { get; set; } + + public WrapperArgs(ILayer layer) + { + Layer = layer; + } + + public static implicit operator WrapperArgs(BidirectionalArgs args) + => new WrapperArgs(args.Layer); + } + +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/SequentialArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/SequentialArgs.cs new file mode 100644 index 000000000..407a9ed5f --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/SequentialArgs.cs @@ -0,0 +1,9 @@ +using System.Collections.Generic; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class SequentialArgs : ModelArgs + { + public List Layers { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorFlowOpLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorFlowOpLayerArgs.cs new file mode 100644 index 000000000..c2981fccc --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorFlowOpLayerArgs.cs @@ -0,0 +1,11 @@ +using Tensorflow.NumPy; +using System.Collections.Generic; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class TensorFlowOpLayerArgs : LayerArgs + { + public NodeDef NodeDef { get; set; } + public Dictionary Constants { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs b/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs new file mode 100644 index 000000000..e114ca97f --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs @@ -0,0 +1,22 @@ +namespace Tensorflow.Keras.Engine; + +public interface ICallback +{ + Dictionary> history { get; set; } + void on_train_begin(); + void on_train_end(); + void on_epoch_begin(int epoch); + void on_train_batch_begin(long step); + void on_train_batch_end(long end_step, Dictionary logs); + void on_epoch_end(int epoch, Dictionary epoch_logs); + void on_predict_begin(); + void on_predict_batch_begin(long step); + void on_predict_batch_end(long end_step, Dictionary logs); + void on_predict_end(); + void on_test_begin(); + void on_test_end(Dictionary logs); + void on_test_batch_begin(long step); + void on_test_batch_end(long end_step, Dictionary logs); + + +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs new file mode 100644 index 000000000..889c76d91 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs @@ -0,0 +1,116 @@ +using Tensorflow.Functions; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Saving; +using Tensorflow.NumPy; +using Tensorflow.Util; + +namespace Tensorflow.Keras.Engine; + +public interface IModel : ILayer +{ + void compile(IOptimizer optimizer, ILossFunc loss); + + void compile(IOptimizer optimizer, ILossFunc loss, string[] metrics); + + void compile(string optimizer, string loss, string[] metrics); + + void compile(IOptimizer optimizer, ILossFunc loss, IMetricFunc[] metrics); + + ICallback fit(NDArray x, NDArray y, + int batch_size = -1, + int epochs = 1, + int verbose = 1, + List callbacks = null, + float validation_split = 0f, + ValidationDataPack validation_data = null, + int validation_step = 10, + bool shuffle = true, + Dictionary class_weight = null, + NDArray sample_weight = null, + int initial_epoch = 0, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false); + + ICallback fit(IEnumerable x, NDArray y, + int batch_size = -1, + int epochs = 1, + int verbose = 1, + List callbacks = null, + float validation_split = 0f, + ValidationDataPack validation_data = null, + bool shuffle = true, + Dictionary class_weight = null, + NDArray sample_weight = null, + int initial_epoch = 0, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false); + + public ICallback fit(IDatasetV2 dataset, + int batch_size = -1, + int epochs = 1, + int verbose = 1, + List callbacks = null, + IDatasetV2 validation_data = null, + int validation_step = 10, // 间隔多少次会进行一次验证 + bool shuffle = true, + Dictionary class_weight = null, + int initial_epoch = 0, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false); + + void save(string filepath, + bool overwrite = true, + bool include_optimizer = true, + string save_format = "tf", + SaveOptions? options = null, + ConcreteFunction? signatures = null, + bool save_traces = true); + + void save_weights(string filepath, + bool overwrite = true, + string save_format = null, + object options = null); + + void load_weights(string filepath, + bool by_name = false, + bool skip_mismatch = false, + object options = null); + + Dictionary evaluate(NDArray x, NDArray y, + int batch_size = -1, + int verbose = 1, + NDArray sample_weight = null, + + int steps = -1, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false, + bool return_dict = false, + bool is_val = false); + + Tensors predict(Tensors x, + int batch_size = -1, + int verbose = 0, + int steps = -1, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false); + + public Tensors predict(IDatasetV2 dataset, + int batch_size = -1, + int verbose = 0, + int steps = -1, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false); + + void summary(int line_length = -1, float[] positions = null); + + IKerasConfig get_config(); + + bool Stop_training { get;set; } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/INode.cs b/src/TensorFlowNET.Core/Keras/Engine/INode.cs new file mode 100644 index 000000000..bd778f6c4 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/INode.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Engine +{ + public interface INode + { + Tensors input_tensors { get; } + Tensors Outputs { get; } + ILayer Layer { get; } + List KerasInputs { get; set; } + INode[] ParentNodes { get; } + ILayer[] InboundLayers { get; } + IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound(); + bool is_input { get; } + List serialize(Func make_node_key, Dictionary node_conversion_map); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs b/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs new file mode 100644 index 000000000..1f989391b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs @@ -0,0 +1,22 @@ +namespace Tensorflow.Keras.Engine; + +public interface IOptimizer +{ + Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars); + Tensor[] clip_gradients(Tensor[] grads); + void apply_gradients((Tensor, IVariableV1) grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true); + void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true); + + void apply_gradients((Tensor, ResourceVariable) grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true); + void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true); + + IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null); +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs new file mode 100644 index 000000000..6743935c8 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs @@ -0,0 +1,84 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Engine +{ + /// + /// Specifies the ndim, dtype and shape of every input to a layer. + /// + public class InputSpec: IKerasConfigable + { + public int? ndim; + public int? max_ndim; + public int? min_ndim; + Dictionary axes; + Shape shape; + TF_DataType dtype; + public int[] AllAxisDim; + + public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, + int? ndim = null, + int? min_ndim = null, + int? max_ndim = null, + Dictionary axes = null, + Shape shape = null) + { + this.ndim = ndim; + if (axes == null) + axes = new Dictionary(); + this.axes = axes; + this.min_ndim = min_ndim; + this.max_ndim = max_ndim; + this.shape = shape; + this.dtype = dtype; + if (ndim == null && shape != null) + this.ndim = shape.ndim; + + if (axes != null) + AllAxisDim = axes.Select(x => x.Value).ToArray(); + } + + public IKerasConfig get_config() + { + return new Config() + { + DType = dtype == TF_DataType.DtInvalid ? null : dtype, + Shape = shape, + Ndim = ndim, + MinNdim = min_ndim, + MaxNdim = max_ndim, + Axes = axes.ToDictionary(x => x.Key.ToString(), x => x.Value) + }; + } + + public override string ToString() + => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; + + public class Config: IKerasConfig + { + public TF_DataType? DType { get; set; } + public Shape Shape { get; set; } + public int? Ndim { get; set; } + public int? MinNdim { get;set; } + public int? MaxNdim { get;set; } + public IDictionary Axes { get; set; } + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs new file mode 100644 index 000000000..f1e4ba0c9 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs @@ -0,0 +1,32 @@ +namespace Tensorflow.Keras.Engine +{ + /// + /// Tracks the Layer call that created a Tensor, for Keras Graph Networks. + /// + public class KerasHistory + { + ILayer layer; + public ILayer Layer => layer; + int node_index; + public int NodeIndex => node_index; + int tensor_index; + public int TensorIndex => tensor_index; + + public KerasHistory(ILayer layer, int node_index, int tensor_index) + { + this.layer = layer; + this.node_index = node_index; + this.tensor_index = tensor_index; + } + + public void Deconstruct(out ILayer layer, out int node_index, out int tensor_index) + { + layer = this.layer; + node_index = this.node_index; + tensor_index = this.tensor_index; + } + + public override string ToString() + => $"{layer.GetType().Name} {layer.Name}"; + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs b/src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs new file mode 100644 index 000000000..5a264b631 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs @@ -0,0 +1,75 @@ +namespace Tensorflow.Keras.Engine; + +/// +/// A representation of a Keras in/output during Functional API construction. +/// +public class KerasTensor +{ + private Tensors _original_tensors; + public Tensors original_tensors + { + get => _original_tensors; + set => _original_tensors = value; + } + + private Shape _inferred_value; + public Shape inferred_value => _inferred_value; + + private string _name; + private TensorSpec _type_spec; + public Shape shape => _type_spec.shape; + public TF_DataType dtype => _type_spec.dtype; + + public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string name = null) + { + _type_spec = type_spec; + _inferred_value = inferred_value; + _name = name; + } + + public static KerasTensor from_tensor(Tensor tensor) + { + var type_spec = tensor.ToTensorSpec(); + Shape? inferred_value = default; + if (tensor.dtype == TF_DataType.TF_INT32 && tensor.rank < 2) + { + inferred_value = tf.ones(tensor).shape; + } + var kt = new KerasTensor(type_spec, inferred_value: inferred_value, name: tensor.name); + kt.original_tensors = tensor; + return kt; + } + + public KerasTensor this[int idx] + => _original_tensors.First()[idx]; + + public KerasTensor this[params Slice[] slices] + => _original_tensors.First()[slices]; + + public override string ToString() + => _original_tensors.Length switch + { + > 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype.as_numpy_name()}{GetInferredValueString()}")) + "]", + 1 => $"KerasTensor: shape={_original_tensors.shape} dtype={_original_tensors.dtype.as_numpy_name()}{GetInferredValueString()}", + _ => _original_tensors.ToString(), + }; + + private string GetInferredValueString() + => _inferred_value == null ? "" : $" inferred_value={_inferred_value}"; + + public static implicit operator Tensors(KerasTensor kt) + => kt._original_tensors; + + public static implicit operator Tensor(KerasTensor kt) + { + Tensor tensor = kt._original_tensors; + tensor.IsFromKerasTensor = true; + return tensor; + } + + public static implicit operator KerasTensor(Tensor tensor) + => from_tensor(tensor); + + public static implicit operator KerasTensor(Tensors tensors) + => from_tensor(tensors.First()); +} diff --git a/src/TensorFlowNET.Core/Keras/IInitializersApi.cs b/src/TensorFlowNET.Core/Keras/IInitializersApi.cs new file mode 100644 index 000000000..3ad5e87b8 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/IInitializersApi.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras +{ + public interface IInitializersApi + { + IInitializer Orthogonal(float gain = 1.0f, int? seed = null); + + IInitializer HeNormal(int? seed = null); + } +} diff --git a/src/TensorFlowNET.Core/Keras/IKerasApi.cs b/src/TensorFlowNET.Core/Keras/IKerasApi.cs new file mode 100644 index 000000000..db8deb24b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/IKerasApi.cs @@ -0,0 +1,61 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using Tensorflow.Framework.Models; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Models; + +namespace Tensorflow.Keras +{ + public interface IKerasApi + { + IInitializersApi initializers { get; } + ILayersApi layers { get; } + ILossesApi losses { get; } + IActivationsApi activations { get; } + IOptimizerApi optimizers { get; } + IMetricsApi metrics { get; } + IModelsApi models { get; } + + /// + /// `Model` groups layers into an object with training and inference features. + /// + /// + /// + /// + IModel Model(Tensors inputs, Tensors outputs, string name = null); + + /// + /// Instantiate a Keras tensor. + /// + /// + /// + /// + /// + /// + /// A boolean specifying whether the placeholder to be created is sparse. + /// + /// + /// A boolean specifying whether the placeholder to be created is ragged. + /// + /// + /// Optional existing tensor to wrap into the `Input` layer. + /// If set, the layer will not create a placeholder tensor. + /// + /// + Tensors Input(Shape shape = null, + int batch_size = -1, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + bool sparse = false, + Tensor tensor = null, + bool ragged = false, + TypeSpec type_spec = null, + Shape batch_input_shape = null, + Shape batch_shape = null); + } +} diff --git a/src/TensorFlowNET.Core/Keras/IOptimizerApi.cs b/src/TensorFlowNET.Core/Keras/IOptimizerApi.cs new file mode 100644 index 000000000..6c15fd469 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/IOptimizerApi.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras +{ + public interface IOptimizerApi + { + /// + /// Adam optimization is a stochastic gradient descent method that is based on + /// adaptive estimation of first-order and second-order moments. + /// + /// + /// + /// + /// + /// + /// + /// + IOptimizer Adam(float learning_rate = 0.001f, + float beta_1 = 0.9f, + float beta_2 = 0.999f, + float epsilon = 1e-7f, + bool amsgrad = false, + string name = "Adam"); + + /// + /// Adam enables L2 weight decay on gradients. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + IOptimizer AdamW(float learning_rate = 0.001f, + float weight_decay = 0.004f, + float beta_1 = 0.9f, + float beta_2 = 0.999f, + float epsilon = 1e-7f, + bool amsgrad = false, + List no_decay_params = null, + string name = "AdamW"); + + /// + /// Construct a new RMSprop optimizer. + /// + /// + /// + /// + /// + /// + /// + /// + IOptimizer RMSprop(float learning_rate = 0.001f, + float rho = 0.9f, + float momentum = 0.0f, + float epsilon = 1e-7f, + bool centered = false, + string name = "RMSprop"); + + IOptimizer SGD(float learning_rate = 0.01f, float momentum = 0f); + } +} diff --git a/src/TensorFlowNET.Core/Keras/IPreprocessing.cs b/src/TensorFlowNET.Core/Keras/IPreprocessing.cs new file mode 100644 index 000000000..28eea0f56 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/IPreprocessing.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras +{ + public interface IPreprocessing + { + public ILayer Resizing(int height, int width, string interpolation = "bilinear"); + public ILayer TextVectorization(Func standardize = null, + string split = "whitespace", + int max_tokens = -1, + string output_mode = "int", + int output_sequence_length = -1); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs new file mode 100644 index 000000000..2f92c4e57 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -0,0 +1,32 @@ +using Tensorflow.Common.Types; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.NumPy; +using Tensorflow.Training; + +namespace Tensorflow.Keras +{ + public interface ILayer: IWithTrackable, IKerasConfigable + { + string Name { get; } + bool Trainable { get; } + bool Built { get; } + void build(KerasShapesWrapper input_shape); + List Layers { get; } + List InboundNodes { get; } + List OutboundNodes { get; } + Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null); + List TrainableVariables { get; } + List TrainableWeights { get; } + List NonTrainableWeights { get; } + List Weights { get; set; } + void set_weights(IEnumerable weights); + List get_weights(); + Shape OutputShape { get; } + KerasShapesWrapper BatchInputShape { get; } + KerasShapesWrapper BuildInputShape { get; } + TF_DataType DType { get; } + int count_params(); + void adapt(Tensor data, int? batch_size = null, int? steps = null); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Activation.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Activation.cs new file mode 100644 index 000000000..524798690 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Activation.cs @@ -0,0 +1,21 @@ +using System; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.NumPy; +using Tensorflow.Operations.Activation; + +namespace Tensorflow.Keras.Layers +{ + public partial interface ILayersApi + { + public ILayer ELU(float alpha = 0.1f); + public ILayer SELU(); + public ILayer Softmax(int axis = -1); + public ILayer Softmax(Axis axis); + public ILayer Softplus(); + public ILayer HardSigmoid(); + public ILayer Softsign(); + public ILayer Swish(); + public ILayer Tanh(); + public ILayer Exponential(); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Attention.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Attention.cs new file mode 100644 index 000000000..22fb50d3d --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Attention.cs @@ -0,0 +1,28 @@ +using System; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.Layers +{ + public partial interface ILayersApi + { + public ILayer Attention(bool use_scale = false, + string score_mode = "dot", + bool causal = false, + float dropout = 0f); + public ILayer MultiHeadAttention(int num_heads, + int key_dim, + int? value_dim = null, + float dropout = 0f, + bool use_bias = true, + Shape output_shape = null, + Shape attention_axes = null, + IInitializer kernel_initializer = null, + IInitializer bias_initializer = null, + IRegularizer kernel_regularizer = null, + IRegularizer bias_regularizer = null, + IRegularizer activity_regularizer = null, + Action kernel_constraint = null, + Action bias_constraint = null); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs new file mode 100644 index 000000000..3578652ee --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs @@ -0,0 +1,13 @@ +using System; +using Tensorflow.Keras.ArgsDefinition.Reshaping; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.Layers +{ + public partial interface ILayersApi + { + public ILayer Cropping1D(NDArray cropping); + public ILayer Cropping2D(NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last); + public ILayer Cropping3D(NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Merging.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Merging.cs new file mode 100644 index 000000000..d0a7f09fd --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Merging.cs @@ -0,0 +1,10 @@ +using System; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.Layers +{ + public partial interface ILayersApi + { + public ILayer Concatenate(int axis = -1); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Reshaping.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Reshaping.cs new file mode 100644 index 000000000..ae34c514f --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Reshaping.cs @@ -0,0 +1,22 @@ +using System; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.Layers +{ + public partial interface ILayersApi + { + public ILayer Reshape(Shape target_shape); + public ILayer Reshape(object[] target_shape); + + public ILayer UpSampling1D( + int size + ); + + public ILayer UpSampling2D(Shape size = null, + string data_format = null, + string interpolation = "nearest"); + + public ILayer ZeroPadding2D(NDArray padding); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs new file mode 100644 index 000000000..57273eb08 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs @@ -0,0 +1,317 @@ +using System; +using Tensorflow.Framework.Models; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.NumPy; +using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; + +namespace Tensorflow.Keras.Layers +{ + public partial interface ILayersApi + { + public IPreprocessing preprocessing { get; } + + public ILayer Add(); + + public ILayer AveragePooling2D(Shape pool_size = null, + Shape strides = null, + string padding = "valid", + string data_format = null); + + public ILayer BatchNormalization(int axis = -1, + float momentum = 0.99f, + float epsilon = 0.001f, + bool center = true, + bool scale = true, + IInitializer beta_initializer = null, + IInitializer gamma_initializer = null, + IInitializer moving_mean_initializer = null, + IInitializer moving_variance_initializer = null, + bool trainable = true, + string name = null, + bool renorm = false, + float renorm_momentum = 0.99f); + + /// + /// A preprocessing layer which encodes integer features. + /// + /// The total number of tokens the layer should support. + /// Specification for the output of the layer. + /// + public ILayer CategoryEncoding(int num_tokens, + string output_mode = "one_hot", + bool sparse = false, + NDArray count_weights = null); + + public ILayer Conv1D(int filters, + Shape kernel_size, + int strides = 1, + string padding = "valid", + string data_format = "channels_last", + int dilation_rate = 1, + int groups = 1, + string activation = null, + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string bias_initializer = "zeros"); + + public ILayer Conv2D(int filters, + Shape kernel_size = null, + Shape strides = null, + string padding = "valid" + ); + + public ILayer Conv2D(int filters, + Shape kernel_size = null, + Shape strides = null, + string padding = "valid", + string data_format = null, + Shape dilation_rate = null, + int groups = 1, + Activation activation = null, + bool use_bias = true, + IInitializer kernel_initializer = null, + IInitializer bias_initializer = null, + IRegularizer kernel_regularizer = null, + IRegularizer bias_regularizer = null, + IRegularizer activity_regularizer = null); + + public ILayer Conv2DTranspose(int filters, + Shape kernel_size = null, + Shape strides = null, + string output_padding = "valid", + string data_format = null, + Shape dilation_rate = null, + string activation = null, + bool use_bias = true, + string kernel_initializer = null, + string bias_initializer = null, + string kernel_regularizer = null, + string bias_regularizer = null, + string activity_regularizer = null); + + public ILayer Conv2D(int filters, + Shape kernel_size = null, + Shape strides = null, + string padding = "valid", + string data_format = null, + Shape dilation_rate = null, + int groups = 1, + string activation = null, + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string bias_initializer = "zeros"); + public ILayer DepthwiseConv2D(Shape kernel_size = null, + Shape strides = null, + string padding = "valid", + string data_format = null, + Shape dilation_rate = null, + int groups = 1, + int depth_multiplier = 1, + string activation = null, + bool use_bias = false, + string kernel_initializer = "glorot_uniform", + string bias_initializer = "zeros", + string depthwise_initializer = "glorot_uniform" + ); + + public ILayer Dense(int units); + public ILayer Dense(int units, + string activation = null, + Shape input_shape = null); + public ILayer Dense(int units, + Activation activation = null, + IInitializer kernel_initializer = null, + bool use_bias = true, + IInitializer bias_initializer = null, + Shape input_shape = null); + + public ILayer Dropout(float rate, Shape noise_shape = null, int? seed = null); + + public ILayer Embedding(int input_dim, + int output_dim, + IInitializer embeddings_initializer = null, + bool mask_zero = false, + Shape input_shape = null, + int input_length = -1); + + public ILayer EinsumDense(string equation, + Shape output_shape, + string bias_axes, + Activation activation = null, + IInitializer kernel_initializer = null, + IInitializer bias_initializer = null, + IRegularizer kernel_regularizer = null, + IRegularizer bias_regularizer = null, + IRegularizer activity_regularizer = null, + Action kernel_constraint = null, + Action bias_constraint = null); + + public ILayer Flatten(string data_format = null); + + public ILayer GlobalAveragePooling1D(string data_format = "channels_last"); + public ILayer GlobalAveragePooling2D(); + public ILayer GlobalAveragePooling2D(string data_format = "channels_last"); + public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); + public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); + + public KerasTensor Input(Shape shape = null, + int batch_size = -1, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + bool sparse = false, + Tensor tensor = null, + bool ragged = false, + TypeSpec type_spec = null, + Shape batch_input_shape = null, + Shape batch_shape = null); + public ILayer InputLayer(Shape input_shape, + string name = null, + bool sparse = false, + bool ragged = false); + + public ILayer LayerNormalization(Axis? axis, + float epsilon = 1e-3f, + bool center = true, + bool scale = true, + IInitializer beta_initializer = null, + IInitializer gamma_initializer = null); + + public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false); + public ILayer LeakyReLU(float alpha = 0.3f); + + public ILayer ReLU6(); + + + public IRnnCell LSTMCell(int uints, + string activation = "tanh", + string recurrent_activation = "sigmoid", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros", + bool unit_forget_bias = true, + float dropout = 0f, + float recurrent_dropout = 0f, + int implementation = 2); + + public ILayer LSTM(int units, + Activation activation = null, + Activation recurrent_activation = null, + bool use_bias = true, + IInitializer kernel_initializer = null, + IInitializer recurrent_initializer = null, + IInitializer bias_initializer = null, + bool unit_forget_bias = true, + float dropout = 0f, + float recurrent_dropout = 0f, + int implementation = 2, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool time_major = false, + bool unroll = false); + + public ILayer MaxPooling1D(int? pool_size = null, + int? strides = null, + string padding = "valid", + string data_format = null); + public ILayer MaxPooling2D(Shape pool_size = null, + Shape strides = null, + string padding = "valid", + string data_format = null); + + public ILayer Permute(int[] dims); + + public ILayer Rescaling(float scale, + float offset = 0, + Shape input_shape = null); + + public IRnnCell SimpleRNNCell( + int units, + string activation = "tanh", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros", + float dropout = 0f, + float recurrent_dropout = 0f); + + public IRnnCell StackedRNNCells( + IEnumerable cells); + + public ILayer SimpleRNN(int units, + string activation = "tanh", + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros", + bool return_sequences = false, + bool return_state = false); + + public ILayer RNN( + IRnnCell cell, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool unroll = false, + bool time_major = false + ); + + public ILayer RNN( + IEnumerable cell, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool unroll = false, + bool time_major = false + ); + + public IRnnCell GRUCell( + int units, + string activation = "tanh", + string recurrent_activation = "sigmoid", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros", + float dropout = 0f, + float recurrent_dropout = 0f, + bool reset_after = true); + + public ILayer GRU( + int units, + string activation = "tanh", + string recurrent_activation = "sigmoid", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros", + float dropout = 0f, + float recurrent_dropout = 0f, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool unroll = false, + bool time_major = false, + bool reset_after = true + ); + + /// + /// Bidirectional wrapper for RNNs. + /// + /// `keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU` + /// automatically. + /// + public ILayer Bidirectional( + ILayer layer, + string merge_mode = "concat", + NDArray weights = null, + ILayer backward_layer = null); + + public ILayer Subtract(); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/IPoolFunction.cs b/src/TensorFlowNET.Core/Keras/Layers/IPoolFunction.cs new file mode 100644 index 000000000..a1a5350cc --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/IPoolFunction.cs @@ -0,0 +1,12 @@ +namespace Tensorflow +{ + public interface IPoolFunction + { + Tensor Apply(Tensor value, + int[] ksize, + int[] strides, + string padding, + string data_format = "NHWC", + string name = null); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs new file mode 100644 index 000000000..43df75b17 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + public interface IRnnCell: ILayer + { + /// + /// If the derived class tends to not implement it, please return null. + /// + INestStructure? StateSize { get; } + /// + /// If the derived class tends to not implement it, please return null. + /// + INestStructure? OutputSize { get; } + /// + /// Whether the optional RNN args are supported when appying the layer. + /// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`. + /// + bool SupportOptionalArgs { get; } + Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs new file mode 100644 index 000000000..8cf6150d3 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Layers +{ + public interface IStackedRnnCells : IRnnCell + { + int Count { get; } + IRnnCell this[int idx] { get; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Losses/ILossFunc.cs b/src/TensorFlowNET.Core/Keras/Losses/ILossFunc.cs new file mode 100644 index 000000000..408c7ca18 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Losses/ILossFunc.cs @@ -0,0 +1,8 @@ +namespace Tensorflow.Keras.Losses; + +public interface ILossFunc +{ + public string Reduction { get; } + public string Name { get; } + Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null); +} diff --git a/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs b/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs new file mode 100644 index 000000000..4c92512d4 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs @@ -0,0 +1,56 @@ +namespace Tensorflow.Keras.Losses; + +public interface ILossesApi +{ + ILossFunc BinaryCrossentropy(bool from_logits = false, + float label_smoothing = 0f, + int axis = -1, + string reduction = "auto", + string name = "binary_crossentropy"); + + ILossFunc SparseCategoricalCrossentropy(string reduction = null, + string name = null, + bool from_logits = false); + + ILossFunc CategoricalCrossentropy(string reduction = null, + string name = null, + bool from_logits = false); + + ILossFunc MeanSquaredError(string reduction = null, + string name = null); + + ILossFunc MeanSquaredLogarithmicError(string reduction = null, + string name = null); + + ILossFunc MeanAbsolutePercentageError(string reduction = null, + string name = null); + + ILossFunc MeanAbsoluteError(string reduction = null, + string name = null); + + ILossFunc CosineSimilarity(string reduction = null, + int axis = -1, + string name = null); + + ILossFunc Huber(string reduction = null, + string name = null, + Tensor delta = null); + + ILossFunc LogCosh(string reduction = null, + string name = null); + + /// + /// Implements the focal loss function. + /// + /// + /// + /// + /// + /// + /// + ILossFunc SigmoidFocalCrossEntropy(bool from_logits = false, + float alpha = 0.25f, + float gamma = 2.0f, + string reduction = "none", + string name = "sigmoid_focal_crossentropy"); +} diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs new file mode 100644 index 000000000..930afa0b0 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs @@ -0,0 +1,18 @@ +namespace Tensorflow.Keras.Metrics; + +public interface IMetricFunc +{ + string Name { get; } + /// + /// Accumulates metric statistics. + /// + /// + /// + /// + /// + Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null); + + Tensor result(); + + void reset_states(); +} diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs new file mode 100644 index 000000000..dbe4ac3fd --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs @@ -0,0 +1,186 @@ +namespace Tensorflow.Keras.Metrics; + +public interface IMetricsApi +{ + Tensor binary_accuracy(Tensor y_true, Tensor y_pred); + + Tensor categorical_accuracy(Tensor y_true, Tensor y_pred); + Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred, + bool from_logits = false, + float label_smoothing = 0f, + Axis? axis = null); + + Tensor mean_absolute_error(Tensor y_true, Tensor y_pred); + + Tensor mean_absolute_percentage_error(Tensor y_true, Tensor y_pred); + + /// + /// Calculates how often predictions matches integer labels. + /// + /// Integer ground truth values. + /// The prediction values. + /// Sparse categorical accuracy values. + Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred); + + /// + /// Computes the sparse categorical crossentropy loss. + /// + /// + /// + /// + /// + /// + /// + Tensor sparse_categorical_crossentropy(Tensor y_true, Tensor y_pred, + bool from_logits = false, + int? ignore_class = null, + Axis? axis = null); + + /// + /// Computes how often targets are in the top `K` predictions. + /// + /// + /// + /// + /// + Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5); + + /// + /// Calculates how often predictions equal labels. + /// + /// + IMetricFunc Accuracy(string name = "accuracy", + TF_DataType dtype = TF_DataType.TF_FLOAT); + + /// + /// Calculates how often predictions match binary labels. + /// + /// + IMetricFunc BinaryAccuracy(string name = "binary_accuracy", + TF_DataType dtype = TF_DataType.TF_FLOAT, + float threshold = 05f); + + /// + /// Calculates how often predictions match one-hot labels. + /// + /// + IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy", + TF_DataType dtype = TF_DataType.TF_FLOAT, + bool from_logits = false, + float label_smoothing = 0f, + Axis? axis = null); + + /// + /// Computes the crossentropy metric between the labels and predictions. + /// + /// + IMetricFunc SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy", + TF_DataType dtype = TF_DataType.TF_FLOAT, + bool from_logits = false, + int? ignore_class = null, + Axis? axis = null); + + /// + /// Computes the crossentropy metric between the labels and predictions. + /// + /// + IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", + TF_DataType dtype = TF_DataType.TF_FLOAT); + + /// + /// Calculates how often predictions match integer labels. + /// + /// + IMetricFunc SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", + TF_DataType dtype = TF_DataType.TF_FLOAT); + + /// + /// Computes the cosine similarity between the labels and predictions. + /// + /// + IMetricFunc CosineSimilarity(string name = "cosine_similarity", + TF_DataType dtype = TF_DataType.TF_FLOAT, + Axis? axis = null); + + /// + /// Computes F-1 Score. + /// + /// + IMetricFunc F1Score(int num_classes, + string? average = null, + float? threshold = null, + string name = "f1_score", + TF_DataType dtype = TF_DataType.TF_FLOAT); + + /// + /// Computes F-Beta score. + /// + /// + IMetricFunc FBetaScore(int num_classes, + string? average = null, + float beta = 0.1f, + float? threshold = null, + string name = "fbeta_score", + TF_DataType dtype = TF_DataType.TF_FLOAT); + + /// + /// Computes hamming loss. + /// + /// multiclass or multilabel + /// + /// + /// + /// + IMetricFunc HammingLoss(string mode, + float? threshold = null, + string name = "hamming_loss", + TF_DataType dtype = TF_DataType.TF_FLOAT); + + /// + /// Computes how often targets are in the top K predictions. + /// + /// + /// + IMetricFunc TopKCategoricalAccuracy(int k = 5, + string name = "top_k_categorical_accuracy", + TF_DataType dtype = TF_DataType.TF_FLOAT); + + /// + /// Computes how often integer targets are in the top K predictions. + /// + /// + /// + IMetricFunc SparseTopKCategoricalAccuracy(int k = 5, + string name = "sparse_top_k_categorical_accuracy", + TF_DataType dtype = TF_DataType.TF_FLOAT); + + /// + /// Computes the precision of the predictions with respect to the labels. + /// + /// + /// + /// + /// + /// + /// + IMetricFunc Precision(float thresholds = 0.5f, + int top_k = 0, + int class_id = 0, + string name = "recall", + TF_DataType dtype = TF_DataType.TF_FLOAT); + + /// + /// Computes the recall of the predictions with respect to the labels. + /// + /// + /// + /// + /// + /// + /// + IMetricFunc Recall(float thresholds = 0.5f, + int top_k = 0, + int class_id = 0, + string name = "recall", + TF_DataType dtype = TF_DataType.TF_FLOAT); +} diff --git a/src/TensorFlowNET.Core/Keras/Models/IModelsApi.cs b/src/TensorFlowNET.Core/Keras/Models/IModelsApi.cs new file mode 100644 index 000000000..007c82a17 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Models/IModelsApi.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Models +{ + public interface IModelsApi + { + public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs b/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs new file mode 100644 index 000000000..06dbb7c8c --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs @@ -0,0 +1,25 @@ +using Newtonsoft.Json; +using System.Collections.Generic; +using Tensorflow.Keras.Saving.Common; + +namespace Tensorflow.Keras +{ + [JsonConverter(typeof(CustomizedRegularizerJsonConverter))] + public interface IRegularizer + { + [JsonProperty("class_name")] + string ClassName { get; } + [JsonProperty("config")] + IDictionary Config { get; } + Tensor Apply(RegularizerArgs args); + } + + public interface IRegularizerApi + { + IRegularizer GetRegularizerFromName(string name); + IRegularizer L1 { get; } + IRegularizer L2 { get; } + IRegularizer L1L2 { get; } + } + +} diff --git a/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs b/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs new file mode 100644 index 000000000..8e7e89b1d --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs @@ -0,0 +1,13 @@ +namespace Tensorflow.Keras +{ + public class RegularizerArgs + { + public Tensor X { get; set; } + + + public RegularizerArgs(Tensor x) + { + X = x; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs new file mode 100644 index 000000000..1217e1e52 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving +{ + public interface IKerasConfig + { + } + + public interface IKerasConfigable + { + IKerasConfig get_config(); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedActivationJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedActivationJsonConverter.cs new file mode 100644 index 000000000..b348780cf --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedActivationJsonConverter.cs @@ -0,0 +1,50 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Saving.Common +{ + public class CustomizedActivationJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Activation); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(""); + token.WriteTo(writer); + } + else if (value is not Activation) + { + throw new TypeError($"Unable to use `CustomizedActivationJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var token = JToken.FromObject(((Activation)value).Name); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + var activationName = serializer.Deserialize(reader); + if (tf.keras is null) + { + throw new RuntimeError("Tensorflow.Keras is not loaded, please install it first."); + } + return tf.keras.activations.GetActivationFromName(string.IsNullOrEmpty(activationName) ? "linear" : activationName); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedAxisJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedAxisJsonConverter.cs new file mode 100644 index 000000000..aea4af6d6 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedAxisJsonConverter.cs @@ -0,0 +1,57 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving.Common +{ + public class CustomizedAxisJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Axis); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(new int[] { }); + token.WriteTo(writer); + } + else if (value is not Axis) + { + throw new TypeError($"Unable to use `CustomizedAxisJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var token = JToken.FromObject((value as Axis)!.axis); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + int[]? axis; + if (reader.ValueType == typeof(long)) + { + axis = new int[1]; + axis[0] = (int)serializer.Deserialize(reader, typeof(int)); + } + else + { + axis = serializer.Deserialize(reader, typeof(int[])) as int[]; + } + if (axis is null) + { + throw new ValueError("Cannot deserialize 'null' to `Axis`."); + } + return new Axis(axis!); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedDTypeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedDTypeJsonConverter.cs new file mode 100644 index 000000000..29b3b094c --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedDTypeJsonConverter.cs @@ -0,0 +1,36 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; + +namespace Tensorflow.Keras.Saving.Common +{ + public class CustomizedDTypeJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(TF_DataType); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + var token = JToken.FromObject(((TF_DataType)value).as_numpy_name()); + token.WriteTo(writer); + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + if (reader.ValueType == typeof(string)) + { + var str = (string)serializer.Deserialize(reader, typeof(string)); + return dtypes.tf_dtype_from_name(str); + } + else + { + return (TF_DataType)serializer.Deserialize(reader, typeof(int)); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedIInitializerJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedIInitializerJsonConverter.cs new file mode 100644 index 000000000..a7bae56d0 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedIInitializerJsonConverter.cs @@ -0,0 +1,69 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations; + +using Tensorflow.Operations.Initializers; + +namespace Tensorflow.Keras.Saving.Common +{ + class InitializerInfo + { + public string class_name { get; set; } + public JObject config { get; set; } + } + public class CustomizedIinitializerJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(IInitializer); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + var initializer = value as IInitializer; + if (initializer is null) + { + JToken.FromObject(null).WriteTo(writer); + return; + } + JToken.FromObject(new InitializerInfo() + { + class_name = initializer.ClassName, + config = JObject.FromObject(initializer.Config) + }, serializer).WriteTo(writer); + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + var info = serializer.Deserialize(reader); + if (info is null) + { + return null; + } + return info.class_name switch + { + "Constant" => new Constant(info.config["value"].ToObject()), + "GlorotUniform" => new GlorotUniform(seed: info.config["seed"].ToObject()), + "Ones" => new Ones(), + "Orthogonal" => new Orthogonal(info.config["gain"].ToObject(), info.config["seed"].ToObject()), + "RandomNormal" => new RandomNormal(info.config["mean"].ToObject(), info.config["stddev"].ToObject(), + info.config["seed"].ToObject()), + "RandomUniform" => new RandomUniform(minval: info.config["minval"].ToObject(), + maxval: info.config["maxval"].ToObject(), seed: info.config["seed"].ToObject()), + "TruncatedNormal" => new TruncatedNormal(info.config["mean"].ToObject(), info.config["stddev"].ToObject(), + info.config["seed"].ToObject()), + "VarianceScaling" => new VarianceScaling(info.config["scale"].ToObject(), info.config["mode"].ToObject(), + info.config["distribution"].ToObject(), info.config["seed"].ToObject()), + "Zeros" => new Zeros(), + _ => throw new ValueError($"The specified initializer {info.class_name} cannot be recognized.") + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs new file mode 100644 index 000000000..3a21db9d2 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs @@ -0,0 +1,76 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Saving.Json +{ + public class CustomizedKerasShapesWrapperJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(KerasShapesWrapper); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + JToken.FromObject(null).WriteTo(writer); + return; + } + if (value is not KerasShapesWrapper wrapper) + { + throw new TypeError($"Expected `KerasShapesWrapper` to be serialized, bug got {value.GetType()}"); + } + if (wrapper.Shapes.Length == 0) + { + JToken.FromObject(null).WriteTo(writer); + } + else if (wrapper.Shapes.Length == 1) + { + JToken.FromObject(wrapper.Shapes[0]).WriteTo(writer); + } + else + { + JToken.FromObject(wrapper.Shapes).WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + if (reader.TokenType == JsonToken.StartArray) + { + TensorShapeConfig[] shapes = serializer.Deserialize(reader); + if (shapes is null) + { + return null; + } + return new KerasShapesWrapper(shapes); + } + else if (reader.TokenType == JsonToken.StartObject) + { + var shape = serializer.Deserialize(reader); + if (shape is null) + { + return null; + } + return new KerasShapesWrapper(shape); + } + else if (reader.TokenType == JsonToken.Null) + { + return null; + } + else + { + throw new ValueError($"Cannot deserialize the token type {reader.TokenType}"); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedNodeConfigJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedNodeConfigJsonConverter.cs new file mode 100644 index 000000000..51194a610 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedNodeConfigJsonConverter.cs @@ -0,0 +1,100 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Saving.Common +{ + public class CustomizedNodeConfigJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(NodeConfig); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(null); + token.WriteTo(writer); + } + else if (value is not NodeConfig) + { + throw new TypeError($"Unable to use `CustomizedNodeConfigJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var config = value as NodeConfig; + var token = JToken.FromObject(new object[] { config!.Name, config.NodeIndex, config.TensorIndex }); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + var values = serializer.Deserialize(reader, typeof(object[])) as object[]; + if (values is null) + { + throw new ValueError("Cannot deserialize 'null' to `Shape`."); + } + if (values.Length == 1) + { + var array = values[0] as JArray; + if (array is null) + { + throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); + } + values = array.ToObject(); + } + if (values.Length < 3) + { + throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); + } + if (values[0] is not string) + { + throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); + } + int nodeIndex; + int tensorIndex; + if (values[1] is long) + { + nodeIndex = (int)(long)values[1]; + } + else if (values[1] is int) + { + nodeIndex = (int)values[1]; + } + else + { + throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); + } + if (values[2] is long) + { + tensorIndex = (int)(long)values[2]; + } + else if (values[1] is int) + { + tensorIndex = (int)values[2]; + } + else + { + throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); + } + return new NodeConfig() + { + Name = values[0] as string, + NodeIndex = nodeIndex, + TensorIndex = tensorIndex + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedRegularizerJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedRegularizerJsonConverter.cs new file mode 100644 index 000000000..4b1790aca --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedRegularizerJsonConverter.cs @@ -0,0 +1,57 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations.Regularizers; + +namespace Tensorflow.Keras.Saving.Common +{ + class RegularizerInfo + { + public string class_name { get; set; } + public JObject config { get; set; } + } + + public class CustomizedRegularizerJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(IRegularizer); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + var regularizer = value as IRegularizer; + if (regularizer is null) + { + JToken.FromObject(null).WriteTo(writer); + return; + } + JToken.FromObject(new RegularizerInfo() + { + class_name = regularizer.ClassName, + config = JObject.FromObject(regularizer.Config) + }, serializer).WriteTo(writer); + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + var info = serializer.Deserialize(reader); + if (info is null) + { + return null; + } + return info.class_name switch + { + "L1L2" => new L1L2 (info.config["l1"].ToObject(), info.config["l2"].ToObject()), + "L1" => new L1(info.config["l1"].ToObject()), + "L2" => new L2(info.config["l2"].ToObject()), + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedShapeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedShapeJsonConverter.cs new file mode 100644 index 000000000..39799e929 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedShapeJsonConverter.cs @@ -0,0 +1,93 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving.Common +{ + class ShapeInfoFromPython + { + public string class_name { get; set; } + public long?[] items { get; set; } + } + public class CustomizedShapeJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Shape); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(null); + token.WriteTo(writer); + } + else if (value is not Shape) + { + throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var shape = (value as Shape)!; + long?[] dims = new long?[shape.ndim]; + for (int i = 0; i < dims.Length; i++) + { + if (shape.dims[i] == -1) + { + dims[i] = null; + } + else + { + dims[i] = shape.dims[i]; + } + } + var token = JToken.FromObject(new ShapeInfoFromPython() + { + class_name = "__tuple__", + items = dims + }); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + long?[] dims; + if (reader.TokenType == JsonToken.StartObject) + { + var shape_info_from_python = serializer.Deserialize(reader); + if (shape_info_from_python is null) + { + return null; + } + dims = shape_info_from_python.items; + } + else if (reader.TokenType == JsonToken.StartArray) + { + dims = serializer.Deserialize(reader); + } + else if (reader.TokenType == JsonToken.Null) + { + return null; + } + else + { + throw new ValueError($"Cannot deserialize the token {reader} as Shape."); + } + long[] convertedDims = new long[dims.Length]; + for (int i = 0; i < dims.Length; i++) + { + convertedDims[i] = dims[i] ?? -1; + } + return new Shape(convertedDims); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs b/src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs new file mode 100644 index 000000000..ea6fe976f --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs @@ -0,0 +1,61 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; +using System.Diagnostics; +using OneOf.Types; +using Tensorflow.Keras.Saving.Json; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Saving +{ + [JsonConverter(typeof(CustomizedKerasShapesWrapperJsonConverter))] + public class KerasShapesWrapper + { + public TensorShapeConfig[] Shapes { get; set; } + + public KerasShapesWrapper(Shape shape) + { + Shapes = new TensorShapeConfig[] { shape }; + } + + public KerasShapesWrapper(TensorShapeConfig shape) + { + Shapes = new TensorShapeConfig[] { shape }; + } + + public KerasShapesWrapper(TensorShapeConfig[] shapes) + { + Shapes = shapes; + } + + public KerasShapesWrapper(IEnumerable shape) + { + Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray(); + } + + public Shape ToSingleShape() + { + Debug.Assert(Shapes.Length == 1); + var shape_config = Shapes[0]; + Debug.Assert(shape_config is not null); + return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray()); + } + + public Shape[] ToShapeArray() + { + return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); + } + + public static implicit operator KerasShapesWrapper(Shape shape) + { + return new KerasShapesWrapper(shape); + } + public static implicit operator KerasShapesWrapper(TensorShapeConfig shape) + { + return new KerasShapesWrapper(shape); + } + + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs new file mode 100644 index 000000000..4ce290c83 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs @@ -0,0 +1,21 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Saving +{ + public class LayerConfig: IKerasConfig + { + [JsonProperty("name")] + public string Name { get; set; } + [JsonProperty("class_name")] + public string ClassName { get; set; } + [JsonProperty("config")] + public LayerArgs Config { get; set; } + [JsonProperty("inbound_nodes")] + public List InboundNodes { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs new file mode 100644 index 000000000..8ddcd1f04 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs @@ -0,0 +1,26 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; + +namespace Tensorflow.Keras.Saving +{ + public class FunctionalConfig : IKerasConfig + { + [JsonProperty("name")] + public string Name { get; set; } + [JsonProperty("layers")] + public List Layers { get; set; } + [JsonProperty("input_layers")] + public List InputLayers { get; set; } + [JsonProperty("output_layers")] + public List OutputLayers { get; set; } + + public override string ToString() + => $"{Name}, {Layers.Count} Layers, {InputLayers.Count} Input Layers, {OutputLayers.Count} Output Layers"; + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs new file mode 100644 index 000000000..8337ae018 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs @@ -0,0 +1,19 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Saving.Common; + +namespace Tensorflow.Keras.Saving +{ + [JsonConverter(typeof(CustomizedNodeConfigJsonConverter))] + public class NodeConfig : IKerasConfig + { + public string Name { get; set; } + public int NodeIndex { get; set; } + public int TensorIndex { get; set; } + + public override string ToString() + => $"{Name}, {NodeIndex}, {TensorIndex}"; + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs b/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs new file mode 100644 index 000000000..ae8a1ab13 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public interface ISerializedAttributes + { + IDictionary Functions { get; } + + IDictionary CheckpointableObjects { get; } + + /// + /// Returns functions to attach to the root object during serialization. + /// + IDictionary FunctionsToSerialize { get; } + + /// + /// Returns objects to attach to the root object during serialization. + /// + IDictionary ObjectsToSerialize{get; } + + /// + /// Saves function dictionary, and validates dictionary values. + /// + /// + IDictionary set_and_validate_functions(IDictionary function_dict); + + /// + /// Saves objects to a dictionary, and validates the values. + /// + /// + IDictionary set_and_validate_objects(IDictionary object_dict); + } +} diff --git a/src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterHandle.cs b/src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterHandle.cs new file mode 100644 index 000000000..feb65711b --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterHandle.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Util; + +namespace Tensorflow.Lite +{ + public class SafeTfLiteInterpreterHandle : SafeTensorflowHandle + { + protected SafeTfLiteInterpreterHandle() + { + } + + public SafeTfLiteInterpreterHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api_lite.TfLiteInterpreterDelete(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterOptionsHandle.cs b/src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterOptionsHandle.cs new file mode 100644 index 000000000..728936468 --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterOptionsHandle.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Util; + +namespace Tensorflow.Lite +{ + public class SafeTfLiteInterpreterOptionsHandle : SafeTensorflowHandle + { + protected SafeTfLiteInterpreterOptionsHandle() + { + } + + public SafeTfLiteInterpreterOptionsHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api_lite.TfLiteInterpreterOptionsDelete(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Lite/SafeTfLiteModelHandle.cs b/src/TensorFlowNET.Core/Lite/SafeTfLiteModelHandle.cs new file mode 100644 index 000000000..bdae15431 --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/SafeTfLiteModelHandle.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Util; + +namespace Tensorflow.Lite +{ + public class SafeTfLiteModelHandle : SafeTensorflowHandle + { + protected SafeTfLiteModelHandle() + { + } + + public SafeTfLiteModelHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api_lite.TfLiteModelDelete(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Lite/TfLiteDataType.cs b/src/TensorFlowNET.Core/Lite/TfLiteDataType.cs new file mode 100644 index 000000000..7b3aa1023 --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/TfLiteDataType.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Lite +{ + public enum TfLiteDataType + { + kTfLiteNoType = 0, + kTfLiteFloat32 = 1, + kTfLiteInt32 = 2, + kTfLiteUInt8 = 3, + kTfLiteInt64 = 4, + kTfLiteString = 5, + kTfLiteBool = 6, + kTfLiteInt16 = 7, + kTfLiteComplex64 = 8, + kTfLiteInt8 = 9, + kTfLiteFloat16 = 10, + kTfLiteFloat64 = 11, + kTfLiteComplex128 = 12, + kTfLiteUInt64 = 13, + kTfLiteResource = 14, + kTfLiteVariant = 15, + kTfLiteUInt32 = 16, + } +} diff --git a/src/TensorFlowNET.Core/Lite/TfLiteQuantizationParams.cs b/src/TensorFlowNET.Core/Lite/TfLiteQuantizationParams.cs new file mode 100644 index 000000000..e564392c5 --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/TfLiteQuantizationParams.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Lite +{ + public struct TfLiteQuantizationParams + { + public float scale; + public int zero_point; + } +} diff --git a/src/TensorFlowNET.Core/Lite/TfLiteStatus.cs b/src/TensorFlowNET.Core/Lite/TfLiteStatus.cs new file mode 100644 index 000000000..066121251 --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/TfLiteStatus.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Lite +{ + public enum TfLiteStatus + { + kTfLiteOk = 0, + + // Generally referring to an error in the runtime (i.e. interpreter) + kTfLiteError = 1, + + // Generally referring to an error from a TfLiteDelegate itself. + kTfLiteDelegateError = 2, + + // Generally referring to an error in applying a delegate due to + // incompatibility between runtime and delegate, e.g., this error is returned + // when trying to apply a TfLite delegate onto a model graph that's already + // immutable. + kTfLiteApplicationError = 3, + + // Generally referring to serialized delegate data not being found. + // See tflite::delegates::Serialization. + kTfLiteDelegateDataNotFound = 4, + + // Generally referring to data-writing issues in delegate serialization. + // See tflite::delegates::Serialization. + kTfLiteDelegateDataWriteError = 5, + } +} diff --git a/src/TensorFlowNET.Core/Lite/TfLiteTensor.cs b/src/TensorFlowNET.Core/Lite/TfLiteTensor.cs new file mode 100644 index 000000000..5a43f58fc --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/TfLiteTensor.cs @@ -0,0 +1,21 @@ +using System; + +namespace Tensorflow.Lite +{ + public struct TfLiteTensor + { + IntPtr _handle; + + public TfLiteTensor(IntPtr handle) + => _handle = handle; + + public static implicit operator TfLiteTensor(IntPtr handle) + => new TfLiteTensor(handle); + + public static implicit operator IntPtr(TfLiteTensor tensor) + => tensor._handle; + + public override string ToString() + => $"TfLiteTensor 0x{_handle.ToString("x16")}"; + } +} diff --git a/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs b/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs new file mode 100644 index 000000000..9ff381299 --- /dev/null +++ b/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Train; +using Tensorflow.Training.Saving.SavedModel; + +namespace Tensorflow.ModelSaving +{ + public class ModelSaver + { + public void save(Trackable obj, string export_dir, SaveOptions options = null) + { + var saved_model = new SavedModel(); + var meta_graph_def = new MetaGraphDef(); + saved_model.MetaGraphs.Add(meta_graph_def); + _build_meta_graph(obj, export_dir, options, meta_graph_def); + } + + void _build_meta_graph(Trackable obj, string export_dir, SaveOptions options, + MetaGraphDef meta_graph_def = null) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/AutoNumPyAttribute.cs b/src/TensorFlowNET.Core/NumPy/AutoNumPyAttribute.cs new file mode 100644 index 000000000..94828922c --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/AutoNumPyAttribute.cs @@ -0,0 +1,36 @@ +using MethodBoundaryAspect.Fody.Attributes; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using static Tensorflow.Binding; + +namespace Tensorflow.NumPy +{ + [DebuggerStepThrough] + public sealed class AutoNumPyAttribute : OnMethodBoundaryAspect + { + bool _changedMode = false; + bool _locked = false; + static object locker = new Object(); + public override void OnEntry(MethodExecutionArgs args) + { + Monitor.Enter(locker, ref _locked); + + if (!tf.executing_eagerly()) + { + tf.Context.eager_mode(); + _changedMode = true; + } + } + + public override void OnExit(MethodExecutionArgs args) + { + if (_changedMode) + tf.Context.restore_mode(); + + if (_locked) + Monitor.Exit(locker); + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs new file mode 100644 index 000000000..7a3ecbf10 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -0,0 +1,76 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Saving.Common; + +namespace Tensorflow +{ + [JsonConverter(typeof(CustomizedAxisJsonConverter))] + public class Axis + { + public int[] axis { get; set; } + public int size => axis == null ? -1 : axis.Length; + public bool IsScalar { get; init; } + + public int this[int index] => axis[index]; + + public Axis(params int[] axis) + { + this.axis = axis; + } + + public static implicit operator int[]?(Axis axis) + => axis?.axis; + + public static implicit operator int(Axis axis) + => axis.axis[0]; + + public static implicit operator Axis(int axis) + => new Axis(axis) { IsScalar = true }; + + public static implicit operator Axis((int, int) axis) + => new Axis(axis.Item1, axis.Item2); + + public static implicit operator Axis((int, int, int) axis) + => new Axis(axis.Item1, axis.Item2, axis.Item3); + + public static implicit operator Axis(int[] axis) + => new Axis(axis); + + public static implicit operator Axis(long[] axis) + => new Axis(axis.Select(x => (int)x).ToArray()); + + public static implicit operator Axis(Shape axis) + => new Axis(axis.dims.Select(x => (int)x).ToArray()); + + public static implicit operator Tensor(Axis axis) + => constant_op.constant(axis); + + public static bool operator ==(Axis left, int right) + => left.IsScalar && left[0] == right; + + public static bool operator !=(Axis left, int right) + => !(left == right); + + public override string ToString() + => IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})"; + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/LinearAlgebraImpl.cs b/src/TensorFlowNET.Core/NumPy/Implementation/LinearAlgebraImpl.cs new file mode 100644 index 000000000..7d287552c --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Implementation/LinearAlgebraImpl.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.NumPy +{ + public class LinearAlgebraImpl + { + [AutoNumPy] + public NDArray lstsq(NDArray a, NDArray b, string rcond = "warn") + => new NDArray(tf.linalg.lstsq(a, b)); + + [AutoNumPy] + public NDArray norm(NDArray a, Axis axis = null) + { + if (a.dtype.is_integer()) + { + var float_a = math_ops.cast(a, dtype: tf.float32); + return new NDArray(tf.linalg.norm(float_a, axis: axis)); + } + + return new NDArray(tf.linalg.norm(a, axis: axis)); + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs new file mode 100644 index 000000000..c0f9e695d --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs @@ -0,0 +1,116 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Tensorflow.Util; +using Razorvine.Pickle; +using Tensorflow.NumPy.Pickle; +using static Tensorflow.Binding; + +namespace Tensorflow.NumPy +{ + public partial class NumPyImpl + { + public NDArray eye(int N, int? M = null, int k = 0, TF_DataType dtype = TF_DataType.TF_DOUBLE) + { + if (!M.HasValue) + M = N; + + var diag_len = min(N, M.Value); + if (k > 0) + { + if (N >= M) + diag_len -= k; + else if (N + k > M) + diag_len = M.Value - k; + } + else + { + if (M >= N) + diag_len += k; + else if (M - k > N) + diag_len = N + k; + } + + var diagonal_ = array_ops.ones(new Shape(diag_len), dtype: dtype); + var tensor = array_ops.matrix_diag(diagonal: diagonal_, num_rows: N, num_cols: M.Value, k: k); + return new NDArray(tensor); + } + + public NDArray frombuffer(byte[] bytes, string dtype) + { + if (dtype == ">u4") + { + var size = bytes.Length / sizeof(uint); + var ints = new int[size]; + for (var index = 0; index < size; index++) + ints[index] = bytes[0] * 256 + bytes[1] + bytes[2] * 256 + bytes[3]; + + return new NDArray(ints, shape: new Shape(size)); + } + + throw new NotImplementedException(""); + } + + public NDArray frombuffer(byte[] bytes, Shape shape, TF_DataType dtype) + { + return new NDArray(bytes, shape, dtype); + } + + public NDArray linspace(T start, T stop, int num = 50, bool endpoint = true, bool retstep = false, + TF_DataType dtype = TF_DataType.TF_DOUBLE, int axis = 0) + { + var start_tensor = array_ops.constant(start, dtype: dtype); + var stop_tensor = array_ops.constant(stop, dtype: dtype); + + // var step_tensor = array_ops.constant(np.nan); + Tensor result = null; + + if (endpoint) + { + result = math_ops.linspace(start_tensor, stop_tensor, num, axis: axis); + } + else + { + if (num > 1) + { + var step = (stop_tensor - start_tensor) / num; + var new_stop = math_ops.cast(stop_tensor, step.dtype) - step; + start_tensor = math_ops.cast(start_tensor, new_stop.dtype); + result = math_ops.linspace(start_tensor, new_stop, num, axis: axis); + } + else + result = math_ops.linspace(start_tensor, stop_tensor, num, axis: axis); + } + + return new NDArray(result); + } + + Array ReadValueMatrix(BinaryReader reader, Array matrix, int bytes, Type type, int[] shape) + { + int total = 1; + for (int i = 0; i < shape.Length; i++) + total *= shape[i]; + + var buffer = reader.ReadBytes(bytes * total); + System.Buffer.BlockCopy(buffer, 0, matrix, 0, buffer.Length); + + return matrix; + } + + Array ReadObjectMatrix(BinaryReader reader, Array matrix, int[] shape) + { + Stream deflateStream = reader.BaseStream; + BufferedStream bufferedStream = new BufferedStream(deflateStream); + var unpickler = new Unpickler(); + return (MultiArrayPickleWarpper)unpickler.load(bufferedStream); + } + + public (NDArray, NDArray) meshgrid(T[] array, bool copy = true, bool sparse = false) + { + var tensors = array_ops.meshgrid(array, copy: copy, sparse: sparse); + return (new NDArray(tensors[0]), new NDArray(tensors[1])); + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs new file mode 100644 index 000000000..bc6047eb1 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; + +namespace Tensorflow.NumPy +{ + public partial class NumPyImpl + { + public NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) + { + var dtype = NumPyUtils.GetResultType(a.dtype, np.float64); + if(weights is null) + { + var tensorA = math_ops.cast(a, dtype); + var nd = math_ops.reduce_mean(tensorA, axis); + return new NDArray(nd); + } + else + { + var tensorW = math_ops.cast(weights, dtype); + if(a.rank != weights.rank) + { + var weights_sum = math_ops.reduce_sum(tensorW); + var axes = np.array(new[,] { { axis }, { 0 } }); + var avg = math_ops.tensordot(a, weights, axes) / weights_sum; + } + + throw new NotImplementedException(""); + } + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.cs b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.cs new file mode 100644 index 000000000..5c6dee2df --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.NumPy +{ + public partial class NumPyImpl + { + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs new file mode 100644 index 000000000..199e5ced3 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs @@ -0,0 +1,167 @@ +using System.IO; + +namespace Tensorflow.NumPy +{ + public partial class NumPyImpl + { + public NDArray load(string file) + { + using var stream = new FileStream(file, FileMode.Open); + using var reader = new BinaryReader(stream, Encoding.ASCII, leaveOpen: true); + if (!ParseReader(reader, out var bytes, out var type, out var shape)) + throw new FormatException(); + + Array array = Create(type, shape.Aggregate((dims, dim) => dims * dim)); + + var result = new NDArray(ReadValueMatrix(reader, array, bytes, type, shape)); + return result.reshape(shape); + } + + public Array LoadMatrix(Stream stream) + { + using (var reader = new BinaryReader(stream, System.Text.Encoding.ASCII, leaveOpen: true)) + { + if (!ParseReader(reader, out var bytes, out var type, out var shape)) + throw new FormatException(); + + Array matrix = Array.CreateInstance(type, shape); + + //if (type == typeof(String)) + //return ReadStringMatrix(reader, matrix, bytes, type, shape); + + if (type == typeof(Object)) + return ReadObjectMatrix(reader, matrix, shape); + else + { + return ReadValueMatrix(reader, matrix, bytes, type, shape); + } + } + } + + public T Load(Stream stream) + where T : class, + ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable + { + // if (typeof(T).IsArray && (typeof(T).GetElementType().IsArray || typeof(T).GetElementType() == typeof(string))) + // return LoadJagged(stream) as T; + return LoadMatrix(stream) as T; + } + + bool ParseReader(BinaryReader reader, out int bytes, out Type t, out int[] shape) + { + bytes = 0; + t = null; + shape = null; + + // The first 6 bytes are a magic string: exactly "x93NUMPY" + if (reader.ReadChar() != 63) return false; + if (reader.ReadChar() != 'N') return false; + if (reader.ReadChar() != 'U') return false; + if (reader.ReadChar() != 'M') return false; + if (reader.ReadChar() != 'P') return false; + if (reader.ReadChar() != 'Y') return false; + + byte major = reader.ReadByte(); // 1 + byte minor = reader.ReadByte(); // 0 + + if (major != 1 || minor != 0) + throw new NotSupportedException(); + + ushort len = reader.ReadUInt16(); + + string header = new String(reader.ReadChars(len)); + string mark = "'descr': '"; + int s = header.IndexOf(mark) + mark.Length; + int e = header.IndexOf("'", s + 1); + string type = header.Substring(s, e - s); + bool? isLittleEndian; + t = GetType(type, out bytes, out isLittleEndian); + + if (isLittleEndian.HasValue && isLittleEndian.Value == false) + throw new Exception(); + + mark = "'fortran_order': "; + s = header.IndexOf(mark) + mark.Length; + e = header.IndexOf(",", s + 1); + bool fortran = bool.Parse(header.Substring(s, e - s)); + + if (fortran) + throw new Exception(); + + mark = "'shape': ("; + s = header.IndexOf(mark) + mark.Length; + e = header.IndexOf(")", s + 1); + shape = header.Substring(s, e - s).Split(',').Where(v => !String.IsNullOrEmpty(v)).Select(Int32.Parse).ToArray(); + + return true; + } + + Type GetType(string dtype, out int bytes, out bool? isLittleEndian) + { + isLittleEndian = IsLittleEndian(dtype); + bytes = dtype.Length > 2 ? Int32.Parse(dtype.Substring(2)) : 0; + + string typeCode = dtype.Substring(1); + + if (typeCode == "b1") + return typeof(bool); + if (typeCode == "i1") + return typeof(Byte); + if (typeCode == "i2") + return typeof(Int16); + if (typeCode == "i4") + return typeof(Int32); + if (typeCode == "i8") + return typeof(Int64); + if (typeCode == "u1") + return typeof(Byte); + if (typeCode == "u2") + return typeof(UInt16); + if (typeCode == "u4") + return typeof(UInt32); + if (typeCode == "u8") + return typeof(UInt64); + if (typeCode == "f4") + return typeof(Single); + if (typeCode == "f8") + return typeof(Double); + if (typeCode.StartsWith("S")) + return typeof(String); + if (typeCode.StartsWith("O")) + return typeof(Object); + + throw new NotSupportedException(); + } + + bool? IsLittleEndian(string type) + { + bool? littleEndian = null; + + switch (type[0]) + { + case '<': + littleEndian = true; + break; + case '>': + littleEndian = false; + break; + case '|': + littleEndian = null; + break; + default: + throw new Exception(); + } + + return littleEndian; + } + + Array Create(Type type, int length) + { + // ReSharper disable once PossibleNullReferenceException + while (type.IsArray) + type = type.GetElementType(); + + return Array.CreateInstance(type, length); + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs b/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs new file mode 100644 index 000000000..a707e8aae --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow.NumPy +{ + public class RandomizedImpl + { + [AutoNumPy] + public NDArray permutation(int x) => new NDArray(random_ops.random_shuffle(math_ops.range(0, x))); + + [AutoNumPy] + public NDArray permutation(NDArray x) => new NDArray(random_ops.random_shuffle(x)); + + [AutoNumPy] + public void shuffle(NDArray x, int? seed = null) + { + var y = random_ops.random_shuffle(x, seed); + Marshal.Copy(y.BufferToArray(), 0, x.TensorDataPointer, (int)x.bytesize); + } + + public NDArray random(Shape size) + => uniform(low: 0, high: 1, size: size); + + [AutoNumPy] + public NDArray randint(int low, int? high = null, Shape? size = null, TF_DataType dtype = TF_DataType.TF_INT32) + { + if(high == null) + { + high = low; + low = 0; + } + size = size ?? Shape.Scalar; + var tensor = random_ops.random_uniform_int(shape: size, minval: low, maxval: (int)high); + return new NDArray(tensor); + } + + [AutoNumPy] + public NDArray randn(params int[] shape) + => new NDArray(random_ops.random_normal(shape ?? Shape.Scalar)); + + [AutoNumPy] + public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape? size = null) + => new NDArray(random_ops.random_normal(size ?? Shape.Scalar, mean: loc, stddev: scale)); + + [AutoNumPy] + public NDArray uniform(float low = 0.0f, float high = 1.0f, Shape? size = null) + => new NDArray(random_ops.random_uniform(size ?? Shape.Scalar, low, high)); + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs new file mode 100644 index 000000000..2aa327b5b --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs @@ -0,0 +1,41 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.NumPy +{ + public partial class NDArray + { + public override bool Equals(object obj) + { + return obj switch + { + int val => GetAtIndex(0) == val, + long val => GetAtIndex(0) == val, + float val => GetAtIndex(0) == val, + double val => GetAtIndex(0) == val, + string val => StringData(0) == val, + int[] val => ToArray().SequenceEqual(val), + long[] val => ToArray().SequenceEqual(val), + float[] val => ToArray().SequenceEqual(val), + double[] val => ToArray().SequenceEqual(val), + NDArray val => Equals(this, val), + _ => base.Equals(obj) + }; + } + + bool Equals(NDArray x, NDArray y) + { + if (x.ndim != y.ndim) + return false; + else if (x.size != y.size) + return false; + else if (x.dtype != y.dtype) + return false; + + return Enumerable.SequenceEqual(x.ToByteArray(), y.ToByteArray()); + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs new file mode 100644 index 000000000..45b236c7b --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs @@ -0,0 +1,125 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.NumPy +{ + public partial class NDArray + { + public void Deconstruct(out byte blue, out byte green, out byte red) + { + var data = ToArray(); + blue = data[0]; + green = data[1]; + red = data[2]; + } + + public static implicit operator NDArray(int[] array) + => new NDArray(array); + + public static implicit operator NDArray(byte[] array) + => new NDArray(array); + + public static implicit operator NDArray(float[] array) + => new NDArray(array); + + public static implicit operator NDArray(double[] array) + => new NDArray(array); + + public static implicit operator NDArray(long[] array) + => new NDArray(array); + + public static implicit operator NDArray(bool[] array) + => new NDArray(array); + + public static implicit operator NDArray(uint[] array) + => new NDArray(array); + + public static implicit operator NDArray(ulong[] array) + => new NDArray(array); + + public static implicit operator NDArray(int[,] array) + => new NDArray(array); + + public static implicit operator NDArray(byte[,] array) + => new NDArray(array); + + public static implicit operator NDArray(float[,] array) + => new NDArray(array); + + public static implicit operator NDArray(double[,] array) + => new NDArray(array); + + public static implicit operator NDArray(long[,] array) + => new NDArray(array); + + public static implicit operator NDArray(bool[,] array) + => new NDArray(array); + + public static implicit operator NDArray(uint[,] array) + => new NDArray(array); + + public static implicit operator NDArray(ulong[,] array) + => new NDArray(array); + + public static implicit operator NDArray(int[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(byte[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(float[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(double[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(long[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(bool[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(uint[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(ulong[,,] array) + => new NDArray(array); + + public unsafe static implicit operator bool(NDArray nd) + => nd.dtype == TF_DataType.TF_BOOL ? *(bool*)nd.data : NDArrayConverter.Scalar(nd); + + public unsafe static implicit operator byte(NDArray nd) + => nd.dtype == TF_DataType.TF_UINT8 ? *(byte*)nd.data : NDArrayConverter.Scalar(nd); + + public unsafe static implicit operator int(NDArray nd) + => nd.dtype == TF_DataType.TF_INT32 ? *(int*)nd.data : NDArrayConverter.Scalar(nd); + + public unsafe static implicit operator long(NDArray nd) + => nd.dtype == TF_DataType.TF_INT64 ? *(long*)nd.data : NDArrayConverter.Scalar(nd); + + public unsafe static implicit operator float(NDArray nd) + => nd.dtype == TF_DataType.TF_FLOAT ? *(float*)nd.data : NDArrayConverter.Scalar(nd); + + public unsafe static implicit operator double(NDArray nd) + => nd.dtype == TF_DataType.TF_DOUBLE ? *(double*)nd.data : NDArrayConverter.Scalar(nd); + + public static implicit operator NDArray(bool value) + => new NDArray(value); + + public static implicit operator NDArray(byte value) + => new NDArray(value); + + public static implicit operator NDArray(int value) + => new NDArray(value); + + public static implicit operator NDArray(long value) + => new NDArray(value); + + public static implicit operator NDArray(float value) + => new NDArray(value); + + public static implicit operator NDArray(double value) + => new NDArray(value); + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs new file mode 100644 index 000000000..9c0d728f8 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs @@ -0,0 +1,287 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.NumPy +{ + public partial class NDArray + { + public NDArray this[params int[] indices] + { + get => GetData(indices.Select(x => new Slice + { + Start = x, + Stop = x + 1, + IsIndex = true + }).ToArray()); + + set => SetData(indices.Select(x => + { + if(x < 0) + x = (int)dims[0] + x; + + var slice = new Slice + { + Start = x, + Stop = x + 1, + IsIndex = true + }; + + return slice; + }), value); + } + + public NDArray this[params Slice[] slices] + { + get => GetData(slices); + set => SetData(slices, value); + } + + public NDArray this[NDArray mask] + { + get + { + if (mask.dtype == TF_DataType.TF_BOOL) + return GetData(enumerate(mask.ToArray()).Where(x => x.Item2).Select(x => x.Item1).ToArray()); + else if (mask.dtype == TF_DataType.TF_INT32) + return GetData(mask.ToArray()); + else if (mask.dtype == TF_DataType.TF_INT64) + return GetData(mask.ToArray().Select(x => Convert.ToInt32(x)).ToArray()); + else if (mask.dtype == TF_DataType.TF_FLOAT) + return GetData(mask.ToArray().Select(x => Convert.ToInt32(x)).ToArray()); + + throw new NotImplementedException(""); + } + + set + { + if (mask.dtype == TF_DataType.TF_BOOL) + MaskData(mask, value); + else + throw new NotImplementedException(""); + } + } + + [AutoNumPy] + unsafe NDArray GetData(Slice[] slices) + { + if (shape.IsScalar) + return GetScalar(); + + if (SliceHelper.AreAllIndex(slices, out var indices1)) + { + var newshape = ShapeHelper.GetShape(shape, slices); + if (newshape.IsScalar) + { + var offset = ShapeHelper.GetOffset(shape, indices1); + return GetScalar((ulong)offset); + } + else + { + return GetArrayData(newshape, indices1); + } + } + else if (slices.Count() == 1) + { + var slice = slices[0]; + if (slice.Step == 1) + { + var newshape = ShapeHelper.GetShape(shape, slice); + var array = new NDArray(newshape, dtype: dtype); + + var new_dims = new int[shape.ndim]; + new_dims[0] = slice.Start ?? 0; + //for (int i = 1; i < shape.ndim; i++) + //new_dims[i] = (int)shape.dims[i]; + + var offset = ShapeHelper.GetOffset(shape, new_dims); + var src = (byte*)data + (ulong)offset * dtypesize; + var dst = (byte*)array.data; + var len = (ulong)newshape.size * dtypesize; + + System.Buffer.MemoryCopy(src, dst, len, len); + + return array; + } + } + + // default, performance is bad + var tensor = base[slices.ToArray()]; + if (tensor.Handle == null) + { + if (tf.executing_eagerly()) + tensor = tf.get_default_session().eval(tensor); + } + + return new NDArray(tensor, tf.executing_eagerly()); + } + + unsafe T GetAtIndex(params int[] indices) where T : unmanaged + { + var offset = (ulong)ShapeHelper.GetOffset(shape, indices); + return *((T*)data + offset); + } + + unsafe NDArray GetScalar(ulong offset = 0) + { + var array = new NDArray(Shape.Scalar, dtype: dtype); + var src = (byte*)data + offset * dtypesize; + System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), dtypesize, dtypesize); + return array; + } + + unsafe NDArray GetArrayData(Shape newshape, int[] indices) + { + var offset = ShapeHelper.GetOffset(shape, indices); + var len = (ulong)newshape.size * dtypesize; + var array = new NDArray(newshape, dtype: dtype); + + var src = (byte*)data + (ulong)offset * dtypesize; + System.Buffer.MemoryCopy(src, array.data.ToPointer(), len, len); + + return array; + } + + unsafe NDArray GetData(int[] indices, int axis = 0) + { + if (shape.IsScalar) + return GetScalar(); + + if(axis == 0) + { + var dims = shape.as_int_list(); + dims[0] = indices.Length; + + var array = np.ndarray(dims, dtype: dtype); + + dims[0] = 1; + var len = new Shape(dims).size * dtype.get_datatype_size(); + + int dst_index = 0; + foreach (var pos in indices) + { + var src_offset = (ulong)ShapeHelper.GetOffset(shape, pos); + var dst_offset = (ulong)ShapeHelper.GetOffset(array.shape, dst_index++); + + var src = (byte*)data + src_offset * dtypesize; + var dst = (byte*)array.data + dst_offset * dtypesize; + System.Buffer.MemoryCopy(src, dst, len, len); + } + + return array; + } + else + throw new NotImplementedException(""); + } + + void SetData(IEnumerable slices, NDArray array) + => SetData(array, slices.ToArray(), new int[shape.ndim].ToArray(), -1); + + unsafe void SetData(NDArray src, Slice[] slices, int[] indices, int currentNDim) + { + if (dtype != src.dtype) + // src = src.astype(dtype); + throw new ArrayTypeMismatchException($"Required dtype {dtype} but {src.dtype} is assigned."); + + if (!slices.Any()) + return; + + if (shape.Equals(src.shape)) + { + System.Buffer.MemoryCopy(src.data.ToPointer(), data.ToPointer(), src.bytesize, src.bytesize); + return; + } + + // first iteration + if(currentNDim == -1) + { + slices = SliceHelper.AlignWithShape(shape, slices); + } + + // last dimension + if (currentNDim == ndim - 1) + { + var offset = (int)ShapeHelper.GetOffset(shape, indices); + var dst = data + offset * (int)dtypesize; + System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize); + return; + } + + currentNDim++; + var slice = slices[currentNDim]; + + var start = slice.Start ?? 0; + var stop = slice.Stop ?? (int)dims[currentNDim]; + var step = slice.Step; + + if(step != 1) + { + for (var i = start; i < stop; i += step) + { + if (i >= dims[currentNDim]) + throw new OutOfRangeError($"Index should be in [0, {dims[currentNDim]}] but got {i}"); + + indices[currentNDim] = i; + if (currentNDim < ndim - src.ndim) + { + SetData(src, slices, indices, currentNDim); + } + else + { + var srcIndex = (i - start) / step; + SetData(src[srcIndex], slices, indices, currentNDim); + } + } + } + else + { + for (var i = start; i < stop; i++) + { + if (i >= dims[currentNDim]) + throw new OutOfRangeError($"Index should be in [0, {dims[currentNDim]}] but got {i}"); + + indices[currentNDim] = i; + if (currentNDim < ndim - src.ndim) + { + SetData(src, slices, indices, currentNDim); + } + // last dimension + else if(currentNDim == ndim - 1) + { + SetData(src, slices, indices, currentNDim); + break; + } + else if(SliceHelper.IsContinuousBlock(slices, currentNDim)) + { + var offset = (int)ShapeHelper.GetOffset(shape, indices); + var dst = data + offset * (int)dtypesize; + System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize); + return; + } + else + { + var srcIndex = i - start; + SetData(src[srcIndex], slices, indices, currentNDim); + } + } + } + + // reset indices + indices[currentNDim] = 0; + } + + unsafe void MaskData(NDArray mask, NDArray value) + { + var masks = mask.ToArray(); + var s1 = new Shape(dims.Skip(mask.rank).ToArray()); + var val = tf.fill(s1, value).numpy(); + for (int i = 0; i < masks.Length; i++) + { + if (masks[i]) + this[i] = val; + } + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs new file mode 100644 index 000000000..dd4577096 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs @@ -0,0 +1,61 @@ +using static Tensorflow.Binding; + +namespace Tensorflow.NumPy +{ + public partial class NDArray + { + [AutoNumPy] + public static NDArray operator +(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("add", lhs, rhs)); + [AutoNumPy] + public static NDArray operator -(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("sub", lhs, rhs)); + [AutoNumPy] + public static NDArray operator *(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mul", lhs, rhs)); + [AutoNumPy] + public static NDArray operator /(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("div", lhs, rhs)); + [AutoNumPy] + public static NDArray operator %(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mod", lhs, rhs)); + [AutoNumPy] + public static NDArray operator >(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.greater(lhs, rhs)); + [AutoNumPy] + public static NDArray operator <(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.less(lhs, rhs)); + [AutoNumPy] + public static NDArray operator -(NDArray lhs) => new NDArray(gen_math_ops.neg(lhs)); + [AutoNumPy] + public static NDArray operator ==(NDArray lhs, NDArray rhs) + { + if (ReferenceEquals(lhs, rhs)) + return Scalar(true); + if (lhs is null) + return Scalar(false); + if (rhs is null) + return Scalar(false); + // TODO(Rinne): use np.allclose instead. + if (lhs.dtype.is_floating() || rhs.dtype.is_floating()) + { + var diff = tf.abs(lhs - rhs); + return new NDArray(gen_math_ops.less(diff, new NDArray(1e-5).astype(diff.dtype))); + } + else + { + return new NDArray(math_ops.equal(lhs, rhs)); + } + } + [AutoNumPy] + public static NDArray operator !=(NDArray lhs, NDArray rhs) + { + if (ReferenceEquals(lhs, rhs)) + return Scalar(false); + if (lhs is null || rhs is null) + return Scalar(true); + if (lhs.dtype.is_floating() || rhs.dtype.is_floating()) + { + var diff = tf.abs(lhs - rhs); + return new NDArray(gen_math_ops.greater_equal(diff, new NDArray(1e-5).astype(diff.dtype))); + } + else + { + return new NDArray(math_ops.not_equal(lhs, rhs)); + } + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs b/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs new file mode 100644 index 000000000..4c64eba74 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs @@ -0,0 +1,119 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.NumPy +{ + public class NDArrayConverter + { + public unsafe static T Scalar(NDArray nd) where T : unmanaged + => nd.dtype switch + { + TF_DataType.TF_BOOL => Scalar(*(bool*)nd.data), + TF_DataType.TF_UINT8 => Scalar(*(byte*)nd.data), + TF_DataType.TF_FLOAT => Scalar(*(float*)nd.data), + TF_DataType.TF_INT32 => Scalar(*(int*)nd.data), + TF_DataType.TF_INT64 => Scalar(*(long*)nd.data), + TF_DataType.TF_DOUBLE => Scalar(*(double*)nd.data), + _ => throw new NotImplementedException(nameof(NDArrayConverter)) + }; + + static T Scalar(byte input) + => Type.GetTypeCode(typeof(T)) switch + { + TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte), + TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), + TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), + TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double), + _ => throw new NotImplementedException(nameof(NDArrayConverter)) + }; + + static T Scalar(float input) + => Type.GetTypeCode(typeof(T)) switch + { + TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte), + TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), + TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), + TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double), + _ => throw new NotImplementedException(nameof(NDArrayConverter)) + }; + + static T Scalar(int input) + => Type.GetTypeCode(typeof(T)) switch + { + TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte), + TypeCode.Int64 => (T)Convert.ChangeType(input, TypeCode.Int64), + TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), + TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double), + _ => throw new NotImplementedException(nameof(NDArrayConverter)) + }; + + static T Scalar(long input) + => Type.GetTypeCode(typeof(T)) switch + { + TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte), + TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), + TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), + TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double), + _ => throw new NotImplementedException(nameof(NDArrayConverter)) + }; + + public static unsafe Array ToMultiDimArray(NDArray nd) where T : unmanaged + { + var ret = Array.CreateInstance(typeof(T), nd.shape.as_int_list()); + + var addr = ret switch + { + T[] array => Addr(array), + T[,] array => Addr(array), + T[,,] array => Addr(array), + T[,,,] array => Addr(array), + T[,,,,] array => Addr(array), + T[,,,,,] array => Addr(array), + _ => throw new NotImplementedException(nameof(NDArrayConverter)) + }; + + System.Buffer.MemoryCopy(nd.data.ToPointer(), addr, nd.bytesize, nd.bytesize); + return ret; + } + + #region multiple array + static unsafe T* Addr(T[] array) where T : unmanaged + { + fixed (T* a = &array[0]) + return a; + } + + static unsafe T* Addr(T[,] array) where T : unmanaged + { + fixed (T* a = &array[0, 0]) + return a; + } + + static unsafe T* Addr(T[,,] array) where T : unmanaged + { + fixed (T* a = &array[0, 0, 0]) + return a; + } + + static unsafe T* Addr(T[,,,] array) where T : unmanaged + { + fixed (T* a = &array[0, 0, 0, 0]) + return a; + } + + static unsafe T* Addr(T[,,,,] array) where T : unmanaged + { + fixed (T* a = &array[0, 0, 0, 0, 0]) + return a; + } + + static unsafe T* Addr(T[,,,,,] array) where T : unmanaged + { + fixed (T* a = &array[0, 0, 0, 0, 0, 0]) + return a; + } + #endregion + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NDArrayRender.cs b/src/TensorFlowNET.Core/NumPy/NDArrayRender.cs new file mode 100644 index 000000000..230797b8b --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/NDArrayRender.cs @@ -0,0 +1,139 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq; + +namespace Tensorflow.NumPy +{ + public class NDArrayRender + { + public static string ToString(NDArray array, int maxLength = 10) + { + Shape shape = array.shape; + if (shape.IsScalar) + return Render(array); + + var s = new StringBuilder(); + s.Append("array("); + Build(s, array, maxLength); + s.Append(")"); + return s.ToString(); + } + + static void Build(StringBuilder s, NDArray array, int maxLength) + { + var shape = array.shape; + + if (shape.Length == 1) + { + s.Append("["); + s.Append(Render(array)); + s.Append("]"); + return; + } + + var len = shape[0]; + s.Append("["); + + if (len <= maxLength) + { + for (int i = 0; i < len; i++) + { + Build(s, array[i], maxLength); + if (i < len - 1) + { + s.Append(", "); + s.AppendLine(); + } + } + } + else + { + for (int i = 0; i < maxLength / 2; i++) + { + Build(s, array[i], maxLength); + if (i < len - 1) + { + s.Append(", "); + s.AppendLine(); + } + } + + s.Append(" ... "); + s.AppendLine(); + + for (int i = (int)len - maxLength / 2; i < len; i++) + { + Build(s, array[i], maxLength); + if (i < len - 1) + { + s.Append(", "); + s.AppendLine(); + } + } + } + + s.Append("]"); + } + + static string Render(NDArray array) + { + if (array.buffer == IntPtr.Zero) + return ""; + + var dtype = array.dtype; + var shape = array.shape; + + if (dtype == TF_DataType.TF_STRING) + { + if (array.rank == 0) + return "'" + string.Join(string.Empty, array.StringBytes()[0] + .Take(256) + .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())) + "'"; + else + return $"'{string.Join("', '", array.StringData().Take(25))}'"; + } + else if (dtype == TF_DataType.TF_VARIANT) + { + return ""; + } + else if (dtype == TF_DataType.TF_RESOURCE) + { + return ""; + } + else + { + return dtype switch + { + TF_DataType.TF_BOOL => Render(array.ToArray(), array.shape), + TF_DataType.TF_INT8 => Render(array.ToArray(), array.shape), + TF_DataType.TF_INT32 => Render(array.ToArray(), array.shape), + TF_DataType.TF_INT64 => Render(array.ToArray(), array.shape), + TF_DataType.TF_UINT64 => Render(array.ToArray(), array.shape), + TF_DataType.TF_FLOAT => Render(array.ToArray(), array.shape), + TF_DataType.TF_DOUBLE => Render(array.ToArray(), array.shape), + _ => Render(array.ToArray(), array.shape) + }; + } + } + + static string Render(T[] array, Shape shape) + { + if (array == null) + return ""; + + if (array.Length == 0) + return ""; + + if (shape.IsScalar) + return array[0].ToString(); + + var display = ""; + if (array.Length <= 10) + display += string.Join(", ", array); + else + display += string.Join(", ", array.Take(5)) + ", ..., " + string.Join(", ", array.Skip(array.Length - 5)); + return display; + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NumPy.Logical.cs b/src/TensorFlowNET.Core/NumPy/NumPy.Logical.cs new file mode 100644 index 000000000..b4add5086 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/NumPy.Logical.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Numerics; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.NumPy +{ + public partial class np + { + [AutoNumPy] + public static NDArray any(NDArray a, Axis axis = null) => new NDArray(a.ToArray().Any(x => x)); + [AutoNumPy] + public static NDArray logical_or(NDArray x1, NDArray x2) => new NDArray(tf.logical_or(x1, x2)); + + [AutoNumPy] + public static NDArray logical_and(NDArray x1, NDArray x2) => new NDArray(tf.logical_and(x1, x2)); + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs b/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs new file mode 100644 index 000000000..4cad36e0b --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Globalization; +using System.Numerics; +using System.Text; + +namespace Tensorflow.NumPy +{ + public partial class np + { + [AutoNumPy] + public static NDArray argmax(NDArray a, Axis? axis = null) + => new NDArray(math_ops.argmax(a, axis ?? 0)); + + [AutoNumPy] + public static NDArray argmin(NDArray a, Axis? axis = null) + => new NDArray(math_ops.argmin(a, axis ?? 0)); + + [AutoNumPy] + public static NDArray argsort(NDArray a, Axis? axis = null) + => new NDArray(sort_ops.argsort(a, axis: axis ?? -1)); + + [AutoNumPy] + public static (NDArray, NDArray) unique(NDArray a) + { + var(u, indice) = array_ops.unique(a); + return (new NDArray(u), new NDArray(indice)); + } + + [AutoNumPy] + public static void shuffle(NDArray x) => np.random.shuffle(x); + + /// + /// Sorts a ndarray + /// + /// + /// + /// The axis along which to sort. The default is -1, which sorts the last axis. + /// + /// + /// The direction in which to sort the values (`'ASCENDING'` or `'DESCENDING'`) + /// + /// + /// A `NDArray` with the same dtype and shape as `values`, with the elements sorted along the given `axis`. + /// + [AutoNumPy] + public static NDArray sort(NDArray values, Axis? axis = null, string direction = "ASCENDING") + => new NDArray(sort_ops.sort(values, axis: axis ?? -1, direction: direction)); + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs b/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs new file mode 100644 index 000000000..bce16ec9f --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Numerics; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.NumPy +{ + public partial class np + { + [AutoNumPy] + public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.min(x, axis)); + + [AutoNumPy] + public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.max(x, axis)); + + [AutoNumPy] + public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) + => tf.numpy.average(a, axis: axis, weights: weights, returned: returned); + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NumPyUtils.cs b/src/TensorFlowNET.Core/NumPy/NumPyUtils.cs new file mode 100644 index 000000000..35356603b --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/NumPyUtils.cs @@ -0,0 +1,19 @@ +using System; +using System.Text; + +namespace Tensorflow.NumPy +{ + internal class NumPyUtils + { + public static TF_DataType GetResultType(params TF_DataType[] dtypes) + { + var resultDType = dtypes[0]; + for(int i = 1; i < dtypes.Length; i++) + { + if (dtypes[i].get_datatype_size() > resultDType.get_datatype_size()) + resultDType = dtypes[i]; + } + return resultDType; + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs new file mode 100644 index 000000000..5e2574170 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs @@ -0,0 +1,45 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Numerics; +using System.Text; + +namespace Tensorflow.NumPy +{ + public partial class np + { + [AutoNumPy] + public static NDArray concatenate((NDArray, NDArray) tuple, int axis = 0) + => new NDArray(array_ops.concat(new[] { tuple.Item1, tuple.Item2 }, axis)); + + [AutoNumPy] + public static NDArray concatenate(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.concat(arrays, axis)); + + [AutoNumPy] + public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException(""); + + [AutoNumPy] + public static NDArray expand_dims(NDArray a, Axis? axis = null) => new NDArray(array_ops.expand_dims(a, axis: axis ?? -1)); + + [AutoNumPy] + public static NDArray reshape(NDArray x1, Shape newshape) => x1.reshape(newshape); + + [AutoNumPy] + public static NDArray squeeze(NDArray x1, Axis? axis = null) => new NDArray(array_ops.squeeze(x1, axis)); + + [AutoNumPy] + public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays)); + + [AutoNumPy] + public static NDArray stack(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.stack(arrays, axis)); + + [AutoNumPy] + public static NDArray stack((NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2 }, axis)); + + [AutoNumPy] + public static NDArray stack((NDArray, NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2, tuple.Item3 }, axis)); + + [AutoNumPy] + public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination)); + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs new file mode 100644 index 000000000..2559638b3 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Numerics; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.NumPy +{ + public partial class np + { + [AutoNumPy] + public static NDArray cos(NDArray x) => new NDArray(math_ops.cos(x)); + + [AutoNumPy] + public static NDArray exp(NDArray x) => new NDArray(tf.exp(x)); + + [AutoNumPy] + public static NDArray floor(NDArray x) => new NDArray(math_ops.floor(x)); + + [AutoNumPy] + public static NDArray log(NDArray x) => new NDArray(tf.log(x)); + + [AutoNumPy] + public static NDArray mean(NDArray x) => new NDArray(math_ops.reduce_mean(x)); + + [AutoNumPy] + public static NDArray multiply(NDArray x1, NDArray x2) => new NDArray(tf.multiply(x1, x2)); + + [AutoNumPy] + //public static NDArray maximum(NDArray x1, NDArray x2) => new NDArray(tf.maximum(x1, x2)); + public static NDArray maximum(NDArray x1, NDArray x2, int? axis = null) + { + var maxValues = tf.maximum(x1, x2); + if (axis.HasValue) + { + maxValues = tf.reduce_max(maxValues, axis: axis.Value); + } + return new NDArray(maxValues); + } + + [AutoNumPy] + public static NDArray minimum(NDArray x1, NDArray x2) => new NDArray(tf.minimum(x1, x2)); + + [AutoNumPy] + public static NDArray prod(NDArray array, Axis? axis = null, Type? dtype = null, bool keepdims = false) + => new NDArray(tf.reduce_prod(array, axis: axis)); + + [AutoNumPy] + public static NDArray prod(params T[] array) where T : unmanaged + => new NDArray(tf.reduce_prod(new NDArray(array))); + [AutoNumPy] + public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string? name = null) + { + //if axes mentioned + if (axes != null) + { + return new NDArray(tf.dot_prod(x1, x2, axes, name)); + } + if (x1.shape.ndim > 1) + { + x1 = GetFlattenArray(x1); + } + if (x2.shape.ndim > 1) + { + x2 = GetFlattenArray(x2); + } + //if axes not mentioned, default 0,0 + return new NDArray(tf.dot_prod(x1, x2, axes: new int[] { 0, 0 }, name)); + + } + [AutoNumPy] + public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y)); + [AutoNumPy] + public static NDArray square(NDArray x) => new NDArray(tf.square(x)); + + [AutoNumPy] + public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x)); + + [AutoNumPy] + public static NDArray sqrt(NDArray x) => new NDArray(tf.sqrt(x)); + + [AutoNumPy] + public static NDArray sum(NDArray x1, Axis? axis = null) => new NDArray(tf.math.sum(x1, axis)); + + [AutoNumPy] + public static NDArray add(NDArray x, NDArray y) => new NDArray(math_ops.add(x, y)); + + [AutoNumPy] + public static NDArray greater(NDArray x, NDArray y) => new NDArray(tf.greater(x, y)); + + [AutoNumPy] + public static NDArray less(NDArray x, NDArray y) => new NDArray(tf.less(x, y)); + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Persistence.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Persistence.cs new file mode 100644 index 000000000..b349f5229 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Persistence.cs @@ -0,0 +1,60 @@ +/***************************************************************************** + Copyright 2023 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.IO; +using System.IO.Compression; + +namespace Tensorflow.NumPy; + +public partial class np +{ + [AutoNumPy] + public static NpzDictionary loadz(string file) + { + using var stream = new FileStream(file, FileMode.Open); + return new NpzDictionary(stream); + } + + public static void save(string file, NDArray nd) + { + using var stream = new FileStream(file, FileMode.Create); + NpyFormat.Save(nd, stream); + } + + public static void savez(string file, params NDArray[] nds) + { + using var stream = new FileStream(file, FileMode.Create); + NpzFormat.Save(nds, stream); + } + + public static void savez(string file, object nds) + { + using var stream = new FileStream(file, FileMode.Create); + NpzFormat.Save(nds, stream); + } + + public static void savez_compressed(string file, params NDArray[] nds) + { + using var stream = new FileStream(file, FileMode.Create); + NpzFormat.Save(nds, stream, CompressionLevel.Fastest); + } + + public static void savez_compressed(string file, object nds) + { + using var stream = new FileStream(file, FileMode.Create); + NpzFormat.Save(nds, stream, CompressionLevel.Fastest); + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Persistence/NpyFormat.cs b/src/TensorFlowNET.Core/NumPy/Persistence/NpyFormat.cs new file mode 100644 index 000000000..10de0e7d2 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Persistence/NpyFormat.cs @@ -0,0 +1,95 @@ +using System.IO; +using System.Runtime.InteropServices; + +namespace Tensorflow.NumPy; + +public class NpyFormat +{ + public static void Save(NDArray array, Stream stream, bool leaveOpen = true) + { + using var writer = new BinaryWriter(stream, Encoding.ASCII, leaveOpen: leaveOpen); + + string dtype = GetDtypeName(array, out var type, out var maxLength); + int[] shape = array.shape.as_int_list(); + var bytesWritten = (ulong)writeHeader(writer, dtype, shape); + stream.Write(array.ToByteArray(), 0, (int)array.bytesize); + } + + private static int writeHeader(BinaryWriter writer, string dtype, int[] shape) + { + // The first 6 bytes are a magic string: exactly "x93NUMPY" + + char[] magic = { 'N', 'U', 'M', 'P', 'Y' }; + writer.Write((byte)147); + writer.Write(magic); + writer.Write((byte)1); // major + writer.Write((byte)0); // minor; + + string tuple = shape.Length == 1 ? $"{shape[0]}," : String.Join(", ", shape.Select(i => i.ToString()).ToArray()); + string header = "{{'descr': '{0}', 'fortran_order': False, 'shape': ({1}), }}"; + header = string.Format(header, dtype, tuple); + int preamble = 10; // magic string (6) + 4 + + int len = header.Length + 1; // the 1 is to account for the missing \n at the end + int headerSize = len + preamble; + + int pad = 16 - (headerSize % 16); + header = header.PadRight(header.Length + pad); + header += "\n"; + headerSize = header.Length + preamble; + + if (headerSize % 16 != 0) + throw new Exception(""); + + writer.Write((ushort)header.Length); + for (int i = 0; i < header.Length; i++) + writer.Write((byte)header[i]); + + return headerSize; + } + + private static string GetDtypeName(NDArray array, out Type type, out int bytes) + { + type = array.dtype.as_system_dtype(); + + bytes = 1; + + if (type == typeof(string)) + { + throw new NotSupportedException(""); + } + else if (type == typeof(bool)) + { + bytes = 1; + } + else + { + bytes = Marshal.SizeOf(type); + } + + if (type == typeof(bool)) + return "|b1"; + else if (type == typeof(byte)) + return "|u1"; + else if (type == typeof(short)) + return " : IDisposable, IReadOnlyDictionary, ICollection + where T : class, + ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable +{ + Stream stream; + ZipArchive archive; + + bool disposedValue = false; + + Dictionary entries; + Dictionary arrays; + + + public NpzDictionary(Stream stream) + { + this.stream = stream; + this.archive = new ZipArchive(stream, ZipArchiveMode.Read, leaveOpen: true); + + this.entries = new Dictionary(); + foreach (var entry in archive.Entries) + this.entries[entry.FullName] = entry; + + this.arrays = new Dictionary(); + } + + + public IEnumerable Keys + { + get { return entries.Keys; } + } + + + public IEnumerable Values + { + get { return entries.Values.Select(OpenEntry); } + } + + public int Count + { + get { return entries.Count; } + } + + + public object SyncRoot + { + get { return ((ICollection)entries).SyncRoot; } + } + + + public bool IsSynchronized + { + get { return ((ICollection)entries).IsSynchronized; } + } + + public bool IsReadOnly + { + get { return true; } + } + + public T this[string key] + { + get { return OpenEntry(entries[key]); } + } + + private T OpenEntry(ZipArchiveEntry entry) + { + T array; + if (arrays.TryGetValue(entry.FullName, out array)) + return array; + + using (Stream s = entry.Open()) + { + array = Load_Npz(s); + arrays[entry.FullName] = array; + return array; + } + } + + protected virtual T Load_Npz(Stream s) + { + return np.Load(s); + } + + public bool ContainsKey(string key) + { + return entries.ContainsKey(key); + } + + public bool TryGetValue(string key, out T value) + { + value = default(T); + ZipArchiveEntry entry; + if (!entries.TryGetValue(key, out entry)) + return false; + value = OpenEntry(entry); + return true; + } + + public IEnumerator> GetEnumerator() + { + foreach (var entry in archive.Entries) + yield return new KeyValuePair(entry.FullName, OpenEntry(entry)); + } + + IEnumerator IEnumerable.GetEnumerator() + { + foreach (var entry in archive.Entries) + yield return new KeyValuePair(entry.FullName, OpenEntry(entry)); + } + + IEnumerator IEnumerable.GetEnumerator() + { + foreach (var entry in archive.Entries) + yield return OpenEntry(entry); + } + + public void CopyTo(Array array, int arrayIndex) + { + foreach (var v in this) + array.SetValue(v, arrayIndex++); + } + + public void CopyTo(T[] array, int arrayIndex) + { + foreach (var v in this) + array.SetValue(v, arrayIndex++); + } + + public void Add(T item) + { + throw new ReadOnlyException(); + } + + public void Clear() + { + throw new ReadOnlyException(); + } + + public bool Contains(T item) + { + foreach (var v in this) + if (Object.Equals(v.Value, item)) + return true; + return false; + } + + public bool Remove(T item) + { + throw new ReadOnlyException(); + } + + protected virtual void Dispose(bool disposing) + { + if (!disposedValue) + { + if (disposing) + { + archive.Dispose(); + stream.Dispose(); + } + + archive = null; + stream = null; + entries = null; + arrays = null; + + disposedValue = true; + } + } + + public void Dispose() + { + Dispose(true); + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs b/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs new file mode 100644 index 000000000..ba7868faa --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs @@ -0,0 +1,138 @@ +using System.IO; +using System.IO.Compression; + +namespace Tensorflow.NumPy; + +public class NpzDictionary +{ + Dictionary arrays = new Dictionary(); + + public NDArray this[string key] => arrays[key]; + + public NpzDictionary(Stream stream) + { + using var archive = new ZipArchive(stream, ZipArchiveMode.Read, leaveOpen: false); + + foreach (var entry in archive.Entries) + { + arrays[entry.FullName] = OpenEntry(entry); + } + } + + private NDArray OpenEntry(ZipArchiveEntry entry) + { + if (arrays.TryGetValue(entry.FullName, out var array)) + return array; + + using var s = entry.Open(); + return (NDArray)LoadMatrix(s); + } + + public Array LoadMatrix(Stream stream) + { + using var reader = new BinaryReader(stream, System.Text.Encoding.ASCII, leaveOpen: false); + + if (!ParseReader(reader, out var bytes, out var type, out var shape)) + throw new FormatException(); + + Array matrix = Array.CreateInstance(type, shape); + + return ReadMatrix(reader, matrix, bytes, type, shape); + } + + bool ParseReader(BinaryReader reader, out int bytes, out Type t, out int[] shape) + { + bytes = 0; + t = null; + shape = null; + + // The first 6 bytes are a magic string: exactly "x93NUMPY" + if (reader.ReadChar() != 63) return false; + if (reader.ReadChar() != 'N') return false; + if (reader.ReadChar() != 'U') return false; + if (reader.ReadChar() != 'M') return false; + if (reader.ReadChar() != 'P') return false; + if (reader.ReadChar() != 'Y') return false; + + byte major = reader.ReadByte(); // 1 + byte minor = reader.ReadByte(); // 0 + + if (major != 1 || minor != 0) + throw new NotSupportedException(); + + ushort len = reader.ReadUInt16(); + + string header = new string(reader.ReadChars(len)); + string mark = "'descr': '"; + int s = header.IndexOf(mark) + mark.Length; + int e = header.IndexOf("'", s + 1); + string type = header.Substring(s, e - s); + bool? isLittleEndian; + t = GetType(type, out bytes, out isLittleEndian); + + if (isLittleEndian.HasValue && isLittleEndian.Value == false) + throw new Exception(); + + mark = "'fortran_order': "; + s = header.IndexOf(mark) + mark.Length; + e = header.IndexOf(",", s + 1); + bool fortran = bool.Parse(header.Substring(s, e - s)); + + if (fortran) + throw new Exception(); + + mark = "'shape': ("; + s = header.IndexOf(mark) + mark.Length; + e = header.IndexOf(")", s + 1); + shape = header.Substring(s, e - s).Split(',').Where(v => !String.IsNullOrEmpty(v)).Select(Int32.Parse).ToArray(); + + return true; + } + + Type GetType(string dtype, out int bytes, out bool? isLittleEndian) + { + isLittleEndian = IsLittleEndian(dtype); + bytes = int.Parse(dtype.Substring(2)); + + string typeCode = dtype.Substring(1); + return typeCode switch + { + "b1" => typeof(bool), + "i1" => typeof(byte), + "i2" => typeof(short), + "i4" => typeof(int), + "i8" => typeof(long), + "u1" => typeof(byte), + "u2" => typeof(ushort), + "u4" => typeof(uint), + "u8" => typeof(ulong), + "f4" => typeof(float), + "f8" => typeof(double), + // typeCode.StartsWith("S") => typeof(string), + _ => throw new NotSupportedException() + }; + } + + bool? IsLittleEndian(string type) + { + return type[0] switch + { + '<' => true, + '>' => false, + '|' => null, + _ => throw new Exception() + }; + } + + Array ReadMatrix(BinaryReader reader, Array matrix, int bytes, Type type, int[] shape) + { + int total = 1; + for (int i = 0; i < shape.Length; i++) + total *= shape[i]; + + var buffer = reader.ReadBytes(bytes * total); + System.Buffer.BlockCopy(buffer, 0, matrix, 0, buffer.Length); + + return matrix; + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Persistence/NpzFormat.cs b/src/TensorFlowNET.Core/NumPy/Persistence/NpzFormat.cs new file mode 100644 index 000000000..7470a1ea7 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Persistence/NpzFormat.cs @@ -0,0 +1,37 @@ +using System.IO.Compression; +using System.IO; +using System; + +namespace Tensorflow.NumPy; + +public class NpzFormat +{ + public static void Save(NDArray[] arrays, Stream stream, CompressionLevel compression = CompressionLevel.NoCompression, bool leaveOpen = false) + { + using var zip = new ZipArchive(stream, ZipArchiveMode.Create, leaveOpen: leaveOpen); + for (int i = 0; i < arrays.Length; i++) + { + var entry = zip.CreateEntry($"arr_{i}", compression); + NpyFormat.Save(arrays[i], entry.Open(), leaveOpen); + } + } + + public static void Save(object arrays, Stream stream, CompressionLevel compression = CompressionLevel.NoCompression, bool leaveOpen = false) + { + var properties = arrays.GetType().GetProperties(); + using var zip = new ZipArchive(stream, ZipArchiveMode.Create, leaveOpen: leaveOpen); + for (int i = 0; i < properties.Length; i++) + { + var entry = zip.CreateEntry(properties[i].Name, compression); + var value = properties[i].GetValue(arrays); + if (value is NDArray nd) + { + NpyFormat.Save(nd, entry.Open(), leaveOpen); + } + else + { + throw new NotSupportedException("Please pass in NDArray."); + } + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Pickle/DTypePickleWarpper.cs b/src/TensorFlowNET.Core/NumPy/Pickle/DTypePickleWarpper.cs new file mode 100644 index 000000000..5dff6c16b --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Pickle/DTypePickleWarpper.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.NumPy.Pickle +{ + public class DTypePickleWarpper + { + TF_DataType dtype { get; set; } + public DTypePickleWarpper(TF_DataType dtype) + { + this.dtype = dtype; + } + public void __setstate__(object[] args) { } + public static implicit operator TF_DataType(DTypePickleWarpper dTypeWarpper) + { + return dTypeWarpper.dtype; + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Pickle/DtypeConstructor.cs b/src/TensorFlowNET.Core/NumPy/Pickle/DtypeConstructor.cs new file mode 100644 index 000000000..160c7d4e9 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Pickle/DtypeConstructor.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using Razorvine.Pickle; + +namespace Tensorflow.NumPy.Pickle +{ + /// + /// + /// + [SuppressMessage("ReSharper", "InconsistentNaming")] + [SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] + [SuppressMessage("ReSharper", "MemberCanBeMadeStatic.Global")] + class DtypeConstructor : IObjectConstructor + { + public object construct(object[] args) + { + var typeCode = (string)args[0]; + TF_DataType dtype; + if (typeCode == "b1") + dtype = np.@bool; + else if (typeCode == "i1") + dtype = np.@byte; + else if (typeCode == "i2") + dtype = np.int16; + else if (typeCode == "i4") + dtype = np.int32; + else if (typeCode == "i8") + dtype = np.int64; + else if (typeCode == "u1") + dtype = np.ubyte; + else if (typeCode == "u2") + dtype = np.uint16; + else if (typeCode == "u4") + dtype = np.uint32; + else if (typeCode == "u8") + dtype = np.uint64; + else if (typeCode == "f4") + dtype = np.float32; + else if (typeCode == "f8") + dtype = np.float64; + else if (typeCode.StartsWith("S")) + dtype = np.@string; + else if (typeCode.StartsWith("O")) + dtype = np.@object; + else + throw new NotSupportedException(); + return new DTypePickleWarpper(dtype); + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayConstructor.cs b/src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayConstructor.cs new file mode 100644 index 000000000..885f368c4 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayConstructor.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using Razorvine.Pickle; +using Razorvine.Pickle.Objects; + +namespace Tensorflow.NumPy.Pickle +{ + /// + /// Creates multiarrays of objects. Returns a primitive type multiarray such as int[][] if + /// the objects are ints, etc. + /// + [SuppressMessage("ReSharper", "InconsistentNaming")] + [SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] + [SuppressMessage("ReSharper", "MemberCanBeMadeStatic.Global")] + public class MultiArrayConstructor : IObjectConstructor + { + public object construct(object[] args) + { + if (args.Length != 3) + throw new InvalidArgumentError($"Invalid number of arguments in MultiArrayConstructor._reconstruct. Expected three arguments. Given {args.Length} arguments."); + + var types = (ClassDictConstructor)args[0]; + if (types.module != "numpy" || types.name != "ndarray") + throw new RuntimeError("_reconstruct: First argument must be a sub-type of ndarray"); + + var arg1 = (object[])args[1]; + var dims = new int[arg1.Length]; + for (var i = 0; i < arg1.Length; i++) + { + dims[i] = (int)arg1[i]; + } + var shape = new Shape(dims); + + TF_DataType dtype; + string identifier; + if (args[2].GetType() == typeof(string)) + identifier = (string)args[2]; + else + identifier = Encoding.UTF8.GetString((byte[])args[2]); + switch (identifier) + { + case "u": dtype = np.uint32; break; + case "c": dtype = np.complex_; break; + case "f": dtype = np.float32; break; + case "b": dtype = np.@bool; break; + default: throw new NotImplementedException($"Unsupported data type: {args[2]}"); + } + return new MultiArrayPickleWarpper(shape, dtype); + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayPickleWarpper.cs b/src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayPickleWarpper.cs new file mode 100644 index 000000000..af8d1ecc2 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Pickle/MultiArrayPickleWarpper.cs @@ -0,0 +1,119 @@ +using Newtonsoft.Json.Linq; +using Serilog.Debugging; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.NumPy.Pickle +{ + public class MultiArrayPickleWarpper + { + public Shape reconstructedShape { get; set; } + public TF_DataType reconstructedDType { get; set; } + public NDArray reconstructedNDArray { get; set; } + public Array reconstructedMultiArray { get; set; } + public MultiArrayPickleWarpper(Shape shape, TF_DataType dtype) + { + reconstructedShape = shape; + reconstructedDType = dtype; + } + public void __setstate__(object[] args) + { + if (args.Length != 5) + throw new InvalidArgumentError($"Invalid number of arguments in NDArray.__setstate__. Expected five arguments. Given {args.Length} arguments."); + + var version = (int)args[0]; // version + + var arg1 = (object[])args[1]; + var dims = new int[arg1.Length]; + for (var i = 0; i < arg1.Length; i++) + { + dims[i] = (int)arg1[i]; + } + var _ShapeLike = new Shape(dims); // shape + + TF_DataType _DType_co = (DTypePickleWarpper)args[2]; // DType + + var F_continuous = (bool)args[3]; // F-continuous + if (F_continuous) + throw new InvalidArgumentError("Fortran Continuous memory layout is not supported. Please use C-continuous layout or check the data format."); + + var data = args[4]; // Data + /* + * If we ever need another pickle format, increment the version + * number. But we should still be able to handle the old versions. + */ + if (version < 0 || version > 4) + throw new ValueError($"can't handle version {version} of numpy.dtype pickle"); + + // TODO: Implement the missing details and checks from the official Numpy C code here. + // https://github.com/numpy/numpy/blob/2f0bd6e86a77e4401d0384d9a75edf9470c5deb6/numpy/core/src/multiarray/descriptor.c#L2761 + + if (data.GetType() == typeof(ArrayList)) + { + Reconstruct((ArrayList)data); + } + else + throw new NotImplementedException(""); + } + private void Reconstruct(ArrayList arrayList) + { + int ndim = 1; + var subArrayList = arrayList; + while (subArrayList.Count > 0 && subArrayList[0] != null && subArrayList[0].GetType() == typeof(ArrayList)) + { + subArrayList = (ArrayList)subArrayList[0]; + ndim += 1; + } + var type = subArrayList[0].GetType(); + if (type == typeof(int)) + { + if (ndim == 1) + { + int[] list = (int[])arrayList.ToArray(typeof(int)); + Shape shape = new Shape(new int[] { arrayList.Count }); + reconstructedMultiArray = list; + reconstructedNDArray = new NDArray(list, shape); + } + if (ndim == 2) + { + int secondDim = 0; + foreach (ArrayList subArray in arrayList) + { + secondDim = subArray.Count > secondDim ? subArray.Count : secondDim; + } + int[,] list = new int[arrayList.Count, secondDim]; + for (int i = 0; i < arrayList.Count; i++) + { + var subArray = (ArrayList?)arrayList[i]; + if (subArray == null) + throw new NullReferenceException(""); + for (int j = 0; j < subArray.Count; j++) + { + var element = subArray[j]; + if (element == null) + throw new NoNullAllowedException("the element of ArrayList cannot be null."); + list[i, j] = (int)element; + } + } + Shape shape = new Shape(new int[] { arrayList.Count, secondDim }); + reconstructedMultiArray = list; + reconstructedNDArray = new NDArray(list, shape); + } + if (ndim > 2) + throw new NotImplementedException("can't handle ArrayList with more than two dimensions."); + } + else + throw new NotImplementedException(""); + } + public static implicit operator Array(MultiArrayPickleWarpper arrayWarpper) + { + return arrayWarpper.reconstructedMultiArray; + } + public static implicit operator NDArray(MultiArrayPickleWarpper arrayWarpper) + { + return arrayWarpper.reconstructedNDArray; + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs new file mode 100644 index 000000000..80f056fe5 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs @@ -0,0 +1,157 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.NumPy +{ + public class ShapeHelper + { + public static long GetSize(Shape shape) + { + if (shape.IsNull) + return 0; + + // scalar + if (shape.ndim == 0) + return 1; + + var computed = 1L; + for (int i = 0; i < shape.ndim; i++) + { + var val = shape.dims[i]; + if (val == 0) + return 0; + else if (val < 0) + continue; + computed *= val; + } + + return computed; + } + + public static long[] GetStrides(Shape shape) + { + var strides = new long[shape.ndim]; + + if (shape.ndim == 0) + return strides; + + strides[strides.Length - 1] = 1; + for (int idx = strides.Length - 1; idx >= 1; idx--) + strides[idx - 1] = strides[idx] * shape.dims[idx]; + + return strides; + } + + public static Shape GetShape(Shape shape1, params Slice[] slices) + { + var new_dims = shape1.dims.ToArray(); + slices = SliceHelper.AlignWithShape(shape1, slices); + + for (int i = 0; i < shape1.dims.Length; i++) + { + Slice slice = slices[i]; + if (slice.Equals(Slice.All)) + new_dims[i] = shape1.dims[i]; + else if (slice.IsIndex) + new_dims[i] = 1; + else // range + new_dims[i] = (slice.Stop ?? shape1.dims[i]) - (slice.Start ?? 0); + } + + // strip first dim if is index + var return_dims = new List(); + for (int i = 0; i< new_dims.Length; i++) + { + if (slices[i].IsIndex) + continue; + return_dims.add(new_dims[i]); + } + + return new Shape(return_dims.ToArray()); + } + + public static Shape AlignWithShape(Shape shape, Shape preShape) + { + if (shape.ndim == preShape.ndim) + return preShape; + + var newShape = shape.dims.Select(x => 1L).ToArray(); + if (preShape.IsScalar) + return new Shape(newShape); + + for (int i = 0; i < preShape.ndim; i++) + { + newShape[i + shape.ndim - preShape.ndim] = preShape[i]; + } + + return new Shape(newShape); + } + + public static bool Equals(Shape shape, object target) + { + if (shape is null && target is null) + return true; + else if (shape is null && target is not null) + return false; + else if (shape is not null && target is null) + return false; + + switch (target) + { + case Shape shape1: + if (shape.ndim == -1 && shape1.ndim == -1) + return false; + else if (shape.ndim != shape1.ndim) + return false; + return Enumerable.SequenceEqual(shape1.dims, shape.dims); + case long[] shape2: + if (shape.ndim != shape2.Length) + return false; + return Enumerable.SequenceEqual(shape.dims, shape2); + case int[] shape3: + if (shape.ndim != shape3.Length) + return false; + return Enumerable.SequenceEqual(shape.as_int_list(), shape3); + case List shape4: + if (shape.ndim != shape4.Count) + return false; + return Enumerable.SequenceEqual(shape.dims, shape4); + case List shape5: + if (shape.ndim != shape5.Count) + return false; + return Enumerable.SequenceEqual(shape.as_int_list(), shape5); + default: + return false; + } + } + + public static string ToString(Shape shape) + { + return shape.ndim switch + { + -1 => "", + 0 => "()", + 1 => $"({shape.dims[0].ToString().Replace("-1", "None")},)", + _ => $"({string.Join(", ", shape.dims).Replace("-1", "None")})" + }; + } + + public static long GetOffset(Shape shape, params int[] indices) + { + if (shape.ndim == 0 && indices.Length == 1) + return indices[0]; + + long offset = 0; + var strides = shape.strides; + for (int i = 0; i < indices.Length; i++) + offset += strides[i] * indices[i]; + + if (offset < 0) + throw new NotImplementedException(""); + + return offset; + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/SliceHelper.cs b/src/TensorFlowNET.Core/NumPy/SliceHelper.cs new file mode 100644 index 000000000..30a14c9ea --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/SliceHelper.cs @@ -0,0 +1,70 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.NumPy +{ + public class SliceHelper + { + public static Slice[] AlignWithShape(Shape shape, Slice[] slices) + { + var ndim = shape.ndim; + if (ndim == slices.Length) + return slices; + + // align slices + var new_slices = new List(); + var slice_index = 0; + + for (int i = 0; i < ndim; i++) + { + if (slice_index > slices.Length - 1) + { + new_slices.Add(Slice.All); + continue; + } + + if (slices[slice_index] == Slice.All) + { + new_slices.Add(Slice.All); + for (int j = 0; j < ndim - slices.Length; j++) + { + new_slices.Add(Slice.All); + i++; + } + } + else + { + new_slices.Add(slices[slice_index]); + } + slice_index++; + } + + return new_slices.ToArray(); + } + + public static bool AreAllIndex(Slice[] slices, out int[] indices) + { + indices = new int[slices.Length]; + for (int i = 0; i< slices.Length; i++) + { + indices[i] = slices[i].Start ?? 0; + if (!slices[i].IsIndex) + return false; + } + return true; + } + + public static bool IsContinuousBlock(Slice[] slices, int ndim) + { + for (int i = ndim + 1; i < slices.Length; i++) + { + if (slices[i].Equals(Slice.All)) + continue; + return false; + } + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Numpy/IMemoryBlock.cs b/src/TensorFlowNET.Core/Numpy/IMemoryBlock.cs new file mode 100644 index 000000000..87c31a214 --- /dev/null +++ b/src/TensorFlowNET.Core/Numpy/IMemoryBlock.cs @@ -0,0 +1,37 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.NumPy +{ + public interface IMemoryBlock + { + /// + /// The size of a single item stored in . + /// + /// Equivalent to extension. + int ItemLength { get; } + + /// + /// The start address of this memory block. + /// + unsafe void* Address { get; } + + /// + /// How many items are stored in . + /// + /// Not to confuse with + long Count { get; } + + /// + /// How many bytes are stored in this memory block. + /// + /// Calculated by * + long BytesLength { get; } + + /// + /// The of the type stored inside this memory block. + /// + TF_DataType TypeCode { get; } + } +} diff --git a/src/TensorFlowNET.Core/Numpy/IteratorType.cs b/src/TensorFlowNET.Core/Numpy/IteratorType.cs new file mode 100644 index 000000000..ab6345abb --- /dev/null +++ b/src/TensorFlowNET.Core/Numpy/IteratorType.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.NumPy +{ + public enum IteratorType + { + Scalar, + Vector, + Matrix, + Tensor + } +} diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs new file mode 100644 index 000000000..af7e94c85 --- /dev/null +++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs @@ -0,0 +1,83 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Eager; +using static Tensorflow.Binding; + +namespace Tensorflow.NumPy +{ + public partial class NDArray + { + protected NDArray() { } + public NDArray(bool value) : base(value) => NewEagerTensorHandle(); + public NDArray(byte value) : base(value) => NewEagerTensorHandle(); + public NDArray(short value) : base(value) => NewEagerTensorHandle(); + public NDArray(int value) : base(value) => NewEagerTensorHandle(); + public NDArray(long value) : base(value) => NewEagerTensorHandle(); + public NDArray(float value) : base(value) => NewEagerTensorHandle(); + public NDArray(double value) : base(value) => NewEagerTensorHandle(); + + public NDArray(Array value, Shape? shape = null) : base(value, shape) + => NewEagerTensorHandle(); + + public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) : base(shape, dtype: dtype) + => NewEagerTensorHandle(); + + public NDArray(byte[] bytes, Shape shape, TF_DataType dtype) : base(bytes, shape, dtype) + => NewEagerTensorHandle(); + + public NDArray(int[] value, Shape? shape = null) : base(value, shape) + => NewEagerTensorHandle(); + + public NDArray(long[] value, Shape? shape = null) : base(value, shape) + => NewEagerTensorHandle(); + + public NDArray(IntPtr address, Shape shape, TF_DataType dtype) : base(address, shape, dtype) + => NewEagerTensorHandle(); + + public NDArray(Tensor tensor, bool clone = false) : base(tensor.Handle, clone: clone) + { + if (_handle is null) + { + tensor = tf.get_default_session().eval(tensor); + _handle = tensor.Handle; + } + + NewEagerTensorHandle(); + } + + public static NDArray Scalar(T value) where T : unmanaged + => value switch + { + bool val => new NDArray(val), + byte val => new NDArray(val), + int val => new NDArray(val), + long val => new NDArray(val), + float val => new NDArray(val), + double val => new NDArray(val), + _ => throw new NotImplementedException("") + }; + + /// + /// Reuse the existing memory instead of copying it. + /// + /// + /// + /// + /// + protected void InitWithExistingMemory(IntPtr data_ptr, Shape shape, TF_DataType dtype, c_api.DeallocatorV2 deallocator) + { + _handle = c_api.TF_NewTensor(TF_DataType.TF_STRING, shape.dims, shape.ndim, data_ptr, (ulong)(shape.size * dtype.get_datatype_size()), deallocator, IntPtr.Zero); + tensor_util.DangerousManuallySetTensorDType(_handle, dtype); + NewEagerTensorHandle(); + } + + void NewEagerTensorHandle() + { + if (_handle is not null) + { + _eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs new file mode 100644 index 000000000..6e4c6b32c --- /dev/null +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -0,0 +1,56 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow.NumPy +{ + public partial class NDArray : Tensor, IEnumerable + { + public IntPtr data => TensorDataPointer; + + [AutoNumPy] + public NDArray reshape(Shape newshape) => new NDArray(tf.reshape(this, newshape)); + [AutoNumPy] + public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(this, dtype)); + public NDArray ravel() => throw new NotImplementedException(""); + public void shuffle(NDArray nd) => np.random.shuffle(nd); + + public unsafe Array ToMultiDimArray() where T : unmanaged + => NDArrayConverter.ToMultiDimArray(this); + + public byte[] ToByteArray() => BufferToArray(); + public override string ToString() => NDArrayRender.ToString(this); + + public IEnumerator GetEnumerator() + { + for (int i = 0; i < dims[0]; i++) + yield return this[i]; + } + + IEnumerator IEnumerable.GetEnumerator() + => GetEnumerator(); + + public static explicit operator NDArray(Array array) + => new NDArray(array); + } +} diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs new file mode 100644 index 000000000..409e5e310 --- /dev/null +++ b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs @@ -0,0 +1,114 @@ +using System.IO; +using static Tensorflow.Binding; + +namespace Tensorflow.NumPy +{ + public partial class np + { + [AutoNumPy] + public static NDArray array(Array data, TF_DataType? dtype = null) + { + var nd = new NDArray(data); + return dtype == null ? nd : nd.astype(dtype.Value); + } + + [AutoNumPy] + public static NDArray array(params T[] data) + where T : unmanaged => new NDArray(data); + + [AutoNumPy] + public static NDArray arange(T end) + where T : unmanaged => new NDArray(tf.range(default(T), limit: end)); + + [AutoNumPy] + public static NDArray arange(T start, T? end = null, T? step = null) + where T : unmanaged => new NDArray(tf.range(start, limit: end, delta: step)); + + [AutoNumPy] + public static NDArray empty(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) + => new NDArray(tf.zeros(shape, dtype: dtype)); + + [AutoNumPy] + public static NDArray eye(int N, int? M = null, int k = 0, TF_DataType dtype = TF_DataType.TF_DOUBLE) + => tf.numpy.eye(N, M: M, k: k, dtype: dtype); + + [AutoNumPy] + public static NDArray full(Shape shape, T fill_value) + where T : unmanaged => new NDArray(tf.fill(tf.constant(shape), fill_value)); + + [AutoNumPy] + public static NDArray full_like(NDArray x, T fill_value, TF_DataType? dtype = null, Shape shape = null) + where T : unmanaged => new NDArray(array_ops.fill(x.shape, constant_op.constant(fill_value))); + + [AutoNumPy] + public static NDArray frombuffer(byte[] bytes, Shape shape, TF_DataType dtype) + => tf.numpy.frombuffer(bytes, shape, dtype); + + [AutoNumPy] + public static NDArray frombuffer(byte[] bytes, string dtype) + => tf.numpy.frombuffer(bytes, dtype); + + [AutoNumPy] + public static NDArray linspace(T start, T stop, int num = 50, bool endpoint = true, bool retstep = false, + TF_DataType dtype = TF_DataType.TF_DOUBLE, int axis = 0) + where T : unmanaged => tf.numpy.linspace(start, stop, + num: num, + endpoint: endpoint, + retstep: retstep, + dtype: dtype, + axis: axis); + + [AutoNumPy] + public static NDArray load(string file) => tf.numpy.load(file); + + [AutoNumPy] + public static T Load(string path) + where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable + { + using (var stream = new FileStream(path, FileMode.Open)) + return Load(stream); + } + + [AutoNumPy] + public static T Load(Stream stream) + where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable + => tf.numpy.Load(stream); + + [AutoNumPy] + public static Array LoadMatrix(Stream stream) => tf.numpy.LoadMatrix(stream); + + [AutoNumPy] + public static NpzDictionary Load_Npz(byte[] bytes) + where T : class, IList, ICloneable, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable + => Load_Npz(new MemoryStream(bytes)); + + [AutoNumPy] + public static NpzDictionary Load_Npz(Stream stream) + where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable + => new NpzDictionary(stream); + + [AutoNumPy] + public static (NDArray, NDArray) meshgrid(T x, T y, bool copy = true, bool sparse = false) + => tf.numpy.meshgrid(new[] { x, y }, copy: copy, sparse: sparse); + + [AutoNumPy] + public static NDArray ndarray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) + => new NDArray(tf.zeros(shape, dtype: dtype)); + + [AutoNumPy] + public static NDArray ones(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) + => new NDArray(tf.ones(shape, dtype: dtype)); + + [AutoNumPy] + public static NDArray ones_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid) + => new NDArray(tf.ones_like(a, dtype: dtype)); + + [AutoNumPy] + public static NDArray zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) + => new NDArray(tf.zeros(shape, dtype: dtype)); + + [AutoNumPy] + public static NDArray zeros_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid) + => new NDArray(tf.zeros_like(a, dtype: dtype)); + } +} diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.cs b/src/TensorFlowNET.Core/Numpy/Numpy.cs new file mode 100644 index 000000000..fee2d63fc --- /dev/null +++ b/src/TensorFlowNET.Core/Numpy/Numpy.cs @@ -0,0 +1,73 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow.NumPy; + +public partial class np +{ + /// + /// A convenient alias for None, useful for indexing arrays. + /// + /// https://docs.scipy.org/doc/numpy-1.17.0/reference/arrays.indexing.html



https://stackoverflow.com/questions/42190783/what-does-three-dots-in-python-mean-when-indexing-what-looks-like-a-number
+ public static readonly Slice newaxis = new Slice(null, null, 1) { IsNewAxis = true }; + + // https://docs.scipy.org/doc/numpy-1.16.0/user/basics.types.html + #region data type + public static readonly TF_DataType @bool = TF_DataType.TF_BOOL; + public static readonly TF_DataType @char = TF_DataType.TF_INT8; + public static readonly TF_DataType @byte = TF_DataType.TF_INT8; + public static readonly TF_DataType uint8 = TF_DataType.TF_UINT8; + public static readonly TF_DataType ubyte = TF_DataType.TF_UINT8; + public static readonly TF_DataType int16 = TF_DataType.TF_INT16; + public static readonly TF_DataType uint16 = TF_DataType.TF_UINT16; + public static readonly TF_DataType int32 = TF_DataType.TF_INT32; + public static readonly TF_DataType uint32 = TF_DataType.TF_UINT32; + public static readonly TF_DataType int64 = TF_DataType.TF_INT64; + public static readonly TF_DataType uint64 = TF_DataType.TF_UINT64; + public static readonly TF_DataType float32 = TF_DataType.TF_FLOAT; + public static readonly TF_DataType float64 = TF_DataType.TF_DOUBLE; + public static readonly TF_DataType @double = TF_DataType.TF_DOUBLE; + public static readonly TF_DataType @decimal = TF_DataType.TF_DOUBLE; + public static readonly TF_DataType complex_ = TF_DataType.TF_COMPLEX; + public static readonly TF_DataType complex64 = TF_DataType.TF_COMPLEX64; + public static readonly TF_DataType complex128 = TF_DataType.TF_COMPLEX128; + public static readonly TF_DataType @string = TF_DataType.TF_STRING; + public static readonly TF_DataType @object = TF_DataType.TF_VARIANT; + #endregion + + public static double nan => double.NaN; + public static double NAN => double.NaN; + public static double NaN => double.NaN; + public static double pi => Math.PI; + public static double e => Math.E; + public static double euler_gamma => 0.57721566490153286060651209008240243d; + public static double inf => double.PositiveInfinity; + public static double infty => double.PositiveInfinity; + public static double Inf => double.PositiveInfinity; + public static double NINF => double.NegativeInfinity; + public static double PINF => double.PositiveInfinity; + public static double Infinity => double.PositiveInfinity; + public static double infinity => double.PositiveInfinity; + + public static bool array_equal(NDArray a, NDArray b) + => a.Equals(b); + + public static bool allclose(NDArray a, NDArray b, double rtol = 1.0E-5, double atol = 1.0E-8, + bool equal_nan = false) => throw new NotImplementedException(""); + + public static RandomizedImpl random = new RandomizedImpl(); + public static LinearAlgebraImpl linalg = new LinearAlgebraImpl(); +} diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs new file mode 100644 index 000000000..cbbf66b44 --- /dev/null +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -0,0 +1,288 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.Saving.Common; +using Tensorflow.NumPy; + +namespace Tensorflow +{ + [JsonConverter(typeof(CustomizedShapeJsonConverter))] + public class Shape : INestStructure + { + public int ndim => _dims == null ? -1 : _dims.Length; + long[] _dims; + public long[] dims => _dims; + public int rank => ndim; + long[] _strides; + public long[] strides + { + get + { + _strides = _strides ?? ShapeHelper.GetStrides(this); + return _strides; + } + } + + public NestType NestType => NestType.List; + + public int ShallowNestedCount => ndim; + /// + /// The total item count of depth 1 of the nested structure. + /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. + /// + public int TotalNestedCount => ndim; + + public IEnumerable Flatten() => dims.Select(x => x); + + public INestStructure MapStructure(Func func) + { + return new NestList(dims.Select(x => func(x))); + } + + public Nest AsNest() + { + return new NestList(Flatten()).AsNest(); + } + + #region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges + public int Length => ndim; + public long[] Slice(int start, int length) + { + var slice = new long[length]; + Array.Copy(_dims, start, slice, 0, length); + return slice; + } + #endregion + + private Shape() + { + } + + public Shape(TensorShapeProto proto) + { + _dims = proto.Dim.Select(x => x.Size).ToArray(); + } + + public void Deconstruct(out long h, out long w) + { + h = dims[0]; + w = dims[1]; + } + + public Shape(params int[] dims) + => _dims = dims?.Select(x => Convert.ToInt64(x))?.ToArray(); + + public Shape(params long[] dims) + => _dims = dims; + + public static implicit operator Shape(int dims) + => new Shape(dims); + + public static implicit operator Shape(long[] dims) + => dims == null ? null : new Shape(dims); + + public static implicit operator Shape(int[] dims) + => dims == null ? null : new Shape(dims); + + public static implicit operator Shape((int, int) dims) + => new Shape(dims.Item1, dims.Item2); + + public static implicit operator Shape((long, long) dims) + => new Shape(dims.Item1, dims.Item2); + + public static implicit operator Shape((int, int, int) dims) + => new Shape(dims.Item1, dims.Item2, dims.Item3); + + public static implicit operator Shape((long, long, long) dims) + => new Shape(dims.Item1, dims.Item2, dims.Item3); + + public static implicit operator Shape((int, int, int, int) dims) + => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); + + public static implicit operator Shape((long, long, long, long) dims) + => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); + + public static implicit operator Shape((int, int, int, int, int) dims) + => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); + + public static implicit operator Shape((long, long, long, long, long) dims) + => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); + + public static implicit operator int[](Shape shape) + => shape.dims.Select(x => (int)x).ToArray(); + + public static implicit operator long[](Shape shape) + => shape.dims; + + public static implicit operator Tensor(Shape shape) + => constant_op.constant(shape); + + public bool IsEmpty => size == 0; + + public bool IsScalar => ndim == 0; + public bool IsNull => _dims == null; + + public bool IsFullyDefined => ndim > -1 && dims.Count(x => x < 1) == 0; + + public static Shape Scalar => new Shape(new long[0]); + public static Shape Null => new Shape(); + + public long this[int n] + { + get => n < 0 ? dims[ndim + n] : dims[n]; + set => dims[n] = value; + } + + public Shape this[Slice slice] + { + get + { + if (!slice.Stop.HasValue) + slice.Stop = dims.Length - slice.Start + 1; + + if (slice.Start.HasValue == false || slice.Length.HasValue == false) + throw new ArgumentException("Slice must has Start and Length."); + + return new Shape(dims.Skip(slice.Start.Value) + .Take(slice.Length.Value) + .ToArray()); + } + } + + /// + /// Returns the size this shape represents. + /// + public long size => ShapeHelper.GetSize(this); + + public bool is_compatible_with(Shape shape2) + { + if (dims != null && shape2.dims != null) + { + if (dims.Contains(-1) || shape2.dims.Contains(-1)) + return true; + + if (size != shape2.size) + return false; + } + + return true; + } + + public Shape with_rank_at_least(int rank) + { + if (ndim < rank) + throw new ValueError($"Shape {this} must have rank at least {rank}"); + else + return this; + } + + public Shape with_rank(int rank) + { + return merge_with(unknown_shape(rank: rank)); + } + + /// + /// Returns an unknown Shape, optionally with a known rank. + /// + /// + /// + public Shape unknown_shape(int rank = -1) + { + if (rank == -1) + return Shape.Null; + else + return new Shape(Enumerable.Repeat(-1L, rank).ToArray()); + } + + public Shape concatenate(long[] other) + { + return concatenate(new Shape(other)); + } + + /// + /// Returns the concatenation of the dimension in `self` and `other`. + /// + /// + /// + public Shape concatenate(Shape other) + { + var otherShape = other; + + if (ndim < 0 || otherShape.ndim < 0) + return Shape.Null; + else + { + var concatenate_dims = new long[ndim + otherShape.ndim]; + for (int i = 0; i < ndim; i++) + concatenate_dims[i] = dims[i]; + + for (int i = 0; i < otherShape.ndim; i++) + concatenate_dims[ndim + i] = otherShape.dims[i]; + + return new Shape(concatenate_dims); + } + } + + /// + /// Returns a `Shape` combining the information in `self` and `other`. + /// + /// + /// + public Shape merge_with(Shape other) + { + if (dims == null) + return other; + + var new_dims = new List(); + + foreach (var i in Enumerable.Range(0, ndim)) + { + var dim = new Dimension(dims[i]); + var merged = dim.merge_with(new Dimension(other.dims[i])); + new_dims.Add(merged.value); + } + + return new Shape(new_dims.ToArray()); + } + + public int[] as_int_list() + { + return _dims.Select(x => (int)x).ToArray(); + } + + public void assert_has_rank(int rank) + { + if (rank != ndim) + throw new ValueError(String.Format("Shape {0} must have rank {1}", ndim, rank)); + } + + public override bool Equals(object obj) => ShapeHelper.Equals(this, obj); + + public override string ToString() => ShapeHelper.ToString(this); + + public static bool operator ==(Shape a, Shape b) + => ShapeHelper.Equals(a, b); + + public static bool operator !=(Shape a, Shape b) + => !ShapeHelper.Equals(a, b); + } +} diff --git a/src/TensorFlowNET.Core/Numpy/Slice.cs b/src/TensorFlowNET.Core/Numpy/Slice.cs new file mode 100644 index 000000000..676ec5e93 --- /dev/null +++ b/src/TensorFlowNET.Core/Numpy/Slice.cs @@ -0,0 +1,295 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Text.RegularExpressions; + +namespace Tensorflow +{ + ///

+ /// NDArray can be indexed using slicing

+ /// A slice is constructed by start:stop:step notation

+ ///

+ /// Examples:

+ ///

+ /// a[start:stop] # items start through stop-1

+ /// a[start:] # items start through the rest of the array

+ /// a[:stop] # items from the beginning through stop-1

+ ///

+ /// The key point to remember is that the :stop value represents the first value that is not

+ /// in the selected slice. So, the difference between stop and start is the number of elements

+ /// selected (if step is 1, the default).

+ ///

+ /// There is also the step value, which can be used with any of the above:

+ /// a[:] # a copy of the whole array

+ /// a[start:stop:step] # start through not past stop, by step

+ ///

+ /// The other feature is that start or stop may be a negative number, which means it counts

+ /// from the end of the array instead of the beginning. So:

+ /// a[-1] # last item in the array

+ /// a[-2:] # last two items in the array

+ /// a[:-2] # everything except the last two items

+ /// Similarly, step may be a negative number:

+ ///

+ /// a[::- 1] # all items in the array, reversed

+ /// a[1::- 1] # the first two items, reversed

+ /// a[:-3:-1] # the last two items, reversed

+ /// a[-3::- 1] # everything except the last two items, reversed

+ ///

+ /// NumSharp is kind to the programmer if there are fewer items than

+ /// you ask for. For example, if you ask for a[:-2] and a only contains one element, you get an

+ /// empty list instead of an error.Sometimes you would prefer the error, so you have to be aware

+ /// that this may happen.

+ ///

+ /// Adapted from Greg Hewgill's answer on Stackoverflow: https://stackoverflow.com/questions/509211/understanding-slice-notation

+ ///

+ /// Note: special IsIndex == true

+ /// It will pick only a single value at Start in this dimension effectively reducing the Shape of the sliced matrix by 1 dimension.

+ /// It can be used to reduce an N-dimensional array/matrix to a (N-1)-dimensional array/matrix

+ ///

+ /// Example:

+ /// a=[[1, 2], [3, 4]]

+ /// a[:, 1] returns the second column of that 2x2 matrix as a 1-D vector

+ ///
+ public class Slice + { + /// + /// return : for this dimension + /// + public static readonly Slice All = new Slice(null, null); + + /// + /// return 0:0 for this dimension + /// + public static readonly Slice None = new Slice(0, 0, 1); + + /// + /// fill up the missing dimensions with : at this point, corresponds to ... + /// + public static readonly Slice Ellipsis = new Slice(0, 0, 1) { IsEllipsis = true }; + + /// + /// insert a new dimension at this point + /// + public static readonly Slice NewAxis = new Slice(0, 0, 1) { IsNewAxis = true }; + + /// + /// return exactly one element at this dimension and reduce the shape from n-dim to (n-1)-dim + /// + /// + /// + public static Slice Index(int index) => new Slice(index, index + 1) { IsIndex = true }; + + ///// + ///// return multiple elements for this dimension specified by the given index array (or boolean mask array) + ///// + ///// + ///// + //[MethodImpl(MethodImplOptions.AggressiveInlining)] + //public static Slice Select(NDArray index_array_or_mask) => new Slice(null, null) { Selection=index_array_or_mask }; + + public int? Start; + public int? Stop; + public int Step; + public bool IsIndex; + public bool IsEllipsis; + public bool IsNewAxis; + + ///// + ///// Array of integer indices to select elements by index extraction or boolean values to select by masking the elements of the given dimension. + ///// + //public NDArray Selection = null; + + /// + /// Length of the slice. + /// + /// The length is not guaranteed to be known for i.e. a slice like ":". Make sure to check Start and Stop + /// for null before using it + /// + public int? Length => Stop - Start; + + /// + /// ndarray can be indexed using slicing + /// slice is constructed by start:stop:step notation + /// + /// Start index of the slice, null means from the start of the array + /// Stop index (first index after end of slice), null means to the end of the array + /// Optional step to select every n-th element, defaults to 1 + public Slice(int? start = null, int? stop = null, int step = 1, bool isIndex = false) + { + Start = start; + Stop = stop; + Step = step; + IsIndex = isIndex; + } + + public Slice(string slice_notation) + { + Parse(slice_notation); + } + + /// + /// Parses Python array slice notation and returns an array of Slice objects + /// + public static Slice[] ParseSlices(string multi_slice_notation) + { + return Regex.Split(multi_slice_notation, @",\s*").Where(s => !string.IsNullOrWhiteSpace(s)).Select(token => new Slice(token)).ToArray(); + } + + /// + /// Creates Python array slice notation out of an array of Slice objects (mainly used for tests) + /// + public static string FormatSlices(params Slice[] slices) + { + return string.Join(",", slices.Select(s => s.ToString())); + } + + private void Parse(string slice_notation) + { + if (string.IsNullOrEmpty(slice_notation)) + throw new ArgumentException("Slice notation expected, got empty string or null"); + var match = Regex.Match(slice_notation, @"^\s*((?'start'[+-]?\s*\d+)?\s*:\s*(?'stop'[+-]?\s*\d+)?\s*(:\s*(?'step'[+-]?\s*\d+)?)?|(?'index'[+-]?\s*\d+)|(?'ellipsis'\.\.\.)|(?'newaxis'(np\.)?newaxis))\s*$"); + if (!match.Success) + throw new ArgumentException($"Invalid slice notation: '{slice_notation}'"); + if (match.Groups["ellipsis"].Success) + { + Start = 0; + Stop = 0; + Step = 1; + IsEllipsis = true; + return; + } + if (match.Groups["newaxis"].Success) + { + Start = 0; + Stop = 0; + Step = 1; + IsNewAxis = true; + return; + } + if (match.Groups["index"].Success) + { + if (!int.TryParse(Regex.Replace(match.Groups["index"].Value ?? "", @"\s+", ""), out var start)) + throw new ArgumentException($"Invalid value for index: '{match.Groups["index"].Value}'"); + Start = start; + Stop = start + 1; + Step = 1; // special case for dimensionality reduction by picking a single element + IsIndex = true; + return; + } + var start_string = Regex.Replace(match.Groups["start"].Value ?? "", @"\s+", ""); // removing spaces from match to be able to parse what python allows, like: "+ 1" or "- 9"; + var stop_string = Regex.Replace(match.Groups["stop"].Value ?? "", @"\s+", ""); + var step_string = Regex.Replace(match.Groups["step"].Value ?? "", @"\s+", ""); + + if (string.IsNullOrWhiteSpace(start_string)) + Start = null; + else + { + if (!int.TryParse(start_string, out var start)) + throw new ArgumentException($"Invalid value for start: {start_string}"); + Start = start; + } + + if (string.IsNullOrWhiteSpace(stop_string)) + Stop = null; + else + { + if (!int.TryParse(stop_string, out var stop)) + throw new ArgumentException($"Invalid value for start: {stop_string}"); + Stop = stop; + } + + if (string.IsNullOrWhiteSpace(step_string)) + Step = 1; + else + { + if (!int.TryParse(step_string, out var step)) + throw new ArgumentException($"Invalid value for start: {step_string}"); + Step = step; + } + } + + #region Equality comparison + + public static bool operator ==(Slice a, Slice b) + { + if (ReferenceEquals(a, b)) + return true; + + if (a is null || b is null) + return false; + + return a.Start == b.Start && a.Stop == b.Stop && a.Step == b.Step; + } + + public static bool operator !=(Slice a, Slice b) + { + return !(a == b); + } + + public override bool Equals(object obj) + { + if (obj == null) + return false; + + if (obj.GetType() != typeof(Slice)) + return false; + + var b = (Slice)obj; + return Start == b.Start && Stop == b.Stop && Step == b.Step; + } + + public override int GetHashCode() + { + return ToString().GetHashCode(); + } + + #endregion + + public override string ToString() + { + if (IsIndex) + return $"{Start ?? 0}"; + else if (IsNewAxis) + return "np.newaxis"; + else if (IsEllipsis) + return "..."; + var optional_step = Step == 1 ? "" : $":{Step}"; + return $"{(Start == 0 ? "" : Start.ToString())}:{(Stop == null ? "" : Stop.ToString())}{optional_step}"; + } + + // return the size of the slice, given the data dimension on this axis + // note: this works only with sanitized shapes! + public int GetSize() + { + var astep = Math.Abs(Step); + return (Math.Abs(Start.Value - Stop.Value) + (astep - 1)) / astep; + } + + #region Operators + + public static Slice operator ++(Slice a) + { + if (a.Start.HasValue) + a.Start++; + if (a.Stop.HasValue) + a.Stop++; + return a; + } + + public static Slice operator --(Slice a) + { + if (a.Start.HasValue) + a.Start--; + if (a.Stop.HasValue) + a.Stop--; + return a; + } + + public static implicit operator Slice(int index) => Slice.Index(index); + public static implicit operator Slice(string slice) => new Slice(slice); + //public static implicit operator Slice(NDArray selection) => Slice.Select(selection); + + #endregion + } +} diff --git a/src/TensorFlowNET.Core/Open.snk b/src/TensorFlowNET.Core/Open.snk new file mode 100644 index 000000000..22a3cbd25 Binary files /dev/null and b/src/TensorFlowNET.Core/Open.snk differ diff --git a/src/TensorFlowNET.Core/Operations/Activation/IActivation.cs b/src/TensorFlowNET.Core/Operations/Activation/IActivation.cs new file mode 100644 index 000000000..8da698da4 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Activation/IActivation.cs @@ -0,0 +1,7 @@ +namespace Tensorflow.Operations.Activation +{ + public interface IActivation + { + Tensor Activate(Tensor features, string name = null); + } +} diff --git a/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs b/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs new file mode 100644 index 000000000..df679bef2 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs @@ -0,0 +1,217 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations.Activation +{ + public class sigmoid : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return tf.sigmoid(x); + } + } + + public class tanh : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return tf.tanh(x); + } + } + + public class leakyrelu : IActivation + { + private readonly float _alpha; + + public leakyrelu(float alpha = 0.3f) + { + _alpha = alpha; + } + + public Tensor Activate(Tensor x, string name = null) + { + return nn_ops.leaky_relu(x, _alpha); + } + } + + public class elu : IActivation + { + private readonly float _alpha; + + public elu(float alpha = 0.1f) + { + _alpha = alpha; + } + + public Tensor Activate(Tensor x, string name = null) + { + var res = gen_ops.elu(x); + if (Math.Abs(_alpha - 0.1f) < 0.00001f) + { + return res; + } + + return array_ops.@where(x > 0, res, _alpha * res); + } + } + + public class softmax : IActivation + { + private readonly int _axis; + + /// Initializes a new instance of the class. + public softmax(int axis = -1) + { + _axis = axis; + } + + public Tensor Activate(Tensor x, string name = null) + { + return nn_ops.softmax(x, _axis); + } + } + + public class softplus : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return gen_ops.softplus(x); + } + } + + public class softsign : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return gen_ops.softsign(x); + } + } + + public class swish : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return tf.multiply(x, tf.nn.sigmoid(x)); + } + } + + public class linear : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return x; + } + } + + + public class exponential : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return tf.exp(x, name: name); + } + } + + + public class relu : IActivation + { + private readonly float _threshold; + private readonly float _alpha; + private readonly float? _maxValue; + + public relu(float threshold = 0f, float alpha = 0.2f, float? max_value = null) + { + _threshold = threshold; + _alpha = alpha; + _maxValue = max_value; + } + + public Tensor Activate(Tensor x, string name = null) + { + //based on keras/backend.py + if (Math.Abs(_alpha) > 0.000001f) + { + if (!_maxValue.HasValue && Math.Abs(_threshold) < 0.0001) + { + return nn_ops.leaky_relu(x, _alpha); + } + } + + Tensor negative_part; + if (Math.Abs(_threshold) > 0.000001f) + { + negative_part = gen_ops.relu(-x + _threshold); + } + else + { + negative_part = gen_ops.relu(-x + _threshold); + } + + if (Math.Abs(_threshold) > 0.000001f) + { + x = x * math_ops.cast(tf.greater(x, _threshold), TF_DataType.TF_FLOAT); + } + else if (Math.Abs(_maxValue.Value - 6f) < 0.0001f) + { + x = gen_ops.relu6(x); + } + else + { + x = gen_ops.relu(x); + } + + bool clip_max = _maxValue.HasValue; + if (clip_max) + { + Tensor maxval = constant_op.constant(_maxValue, x.dtype.as_base_dtype()); + var zero = constant_op.constant(0.0f, x.dtype.as_base_dtype()); + x = gen_ops.clip_by_value(x, zero, maxval); + } + + if (Math.Abs(_alpha) > 0.00001) + { + var a = constant_op.constant(_alpha, x.dtype.as_base_dtype()); + x -= a * negative_part; + } + + return x; + } + } + + public class selu : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + const float alpha = 1.6732632423543772848170429916717f; + const float scale = 1.0507009873554804934193349852946f; + return scale * new elu(alpha).Activate(x, name); + } + } + + public class hard_sigmoid : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + x = (0.2 * x) + 0.5; + var zero = tf.convert_to_tensor(0.0f, x.dtype.as_base_dtype()); + var one = tf.convert_to_tensor(1.0f, x.dtype.as_base_dtype()); + return tf.clip_by_value(x, zero, one); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs new file mode 100644 index 000000000..5d6707799 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -0,0 +1,347 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Operations.ControlFlows; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + /// + /// The context for the conditional construct. + /// + public class CondContext : ControlFlowContext, IProtoBuf + { +#pragma warning disable CS0108 // Member hides inherited member; missing new keyword + private Dictionary _external_values = new Dictionary(); +#pragma warning restore CS0108 // Member hides inherited member; missing new keyword + + /// + /// + /// + /// The `boolean` tensor for the conditional predicate. + /// The predicate tensor in this branch. + /// 0 or 1 representing this branch. + /// Name of the `CondContext` python object. + /// + /// + public CondContext(Tensor pred = null, + Tensor pivot = null, + int branch = 0, + string name = "cond_text", + CondContextDef context_def = null, + string import_scope = null) + { + if (pred == null && context_def == null) return; + + _name = ops.get_default_graph().unique_name(name); + if (context_def != null) + { + _init_from_proto(context_def, import_scope: import_scope); + } + else + { + // Initializes the default fields. + base.__init__(); + _pred = pred; + _pivot = pivot; + _branch = branch; // 0 or 1 representing this branch + // Values considered to have been already seen in this context. pred is not + // included in this context. + _values.Add(pred.name); + _external_values[pred.name] = pred; + _values.Add(pivot.name); + pivot.op._set_control_flow_context(this); + } + } + + private void _init_from_proto(CondContextDef context_def, string import_scope = null) + { + var g = ops.get_default_graph(); + _name = ops.prepend_name_scope(context_def.ContextName, import_scope); + var p1 = ops.prepend_name_scope(context_def.PredName, import_scope); + _pred = g.as_graph_element(p1) as Tensor; + var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope); + _pivot = g.as_graph_element(p2) as Tensor; + _branch = context_def.Branch; + __init__(values_def: context_def.ValuesDef, import_scope: import_scope); + } + + /// + /// Add `val` to the current context and its outer context recursively. + /// + /// + public override Tensor AddValue(Tensor val) + { + Tensor result = null; + if (_values.Contains(val.name)) + { + // Use the real value if it comes from outer context. This is needed in + // particular for nested conds. + if (_external_values.ContainsKey(val.name)) + result = _external_values[val.name]; + + result = result == null ? val : result; + } + else + { + result = val; + _values.Add(val.name); + // TODO: _outer_context + if (_outer_context != null) + { + result = _outer_context.AddValue(val); + _values.Add(result.name); + _external_values[result.name] = result; + } + + tf_with(ops.control_dependencies(null), ctrl => + { + var results = control_flow_ops._SwitchRefOrTensor(result, _pred); + result = results[_branch]; + if (_outer_context != null) + _outer_context.AddInnerOp(result.op); + }); + + result.op.graph.prevent_fetching(result.op); + result.op._set_control_flow_context(this); + + // Mark Switch output as seen by this context and any outer contexts, + // just like what we do for normal op outputs in _AddOpInternal() below. + ControlFlowContext ctxt = this; + while (ctxt != null) + { + ctxt.values.Add(result.name); + ctxt = ctxt.outer_context; + } + _external_values[val.name] = result; + } + return result; + } + + /// + /// Add the subgraph defined by fn() to the graph. + /// + public (T, Tensor) BuildCondBranch(Func fn) + { + // Add the subgraph defined by fn() to the graph. + var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); + var original_result = fn(); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); + + //TODO: port this chunck of missing code: + /* + if len(post_summaries) > len(pre_summaries): + new_summaries = post_summaries[len(pre_summaries):] + summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access + summary_ref[:] = pre_summaries + with ops.control_dependencies(new_summaries): + if original_result is None: + return no_op(), None + else: + original_result = nest.map_structure(array_ops.identity, + original_result) + */ + if (original_result == null) + return (original_result, null); + + switch (original_result) + { + case Tensor result: + return (original_result, _BuildCondTensor(result)); + case Operation op: + return (original_result, _BuildCondTensor(op)); + case float[] fv: + { + var result = ops.convert_to_tensor(fv[0]); + return (original_result, _BuildCondTensor(result)); + } + default: + return (original_result, null); + } + } + + public (T[], Tensor[]) BuildCondBranch(Func fn) + { + // Add the subgraph defined by fn() to the graph. + var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); + var original_result = fn(); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); + + switch (original_result) + { + case Tensor[] results: + return (original_result, results.Select(_BuildCondTensor).ToArray()); + case Operation[] results: + return (original_result, results.Select(_BuildCondTensor).ToArray()); + case float[] fv: + var result = ops.convert_to_tensor(fv[0]); + return (original_result, new Tensor[] { result }); + default: + return (original_result, new Tensor[0]); + } + } + + private Tensor _BuildCondTensor(ITensorOrOperation v) + { + switch (v) + { + case Operation op: + // Use pivot as the proxy for this op. + return control_flow_ops.with_dependencies(new Operation[] { op }, _pivot); + case Tensor t: + return _ProcessOutputTensor(t); + default: + return _ProcessOutputTensor(ops.convert_to_tensor(v)); + + } + } + + /// + /// Process an output tensor of a conditional branch. + /// + private Tensor _ProcessOutputTensor(Tensor val) + { + var real_val = val; + if (!_values.Contains(val.name)) + { + // Handle the special case of lambda: x + _values.Add(val.name); + if (_outer_context != null) + { + real_val = _outer_context.AddValue(val); + _values.Add(real_val.name); + _external_values[real_val.name] = real_val; + } + var results = control_flow_ops._SwitchRefOrTensor(real_val, _pred); + real_val = results[_branch]; + _external_values[val.name] = real_val; + } + else + { + Tensor external_val = null; + if (_external_values.ContainsKey(val.name)) + external_val = _external_values[val.name]; + if (external_val != null) + real_val = external_val; + } + return real_val; + } + + protected override void _AddOpInternal(Operation op) + { + if (op.inputs.Length == 0) + { + //If we're in a while loop, remove any control inputs from outside the + // loop. + _RemoveExternalControlEdges(op); + if (!op.control_inputs.Any(input_op => OpInContext(input_op))) + op._add_control_input(_pivot.op); + } + else + { + // Make each input to 'op' available in this CondContext. If an input is + // already part of this context there's nothing to do, but if it's + // external, AddValue() will handle adding the appropriate Switch node and + // other bookkeeping. + for (int index = 0; index < op.inputs.Length; index++) + { + var x = op.inputs[index]; + Tensor real_x = null; + if (op.type == "Merge" && x.op.type == "NextIteration") + { + //# Edge case: if we're importing a while loop inside this CondContext, + //# AddValue() will not correctly handle the NextIteration inputs to + //# Merge node. The problem is that the NextIteration should also be + //# part of this context, but if we're importing it won't have been + //# processed and added to the context yet, so AddValue() will try to + //# add a Switch which results in an invalid graph. Instead, we use the + //# NextIteration input as-is here, and it will eventually be added to + //# the context via AddOp(). + real_x = x; + } + else + { + real_x = AddValue(x); + } + if (real_x != x) + op._update_input(index, real_x); + } + // Remove any external control dependency on this op. + _RemoveExternalControlEdges(op); + // TODO: implement below code dependencies + //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") + // op._add_control_input(_pivot.op); + } + + // Mark op's outputs as seen by this context and any outer contexts. + var output_names = op.outputs.Select(x => x.name).ToArray(); + ControlFlowContext ctxt = this; + while (ctxt != null) + { + foreach (var name in output_names) + ctxt.values.Add(name); + ctxt = ctxt.outer_context; + } + + if (_outer_context != null || !control_flow_ops.IsLoopExit(op)) + op.graph.prevent_fetching(op); + + if (_outer_context != null) + _outer_context.AddInnerOp(op); + } + + public override GradLoopState grad_state + { + get + { + var whc = GetWhileContext(); + if (whc != null) + return whc.grad_state; + return null; + } + } + + public override bool back_prop + { + get + { + var whc = GetWhileContext(); + if (whc != null) + return whc.back_prop; + return false; + } + } + + public CondContextDef to_proto(string export_scope) + { + throw new NotImplementedException(); + } + + public CondContext from_proto(CondContextDef proto, string import_scope) + { + var ret = new CondContext(context_def: proto, import_scope: import_scope); + + ret.Enter(); + foreach (var nested_def in proto.NestedContexts) + from_control_flow_context_def(nested_def, import_scope: import_scope); + ret.Exit(); + return ret; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs new file mode 100644 index 000000000..0ee73815a --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -0,0 +1,333 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Operations.ControlFlows; +using static Tensorflow.Binding; +using static Tensorflow.ControlFlowContextDef; +using util = Tensorflow.control_flow_util; + +namespace Tensorflow.Operations +{ + /// + /// The base class for control flow context. + /// + /// The usage pattern is a sequence of(Enter, Exit) followed by a final + /// ExitResult. + /// + /// We maintain the following state for control flow contexts during graph + /// construction: + /// 1. graph has _control_flow_context: the current context used to + /// construct new nodes.Changed by ctxt.Enter() and ctxt.Exit() + /// 2. op has _control_flow_context: the context to which the op belongs. + /// Set at the time the op is created.Immutable. + /// 3. A ControlFlowContext has _outer_context: the context in which this + /// context is created.Set at the time a context is created.Immutable. + /// 4. A ControlFlowContext has _context_stack. + /// Pushed and popped by ctxt.Enter() and ctxt.Exit() + /// + public abstract class ControlFlowContext : ITensorFlowObject + { + /// + /// The predicate tensor in this branch + /// + protected Tensor _pivot; + public Tensor pivot => _pivot; + + /// + /// The boolean tensor for the cond predicate + /// + protected Tensor _pred; + public Tensor pred => _pred; + + /// + /// 0 or 1 representing this branch + /// + protected int _branch; + public int branch => _branch; + + protected Stack _context_stack; + protected ControlFlowContext _outer_context; + + /// + /// The keys are the names of tensors referenced by but external to this + /// context. Each value is the Tensor that should be used by this context to + /// access the key value (e.g. a switch output guarding a cond input value). + /// + protected Dictionary _external_values; + + public ControlFlowContext() + { + _context_stack = new Stack(); + _external_values = new Dictionary(); + } + + public string Name { get => _name; } + protected string _name; + + public void __init__(ValuesDef values_def = null, string import_scope = null) + { + _outer_context = ops.get_default_graph()._get_control_flow_context(); + if (values_def != null) + _init_values_from_proto(values_def, import_scope: import_scope); + else + { + _values = new HashSet(); + _external_values = new Dictionary(); + } + + } + + public void __enter__() + { + } + + /// + /// Initializes values and external_values from `ValuesDef` protocol buffer. + /// + /// + /// + protected void _init_values_from_proto(ValuesDef values_def, string import_scope = null) + { + _external_values = new Dictionary(); + foreach (var value in values_def.Values) + _values.Add(value); + var g = ops.get_default_graph(); + foreach (var value in values_def.ExternalValues) + { + var k = ops.prepend_name_scope(value.Key, import_scope); + var v = value.Value; + _external_values[k] = g.as_graph_element(ops.prepend_name_scope(v, import_scope)); + } + + var op_names = _values.Where(x => !_external_values.ContainsKey(x)) + .Select(x => x.Split(':')[0]) + .ToArray(); + + foreach (var op in op_names) + (g.as_graph_element(op) as Operation)._set_control_flow_context(this); + } + + public void __exit__() + { + } + + /// + /// Enter this control flow context. + /// + public virtual void Enter() + { + var graph = ops.get_default_graph(); + _context_stack.Push(graph._get_control_flow_context()); + graph._set_control_flow_context(this); + } + + /// + /// Exit this control flow context. + /// + public virtual void Exit() + { + var graph = ops.get_default_graph(); + var last_context = _context_stack.Pop(); + graph._set_control_flow_context(last_context); + } + + public void ExitResult(Tensor[] result) + { + if (_outer_context != null) + { + throw new NotImplementedException("ExitResult"); + } + } + + /// + /// Add `op` to the current context. + /// + public virtual void AddOp(Operation op) + { + _AddOpInternal(op); + } + + public ControlFlowContext outer_context { get { return _outer_context; } } + public HashSet values => _values; + + public virtual GradLoopState grad_state => throw new NotImplementedException("abstract method"); + + public virtual bool back_prop => throw new NotImplementedException("abstract method"); + + /// + /// Add `val` to the current context and its outer context recursively. + /// + /// + /// + public virtual Tensor AddValue(Tensor val) + { + // to be overridden + return null; + } + + public void AddName(string name) + { + _values.Add(name); + } + + /// + /// Notifies a scope about an operator added to an inner scope. + /// + /// + public virtual void AddInnerOp(Operation op) + { + if (_outer_context != null) + _outer_context.AddInnerOp(op); + } + + protected HashSet _values = new HashSet(); + + /// + /// Add `op` to the current context. + /// + protected virtual void _AddOpInternal(Operation op) + { + if (op == null) + { + throw new NotImplementedException(""); + } + else + { + foreach (var index in range(len(op.inputs))) + { + var x = op.inputs[index]; + var real_x = AddValue(x); + if (real_x != x) + op._update_input(index, real_x); + } + } + } + + protected bool OpInContext(Operation op) + { + return IsContainingContext(op._get_control_flow_context(), this); + } + + /// + /// Returns true if `maybe_containing_ctxt` is or contains `ctxt`. + /// + public static bool IsContainingContext(ControlFlowContext ctxt, ControlFlowContext maybe_containing_ctxt) + { + while (ctxt != maybe_containing_ctxt) + { + if (ctxt == null) + return false; + ctxt = ctxt.outer_context; + } + return true; + } + + protected virtual bool _IsInOuterContext(Operation op) + { + throw new NotImplementedException("_IsInOuterContext"); + } + + /// + /// Remove any external control dependency on this op. + /// + /// + protected virtual (Operation[], Operation[]) _RemoveExternalControlEdges(Operation op) + { + var while_ctxt = GetWhileContext(); + + var internal_control_inputs = new List(); + // A control input of `op` is internal if it is in the same while + // loop context as the enclosing while loop context of self. + if (while_ctxt == null) + { + internal_control_inputs = op.control_inputs.ToList(); + } + else + { + foreach (Operation x in op.control_inputs) + { + var ctxt = util.GetOutputContext(x); + if (ctxt != null && ctxt.GetWhileContext() == while_ctxt) + internal_control_inputs.append(x); + } + } + + var external_control_inputs = new List(); + if (len(internal_control_inputs) != len(op.control_inputs)) + throw new NotImplementedException(""); + + return (internal_control_inputs.ToArray(), external_control_inputs.ToArray()); + } + + /// + /// Return the while context containing this context + /// + public virtual WhileContext GetWhileContext() + { + if (_outer_context != null) + return _outer_context.GetWhileContext(); + return null; + } + + /// + /// Deserializes `context_def` into the appropriate ControlFlowContext. + /// + /// ControlFlowContextDef proto + /// Name scope to add + /// A ControlFlowContext subclass + protected ControlFlowContext from_control_flow_context_def(ControlFlowContextDef context_def, string import_scope = "") + { + switch (context_def.CtxtCase) + { + case CtxtOneofCase.CondCtxt: + return new CondContext().from_proto(context_def.CondCtxt, import_scope: import_scope); + case CtxtOneofCase.WhileCtxt: + return new WhileContext().from_proto(context_def.WhileCtxt, import_scope: import_scope); + } + + throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}"); + } + + public virtual bool IsWhileContext() + => false; + + public virtual bool IsCondContext() + => false; + + public object to_proto() + { + throw new NotImplementedException(); + } + + + public void Dispose() + { + } + + public void __init__() + { + + } + + public void __del__() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs new file mode 100644 index 000000000..a6390c791 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs @@ -0,0 +1,322 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using util = Tensorflow.control_flow_util; + +namespace Tensorflow.Operations.ControlFlows +{ + /// + /// Maintain the mapping from the loops to their grad states. + /// + public class ControlFlowState + { + Dictionary _map; + //class ControlFlowState(object): + // """Maintain the mapping from the loops to their grad states.""" + + // def __init__(self): + // self._map = {} # maps forward loop context to GradLoopState + + // def GetGradState(self, op, before): + // """Return the grad state for this op if it's in a forward loop context.""" + // if before and util.IsLoopExit(op): + // forward_ctxt = op._get_control_flow_context() + // forward_ctxt = forward_ctxt.outer_context + // if forward_ctxt: + // forward_ctxt = forward_ctxt.GetWhileContext() + // else: + // forward_ctxt = _GetWhileContext(op) + // if forward_ctxt: + // return self._map.get(forward_ctxt) + // return None + + public ControlFlowState() + { + _map = new Dictionary(); + } + + /// + /// Return the grad state for this op if it's in a forward loop context. + /// + /// + /// + /// + public GradLoopState GetGradState(Operation op, bool before) + { + ControlFlowContext forward_ctxt = null; + if (before && util.IsLoopExit(op)) + { + forward_ctxt = op._get_control_flow_context(); + forward_ctxt = forward_ctxt.outer_context; + if (forward_ctxt != null) + forward_ctxt = forward_ctxt.GetWhileContext(); + } + else + forward_ctxt = util.GetWhileContext(op); + if (forward_ctxt != null) + return _map.get(forward_ctxt); + return null; + } + + public Tensor[] ProcessUnusedLoopExits(Dictionary pending_count, List to_ops_set) + { + var loop_exits = new List(); + foreach (var grad_state in _map.Values) + { + foreach (var y in grad_state.forward_loop_exits) + { + if (!pending_count.ContainsKey(y.op.name)) + { + grad_state.pending_exits_count -= 1; + if (!to_ops_set.Contains(y.op)) + grad_state.unused_exits.append(y); + if (grad_state.pending_exits_count == 0) + loop_exits.extend(grad_state.unused_exits); + } + } + + foreach (var y in grad_state.forward_context.loop_enters) + { + if (!pending_count.ContainsKey(y.op.name)) + pending_count[y.op.name] = 1; + } + } + + return loop_exits.ToArray(); + } + + public void EnterGradWhileContext(Operation op, bool before) + { + var grad_state = GetGradState(op, before); + if (grad_state != null) + grad_state.grad_context.Enter(); + } + + public void ExitGradWhileContext(Operation op, bool before) + { + var grad_state = GetGradState(op, before); + if (grad_state != null) + grad_state.grad_context.Exit(); + } + + // def AddWhileContext(self, op, between_op_list, between_ops): + // """Add the grad state for the while loop that op belongs to. + + // Note that op is an Exit, and this method must be called in + // the control flow context where gradients() is called. + + // Note that this method modifies `between_op_list` and `between_ops`. + // """ + // forward_ctxt = _GetWhileContext(op) + // grad_state = self._map.get(forward_ctxt) + // if grad_state is None: + // # This is a new while loop so create a grad state for it. + // outer_forward_ctxt = forward_ctxt.outer_context + // if outer_forward_ctxt: + // outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() + // outer_grad_state = None + // if outer_forward_ctxt: + // outer_grad_state = self._map.get(outer_forward_ctxt) + // grad_state = GradLoopState(forward_ctxt, outer_grad_state) + // self._map[forward_ctxt] = grad_state + + // # We need to include all exits of a loop for backprop. + // for loop_exit in grad_state.forward_loop_exits: + // if loop_exit.op not in between_ops: + // between_ops.add(loop_exit.op) + // between_op_list.append(loop_exit.op) + public void AddWhileContext(Operation op, List between_op_list, List between_ops) + { + var forward_ctxt = op.GetWhileContext(); + var grad_state = _map.ContainsKey(forward_ctxt) ? _map[forward_ctxt] : null; + if (grad_state == null) + { + GradLoopState outer_grad_state = null; + var outer_forward_ctxt = forward_ctxt.outer_context; + if (outer_forward_ctxt != null) + outer_forward_ctxt = outer_forward_ctxt.GetWhileContext(); + if (outer_forward_ctxt != null) + outer_grad_state = _map[outer_forward_ctxt]; + grad_state = new GradLoopState(forward_ctxt, outer_grad_state); + _map[forward_ctxt] = grad_state; + + // We need to include all exits of a loop for backprop. + foreach (var loop_exit in grad_state.forward_loop_exits) + { + if (!between_ops.Contains(loop_exit.op)) + { + between_ops.add(loop_exit.op); + between_op_list.append(loop_exit.op); + } + } + } + } + + // def ZerosLikeForExit(self, val): + // """Create zeros_like gradient for a loop exit. + + // If the result of a loop variable is not used but is involved in + // computing the result of some needed loop variable, we create a + // zero-valued tensor that is fed as gradient for the Exit node of that + // loop variable. Note that val.op is an Exit, and this method must be + // called in the control flow context where gradients() is called. + + // Args: + // val: The output tensor of an Exit op. + + // Returns: + // A zero tensor of the same shape of val. + // """ + // val_shape = val.get_shape() + // forward_ctxt = val.op._get_control_flow_context() + // outer_forward_ctxt = forward_ctxt.outer_context + // if outer_forward_ctxt: + // outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() + // outer_grad_state = None + // if outer_forward_ctxt: + // outer_grad_state = self._map.get(outer_forward_ctxt) + // if outer_grad_state: + // # This is a nested loop. + // if val_shape.is_fully_defined(): + // # If the shape is known statically, just create a zero tensor + // # with the right shape in the right context. + // outer_grad_state.grad_context.Enter() + // result = array_ops.zeros(val_shape.dims, val.dtype) + // outer_grad_state.grad_context.Exit() + // else: + // # Only the shape of value is needed for backprop. + // forward_ctxt.outer_context.Enter() + // shape = array_ops.shape_internal(val, optimize=False) + // forward_ctxt.outer_context.Exit() + // # Save the shape to a stack. + // history_shape = outer_grad_state.AddForwardAccumulator(shape) + // # Get the shape back from the stack. + // outer_grad_ctxt = outer_grad_state.grad_context + // outer_grad_ctxt.Enter() + // real_shape = outer_grad_state.AddBackpropAccumulatedValue( + // history_shape, shape) + // result = array_ops.zeros(real_shape, val.dtype) + // outer_grad_ctxt.Exit() + // else: + // # This is not a nested loop. + // if val_shape.is_fully_defined(): + // # If the shape is known statically, just create a zero tensor + // # with the right shape. + // result = array_ops.zeros(val_shape.dims, val.dtype) + // else: + // result = array_ops.zeros_like(val, optimize=False) + // return result + + public Tensor ZerosLike(Operation op, int index) + { + if (util.IsLoopSwitch(op)) + return null; + if (op.graph.building_function) + return array_ops.zeros_like(op.outputs[index]); + var dead_branch = util.IsSwitch(op); + var forward_ctxt = util.GetWhileContext(op); + var grad_state = _map.get(forward_ctxt); + // op is not in a while loop that is part of gradients(). + if (grad_state == null) + return ZerosLikeOutsideLoop(op, index); + throw new NotImplementedException("ZerosLike"); + } + + public Tensor ZerosLikeOutsideLoop(Operation op, int index) + { + var val = op.outputs[index]; + if (!util.IsSwitch(op)) + { + if (val.dtype == dtypes.resource) + throw new NotImplementedException("ZerosLikeOutsideLoop"); + /*return array_ops.zeros( + gen_resource_variable_ops.variable_shape(val), + dtype: default_gradient.get_zeros_dtype(val));*/ + return array_ops.zeros_like(val, optimize: false); + } + else + throw new NotImplementedException("ZerosLikeOutsideLoop"); + } + + /// + /// Create zeros_like gradient for a loop exit. + /// + /// + /// + public Tensor ZerosLikeForExit(Tensor val) + { + Tensor result = null; + var val_shape = val.shape; + var forward_ctxt = val.op._get_control_flow_context(); + var outer_forward_ctxt = forward_ctxt.outer_context; + if (outer_forward_ctxt != null) + outer_forward_ctxt = outer_forward_ctxt.GetWhileContext(); + GradLoopState outer_grad_state = null; + if (outer_forward_ctxt != null) + outer_grad_state = _map.get(outer_forward_ctxt); + // This is a nested loop. + if (outer_grad_state != null) + { + throw new NotImplementedException("ZerosLikeForExit"); + } + else + { + // If the shape is known statically, just create a zero tensor + // with the right shape. + if (val_shape.IsFullyDefined) + result = array_ops.zeros(val_shape.dims, val.dtype); + else + result = array_ops.zeros_like(val, optimize: false); + } + return result; + } + + public void PostProcessing() + { + foreach (var grad_state in _map.Values) + { + foreach (var b_merge in grad_state.switch_map.Values) + { + if (b_merge.op.inputs[0] == b_merge.op.inputs[1]) + { + Tensor next_grad_val = null; + // The value of this loop variable at iteration i+1 doesn't + // depend on its value at iteration i. So use zeros as the + // gradients for all iterations > 0. + var dtype = b_merge.op.inputs[0].dtype; + var shape = b_merge.op.inputs[0].shape; + if (shape.IsFullyDefined) + { + grad_state.grad_context.Enter(); + // Create a zeros and use it for iterations > 0. + var grad_val = constant_op.constant(0, dtype: dtype, shape: shape); + next_grad_val = control_flow_ops._NextIteration(grad_val); + grad_state.grad_context.Exit(); + } + else + { + throw new NotImplementedException("PostProcessing shape is not fully defined."); + } + + b_merge.op._update_input(1, next_grad_val); + } + } + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs new file mode 100644 index 000000000..a807bdb50 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs @@ -0,0 +1,334 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections; +using System.Collections.Generic; +using static Tensorflow.Binding; +using util = Tensorflow.control_flow_util; + +namespace Tensorflow.Operations.ControlFlows +{ + /// + /// The state used for constructing the gradient graph for a while loop. + /// + public class GradLoopState + { + private WhileContext _grad_context = null; + + public WhileContext grad_context => _grad_context; + + // # The loop counter added by AddBackpropLoopCounter. It is the value + // # of the loop counter for the current iteration. + // self._grad_index = None + + // # A sync op for backprop. + // self._grad_sync = None + + // # Information needed by backprop. + private Hashtable _history_map = new Hashtable(); + public Hashtable history_map => _history_map; + Dictionary _switch_map = new Dictionary(); + public Dictionary switch_map => _switch_map; + + /// + /// The while loop context for forward. + /// + WhileContext _forward_context; + public WhileContext forward_context => _forward_context; + + /// + /// The grad loop state for the outer while loop. + /// + GradLoopState _outer_grad_state; + public GradLoopState outer_grad_state => _outer_grad_state; + + Tensor _forward_index; + public Tensor forward_index => _forward_index; + Tensor _grad_index; + + Tensor[] _forward_loop_exits; + /// + /// The list of exits of the forward loop. + /// + public Tensor[] forward_loop_exits => _forward_loop_exits; + + List _deferred_exits; + public List deferred_exits => _deferred_exits; + + List _unused_exits; + public List unused_exits => _unused_exits; + + /// + /// The number of exits we expect to see but haven't. + /// + public int pending_exits_count { get; set; } + + Operation _grad_sync; + public Operation grad_sync + { + get + { + if (_grad_sync == null) + { + tf_with(ops.control_dependencies(null), delegate + { + _grad_sync = gen_control_flow_ops.control_trigger(name: "b_sync"); + }); + _grad_sync._set_control_flow_context(_grad_context); + _grad_index.op._add_control_input(_grad_sync); + if (_grad_context.outer_context != null) + _grad_context.outer_context.AddInnerOp(_grad_sync); + } + return _grad_sync; + } + } + + public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_) + { + // Information needed by backprop. + _unused_exits = new List(); + _deferred_exits = new List(); + _forward_loop_exits = list(forward_ctxt.loop_exits); + pending_exits_count = len(forward_ctxt.loop_exits); + + _outer_grad_state = outer_grad_state_; + + ControlFlowContext outer_forward_ctxt = null; + if (outer_grad_state_ != null) + outer_forward_ctxt = outer_grad_state_.forward_context; + + // Add the forward loop counter. + // with forward_ctxt._graph.as_default(): + Tensor cnt, forward_index; + { + if (outer_forward_ctxt != null) + outer_forward_ctxt.Enter(); + (cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(outer_grad_state); + if (outer_forward_ctxt != null) + outer_forward_ctxt.Exit(); + } + _forward_context = forward_ctxt; + _forward_index = forward_index; + + // Add the backprop WhileContext, and the backprop loop counter. + if (outer_grad_state != null) + { + // This is a nested loop. Remember the iteration counts for each + // execution of this inner loop. + throw new NotImplementedException("GradLoopState"); + } + else + { + if (outer_forward_ctxt != null) + outer_forward_ctxt.Enter(); + _grad_context = new WhileContext( + maximum_iterations: forward_ctxt.maximum_iterations, + parallel_iterations: forward_ctxt.parallel_iterations, + back_prop: forward_ctxt.back_prop, + swap_memory: forward_ctxt.swap_memory, + name: forward_ctxt.Name, + grad_state: this); + _grad_index = _grad_context.AddBackpropLoopCounter(cnt, outer_grad_state); + if (outer_forward_ctxt != null) + outer_forward_ctxt.Exit(); + } + } + + /// + /// Add an accumulator for each forward tensor that is needed in backprop. + /// + /// This is added to the forward loop at the first time when a tensor + /// in the forward loop is used by backprop gradient computation loop. + /// We create an accumulator that accumulates the value of tensor at each + /// iteration. Called in the control flow context where gradients() is called. + /// + /// The pseudocode is: + /// ``` + /// acc = stack(); + /// while (_pivot) { + /// acc = stack_push(acc, value); + /// } + /// ``` + /// + /// We make sure that the stack push op in one iteration is executed before + /// next iteration. This is achieved by adding a control edge from + /// `forward_index.op.inputs[0].op` to the push op, and another control + /// edge from the push op to either `forward_index.op` or `forward_sync`. + /// + /// The source tensor in forward that is to be accumulated. + /// True iff the tensor is on a dead branch of a cond. + /// The stack that contains the accumulated history of the tensor. + public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false) + { + _forward_index.graph.as_default(); + { + var curr_ctxt = ops.get_default_graph()._get_control_flow_context(); + return tf_with(ops.control_dependencies(null), delegate + { + Tensor acc = null; + Tensor push = null; + if (curr_ctxt != null) + curr_ctxt.Enter(); + ops.colocate_with(value); + { + // We only need to pass maximum_iterations to the stack if + // we're inside an XLA context. + var max_size = constant_op.constant(-1, dtypes.int32); + acc = gen_data_flow_ops.stack_v2( + max_size: max_size, elem_type: value.dtype.as_base_dtype(), name: "f_acc"); + } + if (curr_ctxt != null) + curr_ctxt.Exit(); + + // Make acc available in the forward context. + var enter_acc = forward_context.AddValue(acc); + + // Add the stack_push op in the context of value.op. + var swap_enabled = forward_context.swap_memory; + var value_ctxt = util.GetOutputContext(value.op); + if (value_ctxt == forward_context) + { + // value is not nested in the forward context. + forward_context.Enter(); + push = gen_data_flow_ops.stack_push_v2(enter_acc, value, swap_memory: swap_enabled); + forward_context.Exit(); + // Protect stack push and order it before forward_index. + forward_index.op._add_control_input(push.op); + } + else + { + throw new NotImplementedException("AddForwardAccumulator"); + } + + // Order stack push after the successor of forward_index + var add_op = forward_index.op.inputs[0].op; + push.op._add_control_input(add_op); + return acc; + }); + } + } + + // """Add the getter for an accumulated value in the grad context. + // + // This is added to the backprop loop. Called in the grad context to + // get the value of an accumulated value. The stack pop op must be guarded + // by the pred of the controlling cond. + // + // Args: + // history_value: The history (a stack) of a value. + // value: The value that is pushed onto the stack. + // dead_branch: True iff the tensor is on a dead branch of a cond. + // + // Returns: + // The current value (the top of the stack). + // """ + + public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch = false) + { + var history_ctxt = history_value.op._get_control_flow_context(); + // Find the cond context that controls history_value if any. + CondContext cond_ctxt = null; + Tensor pop = null; + var value_ctxt = value.op._get_control_flow_context(); + while (value_ctxt != null && value_ctxt != history_ctxt) + { + if (value_ctxt is CondContext cc) + cond_ctxt = cc; + value_ctxt = value_ctxt.outer_context; + } + tf_with(ops.control_dependencies(null), delegate + { + grad_context.Enter(); + if (cond_ctxt != null) + { + throw new NotImplementedException("AddBackpropAccumulatedValue"); + } + pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype()); + pop.shape = value.shape; + grad_context.Exit(); + }); + var parallel_iterations = grad_context.parallel_iterations; + if (parallel_iterations > 1) + // All pops are ordered after pivot_for_body and before grad_sync. + grad_sync._add_control_input(pop.op); + return pop; + } + + /// + /// Get the real value of `value`. + /// + /// A tensor to be captured. + /// The same tensor obtained from the saved history. + public Tensor GetRealValue(Tensor value) + { + Tensor real_value = null; + if (real_value == null) + { + var cur_value = value; + var cur_grad_state = this; + Tensor history_value = null; + while (true) + { + var enter_op = util.GetLoopConstantEnter(cur_value); + if (enter_op != null) + { + // Special case: cur_value comes from a constant Enter node. + cur_value = enter_op.inputs[0]; + cur_grad_state = cur_grad_state.outer_grad_state; + if (cur_grad_state == null) + { + // We are now outside all nested loops for this gradient(), + // so `value` is a loop invariant and there is no need to + // save the history of value. Just make cur_value to enter + // the right control flow context. + real_value = _grad_context.AddValue(cur_value); + break; + } + } + else if (constant_op.is_constant(cur_value)) + { + // We are now outside all nested loops for this gradient(), + // so `value` is a loop invariant and there is no need to + // save the history of value. Just make cur_value to enter + // the right control flow context. + real_value = constant_op.constant( + tensor_util.constant_value(cur_value), dtype: cur_value.dtype); + break; + } + else + { + // Record the history of this value in forward_ctxt. + _grad_context.Exit(); + history_value = cur_grad_state.AddForwardAccumulator(cur_value); + _grad_context.Enter(); + break; + } + } + + if (real_value == null) + { + // Add the stack pop op in the grad context. + real_value = cur_grad_state.AddBackpropAccumulatedValue(history_value, cur_value); + if (cur_grad_state != this) + real_value = _grad_context.AddValue(real_value); + } + _history_map[value.name] = real_value; + } + return real_value; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs new file mode 100644 index 000000000..9aa6d28a7 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs @@ -0,0 +1,14 @@ +namespace Tensorflow +{ + // henon: this was too much trouble. there is no value just cost to use an interface here. + //public interface IControlFlowContext + //{ + // void AddOp(Operation op); + // IControlFlowContext outer_context { get; } + // HashSet values { get; } + // Tensor pivot { get; } + // Tensor AddValue(Tensor val); + // void AddInnerOp(Operation resultOp); + // object to_proto(); + //} +} diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs new file mode 100644 index 000000000..7b18ee46a --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs @@ -0,0 +1,41 @@ +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Operations +{ + internal class LoopVar : ICanBeFlattened, IPackable> + { + public Tensor Counter { get; set; } + public TItem Item { get; set; } + + public LoopVar(Tensor counter, TItem item) + { + Counter = counter; + Item = item; + } + + public object[] Flatten() + { + var elements = new List { Counter }; + if (typeof(TItem).GetInterface(typeof(ICanBeFlattened).Name) != null) + elements.AddRange((Item as ICanBeFlattened).Flatten()); + else + elements.Add(Item); + return elements.ToArray(); + } + + public LoopVar Pack(object[] sequences) + { + var counter = sequences[0] as Tensor; + var item = default(TItem); + if (typeof(TItem).GetInterface(typeof(IPackable).Name) != null) + item = (Item as IPackable).Pack(sequences.Skip(1).ToArray()); + return new LoopVar(counter, item); + } + + public static implicit operator (Tensor, TItem)(LoopVar loopVar) + { + return (loopVar.Counter, loopVar.Item); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/MergeOutput.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/MergeOutput.cs new file mode 100644 index 000000000..55526b834 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/MergeOutput.cs @@ -0,0 +1,32 @@ +namespace Tensorflow.Operations +{ + public class MergeOutput + { + Tensor output; + Tensor value_index; + public MergeOutput(Tensor[] values) + { + output = values[0]; + value_index = values[1]; + } + + public Tensor this[int idx] + { + get + { + switch (idx) + { + case 0: + return output; + case 1: + return value_index; + default: + return null; + } + } + } + + public static implicit operator Tensor(MergeOutput merge) + => merge.output; + } +} diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs new file mode 100644 index 000000000..8bd430a80 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -0,0 +1,676 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Operations.ControlFlows; +using Tensorflow.Util; +using static Tensorflow.Binding; +using static Tensorflow.control_flow_ops; + +namespace Tensorflow.Operations +{ + /// + /// Creates a `WhileContext`. + /// + public class WhileContext : ControlFlowContext + { + bool _back_prop = true; + GradLoopState _grad_state = null; + Tensor _maximum_iterations; + public Tensor maximum_iterations => _maximum_iterations; + int _parallel_iterations; + public int parallel_iterations => _parallel_iterations; + bool _swap_memory; + public bool swap_memory => _swap_memory; + Tensor _pivot_for_pred; + Tensor _pivot_for_body; + List _loop_exits; + public List loop_exits => _loop_exits; + List _loop_enters; + public List loop_enters => _loop_enters; + Graph _graph; + public override GradLoopState grad_state => _grad_state; + public override bool back_prop => _back_prop; + + public WhileContext(Tensor maximum_iterations = null, + int parallel_iterations = 10, + bool back_prop = true, + bool swap_memory = false, + string name = "while_context", + GradLoopState grad_state = null, + WhileContextDef context_def = null, + string import_scope = null) + { + if (context_def != null) + { + _init_from_proto(context_def, import_scope: import_scope); + } + else + { + __init__(); + _init_from_args(maximum_iterations, parallel_iterations, back_prop, swap_memory, name); + } + + _grad_state = grad_state; + } + + private void _init_from_args(Tensor maximum_iterations, + int parallel_iterations, + bool back_prop, + bool swap_memory, + string name) + { + _name = ops.get_default_graph().unique_name(name); + _maximum_iterations = maximum_iterations; + _parallel_iterations = parallel_iterations; + _back_prop = back_prop; + _swap_memory = swap_memory; + _loop_exits = new List(); + _loop_enters = new List(); + _graph = ops.get_default_graph(); + } + + private void _init_from_proto(WhileContextDef context_def, string import_scope = null) + { + var g = ops.get_default_graph(); + _name = ops.prepend_name_scope(context_def.ContextName, import_scope); + if (!string.IsNullOrEmpty(context_def.MaximumIterationsName)) + _maximum_iterations = g.as_graph_element(ops.prepend_name_scope(context_def.MaximumIterationsName, import_scope)) as Tensor; + _parallel_iterations = context_def.ParallelIterations; + _back_prop = context_def.BackProp; + _swap_memory = context_def.SwapMemory; + _pivot_for_pred = g.as_graph_element(ops.prepend_name_scope(context_def.PivotForPredName, import_scope)) as Tensor; + // We use this node to control constants created by the body lambda. + _pivot_for_body = g.as_graph_element(ops.prepend_name_scope(context_def.PivotForBodyName, import_scope)) as Tensor; + // The boolean tensor for loop termination condition. + _pivot = g.as_graph_element(ops.prepend_name_scope(context_def.PivotName, import_scope)) as Tensor; + // The list of exit tensors for loop variables. + _loop_exits = new List(); + foreach (var (i, exit_name) in enumerate(context_def.LoopExitNames)) + _loop_exits.Add(g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor); + // The list of enter tensors for loop variables. + _loop_enters = new List(); + foreach (var (i, enter_name) in enumerate(context_def.LoopEnterNames)) + _loop_enters.Add(g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor); + + __init__(values_def: context_def.ValuesDef, import_scope: import_scope); + } + + /// + /// Add the loop termination condition and body to the graph. + /// + internal LoopVar BuildLoop(Func, Tensor> pred, + Func, LoopVar> body, + LoopVar loop_vars, + Shape[] shape_invariants, + bool return_same_structure) where TItem : IFromMergeVars, new() + { + // Keep original_loop_vars to identify which are TensorArrays + var original_loop_vars = loop_vars; + // Convert TensorArrays to their flow variables + var loop_vars_tensors = nest.flatten2(loop_vars) + .Select(x => _convert_tensorarray_to_flow(x)) + .ToArray(); + + if (shape_invariants == null) + shape_invariants = loop_vars_tensors + .Select(x => _get_shape_invariant(x as Tensor)) + .ToArray(); + + Enter(); + var (original_body_result, exit_vars) = _BuildLoop( + pred, body, original_loop_vars, loop_vars_tensors, shape_invariants); + Exit(); + + var flat_result = nest.flatten2(original_body_result) + .Select(x => x as ITensorOrTensorArray) + .ToArray(); + + var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars); + var packed_exit_vars = nest.pack_sequence_as2( + structure: original_body_result, + flat_sequence: exit_vars_with_tensor_arrays); + + return packed_exit_vars; + } + + private Tensor _convert_tensorarray_to_flow(object tensor_or_tensor_array) + { + if (tensor_or_tensor_array is TensorArray tensor_array) + return tensor_array.flow; + else if (tensor_or_tensor_array is Tensor tensor) + return tensor; + + throw new NotImplementedException("_convert_tensorarray_to_flow"); + } + + private Shape _get_shape_invariant(Tensor var, int[] shape = null) + { + return var.shape; + } + + /// + /// Add the loop termination condition and body to the graph. + /// + /// + /// + /// + /// + /// + /// + /// + private (LoopVar, Tensor[]) _BuildLoop(Func, Tensor> pred, + Func, LoopVar> body, + LoopVar original_loop_vars, + Tensor[] loop_vars, + Shape[] shape_invariants) where TItem : IFromMergeVars, new() + { + var flat_loop_vars = nest.flatten2(original_loop_vars) + .Select(x => (ITensorOrTensorArray)x) + .ToArray(); + + // Let the context know the loop variables so the loop variables + // would be added in the outer contexts properly. + _InitializeValues(loop_vars); + var real_vars = loop_vars; + Tensor[] enter_vars = null; + tf_with(ops.control_dependencies(null), delegate + { + enter_vars = real_vars.Select(x => control_flow_ops._Enter(x, + _name, + is_constant: false, + parallel_iterations: _parallel_iterations, + use_input_shape: shape_invariants == null)) + .ToArray(); + + foreach (var x in enter_vars) + { + x.graph.prevent_feeding(x); + if (_outer_context != null) + _outer_context.AddInnerOp(x.op); + } + }); + + // Finds the closest enclosing non-None control pivot. + var outer_context = _outer_context; + object control_pivot = null; + while (outer_context != null && control_pivot == null) + { + + } + + if (control_pivot != null) + { + + } + + _SetShapeInvariants(real_vars, enter_vars, shape_invariants); + + // Fix the control inputs and control flow context of these enter ops. + _FixControlInputsAndContext(enter_vars); + _InitializeValues(enter_vars); + _loop_enters = enter_vars.ToList(); + + var merge_vars = enter_vars + .Select(x => merge(new[] { x, x })) + .Select(m => (Tensor)m) + .ToArray(); + + _pivot_for_pred = merge_vars[0]; + + // Build the graph for pred. + var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); + var packed_vars = new LoopVar( + (Tensor)merge_vars_with_tensor_arrays[0], + new TItem().FromMergeVars(merge_vars_with_tensor_arrays)); + var pp = pred(packed_vars); + var c = ops.convert_to_tensor(pp); + _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); + var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) + .ToArray(); + + // Build the graph for body. + var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); + _pivot_for_body = vars_for_body[0]; + // Convert TensorArray flow variables inside the context back into + // their associated TensorArrays for calling the body. + var vars_for_body_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); + var packed_vars_for_body = nest.pack_sequence_as2(original_loop_vars, vars_for_body_with_tensor_arrays); + var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); + var body_result = body(packed_vars_for_body); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); + + // Store body_result to keep track of TensorArrays returned by body + var original_body_result = body_result; + // Convert TensorArrays returned by body into their flow variables + var result = nest.flatten2(body_result) + .Select(x => _convert_tensorarray_to_flow(x)) + .ToArray(); + // result = ops.convert_n_to_tensor_or_composite(result); + var next_vars = new List(); + foreach (var (m, v) in zip(merge_vars, result)) + next_vars.Add(_AddNextAndBackEdge(m, v)); + + // Add the exit ops. + var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); + _loop_exits = exit_vars; + + // Exit the loop. + // ExitResult(exit_vars); + return (original_body_result, exit_vars.ToArray()); + } + + private void _FixControlInputsAndContext(Tensor[] enters) + { + var graph = ops.get_default_graph(); + foreach (var x in enters) + { + var inp_op = x.op.inputs[0].op; + var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op }); + var outer_control_inputs = new List(); + foreach (Operation op in control_inputs) + { + // We need to keep control inputs that are in any ancestor + // ControlFlowContext, and within outer WhileContext. + var keep_as_control_input = true; + var op_ctxt = control_flow_util.GetOutputContext(op); + var outer_ctxt = outer_context; + var outer_while_context = outer_ctxt == null ? null : outer_ctxt.GetWhileContext(); + while (outer_ctxt != op_ctxt) + { + if (outer_ctxt == null || outer_ctxt == outer_while_context) + { + keep_as_control_input = false; + break; + } + outer_ctxt = outer_ctxt.outer_context; + } + if (keep_as_control_input) + outer_control_inputs.append(op); + } + // op for op in control_inputs if self._IsInOuterContext(op) + /*var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) + .Select(x => x.op) + .ToArray();*/ + x.op._set_control_flow_context(this); + x.op._add_control_inputs(outer_control_inputs.ToArray()); + graph._record_op_seen_by_control_dependencies(x.op); + } + } + + /// + /// Makes the values known to this context. + /// + /// + private void _InitializeValues(Tensor[] values) + { + _values = new HashSet(); + foreach (var x in values) + _values.Add(x.name); + } + + protected override void _AddOpInternal(Operation op) + { + if (op.name == "rnn/basic_rnn_cell/kernel/Initializer/random_uniform/shape") + { + + } + + Operation[] external_inputs = new Operation[0]; + Operation[] control_inputs = new Operation[0]; + if (op.inputs.Length == 0) + { + // Remove any external control dependency on this op + (control_inputs, external_inputs) = _RemoveExternalControlEdges(op); + if (control_inputs.Length == 0) + op._add_control_input(GetControlPivot().op); + foreach (var x in op.outputs) + _values.Add(x.name); + } + else + { + foreach (var index in range(len(op.inputs))) + { + var x = op.inputs[index]; + var real_x = AddValue(x); + if (real_x != x) + op._update_input(index, real_x); + } + + // Remove any external control dependency on this op. + (_, external_inputs) = _RemoveExternalControlEdges(op); + // Add a control dependency to prevent loop invariants from + // enabling ops that should not be executed. + _MaybeAddControlDependency(op); + foreach (Tensor x in op.outputs) + _values.Add(x.name); + } + + if (external_inputs.Length > 0) + { + throw new NotImplementedException("external_inputs.Length > 0"); + } + + if (_outer_context != null || !IsLoopExit(op)) + foreach (Tensor x in op.outputs) + op.graph.prevent_feeding(x); + + if (_outer_context != null) + _outer_context.AddInnerOp(op); + } + + protected void _MaybeAddControlDependency(Operation op) + { + // Determines if `op` needs a control dependency. + Func _IsOpFree = (op1) => + { + if (op1.control_inputs.Length > 0) + return false; + + if (op1.type == "SymbolicGradient") + return true; + + foreach (Tensor x in op1.inputs) + if (!control_flow_util.IsLoopConstantEnter(x.op)) + return false; + + return true; + }; + + if (_IsOpFree(op)) + op._add_control_input(GetControlPivot().op); + } + + private Tensor GetControlPivot() + { + if (_pivot_for_body != null) + return _pivot_for_body; + return _pivot_for_pred; + } + + public override void AddOp(Operation op) + { + _AddOpInternal(op); + } + + /// + /// Adds a loop that counts the number of iterations. + /// + /// The outer grad state. None if not nested. + /// The number of iterations taken by the forward loop and the loop index. + public (Tensor, Tensor) AddForwardLoopCounter(GradLoopState outer_grad_state) + { + var n = constant_op.constant(0, name: "f_count"); + if (outer_grad_state != null) + throw new NotImplementedException("AddForwardLoopCounter"); + + Enter(); + AddName(n.name); + var enter_n = _Enter(n, + _name, + is_constant: false, + parallel_iterations: _parallel_iterations, + name: "f_count"); + _loop_enters.Add(enter_n); + + var m1 = merge(new[] { enter_n, enter_n }); + var merge_n = m1[0]; + var switch_n = @switch(merge_n, _pivot); + + var index = math_ops.add(switch_n[1], 1); + var next_n = _NextIteration(index); + merge_n.op._update_input(1, next_n); + + var total_iterations = exit(switch_n[0], name: "f_count"); + loop_exits.append(total_iterations); + ExitResult(new[] { total_iterations }); + Exit(); + + return (total_iterations, next_n); + } + + /// + /// Add an accumulation loop for every loop invariant. + /// + /// The Enter op for a loop invariant. + /// The partial gradient of an iteration for a loop invariant. + /// The gradient for a loop invariant. + public Tensor AddBackpropAccumulator(Operation op, Tensor grad) + { + Tensor acc = null; + Exit(); + // Create a zeros tensor with the right shape for acc. If we don't + // know the full shape statically, we will have to get the shape + // dynamically from the forward inference. Getting the shape right + // for the zeros is only needed for the base case when the loop exits + // without running any iterations. + var shape = grad.shape; + if (shape.IsFullyDefined) + { + if (outer_context != null) + outer_context.Enter(); + acc = constant_op.constant(0, grad.dtype, shape: shape, name: "b_acc"); + if (outer_context != null) + outer_context.Exit(); + } + else + { + var value = op.inputs[0]; + if (outer_context is WhileContext wc) + { + // We are in a nested while loop. + var forward_ctxt = grad_state.forward_context; + forward_ctxt.outer_context.Enter(); + var zeros_shape = array_ops.shape_internal(value, optimize: false); + forward_ctxt.outer_context.Exit(); + var outer_grad_state = grad_state.outer_grad_state; + var history_zeros_shape = outer_grad_state.AddForwardAccumulator(zeros_shape); + outer_context.Enter(); + var real_shape = outer_grad_state.AddBackpropAccumulatedValue( + history_zeros_shape, zeros_shape); + acc = array_ops.zeros(real_shape, grad.dtype); + outer_context.Exit(); + } + else + { + if (outer_context != null) + outer_context.Enter(); + var zeros_shape = array_ops.shape_internal(value, optimize: false); + acc = array_ops.zeros(zeros_shape, grad.dtype); + if (outer_context != null) + outer_context.Exit(); + } + throw new NotImplementedException("AddBackpropAccumulator"); + } + + Enter(); + AddName(acc.name); + var enter_acc = _Enter( + acc, + _name, + is_constant: false, + parallel_iterations: _parallel_iterations, + name: "b_acc"); + loop_enters.append(enter_acc); + var merge_acc = merge(new[] { enter_acc, enter_acc }, name: "b_acc")[0]; + + var switch_result = @switch(merge_acc, _pivot); + var (switch_acc_false, switch_acc_true) = (switch_result[0], switch_result[1]); + + var add_acc = math_ops.add(switch_acc_true, grad); + var next_acc = _NextIteration(add_acc); + merge_acc.op._update_input(1, next_acc); + + var result_acc = exit(switch_acc_false, name: "b_acc"); + loop_exits.append(result_acc); + ExitResult(new[] { result_acc }); + return result_acc; + } + + /// + /// Add the backprop loop that controls the iterations. + /// + /// The number of iterations for backprop. + /// The outer grad state. None if not nested. + /// The loop index. + public Tensor AddBackpropLoopCounter(Tensor count, GradLoopState outer_grad_state) + { + Tensor one = null; + var in_separate_functions = count.graph != ops.get_default_graph(); + if (in_separate_functions) + // Brings the count into this graph + count = array_ops.identity(count); + else + one = constant_op.constant(1, name: "b_count"); + + Enter(); + AddName(count.name); + var enter_count = _Enter( + count, + _name, + is_constant: false, + parallel_iterations: _parallel_iterations, + name: "b_count"); + loop_enters.append(enter_count); + + var merge_count = merge(new[] { enter_count, enter_count })[0]; + _pivot_for_pred = merge_count; + if (in_separate_functions) + one = constant_op.constant(1, name: "b_count"); + var pred = math_ops.greater_equal(merge_count, one); + _pivot = gen_control_flow_ops.loop_cond(pred, name: "b_count"); + var switch_count = @switch(merge_count, _pivot); + + var index = math_ops.subtract(switch_count[1], one); + _pivot_for_body = index; + var next_count = _NextIteration(index); + merge_count.op._update_input(1, next_count); + + var final_zero = exit(switch_count[0], name: "b_count"); + loop_exits.append(final_zero); + // Force the stack pops of i-th execution of an inner loop to be ordered + // before the pops of (i+1)-th execution of the same inner loop. + if (outer_grad_state != null) + throw new NotImplementedException("outer_grad_state"); + //outer_grad_state.grad_sync._add_control_input(final_zero.op); + ExitResult(new[] { final_zero }); + Exit(); + return next_count; + } + + /// + /// Add `val` to the current context and its outer context recursively. + /// + /// + /// + public override Tensor AddValue(Tensor val) + { + var result = val; + var new_value = !_values.Contains(val.name); + new_value &= val.op._get_control_flow_context() != this; + if (new_value) + { + _values.Add(val.name); + + // If we are in a grad context and val is from its forward context, + // use GetRealValue(), which adds the logic to save the history of + // val in forward. + var grad_ctxt = ops.get_default_graph()._get_control_flow_context(); + if (grad_ctxt != null) + { + grad_ctxt = grad_ctxt.GetWhileContext(); + if (grad_ctxt.grad_state != null) + { + var forward_ctxt = val.op.GetWhileContext(); + if (control_flow_util.IsLoopExit(val.op)) + { + forward_ctxt = forward_ctxt.outer_context as WhileContext; + if (forward_ctxt != null) + forward_ctxt = forward_ctxt.GetWhileContext(); + throw new NotImplementedException("control_flow_util.IsLoopExit"); + } + if (forward_ctxt == grad_ctxt.grad_state.forward_context) + { + var real_val = grad_ctxt.grad_state.GetRealValue(val); + _external_values[val.name] = real_val; + return real_val; + } + } + } + + if (_outer_context != null) + result = _outer_context.AddValue(val); + + // Create an Enter to make `result` known to this loop context. + Tensor enter = null; + tf_with(ops.control_dependencies(null), delegate + { + enter = control_flow_ops._Enter( + result, + _name, + is_constant: true, + parallel_iterations: _parallel_iterations); + enter.graph.prevent_feeding(enter); + if (_outer_context != null) + _outer_context.AddInnerOp(enter.op); + }); + + // Fix the control inputs and control flow context of these enter ops. + _FixControlInputsAndContext(new[] { enter }); + // Add `enter` in this context. + _values.Add(enter.name); + _external_values[val.name] = enter; + result = enter; + } + else + { + var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null; + if (actual_val != null) + result = actual_val as Tensor; + } + + return result; + } + + public override bool IsWhileContext() + => true; + + public override WhileContext GetWhileContext() + { + return this; + } + + public WhileContext from_proto(WhileContextDef proto, string import_scope) + { + var ret = new WhileContext(context_def: proto, import_scope: import_scope); + + ret.Enter(); + foreach (var nested_def in proto.NestedContexts) + from_control_flow_context_def(nested_def, import_scope: import_scope); + ret.Exit(); + return ret; + } + +#pragma warning disable CS0108 // Member hides inherited member; missing new keyword + public object to_proto() +#pragma warning restore CS0108 // Member hides inherited member; missing new keyword + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Distributions/DistributionEnum.cs b/src/TensorFlowNET.Core/Operations/Distributions/DistributionEnum.cs new file mode 100644 index 000000000..0139f0332 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Distributions/DistributionEnum.cs @@ -0,0 +1,9 @@ +namespace Tensorflow.Operations.Distributions +{ + public enum DistributionEnum + { + + + + } +} diff --git a/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs b/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs new file mode 100644 index 000000000..4375788d3 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs @@ -0,0 +1,161 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +//Base classes for probability distributions. +using System; +using System.Collections.Generic; +using static Tensorflow.Binding; + + +namespace Tensorflow +{ + public class _BaseDistribution + { + // Abstract base class needed for resolving subclass hierarchy. + } + + /// + /// A generic probability distribution base class. + /// Distribution is a base class for constructing and organizing properties + /// (e.g., mean, variance) of random variables (e.g, Bernoulli, Gaussian). + /// + public class Distribution : _BaseDistribution + { + public TF_DataType _dtype { get; set; } + //public ReparameterizationType _reparameterization_type {get;set;} + public bool _validate_args { get; set; } + public bool _allow_nan_stats { get; set; } + public Dictionary _parameters { get; set; } + public List _graph_parents { get; set; } + public string _name { get; set; } + + + /// + /// Log probability density/mass function. + /// + /// `Tensor`. + /// Python `str` prepended to names of ops created by this function. + /// log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type `self.dtype`. + + + public Tensor log_prob(Tensor value, string name = "log_prob") + { + return _call_log_prob(value, name); + } + + private Tensor _call_log_prob(Tensor value, string name) + { + return tf_with(ops.name_scope(name, "moments", new { value }), scope => + { + return math_ops.log(value); + }); + } + + protected virtual Tensor _log_prob(Tensor value) + { + throw new NotImplementedException(); + } + + private Tensor _prob(Tensor value) + { + throw new NotImplementedException(); + } + + public TF_DataType dtype() + { + return this._dtype; + } + + + /* + /// + /// Constructs the `Distribution' + /// **This is a private method for subclass use.** + /// + /// The type of the event samples. `None` implies no type-enforcement. + /// Instance of `ReparameterizationType`. + /// If `distributions.FULLY_REPARAMETERIZED`, this `Distribution` can be reparameterized + /// in terms of some standard distribution with a function whose Jacobian is constant for the support + /// of the standard distribution. If `distributions.NOT_REPARAMETERIZED`, + /// then no such reparameterization is available. + /// When `True` distribution parameters are checked for validity despite + /// possibly degrading runtime performance. When `False` invalid inputs silently render incorrect outputs. + /// When `True`, statistics (e.g., mean, mode, variance) use the value "`NaN`" + /// to indicate the result is undefined. When `False`, an exception is raised if one or more of the statistic's + /// batch members are undefined. + /// `dict` of parameters used to instantiate this `Distribution`. + /// `list` of graph prerequisites of this `Distribution`. + /// Name prefixed to Ops created by this class. Default: subclass name. + /// Two `Tensor` objects: `mean` and `variance`. + + private Distribution ( + TF_DataType dtype, + ReparameterizationType reparameterization_type, + bool validate_args, + bool allow_nan_stats, + Dictionary parameters=null, + List graph_parents=null, + string name= null) + { + this._dtype = dtype; + this._reparameterization_type = reparameterization_type; + this._allow_nan_stats = allow_nan_stats; + this._validate_args = validate_args; + this._parameters = parameters; + this._graph_parents = graph_parents; + this._name = name; + } + */ + + + + + } + + /// + /// Instances of this class represent how sampling is reparameterized. + /// Two static instances exist in the distributions library, signifying + /// one of two possible properties for samples from a distribution: + /// `FULLY_REPARAMETERIZED`: Samples from the distribution are fully + /// reparameterized, and straight-through gradients are supported. + /// `NOT_REPARAMETERIZED`: Samples from the distribution are not fully + /// reparameterized, and straight-through gradients are either partially + /// unsupported or are not supported at all. In this case, for purposes of + /// e.g. RL or variational inference, it is generally safest to wrap the + /// sample results in a `stop_gradients` call and use policy + /// gradients / surrogate loss instead. + /// + class ReparameterizationType + { + public string _rep_type { get; set; } + public ReparameterizationType(string rep_type) + { + this._rep_type = rep_type; + } + + public void repr() + { + Binding.tf_output_redirect.WriteLine($""); + } + + public bool eq(ReparameterizationType other) + { + return this.Equals(other); + } + } + + +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs b/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs new file mode 100644 index 000000000..a73bbcc02 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs @@ -0,0 +1,126 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class Normal : Distribution + { + public Tensor _loc { get; set; } + public Tensor _scale { get; set; } + + public Dictionary parameters = new Dictionary(); + /// + /// The Normal distribution with location `loc` and `scale` parameters. + /// Mathematical details + /// The probability density function(pdf) is, + /// ''' + /// pdf(x; mu, sigma) = exp(-0.5 (x - mu)**2 / sigma**2) / Z + /// Z = (2 pi sigma**2)**0.5 + /// ''' + /// where `loc = mu` is the mean, `scale = sigma` is the std.deviation, and, `Z` + /// is the normalization constant. + /// + /// + /// + /// + /// + /// + public Normal(Tensor loc, Tensor scale, bool validate_args = false, bool allow_nan_stats = true, string name = "Normal") + { + parameters.Add("name", name); + parameters.Add("loc", loc); + parameters.Add("scale", scale); + parameters.Add("validate_args", validate_args); + parameters.Add("allow_nan_stats", allow_nan_stats); + + tf_with(ops.name_scope(name, "", new { loc, scale }), scope => + { + tf_with(ops.control_dependencies(validate_args ? new Operation[] { scale.op } : new Operation[] { }), cd => + { + this._loc = array_ops.identity(loc, name); + this._scale = array_ops.identity(scale, name); + base._dtype = this._scale.dtype; + // base._reparameterization_type = new ReparameterizationType("FULLY_REPARAMETERIZED"); + base._validate_args = validate_args; + base._allow_nan_stats = allow_nan_stats; + base._parameters = parameters; + base._graph_parents = new List(new Tensor[] { this._loc, this._scale }); + base._name = name; + }); + + }); + + } + /// + /// Distribution parameter for the mean. + /// + /// + public Tensor loc() + { + return _loc; + } + /// + /// Distribution parameter for standard deviation." + /// + /// + public Tensor scale() + { + return _scale; + } + + public Tensor _batch_shape_tensor() + { + return array_ops.broadcast_dynamic_shape(array_ops.shape(_loc), array_ops.shape(_scale)); + } + + public Tensor _batch_shape() + { + return array_ops.broadcast_static_shape(new Tensor(_loc.shape.dims), new Tensor(_scale.shape.dims)); + } + + protected override Tensor _log_prob(Tensor x) + { + var log_prob = _log_unnormalized_prob(x); + var log_norm = _log_normalization(); + return tf.sub(log_prob, log_norm); + } + + private Tensor _log_unnormalized_prob(Tensor x) + { + return -0.5 * math_ops.square(_z(x)); + } + /// + /// Standardize input `x` to a unit normal. + /// + /// + /// + private Tensor _z(Tensor x) + { + return tf.divide(tf.sub(x, this._loc), this._scale); + } + + private Tensor _log_normalization() + { + Tensor t1 = ops.convert_to_tensor(Math.Log(2.0 * Math.PI), TF_DataType.TF_FLOAT); + Tensor t2 = tf.multiply(ops.convert_to_tensor(0.5, TF_DataType.TF_FLOAT), t1); + return tf.add(t2, math_ops.log(this._scale)); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs new file mode 100644 index 000000000..e7e9955c0 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs @@ -0,0 +1,55 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; + +namespace Tensorflow.Operations.Initializers +{ + public class Constant : IInitializer + { + TF_DataType dtype; + T value; + bool _verify_shape; + + private readonly Dictionary _config; + + public string ClassName => "Constant"; + public IDictionary Config => _config; + + public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) + { + this.value = value; + this.dtype = dtype; + _verify_shape = verify_shape; + + _config = new Dictionary(); + _config["value"] = this.value; + } + + public Tensor Apply(InitializerArgs args) + { + if (args.DType == TF_DataType.DtInvalid) + args.DType = this.dtype; + + args.VerifyShape = _verify_shape; + + return constant_op.constant(value, args.DType, args.Shape, + name: "Const", + verify_shape: args.VerifyShape, + allow_broadcast: false); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs new file mode 100644 index 000000000..7cd88cc68 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs @@ -0,0 +1,42 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; + +namespace Tensorflow.Operations.Initializers +{ + public class GlorotUniform : VarianceScaling + { + private readonly Dictionary _config; + + public override string ClassName => "GlorotUniform"; + public override IDictionary Config => _config; + + public GlorotUniform(float scale = 1.0f, + string mode = "fan_avg", + string distribution = "uniform", + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale: scale, + mode: mode, + distribution: distribution, + seed: seed, + dtype: dtype) + { + _config = new Dictionary(); + _config["seed"] = _seed; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs new file mode 100644 index 000000000..35b92448c --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs @@ -0,0 +1,32 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Newtonsoft.Json; +using System.Collections.Generic; +using Tensorflow.Keras.Saving.Common; + +namespace Tensorflow +{ + [JsonConverter(typeof(CustomizedIinitializerJsonConverter))] + public interface IInitializer + { + [JsonProperty("class_name")] + string ClassName { get; } + [JsonProperty("config")] + IDictionary Config { get; } + Tensor Apply(InitializerArgs args); + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs b/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs new file mode 100644 index 000000000..9df8b5bde --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs @@ -0,0 +1,21 @@ +namespace Tensorflow +{ + public class InitializerArgs + { + public string Name { get; set; } + public Shape Shape { get; set; } + public TF_DataType DType { get; set; } + public bool VerifyShape { get; set; } + + public InitializerArgs(Shape shape, + TF_DataType dtype = TF_DataType.DtInvalid, + bool verify_shape = false, + string name = null) + { + Shape = shape; + DType = dtype; + VerifyShape = verify_shape; + Name = name; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs new file mode 100644 index 000000000..202af652a --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.NumPy; + +namespace Tensorflow.Operations.Initializers +{ + /// + /// An initializer specially used for debugging (to load weights from disk). + /// + class NpyLoadInitializer : IInitializer + { + string _path; + public NpyLoadInitializer(string path) { _path = path; } + public string ClassName => ""; + public IDictionary Config => new Dictionary(); + public Tensor Apply(InitializerArgs args) + { + return np.load(_path); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs new file mode 100644 index 000000000..3077a1e0e --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs @@ -0,0 +1,43 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; + +namespace Tensorflow.Operations.Initializers +{ + public class Ones : IInitializer + { + private TF_DataType dtype; + + private readonly Dictionary _config; + + public string ClassName => "Ones"; + public IDictionary Config => new Dictionary(); + + public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) + { + this.dtype = dtype; + } + + public Tensor Apply(InitializerArgs args) + { + if (args.DType == TF_DataType.DtInvalid) + args.DType = this.dtype; + + return array_ops.ones(args.Shape, dtype); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs new file mode 100644 index 000000000..ae8733740 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs @@ -0,0 +1,66 @@ +/***************************************************************************** + Copyright 2023 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations.Initializers; +using System.Collections.Generic; + +public class Orthogonal : IInitializer +{ + float _gain = 0f; + int? _seed; + + public Orthogonal(float gain = 1.0f, int? seed = null) + { + _gain = gain; + _seed = seed; + } + + private readonly Dictionary _config; + + public string ClassName => "Orthogonal"; + public IDictionary Config => throw new NotImplementedException(); + public Tensor Apply(InitializerArgs args) + { + return _generate_init_val(args.Shape, args.DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : args.DType); + } + + private Tensor _generate_init_val(Shape shape, TF_DataType dtype) + { + var num_rows = 1L; + foreach (var dim in shape.dims.Take(shape.ndim - 1)) + num_rows *= dim; + var num_cols = shape.dims.Last(); + var flat_shape = (Math.Max(num_cols, num_rows), Math.Min(num_cols, num_rows)); + + var a = tf.random.stateless_normal(flat_shape, dtype: dtype); + // Compute the qr factorization + var (q, r) = tf.linalg.qr(a, full_matrices: false); + // Make Q uniform + var d = tf.linalg.tensor_diag_part(r.Single); + q *= tf.sign(d); + + if (num_rows < num_cols) + { + q = array_ops.matrix_transpose(q); + } + + return _gain * tf.reshape(q, shape); + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs new file mode 100644 index 000000000..21fa7e2b2 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -0,0 +1,56 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; + +namespace Tensorflow.Operations.Initializers +{ + public class RandomNormal : IInitializer + { + private float mean; + private float stddev; + private int? seed; + private TF_DataType dtype; + + private readonly Dictionary _config; + + public string ClassName => "RandomNormal"; + public IDictionary Config => _config; + + public RandomNormal(float mean = 0.0f, + float stddev = 0.05f, + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) + { + this.mean = mean; + this.stddev = stddev; + this.seed = seed; + this.dtype = dtype; + + _config = new Dictionary(); + _config["mean"] = this.mean; + _config["stddev"] = this.stddev; + _config["seed"] = this.seed; + } + + public Tensor Apply(InitializerArgs args) + { + if (args.DType == TF_DataType.DtInvalid) + args.DType = dtype; + return random_ops.random_normal(args.Shape, mean, stddev, args.DType, seed: seed); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs new file mode 100644 index 000000000..87404708c --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs @@ -0,0 +1,58 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; + +namespace Tensorflow.Operations.Initializers +{ + public class RandomUniform : IInitializer + { + private int? seed; + private float minval; + private float maxval; + private TF_DataType dtype; + + private readonly Dictionary _config; + + public string ClassName => "RandomUniform"; + public IDictionary Config => _config; + + public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null) + { + this.dtype = dtype; + this.minval = minval; + this.maxval = maxval; + this.seed = seed; + + _config = new Dictionary(); + _config["minval"] = this.minval; + _config["maxval"] = this.maxval; + _config["seed"] = this.seed; + } + + public Tensor Apply(InitializerArgs args) + { + if (args.DType == TF_DataType.DtInvalid) + args.DType = dtype; + + return random_ops.random_uniform(args.Shape, + minval: minval, + maxval: maxval, + dtype: dtype, + seed: seed); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs new file mode 100644 index 000000000..c1c3e9996 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs @@ -0,0 +1,55 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; + +namespace Tensorflow.Operations.Initializers +{ + public class TruncatedNormal : IInitializer + { + private float mean; + private float stddev; + private int? seed; + private TF_DataType dtype; + + private readonly Dictionary _config; + + public string ClassName => "TruncatedNormal"; + public IDictionary Config => _config; + + public TruncatedNormal(float mean = 0.0f, + float stddev = 1.0f, + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) + { + this.mean = mean; + this.stddev = stddev; + this.seed = seed; + this.dtype = dtype; + _config = new Dictionary(); + _config["mean"] = this.mean; + _config["stddev"] = this.stddev; + _config["seed"] = this.seed; + } + + public Tensor Apply(InitializerArgs args) + { + if (args.DType != TF_DataType.DtInvalid) + dtype = args.DType; + return random_ops.truncated_normal(args.Shape, mean, stddev, dtype: dtype, seed: seed); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs new file mode 100644 index 000000000..37fdd764c --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -0,0 +1,128 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; + +namespace Tensorflow.Operations.Initializers +{ + /// + /// Initializer capable of adapting its scale to the shape of weights tensors. + /// + public class VarianceScaling : IInitializer + { + protected float _scale; + protected string _mode; + protected int? _seed; + protected TF_DataType _dtype; + protected string _distribution; + private readonly Dictionary _config; + + public virtual string ClassName => "VarianceScaling"; + + public virtual IDictionary Config => _config; + + public VarianceScaling(float scale = 1.0f, + string mode = "fan_in", + string distribution = "truncated_normal", + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) + { + if (!dtype.is_floating()) + throw new TypeError("Cannot create initializer for non-floating point type."); + if (!new string[] { "fan_in", "fan_out", "fan_avg" }.Contains(mode)) + throw new TypeError($"Unknown {mode} %s [fan_in, fan_out, fan_avg]"); + if(distribution == "normal") + { + distribution = "truncated_normal"; + } + if(!new string[] { "uniform", "truncated_normal", "untruncated_normal" }.Contains(distribution)) + { + throw new ValueError($"Invalid `distribution` argument: {distribution}"); + } + + if (scale <= 0) + throw new ValueError("`scale` must be positive float."); + + _scale = scale; + _mode = mode; + _seed = seed; + _dtype = dtype; + _distribution = distribution; + + _config = new(); + _config["scale"] = _scale; + _config["mode"] = _mode; + _config["distribution"] = _distribution; + _config["seed"] = _seed; + } + + public Tensor Apply(InitializerArgs args) + { + if (args.DType == TF_DataType.DtInvalid) + args.DType = this._dtype; + + float n = 0; + var (fan_in, fan_out) = _compute_fans(args.Shape); + var scale = this._scale; + if (_mode == "fan_in") + scale /= Math.Max(1.0f, fan_in); + else if (_mode == "fan_out") + scale /= Math.Max(1.0f, fan_out); + else + scale /= Math.Max(1.0f, (fan_in + fan_out) / 2); + + if(_distribution == "truncated_normal") + { + var stddev = Math.Sqrt(scale) / .87962566103423978f; + return random_ops.truncated_normal(args.Shape, 0.0f, (float)stddev, args.DType); + } + else if(_distribution == "untruncated_normal") + { + var stddev = Math.Sqrt(scale); + return random_ops.random_normal(args.Shape, 0.0f, (float)stddev, args.DType); + } + else + { + var limit = (float)Math.Sqrt(scale * 3.0f); + return random_ops.random_uniform(args.Shape, -limit, limit, args.DType); + } + } + + private (int, int) _compute_fans(int[] shape) + { + if (shape.Length < 1) + return (1, 1); + if (shape.Length == 1) + return (shape[0], shape[0]); + if (shape.Length == 2) + return (shape[0], shape[1]); + else + { + // Assuming convolution kernels (2D, 3D, or more). + // kernel shape: (..., input_depth, depth) + int receptive_field_size = 1; + foreach (var dim in shape.Take(shape.Length - 2)) + receptive_field_size *= dim; + var fan_in = shape[shape.Length - 2] * receptive_field_size; + var fan_out = shape[shape.Length - 1] * receptive_field_size; + return (fan_in, fan_out); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs new file mode 100644 index 000000000..c4ed25a17 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs @@ -0,0 +1,45 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; + +namespace Tensorflow.Operations.Initializers +{ + public class Zeros : IInitializer + { + Shape shape; + TF_DataType dtype; + + public string ClassName => "Zeros"; + public IDictionary Config => new Dictionary(); + + public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) + { + this.shape = shape; + this.dtype = dtype; + } + + public Tensor Apply(InitializerArgs args) + { + if (args.DType == TF_DataType.DtInvalid) + args.DType = dtype; + if (args.Shape == null) + args.Shape = shape; + + return array_ops.zeros(args.Shape, dtype); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/InputList.cs b/src/TensorFlowNET.Core/Operations/InputList.cs new file mode 100644 index 000000000..40c897c5d --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/InputList.cs @@ -0,0 +1,57 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow +{ + public class InputList : IEnumerable + { + public Tensor[] _inputs; + public int Length => _inputs.Length; + public Tensor this[int index] + { + get + { + if (index == -1) + index = _inputs.Length - 1; + return _inputs[index]; + } + } + + public InputList(Tensor[] inputs) + { + _inputs = inputs; + } + + public IEnumerator GetEnumerator() + { + return _inputs.GetEnumerator(); + } + + public static implicit operator List(InputList input) + { + return input._inputs.ToList(); + } + + public static implicit operator Tensor[](InputList input) + { + return input._inputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Losses/Reduction.cs b/src/TensorFlowNET.Core/Operations/Losses/Reduction.cs new file mode 100644 index 000000000..bef485461 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Losses/Reduction.cs @@ -0,0 +1,13 @@ +namespace Tensorflow +{ + public class Reduction + { + public const string NONE = "none"; + public const string SUM = "sum"; + public const string WEIGHTED_SUM = "weighted_sum"; + public const string SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"; + public const string WEIGHTED_MEAN = "weighted_mean"; + public const string SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights"; + public const string SUM_OVER_NONZERO_WEIGHTS = SUM_BY_NONZERO_WEIGHTS; + } +} diff --git a/src/TensorFlowNET.Core/Operations/Losses/Util.cs b/src/TensorFlowNET.Core/Operations/Losses/Util.cs new file mode 100644 index 000000000..fde5bcb09 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Losses/Util.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Operations.Losses +{ + public class Util + { + public static void add_loss(Tensor loss, string loss_collection = "losses") + { + if (!string.IsNullOrEmpty(loss_collection)) + ops.add_to_collection(loss_collection, loss); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs new file mode 100644 index 000000000..a412f07ee --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs @@ -0,0 +1,158 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class LossesImpl + { + public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, + string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) + { + if (weights == null) + weights = tf.constant(1.0f); + + return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate + { + // Save the `reduction` argument for loss normalization when distributing + // to multiple replicas. Used only for estimator + v1 optimizer flow. + ops.get_default_graph()._last_loss_reduction = reduction; + + /*var dp = weights_broadcast_ops.assert_broadcastable(weights, losses); + with(ops.control_dependencies(dp), delegate + { + + });*/ + + losses = ops.convert_to_tensor(losses); + var input_dtype = losses.dtype; + losses = math_ops.cast(losses, dtype: dtypes.float32); + weights = math_ops.cast(weights, dtype: dtypes.float32); + var weighted_losses = math_ops.multiply(losses, weights); + Tensor loss = null; + if (reduction == Reduction.NONE) + loss = weighted_losses; + else + { + loss = math_ops.reduce_sum(weighted_losses); + if (reduction == Reduction.WEIGHTED_MEAN) + loss = _safe_mean( + loss, math_ops.reduce_sum(array_ops.ones_like(losses) * weights)); + else if (reduction == Reduction.SUM_BY_NONZERO_WEIGHTS || + reduction == Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss = _safe_mean(loss, _num_present(losses, weights)); + else if (reduction == Reduction.SUM_OVER_BATCH_SIZE) + loss = _safe_mean(loss, _num_elements(losses)); + } + + // Convert the result back to the input type. + loss = math_ops.cast(loss, input_dtype); + Operations.Losses.Util.add_loss(loss, loss_collection); + return loss; + }); + } + + public Tensor _safe_mean(Tensor losses, Tensor num_present) + { + var total_loss = math_ops.reduce_sum(losses); + return math_ops.div_no_nan(total_loss, num_present, name: "value"); + } + + public Tensor _num_elements(Tensor losses) + { + throw new NotImplementedException("LossesImpl._num_elements"); + } + + public Tensor _num_present(Tensor losses, Tensor weights, bool per_batch = false) + { + return tf_with(ops.name_scope(null, default_name: "num_present", (losses, weights)), name_scope => + { + string scope = name_scope; + weights = math_ops.cast(weights, dtype: dtypes.float32); + var present = array_ops.where( + math_ops.equal(weights, 0.0), + array_ops.zeros_like(weights), + array_ops.ones_like(weights)); + present = weights_broadcast_ops.broadcast_weights(present, losses); + + if (per_batch) + return math_ops.reduce_sum( + present, + axis: math_ops.range(1, array_ops.rank(present)), + keepdims: true, + name: scope); + return math_ops.reduce_sum(present, name: scope); + }); + } + + public Tensor sparse_softmax_cross_entropy(Tensor labels, + Tensor logits, + float weights = 1.0f, + string scope = null, + string loss_collection = "losses", + string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) + { + return tf_with(ops.name_scope(scope, + "sparse_softmax_cross_entropy_loss", + (logits, labels, weights)), + name_scope => + { + scope = name_scope; + Tensor weights_tensor = null; + (labels, logits, weights_tensor) = _remove_squeezable_dimensions( + labels, logits, weights, expected_rank_diff: 1); + + var losses = nn_ops.sparse_softmax_cross_entropy_with_logits(labels: labels, + logits: logits, + name: "xentropy"); + return compute_weighted_loss(losses, weights_tensor, scope, loss_collection, reduction: reduction); + }); + } + + public (Tensor, Tensor, Tensor) _remove_squeezable_dimensions(Tensor labels, + Tensor predictions, + float weights = 0, + int expected_rank_diff = 0) + { + (labels, predictions) = confusion_matrix.remove_squeezable_dimensions( + labels, predictions, expected_rank_diff: expected_rank_diff); + + if (weights > 0) + { + var weights_tensor = ops.convert_to_tensor(weights); + var labels_rank = labels.shape.ndim; + var weights_shape = weights_tensor.shape; + var weights_rank = weights_shape.ndim; + + if (labels_rank > -1 && weights_rank > -1) + { + // Use static rank. + var rank_diff = weights_rank - labels_rank; + if (rank_diff == 1) + weights = (float)array_ops.squeeze(weights_tensor, new int[] { -1 }); + return (labels, predictions, weights_tensor); + } + + // Use dynamic rank. + throw new NotImplementedException("_remove_squeezable_dimensions dynamic rank"); + } + + throw new NotImplementedException("_remove_squeezable_dimensions"); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/AveragePoolFunction.cs b/src/TensorFlowNET.Core/Operations/NnOps/AveragePoolFunction.cs new file mode 100644 index 000000000..84ce56a4b --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/AveragePoolFunction.cs @@ -0,0 +1,47 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + /// + /// Performs the average pooling on the input. + /// + public class AveragePoolFunction : IPoolFunction + { + public Tensor Apply(Tensor value, + int[] ksize, + int[] strides, + string padding, + string data_format = "NHWC", + string name = null) + { + return tf_with(ops.name_scope(name, "AveragePool", value), scope => + { + name = scope; + value = ops.convert_to_tensor(value, name: "input"); + return gen_nn_ops.avg_pool( + value, + ksize: ksize, + strides: strides, + padding: padding, + data_format: data_format, + name: name); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs new file mode 100644 index 000000000..16cbd0010 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs @@ -0,0 +1,170 @@ +using System; +using System.Linq; +using Tensorflow.Keras.Engine; +using Tensorflow.Operations; +using Tensorflow.Operations.Activation; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Basic LSTM recurrent network cell. + /// The implementation is based on: http://arxiv.org/abs/1409.2329. + /// + [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] + public class BasicLstmCell : LayerRnnCell + { + int _num_units; + float _forget_bias; + bool _state_is_tuple; + IActivation _activation; + LSTMStateTuple _state; + IVariableV1 _kernel; + IVariableV1 _bias; + string _WEIGHTS_VARIABLE_NAME = "kernel"; + string _BIAS_VARIABLE_NAME = "bias"; + + /// + /// Initialize the basic LSTM cell. + /// + /// The number of units in the LSTM cell. + /// + /// + /// + /// + /// + /// + public BasicLstmCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true, + IActivation activation = null, bool? reuse = null, string name = null, + TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype) + { + inputSpec = new InputSpec(ndim: 2); + _num_units = num_units; + _forget_bias = forget_bias; + _state_is_tuple = state_is_tuple; + _activation = activation; + if (_activation == null) + _activation = tf.nn.tanh(); + } + + protected override void build(Shape input_shape) + { + var input_depth = input_shape.dims.Last(); + var h_depth = _num_units; + _kernel = add_weight(_WEIGHTS_VARIABLE_NAME, + shape: new int[] { (int)(input_depth + h_depth), 4 * _num_units }); + _bias = add_weight(_BIAS_VARIABLE_NAME, + shape: new[] { 4 * _num_units }, + initializer: tf.zeros_initializer); + built = true; + } + + public Tensor __call__(Tensor inputs, LSTMStateTuple state) + { + _state = state; + return base.__call__(inputs); + } + + /// + /// Long short-term memory cell (LSTM). + /// + /// + /// + /// + /// + protected Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + { + var one = constant_op.constant(1, dtype: dtypes.int32); + // Parameters of gates are concatenated into one multiply for efficiency. + Tensor c = null; + Tensor h = null; + if (_state_is_tuple) + (c, h) = ((Tensor)_state.c, (Tensor)_state.h); + else + { + // array_ops.split(value: state, num_or_size_splits: 2, axis: one); + throw new NotImplementedException("BasicLstmCell call"); + } + var gate_inputs = math_ops.matmul(array_ops.concat(new[] { (Tensor)inputs, h }, 1), _kernel.AsTensor()); + gate_inputs = nn_ops.bias_add(gate_inputs, _bias); + + // i = input_gate, j = new_input, f = forget_gate, o = output_gate + var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one); + var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); + + var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); + // Note that using `add` and `multiply` instead of `+` and `*` gives a + // performance improvement. So using those at the cost of readability. + var new_c = gen_math_ops.add( + math_ops.multiply(c, math_ops.sigmoid(gen_math_ops.add(f, forget_bias_tensor))), + math_ops.multiply(math_ops.sigmoid(i), _activation.Activate(j))); + + var new_h = math_ops.multiply(_activation.Activate(new_c), math_ops.sigmoid(o)); + + + if (_state_is_tuple) + return new_c; + else + return array_ops.concat(new[] { new_c, new_h }, 1); + } + + public override object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) + { + if (inputs != null) + throw new NotImplementedException("get_initial_state input is not null"); + + return zero_state(batch_size, dtype); + } + + /// + /// Return zero-filled state tensor(s). + /// + /// + /// + /// + private LSTMStateTuple zero_state(Tensor batch_size, TF_DataType dtype) + { + LSTMStateTuple output = null; + tf_with(ops.name_scope($"{GetType().Name}ZeroState", values: new { batch_size }), delegate + { + output = _zero_state_tensors(state_size, batch_size, dtype); + }); + + return output; + } + + private LSTMStateTuple _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype) + { + if (state_size is LSTMStateTuple state_size_tuple) + { + var outputs = state_size_tuple.Flatten() + .Select(x => (int)x) + .Select(s => + { + var c = rnn_cell_impl._concat(batch_size, s); + var size = array_ops.zeros(c, dtype: dtype); + + var c_static = rnn_cell_impl._concat(batch_size, s, @static: true); + size.set_shape(c_static); + + return size; + }).ToArray(); + + return new LSTMStateTuple(outputs[0], outputs[1]); + } + + throw new NotImplementedException("_zero_state_tensors"); + } + + public override object state_size + { + get + { + if (_state_is_tuple) + return new LSTMStateTuple(_num_units, _num_units); + else + return 2 * _num_units; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs new file mode 100644 index 000000000..3308aebb7 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs @@ -0,0 +1,80 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] + public class BasicRnnCell : LayerRnnCell + { + int _num_units; + Func _activation; + + public override object state_size => _num_units; + public override int output_size => _num_units; + public IVariableV1 _kernel; + string _WEIGHTS_VARIABLE_NAME = "kernel"; + public IVariableV1 _bias; + string _BIAS_VARIABLE_NAME = "bias"; + + public BasicRnnCell(int num_units, + Func activation = null, + bool? reuse = null, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, + name: name, + dtype: dtype) + { + // Inputs must be 2-dimensional. + inputSpec = new InputSpec(ndim: 2); + + _num_units = num_units; + if (activation == null) + _activation = math_ops.tanh; + else + _activation = activation; + } + + protected override void build(Shape inputs_shape) + { + var input_depth = inputs_shape.dims[inputs_shape.ndim - 1]; + + _kernel = add_weight( + _WEIGHTS_VARIABLE_NAME, + shape: new int[] { (int)(input_depth + _num_units), _num_units }); + + _bias = add_weight( + _BIAS_VARIABLE_NAME, + shape: new[] { _num_units }, + initializer: tf.zeros_initializer); + + built = true; + } + + protected Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + { + // Most basic RNN: output = new_state = act(W * input + U * state + B). + var concat = array_ops.concat(new Tensor[] { inputs, state }, 1); + var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor()); + gate_inputs = nn_ops.bias_add(gate_inputs, _bias); + var output = _activation(gate_inputs, null); + return new Tensors(output, output); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs new file mode 100644 index 000000000..d8cc0c25d --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs @@ -0,0 +1,59 @@ +using System.Collections.Generic; + +namespace Tensorflow.Operations +{ + internal class BodyItemInRnnWhileLoop : ICanBeFlattened, IPackable, IFromMergeVars + { + /// + /// int32 scalar Tensor. + /// + public Tensor time { get; set; } + /// + /// List of `TensorArray`s that represent the output. + /// + public TensorArray[] output_ta_t { get; set; } + /// + /// nested tuple of vector tensors that represent the state. + /// + public Tensor state { get; set; } + + public BodyItemInRnnWhileLoop() + { + } + + public BodyItemInRnnWhileLoop(Tensor time, TensorArray[] output_ta_t, Tensor state) + { + this.time = time; + this.output_ta_t = output_ta_t; + this.state = state; + } + + public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item) + => (item.time, item.output_ta_t, item.state); + + public object[] Flatten() + { + var elements = new List { time }; + elements.AddRange(output_ta_t); + elements.Add(state); + return elements.ToArray(); + } + + public BodyItemInRnnWhileLoop Pack(object[] sequences) + { + time = sequences[0] as Tensor; + output_ta_t = new[] { sequences[1] as TensorArray }; + state = sequences[2] as Tensor; + + return new BodyItemInRnnWhileLoop(time, output_ta_t, state); + } + + public BodyItemInRnnWhileLoop FromMergeVars(ITensorOrTensorArray[] mergeVars) + { + time = (Tensor)mergeVars[1]; + output_ta_t = new[] { (TensorArray)mergeVars[2] }; + state = (Tensor)mergeVars[3]; + return this; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/Conv1dParams.cs b/src/TensorFlowNET.Core/Operations/NnOps/Conv1dParams.cs new file mode 100644 index 000000000..4282a2791 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/Conv1dParams.cs @@ -0,0 +1,81 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow.Operations +{ + public class Conv1dParams + { + public string Name { get; set; } + + /// + /// An optional `string` from: `"NHWC", "NCHW"`. Defaults to `"NHWC"`. + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, height, width, channels]. + /// + public string DataFormat { get; set; } = "NHWC"; + + /// + /// Must be one of the following types: `half`, `bfloat16`, `float32`, `float64`. + /// A 4-D tensor. The dimension order is interpreted according to the value + /// + public Tensor Input { get; set; } + + /// + /// An integer vector representing the shape of `input` + /// + public Tensor InputSizes { get; set; } + + /// + /// A 4-D tensor of shape + /// + public IVariableV1 Filter { get; set; } + + /// + /// An integer vector representing the tensor shape of `filter` + /// + public Tensor FilterSizes { get; set; } + + /// + /// A `Tensor`. Must have the same type as `filter`. + /// 4-D with shape `[batch, out_height, out_width, out_channels]`. + /// + public Tensor OutBackProp { get; set; } + + /// + /// The stride of the sliding window for each + /// dimension of `input`. The dimension order is determined by the value of + /// `data_format`, see below for details. + /// + public int[] Strides { get; set; } + + /// + /// A `string` from: `"SAME", "VALID", "EXPLICIT"`. + /// + public string Padding { get; set; } + + public int[] ExplicitPaddings { get; set; } = new int[0]; + + public bool UseCudnnOnGpu { get; set; } = true; + + public int[] Dilations { get; set; } = new int[] { 1, 1, 1 }; + + public Conv1dParams() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs b/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs new file mode 100644 index 000000000..fa0d5bef6 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs @@ -0,0 +1,81 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow.Operations +{ + public class Conv2dParams + { + public string Name { get; set; } + + /// + /// An optional `string` from: `"NHWC", "NCHW"`. Defaults to `"NHWC"`. + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, height, width, channels]. + /// + public string DataFormat { get; set; } = "NHWC"; + + /// + /// Must be one of the following types: `half`, `bfloat16`, `float32`, `float64`. + /// A 4-D tensor. The dimension order is interpreted according to the value + /// + public Tensor Input { get; set; } + + /// + /// An integer vector representing the shape of `input` + /// + public Tensor InputSizes { get; set; } + + /// + /// A 4-D tensor of shape + /// + public Tensor Filter { get; set; } + + /// + /// An integer vector representing the tensor shape of `filter` + /// + public Tensor FilterSizes { get; set; } + + /// + /// A `Tensor`. Must have the same type as `filter`. + /// 4-D with shape `[batch, out_height, out_width, out_channels]`. + /// + public Tensor OutBackProp { get; set; } + + /// + /// The stride of the sliding window for each + /// dimension of `input`. The dimension order is determined by the value of + /// `data_format`, see below for details. + /// + public int[] Strides { get; set; } + + /// + /// A `string` from: `"SAME", "VALID", "EXPLICIT"`. + /// + public string Padding { get; set; } + + public int[] ExplicitPaddings { get; set; } = new int[0]; + + public bool UseCudnnOnGpu { get; set; } = true; + + public int[] Dilations { get; set; } = new int[] { 1, 1, 1, 1 }; + + public Conv2dParams() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs b/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs new file mode 100644 index 000000000..ec70b1858 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs @@ -0,0 +1,131 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + public class ConvolutionInternal + { + ConvolutionalArgs args; + + string data_format => args.DataFormat; + string name; + string padding => args.Padding; + + public ConvolutionInternal(ConvolutionalArgs args) + { + this.args = args; + name = args.Name; + } + + public Tensor Apply(Tensors input, Tensor filters) + { + var filters_rank = filters.shape.ndim; + var inputs_rank = input.shape.ndim; + var num_spatial_dims = args.NumSpatialDims; + if (args.Rank == 1) + { + // Special case: Conv1D + num_spatial_dims = 1; + } + else if (num_spatial_dims == Unknown) + { + num_spatial_dims = filters_rank - 2; + } + + // Channel dimension. + var num_batch_dims = inputs_rank - num_spatial_dims - 1; + if (!new[] { 1, 2, 3 }.Contains(num_spatial_dims)) + throw new ValueError($"num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one " + + $"of 1, 2 or 3 but saw {num_spatial_dims}. num_batch_dims: {num_batch_dims}."); + + Tensor result = null; + tf_with(ops.name_scope(name, default_name: null), scope => + { + name = scope; + if (num_spatial_dims == 2) + { + var channel_index = num_batch_dims + num_spatial_dims; + var dilations = _get_sequence(args.DilationRate, num_spatial_dims, channel_index).ToArray(); + var strides = _get_sequence(args.Strides, num_spatial_dims, channel_index).ToArray(); + + result = gen_nn_ops.conv2d( + input, + filters, + strides, + padding, + data_format: data_format, + dilations: dilations, + name: name + ); + } + else + { + var channel_first = data_format == "NCW"; + var spatial_start_dim = channel_first ? -2 : -3; + + var channel_index = channel_first ? 1 : 2; + var dilations = _get_sequence(args.DilationRate, 1, channel_index); + var strides = _get_sequence(args.Strides, 1, channel_index); + + strides.Insert(0, 1); + dilations.Insert(0, 1); + + input = array_ops.expand_dims(input, spatial_start_dim); + filters = array_ops.expand_dims(filters, 0); + + result = gen_nn_ops.conv2d( + input, + filters, + strides.ToArray(), + padding, + data_format: channel_first ? "NCHW" : "NHWC", + dilations: dilations.ToArray(), + name: name + ); + result = array_ops.squeeze(result, new[] { spatial_start_dim }); + } + }); + + return result; + } + + IList _get_sequence(int[] value, int n, int channel_index) + { + var seq = new List(); + + if (channel_index == 1) + { + seq.Add(1); + seq.Add(1); + seq.AddRange(value); + } + else + { + seq.Add(1); + seq.AddRange(value); + seq.Add(1); + } + + return seq; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs b/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs new file mode 100644 index 000000000..5826ad8b1 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs @@ -0,0 +1,23 @@ +namespace Tensorflow.Operations +{ + public class FusedBatchNormParams + { + public string Name { get; set; } + public Tensor YBackprop { get; set; } + public Tensor X { get; set; } + public Tensor Scale { get; set; } + public Tensor ReserveSpace1 { get; set; } + public Tensor ReserveSpace2 { get; set; } + public Tensor ReserveSpace3 { get; set; } + public float Epsilon { get; set; } + public string DataFormat { get; set; } + public bool IsTraining { get; set; } + + public FusedBatchNormParams() + { + Epsilon = 0.0001f; + DataFormat = "NHWC"; + IsTraining = true; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs b/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs new file mode 100644 index 000000000..a86233663 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs @@ -0,0 +1,31 @@ +namespace Tensorflow.Operations +{ + /// + /// Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. + /// + /// Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state + /// and `h` is the output. + /// + /// Only used when `state_is_tuple=True`. + /// + public class LSTMStateTuple : ICanBeFlattened + { + public object c; + public object h; + + public LSTMStateTuple(int c, int h) + { + this.c = c; + this.h = h; + } + + public LSTMStateTuple(Tensor c, Tensor h) + { + this.c = c; + this.h = h; + } + + public object[] Flatten() + => new[] { c, h }; + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs new file mode 100644 index 000000000..65de4fe90 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs @@ -0,0 +1,178 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ +using System; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] + public class LayerRnnCell : RnnCell + { + protected InputSpec inputSpec; + protected bool built; + protected Graph _graph; + + protected VariableScope _scope; + protected VariableScope _current_scope; + + protected bool? _reuse; + protected bool _use_resource_variables; + protected bool _keras_style; + + public LayerRnnCell(bool trainable = true, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + bool? _reuse = null) : base(_reuse: _reuse, + name: name, + dtype: dtype) + { + // For backwards compatibility, legacy layers do not use `ResourceVariable` + // by default. + this._use_resource_variables = false; + this._reuse = _reuse; + + // Avoid an incorrect lint error + this.built = false; + _keras_style = false; + } + + protected virtual void build(Shape inputs_shape) + { + + } + + public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null) + { + var results = __call__(inputs, training: training); + return (results[0], results[1]); + } + + public Tensors __call__(Tensors inputs, + Tensor state = null, + Tensor training = null, + VariableScope scope = null) + { + _set_scope(scope); + _graph = ops._get_graph_from_inputs(inputs, graph: _graph); + + variable_scope scope_context_manager = null; + if (built) + { + scope_context_manager = tf.variable_scope(_scope, + reuse: true, + auxiliary_name_scope: false); + } + else + { + scope_context_manager = tf.variable_scope(_scope, + reuse: _reuse, + auxiliary_name_scope: false); + } + + Tensors outputs = null; + tf_with(scope_context_manager, scope2 => + { + _current_scope = scope2; + // Actually call layer + + }); + + + // Update global default collections. + + return outputs; + } + + protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list) + { + foreach (var name in collection_list) + { + var collection = ops.get_collection_ref(name); + + foreach (var element in elements) + if (!collection.Contains(element)) + collection.Add(element); + } + } + + /// + /// Adds a new variable to the layer, or gets an existing one; returns it. + /// + /// + /// + /// + /// + /// + /// + /// + /// + protected virtual IVariableV1 add_weight(string name, + int[] shape, + TF_DataType dtype = TF_DataType.DtInvalid, + IInitializer initializer = null, + bool trainable = true, + VariableSynchronization synchronization = VariableSynchronization.Auto, + VariableAggregation aggregation = VariableAggregation.None) + { + var default_graph = ops.get_default_graph(); + Graph init_graph = null; + IVariableV1[] existing_variables = null; + + if (synchronization == VariableSynchronization.OnRead) + trainable = false; + + if (default_graph.building_function) + { + throw new NotImplementedException("add_weight"); + } + else + { + init_graph = default_graph; + existing_variables = variables.global_variables().ToArray(); + } + + if (dtype == TF_DataType.DtInvalid) + dtype = TF_DataType.TF_FLOAT; + + _set_scope(); + var reuse = built || (_reuse != null && _reuse.Value); + return tf.Variable(0); + } + + protected string _name_scope() + { + return _current_scope.original_name_scope; + } + + protected void _set_scope(VariableScope scope = null) + { + if (_scope == null) + { + if (_reuse.HasValue && _reuse.Value) + { + throw new NotImplementedException("_set_scope _reuse.HasValue"); + /*with(tf.variable_scope(scope == null ? _base_name : scope), + captured_scope => _scope = captured_scope);*/ + } + else + { + + } + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs b/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs new file mode 100644 index 000000000..149d2e889 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs @@ -0,0 +1,47 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + /// + /// Performs the max pooling on the input. + /// + public class MaxPoolFunction : IPoolFunction + { + public Tensor Apply(Tensor value, + int[] pool_size, + int[] strides, + string padding, + string data_format = "NHWC", + string name = null) + { + return tf_with(ops.name_scope(name, "MaxPool", value), scope => + { + name = scope; + return gen_nn_ops.max_pool( + value, + ksize: pool_size, + strides: strides, + padding: padding, + data_format: data_format, + name: name); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs new file mode 100644 index 000000000..9905d39c8 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -0,0 +1,192 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using Tensorflow.Common.Types; +using Tensorflow.Keras; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Saving; +using Tensorflow.NumPy; +using Tensorflow.Operations; +using Tensorflow.Train; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Abstract object representing an RNN cell. + /// + /// Every `RNNCell` must have the properties below and implement `call` with + /// the signature `(output, next_state) = call(input, state)`. The optional + /// third input argument, `scope`, is allowed for backwards compatibility + /// purposes; but should be left off for new subclasses. + /// + /// This definition of cell differs from the definition used in the literature. + /// In the literature, 'cell' refers to an object with a single scalar output. + /// This definition refers to a horizontal array of such units. + /// + /// An RNN cell, in the most abstract setting, is anything that has + /// a state and performs some operation that takes a matrix of inputs. + /// This operation results in an output matrix with `self.output_size` columns. + /// If `self.state_size` is an integer, this operation also results in a new + /// state matrix with `self.state_size` columns. If `self.state_size` is a + /// (possibly nested tuple of) Shape object(s), then it should return a + /// matching structure of Tensors having shape `[batch_size].concatenate(s)` + /// for each `s` in `self.batch_size`. + /// + [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] + public abstract class RnnCell : ILayer, IRnnCell + { + /// + /// Attribute that indicates whether the cell is a TF RNN cell, due the slight + /// difference between TF and Keras RNN cell. + /// + protected bool _is_tf_rnn_cell = false; + public virtual object state_size { get; } + + public virtual int output_size { get; } + public string Name { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + + public List InboundNodes => throw new NotImplementedException(); + + public List OutboundNodes => throw new NotImplementedException(); + + public List Layers => throw new NotImplementedException(); + + public bool Trainable => throw new NotImplementedException(); + + public List TrainableVariables => throw new NotImplementedException(); + public List TrainableWeights => throw new NotImplementedException(); + public List Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + + public List get_weights() => throw new NotImplementedException(); + public void set_weights(IEnumerable weights) => throw new NotImplementedException(); + public List NonTrainableWeights => throw new NotImplementedException(); + + public Shape OutputShape => throw new NotImplementedException(); + + public KerasShapesWrapper BatchInputShape => throw new NotImplementedException(); + + public KerasShapesWrapper BuildInputShape => throw new NotImplementedException(); + + public TF_DataType DType => throw new NotImplementedException(); + protected bool built = false; + public bool Built => built; + + public RnnCell(bool trainable = true, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + bool? _reuse = null) + { + _is_tf_rnn_cell = true; + } + + public virtual object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) + { + if (inputs != null) + throw new NotImplementedException("get_initial_state input is not null"); + + return zero_state(batch_size, dtype); + } + + /// + /// Return zero-filled state tensor(s). + /// + /// + /// + /// + private Tensor zero_state(Tensor batch_size, TF_DataType dtype) + { + Tensor output = null; + tf_with(ops.name_scope($"{GetType().Name}ZeroState", values: new { batch_size }), delegate + { + output = _zero_state_tensors(state_size, batch_size, dtype); + }); + + return output; + } + + private Tensor _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype) + { + if (state_size is int state_size_int) + { + var output = nest.map_structure(s => + { + var c = rnn_cell_impl._concat(batch_size, s); + var size = array_ops.zeros(c, dtype: dtype); + + var c_static = rnn_cell_impl._concat(batch_size, s, @static: true); + size.set_shape(c_static); + + return size; + }, state_size_int); + + return output; + } + + throw new NotImplementedException("_zero_state_tensors"); + } + + public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null) + { + throw new NotImplementedException(); + } + + public int count_params() + { + throw new NotImplementedException(); + } + + public IKerasConfig get_config() + { + throw new NotImplementedException(); + } + + public void build(Shape input_shape) + { + throw new NotImplementedException(); + } + + public void build(KerasShapesWrapper input_shape) + { + throw new NotImplementedException(); + } + + public Trackable GetTrackable() { throw new NotImplementedException(); } + + public void adapt(Tensor data, int? batch_size = null, int? steps = null) + { + throw new NotImplementedException(); + } + + public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) + { + throw new NotImplementedException(); + } + public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) + { + throw new NotImplementedException(); + } + public INestStructure StateSize => throw new NotImplementedException(); + public INestStructure OutputSize => throw new NotImplementedException(); + public bool IsTFRnnCell => throw new NotImplementedException(); + public bool SupportOptionalArgs => throw new NotImplementedException(); + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs new file mode 100644 index 000000000..6b9f073c1 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -0,0 +1,472 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Framework; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + public class rnn + { + /// + /// Creates a bidirectional recurrent neural network. + /// + public static (Tensor[], LSTMStateTuple, LSTMStateTuple) static_bidirectional_rnn(BasicLstmCell cell_fw, + BasicLstmCell cell_bw, + Tensor[] inputs, + Tensor initial_state_fw = null, + Tensor initial_state_bw = null, + TF_DataType dtype = TF_DataType.DtInvalid, + Tensor sequence_length = null, + string scope = null) + { + if (inputs == null || inputs.Length == 0) + throw new ValueError("inputs must not be empty"); + + Tensor[] output_fw = null; + Tensor[] output_bw = null; + LSTMStateTuple output_state_fw = null; + LSTMStateTuple output_state_bw = null; + + tf_with(tf.variable_scope(scope ?? "bidirectional_rnn"), delegate + { + // Forward direction + tf_with(tf.variable_scope("fw"), fw_scope => + { + (output_fw, output_state_fw) = static_rnn( + cell_fw, + inputs, + initial_state_fw, + dtype, + sequence_length, + scope: fw_scope); + }); + + // backward direction + tf_with(tf.variable_scope("bw"), bw_scope => + { + var reversed_inputs = _reverse_seq(inputs, sequence_length); + (output_bw, output_state_bw) = static_rnn( + cell_bw, + reversed_inputs, + initial_state_bw, + dtype, + sequence_length, + scope: bw_scope); + }); + }); + + output_bw = _reverse_seq(output_bw, sequence_length); + + var flat_outputs = zip(output_fw, output_bw) + .Select(x => array_ops.concat(new[] { x.Item1, x.Item2 }, 1)) + .ToArray(); + + return (flat_outputs, output_state_fw, output_state_bw); + } + + private static Tensor[] _reverse_seq(Tensor[] input_seq, Tensor lengths) + { + if (lengths == null) + return input_seq.Reverse().ToArray(); + + throw new NotImplementedException("_reverse_seq"); + } + + public static (Tensor[], LSTMStateTuple) static_rnn(BasicLstmCell cell, + Tensor[] inputs, + Tensor initial_state, + TF_DataType dtype = TF_DataType.DtInvalid, + Tensor sequence_length = null, + VariableScope scope = null) + { + List outputs = new List(); + object state = null; + + // Create a new scope in which the caching device is either + // determined by the parent scope, or is set to place the cached + // Variable using the same placement as for the rest of the RNN. + if (scope == null) + tf_with(tf.variable_scope("rnn"), varscope => + { + throw new NotImplementedException("static_rnn"); + }); + else + tf_with(tf.variable_scope(scope), scope1 => + { + Dimension fixed_batch_size = null; + Dimension batch_size = null; + Tensor batch_size_tensor = null; + VariableScope varscope = scope1; + // Obtain the first sequence of the input + var first_input = inputs[0]; + if (first_input.shape.ndim != 1) + { + var input_shape = first_input.shape.with_rank_at_least(2); + fixed_batch_size = input_shape.dims[0]; + var flat_inputs = nest.flatten2(inputs); + foreach (var flat_input in flat_inputs) + { + input_shape = flat_input.shape.with_rank_at_least(2); + batch_size = tensor_shape.dimension_at_index(input_shape, 0); + var input_size = input_shape[new Slice(1)]; + fixed_batch_size.merge_with(batch_size); + foreach (var (i, size) in enumerate(input_size.dims)) + { + if (size < 0) + throw new ValueError($"Input size (dimension {i} of inputs) must be accessible via " + + "shape inference, but saw value None."); + } + } + } + else + fixed_batch_size = first_input.shape.with_rank_at_least(1).dims[0]; + + if (tensor_shape.dimension_value(fixed_batch_size) >= 0) + batch_size = tensor_shape.dimension_value(fixed_batch_size); + else + batch_size_tensor = array_ops.shape(first_input)[0]; + + if (initial_state != null) + state = initial_state; + else + { + state = cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype); + } + + Tensor output = null; + if (state is LSTMStateTuple state_tuple) + { + foreach (var (time, input_) in enumerate(inputs)) + { + if (time > 0) + varscope.reuse_variables(); + if (sequence_length != null) + throw new NotImplementedException("static_rnn"); + + var results = cell.__call__(input_, state_tuple); + (output, state_tuple) = (results[1], new LSTMStateTuple(results[0], results[1])); + outputs.Add(output); + } + } + }); + + return (outputs.ToArray(), state as LSTMStateTuple); + } + + public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, + Tensor sequence_length = null, Tensor initial_state = null, + TF_DataType dtype = TF_DataType.DtInvalid, + int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) + { + return tf_with(tf.variable_scope("rnn"), scope => + { + VariableScope varscope = scope; + var flat_input = nest.flatten(inputs_tensor); + + if (!time_major) + { + flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList(); + flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList(); + } + + parallel_iterations = parallel_iterations ?? 32; + + if (sequence_length != null) + throw new NotImplementedException("dynamic_rnn sequence_length has value"); + + var batch_size = _best_effort_input_batch_size(flat_input); + + Tensor state = null; + if (initial_state != null) + state = initial_state; + else + state = cell.get_initial_state(batch_size: batch_size, dtype: dtype) as Tensor; + + var inputs = nest.pack_sequence_as(structure: inputs_tensor, flat_sequence: flat_input); + + var (outputs, final_state) = _dynamic_rnn_loop( + cell, + inputs as Tensor, + state, + parallel_iterations: parallel_iterations.Value, + swap_memory: swap_memory, + sequence_length: sequence_length, + dtype: dtype); + + if (!time_major) + outputs = nest.map_structure(_transpose_batch_time, outputs); + + return (outputs, final_state); + }); + } + + /// + /// Internal implementation of Dynamic RNN. + /// + /// + /// + /// + /// + /// + /// + /// + /// + private static (Tensor, Tensor) _dynamic_rnn_loop(RnnCell cell, Tensor inputs, Tensor initial_state, + int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid) + { + var state = initial_state; + var state_size = cell.state_size; + + var flat_input = nest.flatten(inputs); + var flat_output_size = nest.flatten(cell.output_size); + + // Construct an initial output + var input_shape = array_ops.shape(flat_input[0]); + var time_steps = input_shape.slice(0); + var batch_size = _best_effort_input_batch_size(flat_input); + var inputs_got_shape = flat_input.Select(input_ => input_.shape.with_rank_at_least(3)).ToArray(); + + var dims = inputs_got_shape[0].dims.Take(2).ToArray(); + var (const_time_steps, const_batch_size) = (dims[0], dims[1]); + + foreach (var shape in inputs_got_shape) + { + if (shape.dims[2] == -1) + throw new ValueError("Input size (depth of inputs) must be accessible via shape inference," + + " but saw value None."); + + var got_time_steps = shape.dims[0]; + var got_batch_size = shape.dims[1]; + + if (const_time_steps != got_time_steps) + throw new ValueError("Time steps is not the same for all the elements in the input in a " + + "batch."); + + if (const_batch_size != got_batch_size) + throw new ValueError("Batch_size is not the same for all the elements in the input."); + } + + Func _create_zero_arrays = (size_) => + { + var size = rnn_cell_impl._concat(batch_size, size_); + return array_ops.zeros( + array_ops.stack(size), dtype: _infer_state_dtype(dtype, state)); + }; + + // Prepare dynamic conditional copying of state & output + var flat_zero_output = flat_output_size.Select(output => _create_zero_arrays(output)).ToArray(); + var zero_output = nest.pack_sequence_as(structure: cell.output_size, flat_sequence: flat_zero_output); + + Tensor min_sequence_length = null, max_sequence_length = null; + if (sequence_length != null) + { + min_sequence_length = math_ops.reduce_min(sequence_length); + max_sequence_length = math_ops.reduce_max(sequence_length); + } + else + { + max_sequence_length = time_steps; + } + + var time = array_ops.constant(0, dtype: dtypes.int32, name: "time"); + + string base_name = null; + tf_with(ops.name_scope("dynamic_rnn"), scope => base_name = scope); + + Func _create_ta = (name, element_shape, dtype_) => + { + var ta = tf.TensorArray(dtype: dtype_, + size: time_steps, + element_shape: element_shape); + return ta; + }; + + bool in_graph_mode = true; + var output_ta = new List(); + var input_ta = new List(); + if (in_graph_mode) + { + foreach (var (i, out_size) in enumerate(flat_output_size)) + { + output_ta.Add(_create_ta($"output_{i}", + new Shape(const_batch_size).concatenate( + _maybe_tensor_shape_from_tensor(out_size)), + _infer_state_dtype(dtype, state))); + } + + foreach (var (i, flat_input_i) in enumerate(flat_input)) + { + input_ta.Add(_create_ta($"input_{i}", + new Shape(flat_input_i.dims.Skip(1).ToArray()), + flat_input_i.dtype)); + } + + input_ta = zip(input_ta, flat_input).Select(x => + { + var (ta, input_) = (x.Item1, x.Item2); + return ta.unstack(input_); + }).ToList(); + } + + // Make sure that we run at least 1 step, if necessary, to ensure + // the TensorArrays pick up the dynamic shape. + Tensor loop_bound = null; + if (in_graph_mode) + loop_bound = math_ops.minimum( + time_steps, math_ops.maximum(1, max_sequence_length)); + + Func cond = (item) => + { + return item.time < loop_bound; + }; + + // Take a time step of the dynamic RNN. + Func _time_step = (item) => + { + Tensor[] input_t = null; + var (time1, output_ta_t, state1) = (item.time, item.output_ta_t, item.state); + if (in_graph_mode) + { + input_t = input_ta.Select(ta => ta.read(time1)).ToArray(); + // Restore some shape information + foreach (var (input_, shape) in zip(input_t, inputs_got_shape)) + input_.shape = shape[new Slice(1)]; + } + else + { + // input_t = tuple(ta[time.numpy()] for ta in input_ta) + } + + var input_t_t = nest.pack_sequence_as2(structure: inputs, flat_sequence: input_t); + // Keras RNN cells only accept state as list, even if it's a single tensor. + // var is_keras_rnn_cell = _is_keras_rnn_cell(cell); + Tensor[] outputs = null; + if (sequence_length != null) + throw new NotImplementedException("sequence_length != null"); + /*else + outputs = cell.__call__(input_t_t, state: state1);*/ + + var (output, new_state) = (outputs[0], outputs[1]); + // Keras cells always wrap state as list, even if it's a single tensor. + // if(is_keras_rnn_cell && len(new_state)) == 1 + // Pack state if using state tuples + outputs = nest.flatten2(output).Select(x => x as Tensor).ToArray(); + + output_ta_t = zip(output_ta_t, outputs).Select(x => + { + var (ta, @out) = (x.Item1, x.Item2); + return ta.write(item.time, @out); + }).ToArray(); + + return new BodyItemInRnnWhileLoop(item.time + 1, output_ta_t, new_state); + }; + + var while_loop_result = control_flow_ops.while_loop( + cond: cond, + body: _time_step, + loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state), + parallel_iterations: parallel_iterations, + maximum_iterations: time_steps, + swap_memory: swap_memory); + + (_, TensorArray[] output_final_ta, Tensor final_state) = (while_loop_result.time, while_loop_result.output_ta_t, while_loop_result.state); + + // Unpack final output if not using output tuples. + var final_outputs = output_final_ta.Select(ta => ta.stack()).ToArray(); + // Restore some shape information + foreach (var (output, output_size) in zip(final_outputs, flat_output_size)) + { + var shape = rnn_cell_impl._concat(new int[] { (int)const_time_steps, (int)const_batch_size }, output_size, @static: true); + output.shape = shape; + } + + return (final_outputs[0], final_state); + } + + private static Shape _maybe_tensor_shape_from_tensor(Tensor shape) + => shape.shape; + + private static Shape _maybe_tensor_shape_from_tensor(int shape) + => new Shape(shape); + + private static TF_DataType _infer_state_dtype(TF_DataType explicit_dtype, Tensor state) + { + if (explicit_dtype != TF_DataType.DtInvalid) + return explicit_dtype; + + throw new NotImplementedException("_infer_state_dtype"); + } + + /// + /// Transposes the batch and time dimensions of a Tensor. + /// + /// + /// + public static Tensor _transpose_batch_time(Tensor x) + { + var x_static_shape = x.shape; + if (x_static_shape.ndim == 1) + return x; + + var x_rank = array_ops.rank(x); + var con1 = new object[] + { + new []{1, 0 }, + math_ops.range(2, x_rank) + }; + var x_t = array_ops.transpose(x, array_ops.concat(con1, 0)); + + var dims = new long[] { x_static_shape.dims[1], x_static_shape.dims[0] } + .ToList(); + dims.AddRange(x_static_shape.dims.Skip(2)); + var shape = new Shape(dims.ToArray()); + + x_t.shape = shape; + + return x_t; + } + + /// + /// Get static input batch size if available, with fallback to the dynamic one. + /// + /// + /// + private static Tensor _best_effort_input_batch_size(List flat_input) + { + foreach (var input_ in flat_input) + { + var shape = input_.shape; + if (shape.ndim < 0) + continue; + if (shape.ndim < 2) + throw new ValueError($"Expected input tensor {input_.name} to have rank at least 2"); + + var batch_size = shape.dims[1]; + if (batch_size > -1) + throw new ValueError("_best_effort_input_batch_size batch_size > -1"); + //return batch_size; + } + + return array_ops.shape(flat_input[0]).slice(1); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs new file mode 100644 index 000000000..49fe843bd --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs @@ -0,0 +1,86 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow.Operations +{ + public class rnn_cell_impl + { + public BasicRnnCell BasicRNNCell(int num_units) + => new BasicRnnCell(num_units); + + public static Tensor _concat(Tensor prefix, int suffix, bool @static = false) + { + var p = prefix; + var p_static = tensor_util.constant_value(prefix); + if (p.ndim == 0) + p = array_ops.expand_dims(p, 0); + else if (p.ndim != 1) + throw new ValueError($"prefix tensor must be either a scalar or vector, but saw tensor: {p}"); + + var s_tensor_shape = new Shape(suffix); + var s_static = s_tensor_shape.ndim > -1 ? + s_tensor_shape.dims : + null; + var s = s_tensor_shape.IsFullyDefined ? + constant_op.constant(s_tensor_shape.dims, dtype: dtypes.int32) : + null; + + if (@static) + { + if (p_static is null) return null; + var shape = new Shape(p_static).concatenate(s_static); + throw new NotImplementedException("RNNCell _concat"); + } + else + { + if (p is null || s is null) + throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}"); + return array_ops.concat(new[] { p, s }, 0); + } + } + + public static Shape _concat(int[] prefix, int suffix, bool @static = false) + { + var p = new Shape(prefix); + var p_static = prefix; + var p_tensor = p.IsFullyDefined ? constant_op.constant(p, dtype: dtypes.int32) : null; + + var s_tensor_shape = new Shape(suffix); + var s_static = s_tensor_shape.ndim > -1 ? + s_tensor_shape.dims : + null; + var s_tensor = s_tensor_shape.IsFullyDefined ? + constant_op.constant(s_tensor_shape.dims, dtype: dtypes.int32) : + null; + + if (@static) + { + if (p_static is null) return null; + var shape = new Shape(p_static).concatenate(s_static); + return shape; + } + else + { + if (p is null || s_tensor is null) + throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}"); + // return array_ops.concat(new[] { p_tensor, s_tensor }, 0); + throw new NotImplementedException(""); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 27ce59109..29e1f074f 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -1,78 +1,196 @@ -using System; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using Google.Protobuf.Collections; +using System; using System.Collections.Generic; -using System.IO; -using System.Runtime.InteropServices; -using System.Text; +using System.Linq; +using Tensorflow.Functions; +using static Tensorflow.Binding; using static Tensorflow.OpDef.Types; namespace Tensorflow { public class OpDefLibrary { - public Dictionary _ops = new Dictionary(); - - public void add_op_list(OpList op_list) - { - foreach(var op_def in op_list.Op) - { - add_op(op_def); - } - } + public Operation _apply_op_helper(string op_type_name, string name = null, object args = null) + => _apply_op_helper(op_type_name, name: name, keywords: ConvertToDict(args)); - public void add_op(OpDef op_def) + public Operation _apply_op_helper(string op_type_name, string name = null, Dictionary keywords = null) { - _ops[op_def.Name] = op_def; - } - - public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary keywords = null) - { - var op_def = _ops[op_type_name]; - - var status = new Status(); - var buffer = new Buffer(); - - var g = ops.get_default_graph(); + var g = ops._get_graph_from_inputs(keywords == null ? new object[0] : keywords.Values.ToArray()); + var op_def = g.GetOpDef(op_type_name); + // Default name if not specified. if (String.IsNullOrEmpty(name)) - { name = op_type_name; - } - string scope = g.unique_name(name) + "/"; + // Check for deprecation + if (op_def.Deprecation != null && op_def.Deprecation.Version > 0) + { + } + + var default_type_attr_map = new Dictionary(); foreach (var attr_def in op_def.Attr) { if (attr_def.Type != "type") continue; var key = attr_def.Name; + if (attr_def.DefaultValue != null) + { + default_type_attr_map[key] = attr_def.DefaultValue.Type; + } } var attrs = new Dictionary(); - - // Perform input type inference var inputs = new List(); var input_types = new List(); - - foreach (var input_arg in op_def.InputArg) + object values = null; + + g.as_default(); + + var scope = ops.name_scope(name); + scope.__enter__(); + + var inferred_from = new Dictionary(); + var base_types = new List(); + var types = new List(); + string _scope_name = scope; + + // Perform input type inference + foreach (var (i, input_arg) in enumerate(op_def.InputArg)) { var input_name = input_arg.Name; + if (keywords.ContainsKey(input_name)) + values = keywords[input_name]; + else if (keywords.ContainsKey(input_name + "_")) + { + input_name += "_"; + values = keywords[input_name]; + } + else if (keywords.ContainsKey($"input_{i}")) { - inputs.Add(keywords[input_name] as Tensor); + values = keywords[$"input_{i}"]; } + else + throw new TypeError("No argument for input " + input_name); - if (!String.IsNullOrEmpty(input_arg.TypeAttr)) + // Goals: + // * Convert values to Tensors if it contains constants. + // * Verify that values is a list if that matches the input_arg's + // type. + // * If the input_arg's type is determined by attrs, either set + // those attrs and validate those attr values are legal (if + // they have not yet been set) or validate the input matches + // the type indicated by the attrs (if they have already been + // inferred via an earlier input). + // * If the input_arg has an explicit type, make sure the input + // conforms. + + DataType dtype = DataType.DtInvalid; + DataType default_dtype = DataType.DtInvalid; + + if (values is Tensors tensors) { - attrs[input_arg.TypeAttr] = DataType.DtFloat; + values = (Tensor[])tensors; } - if (input_arg.IsRef) + if (_IsListParameter(input_arg)) { + if (!_IsListValue(values)) + throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}."); + if (input_arg.Type != DataType.DtInvalid) + dtype = input_arg.Type; + else if (!String.IsNullOrEmpty(input_arg.NumberAttr)) + { + if (attrs.ContainsKey(input_arg.TypeAttr)) + dtype = (DataType)attrs[input_arg.TypeAttr]; + else + switch (values) + { + case Tensor[] values1: + dtype = values1[0].dtype.as_datatype_enum(); + break; + case object[] values1: + foreach (var t in values1) + if (t is Tensor tensor) + { + dtype = tensor.dtype.as_datatype_enum(); + break; + } + break; + default: + throw new NotImplementedException($"can't infer the dtype for {values.GetType()}"); + } + + if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr)) + default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; + } + + if (!input_arg.IsRef && dtype != DataType.DtInvalid) + dtype = dtype.as_base_dtype(); + values = ops.internal_convert_n_to_tensor(values as object[], + name: input_arg.Name, + dtype: dtype.as_tf_dtype(), + preferred_dtype: default_dtype.as_tf_dtype(), + as_ref: input_arg.IsRef); } else { - input_types.Add((keywords[input_name] as Tensor).dtype); + if (input_arg.Type != DataType.DtInvalid) + dtype = input_arg.Type; + else if (attrs.ContainsKey(input_arg.TypeAttr)) + dtype = (DataType)attrs[input_arg.TypeAttr]; + else if (isinstance(values, typeof(string)) && dtype == DataType.DtInvalid) + dtype = DataType.DtString; + else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) + default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; + + var value = ops.convert_to_tensor(values, + name: input_name, + dtype: dtype.as_tf_dtype(), + as_ref: input_arg.IsRef, + preferred_dtype: default_dtype.as_tf_dtype()); + + //if (!String.IsNullOrEmpty(input_arg.TypeAttr)) + //attrs[input_arg.TypeAttr] = values.dtype; + + values = new Tensor[] { value }; } + + if (values is Tensor[] values2) + { + types = values2.Select(x => x.dtype).ToList(); + inputs.AddRange(values2); + base_types = values2.Select(x => x.dtype.as_base_dtype()).ToList(); + } + else throw new NotImplementedException("_IsListParameter"); + + SetAttrs(op_type_name, + input_arg, + op_def, + attrs, + inferred_from, + types, + base_types, + input_types, + values); } // Process remaining attrs @@ -86,53 +204,298 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "", // Convert attr values to AttrValue protos. var attr_protos = new Dictionary(); - foreach (var attr_def in op_def.Attr) + foreach (AttrDef attr_def in op_def.Attr) { var key = attr_def.Name; - var value = attrs[key]; - var attr_value = new AttrValue(); - - switch (attr_def.Type) + if (attrs.ContainsKey(key)) { - case "type": - attr_value.Type = _MakeType(value, attr_def); - break; - case "shape": - attr_value.Shape = new TensorShapeProto(); - break; + attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]); + } + else + { + if (attr_def.DefaultValue == null) + { + throw new TypeError("Missing required positional argument " + key); + } } - - attr_protos[key] = attr_value; } + attrs.Clear(); + // Determine output types (possibly using attrs) var output_types = new List(); foreach (var arg in op_def.OutputArg) { - if (!String.IsNullOrEmpty(arg.NumberAttr)) + types = new List(); + if (!string.IsNullOrEmpty(arg.NumberAttr)) { } - else if (!String.IsNullOrEmpty(arg.TypeAttr)) + else if (!string.IsNullOrEmpty(arg.TypeAttr)) { - output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); + types = new List() { (TF_DataType)attr_protos[arg.TypeAttr].Type }; } + + if (arg.IsRef) + types = types.Select(x => x.as_ref()).ToList(); + + output_types.AddRange(types); } + // We add an explicit colocation constraint between + // the newly created op and any of its reference-typed inputs. + var must_colocate_inputs = zip(op_def.InputArg, inputs) + .Where(x => x.Item1.IsRef) + .Select(x => x.Item2) + .ToArray(); + + _MaybeColocateWith(must_colocate_inputs); + // Add Op to graph - var op = g.create_op(op_type_name, inputs, output_types.ToArray(), - name: scope, + var ret_op = g.create_op(op_type_name, + inputs.ToArray(), + output_types.ToArray(), + name: _scope_name, input_types: input_types.ToArray(), attrs: attr_protos, op_def: op_def); - return op; + scope.__exit__(); + + g.Exit(); + + return ret_op; + } + + private void _MaybeColocateWith(ITensorOrOperation[] inputs) + { + + } + + private void SetAttrs(string op_type_name, + ArgDef input_arg, + OpDef op_def, + Dictionary attrs, + Dictionary inferred_from, + List types, + List base_types, + List input_types, + object values) + { + var input_name = input_arg.Name; + + if (!string.IsNullOrEmpty(input_arg.NumberAttr)) + { + if (attrs.ContainsKey(input_arg.NumberAttr)) + { + + } + else + { + if(values is Tensor[] tensors) + { + var num_attr = op_def.Attr.First(x => x.Name == input_arg.NumberAttr); + if (num_attr.HasMinimum && tensors.Length < num_attr.Minimum) + throw new ValueError($"List argument '{input_name}' to '{op_type_name}' Op with length {(values as Tensor[]).Length} shorter " + + $"than minimum length {num_attr.Minimum}"); + + attrs[input_arg.NumberAttr] = Convert.ToInt64(tensors.Length); + inferred_from[input_arg.NumberAttr] = input_name; + } + } + + // All tensors must have the same base type. + if (input_arg.Type != DataType.DtInvalid) + { + + } + else + { + attrs[input_arg.TypeAttr] = base_types[0]; + inferred_from[input_arg.TypeAttr] = input_name; + var type_attr = op_def.Attr.First(x => x.Name == input_arg.TypeAttr); + } + } + else if (!string.IsNullOrEmpty(input_arg.TypeAttr)) + { + var attr_value = base_types[0]; + if (attrs.ContainsKey(input_arg.TypeAttr)) + { + + } + else + { + attrs[input_arg.TypeAttr] = attr_value; + inferred_from[input_arg.TypeAttr] = input_name; + } + } + else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) + { + var attr_value = base_types; + if (attrs.ContainsKey(input_arg.TypeListAttr)) + { + + } + else + { + attrs[input_arg.TypeListAttr] = attr_value; + inferred_from[input_arg.TypeListAttr] = input_name; + } + } + + if (input_arg.IsRef) + input_types.AddRange(types); + else + input_types.AddRange(base_types); + } + + public ByteString _MakeStr(string value, AttrDef attr_def) + { + return ByteString.CopyFromUtf8(value ?? string.Empty); + } + + public TensorShapeProto _MakeShape(Shape shape, AttrDef attr_def) + { + return shape.as_proto(); + } + + public DataType _MakeType(TF_DataType v, AttrDef attr_def) + { + return v.as_base_dtype().as_datatype_enum(); + } + + private AttrValue SetAttrValue(OpDef op_def, AttrDef attr_def, object value) + { + var attr_value = new AttrValue(); + + if (attr_def.Type.StartsWith("list(")) + { + if (attr_def.HasMinimum) +#pragma warning disable CS0642 // Possible mistaken empty statement + ; +#pragma warning restore CS0642 // Possible mistaken empty statement + attr_value.List = new AttrValue.Types.ListValue(); + } + + switch (attr_def.Type) + { + case "string": + attr_value.S = _MakeStr((string)value, attr_def); + break; + case "type": + attr_value.Type = _MakeType((TF_DataType)value, attr_def); + break; + case "list(type)": + attr_value.List.Type.AddRange((value as IList).Select(x => _MakeType(x, attr_def))); + break; + case "list(float)": + if (value != null) + attr_value.List.F.AddRange((value as IEnumerable).ToArray()); + break; + case "list(int)": + if (value != null) + attr_value.List.I.AddRange((value as IEnumerable).Select(x => Convert.ToInt64(x))); + break; + case "bool": + attr_value.B = (bool)value; + break; + case "float": + attr_value.F = (float)value; + break; + case "int": + if (value is long value_long) + attr_value.I = value_long; + else + attr_value.I = Convert.ToInt64(value); + if (attr_def.HasMinimum && attr_value.I < attr_def.Minimum) + throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}."); + break; + case "shape": + if (value == null && attr_def.DefaultValue != null) + attr_value.Shape = attr_def.DefaultValue.Shape; + + if (value is Shape val1) + attr_value.Shape = val1.as_proto(); + else if (value is long[] val2) + attr_value.Shape = tensor_util.as_shape(val2); + else if (value is int[] val3) + attr_value.Shape = tensor_util.as_shape(val3); + + break; + case "list(shape)": + attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def))); + break; + case "func": + attr_value.Func = _MakeFunc(value, attr_def.Name); + break; + case "list(func)": + attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name)); + break; + case "list(string)": + attr_value.List.S.AddRange((value as IEnumerable).Select(x => ByteString.CopyFromUtf8(x))); + break; + default: + throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); + } + + return attr_value; + } + + private NameAttrList _MakeFunc(object func, string arg_name) + { + if(func is NameAttrList attrList) + { + return attrList; + } + NameAttrList fn_attr; + if(func is string funcStr) + { + fn_attr = new NameAttrList() { Name = funcStr }; + } + else if(func is ConcreteFunction concrete) + { + concrete.AddTograph(ops.get_default_graph()); + fn_attr = concrete.AsNameAttrList; + } + else if(func is EagerDefinedFunction eager) + { + eager.AddToGraph(ops.get_default_graph()); + fn_attr = new NameAttrList() { Name = eager.Name }; + } + else + { + throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}"); + } + return fn_attr; + } + + private List _MakeFuncList(object funcList, string arg_name) + { + List res = new List(); + if(funcList is IEnumerable enumerable) + { + foreach(var func in enumerable) + { + res.Add(_MakeFunc(func, arg_name)); + } + } + return res; + } + + private bool _IsListParameter(ArgDef arg) + { + if (!String.IsNullOrEmpty(arg.NumberAttr)) + return true; + else if (!String.IsNullOrEmpty(arg.TypeListAttr)) + return true; + else + return false; } - public DataType _MakeType(Object v, AttrDef attr_def) + private bool _IsListValue(object v) { - return DataType.DtFloat; + return v.GetType().IsArray; } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs new file mode 100644 index 000000000..89145e413 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -0,0 +1,67 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Operations; + +namespace Tensorflow +{ + public partial class Operation + { + private ControlFlowContext _control_flow_context; + + /// + /// Add this op to its control flow context. + /// + /// This may add new ops and change this op's inputs. self.inputs must be + /// available before calling this method. + /// + public void _control_flow_post_processing() + { + foreach (Tensor input_tensor in inputs) + control_flow_util.CheckInputFromValidContext(this, input_tensor.op); + + if (_control_flow_context != null) + _control_flow_context.AddOp(this); + } + + public void _add_control_input(Operation op) + { + // c_api.TF_AddControlInput(_opDesc, op); + //c_api.AddControlInput(graph, _handle, op); + } + + public void _add_control_inputs(Operation[] ops) + { + foreach (var op in ops) + _add_control_input(op); + } + + public void _set_control_flow_context(ControlFlowContext ctx) + { + _control_flow_context = ctx; + } + + public ControlFlowContext _get_control_flow_context() + { + return _control_flow_context; + } + + public WhileContext GetWhileContext() + { + return _control_flow_context as WhileContext; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs new file mode 100644 index 000000000..ec49f8505 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs @@ -0,0 +1,53 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow +{ + /// + /// Convert to other datatype implicitly + /// + public partial class Operation + { + // make sure the new op is in the same graph instance + public static implicit operator Operation(IntPtr handle) + => new Operation(handle); + + public static implicit operator IntPtr(Operation op) + => op._handle; + public static implicit operator Tensor(Operation op) + => op.output; + + public override string ToString() + { + return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $""; + } + + public override bool Equals(object obj) + { + switch (obj) + { + case IntPtr val: + return val == _handle; + case Operation val: + return val._handle == _handle; + } + + return base.Equals(obj); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs new file mode 100644 index 000000000..9aa6fde22 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -0,0 +1,107 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using System.Runtime.InteropServices; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + + // from ops.py + public partial class Operation + { + public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index)); + public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index)); + + public int InputListLength(string name) + { + int num = 0; + num = c_api.TF_OperationInputListLength(_handle, name, tf.Status); + tf.Status.Check(true); + return num; + } + public int NumInputs => _handle == IntPtr.Zero ? -1 : c_api.TF_OperationNumInputs(_handle); + private TF_DataType[] _input_types => _inputs_val._inputs.Select(x => x.dtype).ToArray(); + + protected InputList _inputs_val; + + public virtual InputList inputs + { + get + { + if (_inputs_val == null) + { + var retval = new Tensor[NumInputs]; + + for (int i = 0; i < NumInputs; i++) + { + var tf_output = Input(i); + var op = GetOperation(tf_output.oper); + retval[i] = op.outputs[tf_output.index]; + } + + _inputs_val = new InputList(retval); + } + + return _inputs_val; + } + } + + public int NumControlInputs + => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationNumControlInputs(_handle); + + Operation[] _control_inputs; + /// + /// The `Operation` objects on which this op has a control dependency. + /// + /// Before this op is executed, TensorFlow will ensure that the + /// operations in `self.control_inputs` have finished executing.This + /// mechanism can be used to run ops sequentially for performance + /// reasons, or to ensure that the side effects of an op are observed + /// in the correct order. + /// + public Operation[] control_inputs + { + get + { + if (_control_inputs == null || _control_inputs.Length == 0) + _control_inputs = GetControlInputs(); + return _control_inputs; + } + } + + public unsafe Operation[] GetControlInputs() + { + var control_inputs = new Operation[NumControlInputs]; + + if (NumControlInputs > 0) + { + IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlInputs); + c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs); + for (int i = 0; i < NumControlInputs; i++) + { + var handle = control_input_handle + Marshal.SizeOf() * i; + control_inputs[i] = new Operation(*(IntPtr*)handle); + } + Marshal.FreeHGlobal(control_input_handle); + } + + return control_inputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.Instance.cs b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs new file mode 100644 index 000000000..e6e59fe15 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs @@ -0,0 +1,44 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Operation + { + /// + /// Get operation by handle + /// + /// + /// + public Operation GetOperation(IntPtr handle) + { + var nodes = tf.get_default_graph()._nodes_by_name; + foreach (var node in nodes.Values) + { + if (node is Operation op) + { + if (op == handle) + return op; + } + } + + return new Operation(handle); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs new file mode 100644 index 000000000..2329a4786 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -0,0 +1,92 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using System.Runtime.InteropServices; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Operation + { + public int NumOutputs => _handle == IntPtr.Zero ? -1 : c_api.TF_OperationNumOutputs(_handle); + public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index)); + + public int OutputListLength(string name) + { + int num = c_api.TF_OperationOutputListLength(_handle, name, tf.Status); + tf.Status.Check(true); + + return num; + } + + internal Tensor[] _outputs; + public virtual Tensor[] outputs => _outputs; + public Tensor output => _outputs.FirstOrDefault(); + + public int NumControlOutputs => _handle == IntPtr.Zero ? -1 : c_api.TF_OperationNumControlOutputs(_handle); + + public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); + + public TF_Output this[int index] => _tf_output(index); + + /// + /// List this operation's output types. + /// + public TF_DataType[] _output_types + { + get + { + var output_types = range(NumOutputs) + .Select(i => OutputType(i)) + .ToArray(); + return output_types; + } + } + + public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) + { + var handle = Marshal.AllocHGlobal(Marshal.SizeOf()); + int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); + var consumers = new TF_Input[num]; + var inputptr = (TF_Input*)handle; + for (int i = 0; i < num; i++) + consumers[i] = *(inputptr + i); + Marshal.FreeHGlobal(handle); + return consumers; + } + + public unsafe Operation[] GetControlOutputs() + { + var control_outputs = new Operation[NumControlOutputs]; + + if (NumControlOutputs > 0) + { + IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlOutputs); + c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlOutputs); + for (int i = 0; i < NumControlOutputs; i++) + { + var handle = control_output_handle + Marshal.SizeOf() * i; + control_outputs[i] = new Operation(*(IntPtr*)handle); + } + Marshal.FreeHGlobal(control_output_handle); + } + + return control_outputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 70ed5d582..2105c53fa 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -1,67 +1,444 @@ -using System; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; using System.Collections.Generic; -using System.Text; -using TF_DataType = Tensorflow.DataType; +using System.Linq; +using Tensorflow.Util; +using static Tensorflow.Binding; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using System.Diagnostics; namespace Tensorflow { - public class Operation + /// + /// Represents a graph node that performs computation on tensors. + /// + /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or + /// more `Tensor` objects as input, and produces zero or more `Tensor` + /// objects as output. Objects of type `Operation` are created by + /// calling an op constructor(such as `tf.matmul`) + /// or `tf.Graph.create_op`. + /// + /// For example `c = tf.matmul(a, b)` creates an `Operation` of type + /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` + /// as output. + /// + /// After the graph has been launched in a session, an `Operation` can + /// be executed by passing it to + /// `tf.Session.run`. + /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. + /// + public partial class Operation : ITensorOrOperation { - private Graph _graph; + protected IntPtr _handle; // _c_op in python + + protected Graph _graph; + + internal Func _gradient_function; + + public string type => OpType; + public Graph graph => _graph; - public IntPtr _c_op; + public int _id => _id_value; - private int _id_value; - public string name; - private Tensor[] _outputs; - public Tensor[] outputs => _outputs; - public Tensor[] inputs; - public Operation(Graph g, string opType, string oper_name) + public int _id_value { get; set; } + public Operation op => this; + public TF_DataType dtype => output.dtype; + public virtual string name => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationName(_handle)); + public string OpType => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); + + public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); + + //private OperationDescription _op_desc; + + public NodeDef node_def => GetNodeDef(); + protected Operation() { } + + public Operation(IntPtr handle, Graph g = null) { - _graph = g; + if (handle == IntPtr.Zero) + return; - var status = new Status(); + _handle = handle; + _graph = g ?? ops.get_default_graph(); + _outputs = new Tensor[NumOutputs]; + for (int i = 0; i < NumOutputs; i++) + _outputs[i] = new Tensor(this, i, OutputType(i)); - var desc = c_api.TF_NewOperation(g.Handle, opType, oper_name); - c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); - c_api.TF_FinishOperation(desc, status.Handle); + // Dict mapping op name to file and line information for op colocation + // context managers. + _control_flow_context = _graph._get_control_flow_context(); + + // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. } - public Operation(NodeDef node_def, Graph g, List inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) + /*public Operation(Graph g, string opType, string oper_name) { _graph = g; + var _operDesc = c_api.TF_NewOperation(g, opType, oper_name); + c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); + lock (Locks.ProcessWide) + using (var status = new Status()) + { + _handle = c_api.TF_FinishOperation(_operDesc, status); + status.Check(true); + } + + // Dict mapping op name to file and line information for op colocation + // context managers. + _control_flow_context = graph._get_control_flow_context(); + }*/ + + /// + /// Creates an `Operation`. + /// + /// `node_def_pb2.NodeDef`. `NodeDef` for the `Operation`. + /// `Graph`. The parent graph. + /// list of `Tensor` objects. The inputs to this `Operation`. + /// list of `DType` objects. + /// + /// list of operations or tensors from which to have a + /// control dependency. + /// + /// + /// List of `DType` objects representing the + /// types of the tensors accepted by the `Operation`. By default + /// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect + /// reference-typed inputs must specify these explicitly. + /// + /// + /// + public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) + { + _graph = g; + // Build the list of control inputs. + var control_input_ops = new List(); + if (control_inputs != null) + { + foreach (var c in control_inputs) + { + switch (c) + { + case Operation c1: + control_input_ops.Add(c1); + break; + case Tensor tensor: + control_input_ops.Add(tensor.op); + break; + // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented + //case IndexedSlices islices: + // control_input_ops.Add(islices.op); + // break; + default: + throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}"); + } + } + } + _id_value = _graph._next_id(); - _c_op = ops._create_c_op(g, node_def, inputs); - var num_outputs = c_api.TF_OperationNumOutputs(_c_op); - _outputs = new Tensor[num_outputs]; - for (int i = 0; i < num_outputs; i++) + // Dict mapping op name to file and line information for op colocation + // context managers. + _control_flow_context = graph._get_control_flow_context(); + + // This will be set by self.inputs. + if (op_def == null) + op_def = g.GetOpDef(node_def.Op); + + (_handle, _) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray(), op_def); + + // Initialize self._outputs. + output_types = new TF_DataType[NumOutputs]; + for (int i = 0; i < NumOutputs; i++) + output_types[i] = OutputType(i); + + _outputs = new Tensor[NumOutputs]; + for (int i = 0; i < NumOutputs; i++) + _outputs[i] = new Tensor(this, i, output_types[i]); + + graph._add_op(this); + + if (_handle != IntPtr.Zero) + _control_flow_post_processing(); + } + + public void run(FeedItem[] feed_dict = null, Session session = null) + { + ops._run_using_default_session(this, feed_dict, graph, session); + } + + public virtual T get_attr(string name) + { + if (typeof(T).IsValueType) + { + return (T)Convert.ChangeType(get_attr(name), typeof(T)); + } + else + { + return (T)get_attr(name); + } + } + + internal unsafe TF_DataType _get_attr_type(string name) + { + Status status = new(); + TF_DataType result; + c_api.TF_OperationGetAttrType(_handle, name, new IntPtr(&result), status); + status.Check(true); + return result; + } + + internal unsafe long _get_attr_int(string name) + { + long result; + c_api.TF_OperationGetAttrInt(_handle, name, new IntPtr(&result), tf.Status); + tf.Status.Check(true); + return result; + } + + internal unsafe bool _get_attr_bool(string name) + { + Status status = new(); + bool result; + c_api.TF_OperationGetAttrBool(_handle, name, new IntPtr(&result), status); + status.Check(true); + return result; + } + + public virtual T[] get_attr_list(string name) + { + if (tf.executing_eagerly()) + return (T[])get_attr(name); + + var buf = new Buffer(); + c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status); + tf.Status.Check(true); + + var x = AttrValue.Parser.ParseFrom(buf.ToArray()); + + string oneof_value = x.ValueCase.ToString(); + if (string.IsNullOrEmpty(oneof_value)) + return null; + + switch (typeof(T).Name) + { + case nameof(Int32): + return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); + case nameof(Int64): + return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); + default: + return null; + } + } + + public virtual object get_attr(string name) + { + var buf = new Buffer(); + Status status = new(); + c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); + status.Check(true); + var tf_buffer = c_api.TF_GetBuffer(buf); + + var x = AttrValue.Parser.ParseFrom(tf_buffer.AsSpan()); + + var oneof_value = x.ValueCase; + if (oneof_value == AttrValue.ValueOneofCase.None) + return new object[0]; + + if(oneof_value == AttrValue.ValueOneofCase.List) + { + if (x.List.S is not null && x.List.S.Count > 0) + { + return x.List.S.Select(x => x.ToStringUtf8()).ToArray(); + } + else if (x.List.I is not null && x.List.I.Count > 0) + { + return x.List.I.ToArray(); + } + else if (x.List.F is not null && x.List.F.Count > 0) + { + return x.List.F.ToArray(); + } + else if (x.List.B is not null && x.List.B.Count > 0) + { + return x.List.B.ToArray(); + } + else if (x.List.Shape is not null && x.List.Shape.Count > 0) + { + return x.List.Shape.ToArray(); + } + else if (x.List.Tensor is not null && x.List.Tensor.Count > 0) + { + return x.List.Tensor.ToArray(); + } + else if (x.List.Func is not null && x.List.Func.Count > 0) + { + return x.List.Func.ToArray(); + } + else if (x.List.Type is not null && x.List.Type.Count > 0) + { + return x.List.Type.Select(x => x.as_tf_dtype()).ToArray(); + } + else + { + return null; + } + } + if(oneof_value == AttrValue.ValueOneofCase.Type) { - _outputs[i] = new Tensor(this, i, TF_DataType.TF_FLOAT); + return dtypes.as_tf_dtype(x.Type); } + return ProtoUtils.GetSingleAttrValue(x, oneof_value); + } - _graph._add_op(this); + public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) + { + return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); } - public object get_attr(string name) + [Obsolete("The implementation is not complete.")] + internal void _set_device_from_string(string device_str) { - object ret = null; + // TODO(Rinne): complete it with new C API `SetRequestedDevice`. + //c_api.TF_SetDevice(_handle, device_str); + } - var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" }; + [Obsolete("The implementation is not complete.")] + internal void _set_device(string device) + { + _set_device_from_string(device); + } + + private NodeDef GetNodeDef() + { + var buffer = new Buffer(); + c_api.TF_OperationToNodeDef(_handle, buffer, tf.Status); + tf.Status.Check(throwException: true); + return NodeDef.Parser.ParseFrom(buffer.ToArray()); + } + + /// + /// Update the input to this operation at the given index. + /// + /// NOTE: This is for TF internal use only.Please don't use it. + /// + /// the index of the input to update. + /// the Tensor to be used as the input at the given index. + public void _update_input(int index, Tensor tensor) + { + _assert_same_graph(tensor); + + // var input = _tf_input(index); + // var output = tensor._as_tf_output(); + + // Reset cached inputs. + _inputs_val = null; + // _node_def = null; + // after the c_api call next time _inputs is accessed + // the updated inputs are reloaded from the c_api + // lock (Locks.ProcessWide) + // { + // disable + // c_api.TF_UpdateEdge(_graph, output, input, tf.Status.Handle); + //var updated_inputs = inputs; + // tf.Status.Check(); + // } + } + + private void _assert_same_graph(Tensor tensor) + { + //TODO: implement + } - switch (name) + /// + /// Create and return a new TF_Output for output_idx'th output of this op. + /// + public TF_Output _tf_output(int output_idx) + { + return new TF_Output(_handle, output_idx); + } + + /// + /// Create and return a new TF_Input for input_idx'th input of this op. + /// + public TF_Input _tf_input(int input_idx) + { + return new TF_Input(_handle, input_idx); + } + + public NDArray numpy() => throw new NotImplementedException(""); + + internal void _add_outputs(TF_DataType[] types, Shape[] shapes) + { + Debug.Assert(types.Length == shapes.Length); + int orig_num_outputs = this.outputs.Length; + var new_outputs = new List(_outputs); + + // Since the `_outputs` is defined as `Array`, when we add new output, we + // have to create a new array, which brings some performance concerns. + // In the future maybe the type of `outputs` should be reconsidered. + for(int i = 0; i < types.Length; i++) + { + var t = new Tensor(this, orig_num_outputs + i, types[i]); + t.shape = shapes[i]; + new_outputs.Add(t); + } + _outputs = new_outputs.ToArray(); + } + + internal void _set_func_attr(string attr_name, string func_name) + { + var func = new NameAttrList() { Name = func_name }; + _set_attr(attr_name, new AttrValue() { Func = func }); + } + + internal void _set_type_list_attr(string attr_name, DataType[] types) + { + if(types is null || types.Length == 0) { - case "dtype": - ret = _outputs[0]; - break; - case "shape": - ret = new TensorShapeProto(); - break; + return; } + var type_list = new AttrValue.Types.ListValue(); + type_list.Type.AddRange(types); + _set_attr(attr_name, new AttrValue() { List = type_list }); + } - return ret; + internal void _set_attr(string attr_name, AttrValue attr_value) + { + var buffer = new Buffer(attr_value.ToByteArray()); + try + { + _set_attr_with_buf(attr_name, buffer); + } + finally + { + buffer.Release(); + } + } + + internal void _set_attr_with_buf(string attr_name, Buffer attr_buf) + { + Status status = new(); + c_api.TF_SetAttr(graph, _handle, attr_name, attr_buf, status); + status.Check(true); } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/OperationDescription.cs b/src/TensorFlowNET.Core/Operations/OperationDescription.cs new file mode 100644 index 000000000..28df548dd --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/OperationDescription.cs @@ -0,0 +1,66 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow +{ + public class OperationDescription + { + private IntPtr _handle; + public IntPtr op => _handle; + + public OperationDescription(Graph graph, string opType, string opName) + { + _handle = c_api.TF_NewOperation(graph, opType, opName); + } + + public OperationDescription(IntPtr handle) + { + _handle = handle; + } + + public void AddInputList(params TF_Output[] inputs) + { + c_api.TF_AddInputList(_handle, inputs, inputs.Length); + } + + public void SetAttrType(string attr_name, TF_DataType value) + { + c_api.TF_SetAttrType(_handle, attr_name, value); + } + + public void SetAttrShape(string attr_name, long[] dims) + { + c_api.TF_SetAttrShape(_handle, attr_name, dims, dims.Length); + } + + public Operation FinishOperation(Status status) + { + return c_api.TF_FinishOperation(_handle, status); + } + + public static implicit operator OperationDescription(IntPtr handle) + { + return new OperationDescription(handle); + } + + public static implicit operator IntPtr(OperationDescription desc) + { + return desc._handle; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs new file mode 100644 index 000000000..0f824b9bf --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs @@ -0,0 +1,41 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Linq; + +namespace Tensorflow.Queues +{ + public class FIFOQueue : QueueBase + { + public FIFOQueue(int capacity, + TF_DataType[] dtypes, + Shape[] shapes, + string[] names = null, + string shared_name = null, + string name = "fifo_queue") + : base(dtypes: dtypes, shapes: shapes, names: names) + { + _queue_ref = gen_data_flow_ops.fifo_queue_v2( + component_types: dtypes, + shapes: shapes, + capacity: capacity, + shared_name: shared_name, + name: name); + + _name = _queue_ref.op.name.Split('/').Last(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs new file mode 100644 index 000000000..d18f90220 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs @@ -0,0 +1,44 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Linq; + +namespace Tensorflow.Queues +{ + /// + /// A FIFOQueue that supports batching variable-sized tensors by padding. + /// + public class PaddingFIFOQueue : QueueBase + { + public PaddingFIFOQueue(int capacity, + TF_DataType[] dtypes, + Shape[] shapes, + string[] names = null, + string shared_name = null, + string name = "padding_fifo_queue") + : base(dtypes: dtypes, shapes: shapes, names: names) + { + _queue_ref = gen_data_flow_ops.padding_fifo_queue_v2( + component_types: dtypes, + shapes: shapes, + capacity: capacity, + shared_name: shared_name, + name: name); + + _name = _queue_ref.op.name.Split('/').Last(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs new file mode 100644 index 000000000..e54427bc8 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs @@ -0,0 +1,82 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Queues +{ + public class PriorityQueue : QueueBase + { + public PriorityQueue(int capacity, + TF_DataType[] dtypes, + Shape[] shapes, + string[] names = null, + string shared_name = null, + string name = "priority_queue") + : base(dtypes: dtypes, shapes: shapes, names: names) + { + _queue_ref = gen_data_flow_ops.priority_queue_v2( + component_types: dtypes, + shapes: shapes, + capacity: capacity, + shared_name: shared_name, + name: name); + + _name = _queue_ref.op.name.Split('/').Last(); + + var dtypes1 = dtypes.ToList(); + dtypes1.Insert(0, TF_DataType.TF_INT64); + _dtypes = dtypes1.ToArray(); + + var shapes1 = shapes.ToList(); + shapes1.Insert(0, Shape.Null); + _shapes = shapes1.ToArray(); + } + + public Operation enqueue_many(long[] indexes, T[] vals, string name = null) + { + return tf_with(ops.name_scope(name, $"{_name}_EnqueueMany", vals), scope => + { + var vals_tensor1 = _check_enqueue_dtypes(indexes); + var vals_tensor2 = _check_enqueue_dtypes(vals); + + var tensors = new List(); + tensors.AddRange(vals_tensor1); + tensors.AddRange(vals_tensor2); + + return gen_data_flow_ops.queue_enqueue_many_v2(_queue_ref, tensors.ToArray(), name: scope); + }); + } + +#pragma warning disable CS0108 // Member hides inherited member; missing new keyword + public Tensor[] dequeue(string name = null) +#pragma warning restore CS0108 // Member hides inherited member; missing new keyword + { + Tensor[] ret; + if (name == null) + name = $"{_name}_Dequeue"; + + if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) + ret = gen_data_flow_ops.queue_dequeue_v2(_queue_ref, _dtypes, name: name); + else + ret = gen_data_flow_ops.queue_dequeue(_queue_ref, _dtypes, name: name); + + return ret; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs b/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs new file mode 100644 index 000000000..992646eee --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs @@ -0,0 +1,121 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow.Queues +{ + public class QueueBase + { + protected TF_DataType[] _dtypes; + protected Shape[] _shapes; + protected string[] _names; + protected Tensor _queue_ref; + protected string _name; + + public QueueBase(TF_DataType[] dtypes, Shape[] shapes, string[] names) + { + _dtypes = dtypes; + _shapes = shapes; + _names = names; + } + + public Operation enqueue(Tensor val, string name = null) + { + return tf_with(ops.name_scope(name, $"{_name}_enqueue", val), scope => + { + var vals = new[] { val }; + if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) + return gen_data_flow_ops.queue_enqueue_v2(_queue_ref, vals, name: scope); + else + return gen_data_flow_ops.queue_enqueue(_queue_ref, vals, name: scope); + }); + } + + public Operation enqueue_many(T[] vals, string name = null) + { + return tf_with(ops.name_scope(name, $"{_name}_EnqueueMany", vals), scope => + { + var vals_tensor = _check_enqueue_dtypes(vals); + return gen_data_flow_ops.queue_enqueue_many_v2(_queue_ref, vals_tensor, name: scope); + }); + } + + protected Tensor[] _check_enqueue_dtypes(object vals) + { + var tensors = new List(); + + switch (vals) + { + case int[][] vals1: + { + int i = 0; + foreach (var (val, dtype) in zip(vals1, _dtypes)) + tensors.Add(ops.convert_to_tensor(val, dtype: dtype, name: $"component_{i++}")); + } + break; + + default: + var dtype1 = GetType().Name == "PriorityQueue" ? _dtypes[1] : _dtypes[0]; + tensors.Add(ops.convert_to_tensor(vals, dtype: dtype1, name: $"component_0")); + break; + } + + return tensors.ToArray(); + } + + /// + /// Dequeues one element from this queue. + /// + /// + /// + public Tensor dequeue(string name = null) + { + Tensor ret; + if (name == null) + name = $"{_name}_Dequeue"; + + if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) + ret = gen_data_flow_ops.queue_dequeue_v2(_queue_ref, _dtypes, name: name)[0]; + else + ret = gen_data_flow_ops.queue_dequeue(_queue_ref, _dtypes, name: name)[0]; + + return ret; + } + + public Tensor[] dequeue_many(int n, string name = null) + { + if (name == null) + name = $"{_name}_DequeueMany"; + + var ret = gen_data_flow_ops.queue_dequeue_many_v2(_queue_ref, n: n, component_types: _dtypes, name: name); + //var op = ret[0].op; + //var cv = tensor_util.constant_value(op.inputs[1]); + //var batch_dim = new Dimension(cv); + + return _dequeue_return_value(ret); + } + + public Tensor[] _dequeue_return_value(Tensor[] tensors) + { + if (_names != null) + throw new NotImplementedException(""); + return tensors; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs new file mode 100644 index 000000000..3f15c593a --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs @@ -0,0 +1,54 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Linq; + +namespace Tensorflow.Queues +{ + /// + /// Create a queue that dequeues elements in a random order. + /// + public class RandomShuffleQueue : QueueBase + { + public RandomShuffleQueue(int capacity, + int min_after_dequeue, + TF_DataType[] dtypes, + Shape[] shapes, + string[] names = null, + int? seed = null, + string shared_name = null, + string name = "random_shuffle_queue") + : base(dtypes: dtypes, shapes: shapes, names: names) + { + var (seed1, seed2) = random_seed.get_seed(seed); + if (!seed1.HasValue && !seed2.HasValue) + (seed1, seed2) = (0, 0); + + + _queue_ref = gen_data_flow_ops.random_shuffle_queue_v2( + component_types: dtypes, + shapes: shapes, + capacity: capacity, + min_after_dequeue: min_after_dequeue, + seed: seed1.Value, + seed2: seed2.Value, + shared_name: shared_name, + name: name); + + _name = _queue_ref.op.name.Split('/').Last(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Regularizers/L1.cs b/src/TensorFlowNET.Core/Operations/Regularizers/L1.cs new file mode 100644 index 000000000..9e0619454 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Regularizers/L1.cs @@ -0,0 +1,33 @@ +using System; + +using Tensorflow.Keras; + +namespace Tensorflow.Operations.Regularizers +{ + public class L1 : IRegularizer + { + float _l1; + private readonly Dictionary _config; + + public string ClassName => "L1"; + public virtual IDictionary Config => _config; + + public L1(float l1 = 0.01f) + { + // l1 = 0.01 if l1 is None else l1 + // validate_float_arg(l1, name = "l1") + // self.l1 = ops.convert_to_tensor(l1) + this._l1 = l1; + + _config = new(); + _config["l1"] = _l1; + } + + + public Tensor Apply(RegularizerArgs args) + { + //return self.l1 * ops.sum(ops.absolute(x)) + return _l1 * math_ops.reduce_sum(math_ops.abs(args.X)); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Regularizers/L1L2.cs b/src/TensorFlowNET.Core/Operations/Regularizers/L1L2.cs new file mode 100644 index 000000000..e3af00eb5 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Regularizers/L1L2.cs @@ -0,0 +1,48 @@ +using System; + +using Tensorflow.Keras; + +namespace Tensorflow.Operations.Regularizers +{ + public class L1L2 : IRegularizer + { + float _l1; + float _l2; + private readonly Dictionary _config; + + public string ClassName => "L1L2"; + public virtual IDictionary Config => _config; + + public L1L2(float l1 = 0.0f, float l2 = 0.0f) + { + //l1 = 0.0 if l1 is None else l1 + //l2 = 0.0 if l2 is None else l2 + // validate_float_arg(l1, name = "l1") + // validate_float_arg(l2, name = "l2") + + // self.l1 = l1 + // self.l2 = l2 + this._l1 = l1; + this._l2 = l2; + + _config = new(); + _config["l1"] = l1; + _config["l2"] = l2; + } + + public Tensor Apply(RegularizerArgs args) + { + //regularization = ops.convert_to_tensor(0.0, dtype = x.dtype) + //if self.l1: + // regularization += self.l1 * ops.sum(ops.absolute(x)) + //if self.l2: + // regularization += self.l2 * ops.sum(ops.square(x)) + //return regularization + + Tensor regularization = tf.constant(0.0, args.X.dtype); + regularization += _l1 * math_ops.reduce_sum(math_ops.abs(args.X)); + regularization += _l2 * math_ops.reduce_sum(math_ops.square(args.X)); + return regularization; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Regularizers/L2.cs b/src/TensorFlowNET.Core/Operations/Regularizers/L2.cs new file mode 100644 index 000000000..6c0e950a9 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Regularizers/L2.cs @@ -0,0 +1,33 @@ +using System; + +using Tensorflow.Keras; + +namespace Tensorflow.Operations.Regularizers +{ + public class L2 : IRegularizer + { + float _l2; + private readonly Dictionary _config; + + public string ClassName => "L2"; + public virtual IDictionary Config => _config; + + public L2(float l2 = 0.01f) + { + // l2 = 0.01 if l2 is None else l2 + // validate_float_arg(l2, name = "l2") + // self.l2 = l2 + this._l2 = l2; + + _config = new(); + _config["l2"] = _l2; + } + + + public Tensor Apply(RegularizerArgs args) + { + //return self.l2 * ops.sum(ops.square(x)) + return _l2 * math_ops.reduce_sum(math_ops.square(args.X)); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/SafeOperationHandle.cs b/src/TensorFlowNET.Core/Operations/SafeOperationHandle.cs new file mode 100644 index 000000000..41364fe65 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/SafeOperationHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Util; + +namespace Tensorflow; + +public sealed class SafeOperationHandle : SafeTensorflowHandle +{ + private SafeOperationHandle() + { + } + + public SafeOperationHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + var status = new Status(); + // c_api.TF_CloseSession(handle, status); + c_api.TF_DeleteSession(handle, status); + SetHandle(IntPtr.Zero); + return true; + } +} diff --git a/src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs b/src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs new file mode 100644 index 000000000..a66ee0961 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs @@ -0,0 +1,10 @@ +namespace Tensorflow +{ + public struct TF_AttrMetadata + { + public byte is_list; + public long list_size; + public TF_AttrType type; + public long total_size; + } +} diff --git a/src/TensorFlowNET.Core/Operations/TF_AttrType.cs b/src/TensorFlowNET.Core/Operations/TF_AttrType.cs new file mode 100644 index 000000000..aa520e4bf --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/TF_AttrType.cs @@ -0,0 +1,15 @@ +namespace Tensorflow +{ + public enum TF_AttrType + { + TF_ATTR_STRING = 0, + TF_ATTR_INT = 1, + TF_ATTR_FLOAT = 2, + TF_ATTR_BOOL = 3, + TF_ATTR_TYPE = 4, + TF_ATTR_SHAPE = 5, + TF_ATTR_TENSOR = 6, + TF_ATTR_PLACEHOLDER = 7, + TF_ATTR_FUNC = 8 + } +} diff --git a/src/TensorFlowNET.Core/Graphs/TF_Input.cs b/src/TensorFlowNET.Core/Operations/TF_Input.cs similarity index 62% rename from src/TensorFlowNET.Core/Graphs/TF_Input.cs rename to src/TensorFlowNET.Core/Operations/TF_Input.cs index 4adcd0408..c68cb2dba 100644 --- a/src/TensorFlowNET.Core/Graphs/TF_Input.cs +++ b/src/TensorFlowNET.Core/Operations/TF_Input.cs @@ -1,13 +1,17 @@ using System; -using System.Collections.Generic; using System.Runtime.InteropServices; -using System.Text; namespace Tensorflow { [StructLayout(LayoutKind.Sequential)] public struct TF_Input { + public TF_Input(IntPtr oper, int index) + { + this.oper = oper; + this.index = index; + } + public IntPtr oper; public int index; } diff --git a/src/TensorFlowNET.Core/Operations/TF_Operation.cs b/src/TensorFlowNET.Core/Operations/TF_Operation.cs new file mode 100644 index 000000000..81a142eda --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/TF_Operation.cs @@ -0,0 +1,11 @@ +using System; +using System.Runtime.InteropServices; + +namespace Tensorflow +{ + [StructLayout(LayoutKind.Sequential)] + public struct TF_Operation + { + public IntPtr node; + } +} diff --git a/src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs b/src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs index 9f04289be..4d0a91d99 100644 --- a/src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs +++ b/src/TensorFlowNET.Core/Operations/TF_OperationDescription.cs @@ -1,7 +1,5 @@ using System; -using System.Collections.Generic; using System.Runtime.InteropServices; -using System.Text; namespace Tensorflow { @@ -9,6 +7,7 @@ namespace Tensorflow public struct TF_OperationDescription { public IntPtr node_builder; - //public TF_Graph graph; + public IntPtr graph; + public IntPtr colocation_constraints; } } diff --git a/src/TensorFlowNET.Core/Graphs/TF_Output.cs b/src/TensorFlowNET.Core/Operations/TF_Output.cs similarity index 62% rename from src/TensorFlowNET.Core/Graphs/TF_Output.cs rename to src/TensorFlowNET.Core/Operations/TF_Output.cs index 98ec3d178..a02daa341 100644 --- a/src/TensorFlowNET.Core/Graphs/TF_Output.cs +++ b/src/TensorFlowNET.Core/Operations/TF_Output.cs @@ -1,13 +1,17 @@ using System; -using System.Collections.Generic; using System.Runtime.InteropServices; -using System.Text; namespace Tensorflow { [StructLayout(LayoutKind.Sequential)] public struct TF_Output { + public TF_Output(IntPtr oper, int index) + { + this.oper = oper; + this.index = index; + } + public IntPtr oper; public int index; } diff --git a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs new file mode 100644 index 000000000..591760600 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs @@ -0,0 +1,273 @@ +/***************************************************************************** + Copyright 2022 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Common.Types; +using Tensorflow.Eager; +using Tensorflow.Framework; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + public class _EagerTensorArray : TensorArray + { + TF_DataType _dtype; + public override TF_DataType dtype => _dtype; + + /// + /// Used to keep track of what tensors the TensorArray should be + /// colocated with. We choose to colocate the TensorArray with the + /// first tensor written to it. + /// + bool _colocate_with_first_write_call; + public override bool colocate_with_first_write_call => _colocate_with_first_write_call; + + bool _infer_shape; + public override bool infer_shape => _infer_shape; + + Tensor _handle; + public override Tensor handle => _handle; + Tensor _flow; + public override Tensor flow => _flow; + bool _clear_after_read; + List _tensor_array; + List _previous_read_indices; + + public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false, + bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, + bool infer_shape = true, Shape? element_shape = null, + bool colocate_with_first_write_call = true, string name = null) + { + _size = size; + _flow = constant_op.constant(0); + _infer_shape = infer_shape; + _element_shape = element_shape ?? Shape.Null; + _colocate_with_first_write_call = colocate_with_first_write_call; + _dtype = dtype.as_base_dtype(); + _dynamic_size = dynamic_size; + _clear_after_read = clear_after_read; + _tensor_array = Enumerable.Repeat(null, size.numpy()).ToList(); + _previous_read_indices = new(); + } + + public override TensorArray unstack(Tensor value, string name = null) + { + var tensors = array_ops.unstack(value, name: name); + if(tensors.Length > _tensor_array.Count && !_dynamic_size) + { + throw new ValueError($"Cannot unstack {tensors.Length} tensors into a TensorArray of static size {_tensor_array.Count}"); + } + _tensor_array = tensors.ToList(); + // TODO(Rinne): revise the implementation. Here we should return `parent()`. + return this; + } + + public TensorArray scatter(Tensor indices, Tensor value, string name = null) + { + /*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + if (_infer_shape) + { + var shape = new Shape(value.shape.dims.Skip(1).ToArray()); + _merge_element_shape(shape); + } + + _maybe_colocate_with(value); + var flow_out = gen_data_flow_ops.tensor_array_scatter_v3( + handle: _handle, + indices: indices, + value: value, + flow_in: _flow, + name: name); + + var ta = new _EagerTensorArray(_dtype, + infer_shape: _infer_shape, + element_shape: _element_shape[0], + dynamic_size: _dynamic_size, + handle: _handle, + flow: flow_out, + colocate_with_first_write_call: _colocate_with_first_write_call); + + + return ta; + });*/ + //if (indices is EagerTensor) + //{ + // indices = indices as EagerTensor; + // indices = indices.numpy(); + //} + + //foreach (var (index, val) in zip(indices.ToArray(), array_ops.unstack(value))) + //{ + // this.write(index, val); + //} + //return base; + //throw new NotImplementedException(""); + return this; + } + + public void _merge_element_shape(Shape shape) + { + _element_shape.concatenate(shape); + } + + public void _maybe_colocate_with(Tensor value) + { + _colocate_with.Add(value); + } + + private Tensor _maybe_zero(int ix) + { + var val = _tensor_array[ix]; + if(val is null) + { + val = _tensor_array[ix] = array_ops.zeros(_element_shape, _dtype); + } + return val; + } + + public override Tensor read(T index, string name = null) + { + int index_int; + if (index is int int_index) + index_int = int_index; + else if (index is Tensor tensor_index) + index_int = tensor_index.numpy(); + else + throw new ValueError(""); + + if(index_int >= _tensor_array.Count) + { + throw new OutOfRangeError($"Tried to read from index {index_int} but array size is: {_tensor_array.Count} "); + } + + var res = _tensor_array[index_int]; + if(res is null) + { + if (_previous_read_indices.Contains(index_int)) + { + throw new InvalidArgumentError($"Could not read index {index_int} twice because it was cleared after " + + $"a previous read (perhaps try setting clear_after_read = false?)"); + } + else + { + res = _maybe_zero(index_int); + } + } + + if (_clear_after_read) + { + _tensor_array[index_int] = null; + _previous_read_indices.Add(index_int); + } + return res; + } + + public override TensorArray write(Tensor index, Tensor value, string name = null) + { + int index_int; + if(index is EagerTensor eager) + { + return write(eager.numpy(), value, name); + } + throw new InvalidArgumentError("The index is supposed to be an EagerTensor"); + } + + public override TensorArray write(int index, T value, string name = null) + { + int size = _tensor_array.Count; + if(index >= size) + { + if (!_dynamic_size) + { + throw new OutOfRangeError($"Tried to write to index {index} but array is not resizeable and size " + + $"is: {size} "); + } + _tensor_array.AddRange(Enumerable.Repeat(null, index - size + 1)); + } + + Tensor tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + + if(_dtype != tensor.dtype) + { + throw new InvalidArgumentError($"TensorArray dtype is {_dtype.as_python_name()} but Op is " + + $"trying to write dtype {tensor.dtype.as_python_name()} "); + } + + if (!_element_shape.is_compatible_with(tensor.shape)) + { + throw new ValueError($"Incompatible shape for value ({tensor.shape}), expected ({_element_shape})"); + } + + if (_infer_shape) + { + _element_shape = _element_shape.merge_with(tensor.shape); + } + _tensor_array[index] = tensor; + return this; + } + + private Tensor size(string name = null) + { + return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); + } + + public override Tensor stack(string name = null) + { + if(_tensor_array.Count > 0) + { + for(int i = 0; i < _tensor_array.Count; i++) + { + _maybe_zero(i); + } + } + if(_tensor_array.Count == 0 && _element_shape.IsFullyDefined) + { + return ops.convert_to_tensor(new Shape(new long[] { 0 }.Concat(_element_shape.dims).ToArray()), name: name, dtype: _dtype); + } + else + { + return ops.convert_to_tensor(_tensor_array, name: name, dtype: _dtype); + } + //ops.colocate_with(_handle); + //return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate + //{ + // return gather(math_ops.range(0, size()), name: name); + //}); + } + + public override Tensor gather(Tensor indices, string name = null) + { + var element_shape = Shape.Null; + + var value = gen_data_flow_ops.tensor_array_gather_v3( + handle: _handle, + indices: indices, + flow_in: _flow, + dtype: _dtype, + name: name, + element_shape: element_shape); + + //if (element_shape != null) + //value.set_shape(-1, element_shape.dims); + + return value; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs new file mode 100644 index 000000000..2384e8146 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -0,0 +1,410 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Common.Types; +using Tensorflow.Eager; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + public class _GraphTensorArray : TensorArray + { + internal TF_DataType _dtype; + public TF_DataType dtype => _dtype; + + /// + /// Used to keep track of what tensors the TensorArray should be + /// colocated with. We choose to colocate the TensorArray with the + /// first tensor written to it. + /// + bool _colocate_with_first_write_call; + public override bool colocate_with_first_write_call => _colocate_with_first_write_call; + + bool _infer_shape; + public override bool infer_shape => _infer_shape; + public List _element_shape; + + public List _colocate_with; + + internal Tensor _handle; + public override Tensor handle => _handle; + internal Tensor _flow; + public override Tensor flow => _flow; + + public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, + bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, + bool infer_shape = true, Shape? element_shape = null, + bool colocate_with_first_write_call = true, string name = null) + { + clear_after_read = clear_after_read ?? true; + dynamic_size = dynamic_size ?? false; + _dynamic_size = dynamic_size.Value; + _dtype = dtype; + _size = size; + + _colocate_with_first_write_call = colocate_with_first_write_call; + if (colocate_with_first_write_call) + _colocate_with = new List(); + + // Record the current static shape for the array elements. The element + // shape is defined either by `element_shape` or the shape of the tensor + // of the first write. If `infer_shape` is true, all writes checks for + // shape equality. + if (element_shape == null) + { + _infer_shape = infer_shape; + _element_shape = new List { }; + } + else + { + _infer_shape = true; + _element_shape = new List { element_shape }; + } + + tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => + { + if (handle != null) + { + _handle = handle; + _flow = flow; + } + else + { + Func<(Tensor, Tensor)> create = () => gen_data_flow_ops.tensor_array_v3(size, + dtype: dtype, + element_shape: element_shape, + identical_element_shapes: infer_shape, + dynamic_size: dynamic_size.Value, + clear_after_read: clear_after_read.Value, + tensor_array_name: tensor_array_name, + name: scope); + + // Construct the TensorArray with an empty device. The first + // write into the TensorArray from a Tensor with a set device + // will retroactively set the device value of this op. + if (colocate_with_first_write_call) + { + ops.colocate_with(ignore_existing: true); + (_handle, _flow) = create(); + } + else + { + (_handle, _flow) = create(); + } + } + }); + } + + public override TensorArray unstack(Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate + { + var num_elements = array_ops.shape(value)[0]; + return scatter(indices: math_ops.range(0, num_elements), value: value, name: name); + }); + } + + public TensorArray scatter(Tensor indices, Tensor value, string name = null) + { + /*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + if (_infer_shape) + { + var shape = new Shape(value.shape.dims.Skip(1).ToArray()); + _merge_element_shape(shape); + } + + _maybe_colocate_with(value); + var flow_out = gen_data_flow_ops.tensor_array_scatter_v3( + handle: _handle, + indices: indices, + value: value, + flow_in: _flow, + name: name); + + var ta = new _GraphTensorArray(_dtype, + infer_shape: _infer_shape, + element_shape: _element_shape[0], + dynamic_size: _dynamic_size, + handle: _handle, + flow: flow_out, + colocate_with_first_write_call: _colocate_with_first_write_call); + + return ta; + });*/ + + //throw new NotImplementedException(""); + return this; + } + + public void _merge_element_shape(Shape shape) + { + _element_shape.Add(shape); + } + + public void _maybe_colocate_with(Tensor value) + { + _colocate_with.Add(value); + } + + public override Tensor read(T index, string name = null) + { + var value = gen_data_flow_ops.tensor_array_read_v3( + handle: _handle, + index: constant_op.constant(index), + flow_in: _flow, + dtype: _dtype, + name: name); + + if (_element_shape != null) + value.shape = _element_shape[0].dims; + + return value; + } + + public override TensorArray write(Tensor index, Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate + { + _maybe_colocate_with(value); + var flow_out = gen_data_flow_ops.tensor_array_write_v3( + handle: _handle, + index: index, + value: value, + flow_in: _flow, + name: name); + + return tensor_array_ops.build_ta_with_new_flow(this, flow_out); + }); + } + + public override TensorArray write(int index, T value, string name = null) + { + var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + var index_tensor = ops.convert_to_tensor(index, name: "index"); + return write(index_tensor, value_tensor); + } + + private Tensor size(string name = null) + { + return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); + } + + public override Tensor stack(string name = null) + { + ops.colocate_with(_handle); + return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate + { + return gather(math_ops.range(0, size()), name: name); + }); + } + + public override Tensor gather(Tensor indices, string name = null) + { + var element_shape = Shape.Null; + + if (_element_shape.Count > 0) + element_shape = _element_shape[0]; + + var value = gen_data_flow_ops.tensor_array_gather_v3( + handle: _handle, + indices: indices, + flow_in: _flow, + dtype: _dtype, + name: name, + element_shape: element_shape); + + //if (element_shape != null) + //value.set_shape(-1, element_shape.dims); + + return value; + } + } + + public class _GraphTensorArrayV2 : TensorArray + { + internal TF_DataType _dtype; + public override TF_DataType dtype => _dtype; + + /// + /// Used to keep track of what tensors the TensorArray should be + /// colocated with. We choose to colocate the TensorArray with the + /// first tensor written to it. + /// + bool _colocate_with_first_write_call; + public override bool colocate_with_first_write_call => _colocate_with_first_write_call; + + bool _infer_shape; + public override bool infer_shape => _infer_shape; + public Shape _element_shape; + + public List _colocate_with; + + internal Tensor _handle; + public override Tensor handle => _handle; + internal Tensor _flow; + public override Tensor flow => _flow; + + public _GraphTensorArrayV2(TF_DataType dtype, Tensor size, bool? dynamic_size = null, + bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, + bool infer_shape = true, Shape? element_shape = null, + bool colocate_with_first_write_call = true, string name = null) + { + Debug.Assert(handle is null); + dynamic_size = dynamic_size ?? false; + _dynamic_size = dynamic_size.Value; + _size = size; + + if(flow is not null && flow.dtype != dtypes.variant) + { + throw new TypeError($"Expected `flow` to be a variant tensor, but received `{flow.dtype}` instead"); + } + if(flow is null && size is null) + { + throw new ValueError("Argument `size` must be provided if argument `flow` is not provided."); + } + if(flow is not null && size is not null) + { + throw new ValueError("Cannot provide both `flow` and `size` arguments at the same time."); + } + if(flow is not null && element_shape is not null) + { + throw new ValueError("Cannot provide both `flow` and `element_shape` arguments at the same time."); + } + + _dtype = dtype; + + _element_shape = element_shape; + _infer_shape = infer_shape; + tf_with(ops.name_scope(name, "TensorArrayV2", new object[] { size, flow }), scope => + { + if (flow is null) + { + _flow = list_ops.tensor_list_reserve(element_shape, size, dtype, scope.scope_name); + } + else + { + _flow = flow; + } + }); + + _colocate_with_first_write_call = false; + _colocate_with = null; + } + + public override TensorArray unstack(Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _flow, value }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + Debug.Assert(value.dtype == _dtype); + var flow_out = list_ops.tensor_list_from_tensor(value, value.shape.dims.Skip(1).ToArray()); + return tensor_array_ops.build_ta_with_new_flow(this, flow_out); + }); + } + + public TensorArray scatter(Tensor indices, Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _flow, value, indices }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + Debug.Assert(value.dtype == _dtype); + var flow_out = list_ops.tensor_list_scatter(value, indices, _element_shape, _flow); + return tensor_array_ops.build_ta_with_new_flow(this, flow_out); + }); + } + + public override Tensor read(T index, string name = null) + { + if(index is Tensor tensor) + { + return read(tensor, name); + } + else + { + throw new TypeError("Please use non-generic method instead."); + } + } + + public Tensor read(Tensor index, string name = null) + { + return tf_with(tf.name_scope(name, "TensorArrayV2Read", new object[] { _flow, index }), scope => + { + return list_ops.tensor_list_get_item(_flow, index, _dtype, _element_shape, name); + }); + } + + public override TensorArray write(Tensor index, Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayV2Write", new { _flow, index, value }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + Debug.Assert(value.dtype == _dtype); + var flow_out = list_ops.tensor_list_set_item(_flow, index, value, _dynamic_size, name); + + return tensor_array_ops.build_ta_with_new_flow(this, flow_out); + }); + } + + public override TensorArray write(int index, T value, string name = null) + { + var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + var index_tensor = ops.convert_to_tensor(index, name: "index"); + return write(index_tensor, value_tensor); + } + + private Tensor size(string name = null) + { + if(!_dynamic_size && _size is not null) + { + return ops.convert_to_tensor(_size, dtypes.int32); + } + else + { + return gen_list_ops.tensor_list_length(_flow, name); + } + } + + public override Tensor stack(string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayV2Stack", _flow), delegate + { + int ta_size; + if(!_dynamic_size && (_size is not null)) + { + var size_tensor = tensor_util.constant_value(_size); + ta_size = size_tensor is null ? -1 : (int)size_tensor; + } + else + { + ta_size = -1; + } + var value = list_ops.tensor_list_stack(_flow, _dtype, ta_size, _element_shape); + return value; + }); + } + + public override Tensor gather(Tensor indices, string name = null) + { + return list_ops.tensor_list_gather(_flow, indices, _dtype, _element_shape, name); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/_UserDeviceSpec.cs b/src/TensorFlowNET.Core/Operations/_UserDeviceSpec.cs new file mode 100644 index 000000000..0ffead377 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/_UserDeviceSpec.cs @@ -0,0 +1,78 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public class StringOrFunction + { + private object variable; + + private StringOrFunction(object val) + { + variable = val; + } + + public static implicit operator StringOrFunction(string val) + { + return new StringOrFunction(val); + } + + public static implicit operator StringOrFunction(Function val) + { + return new StringOrFunction(val); + } + + public bool IsFunction + { + get + { + return variable is FunctionDef; + } + } + + public override string ToString() + { + if (variable == null) + return ""; + + if (!IsFunction) + { + return variable.ToString(); + } + + return ((FunctionDef)variable).ToString(); + } + } + + public class _UserDeviceSpec + { + private StringOrFunction _device_name_or_function; + private string display_name; +#pragma warning disable CS0169 // The field '_UserDeviceSpec.function' is never used + private FunctionDef function; +#pragma warning restore CS0169 // The field '_UserDeviceSpec.function' is never used +#pragma warning disable CS0169 // The field '_UserDeviceSpec.raw_string' is never used + private string raw_string; +#pragma warning restore CS0169 // The field '_UserDeviceSpec.raw_string' is never used + + public _UserDeviceSpec(StringOrFunction device_name_or_function) + { + + _device_name_or_function = device_name_or_function; + display_name = device_name_or_function.ToString(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs new file mode 100644 index 000000000..548a885ed --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -0,0 +1,1161 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using Tensorflow.Framework; +using static Tensorflow.Binding; +using System.Diagnostics; + +namespace Tensorflow +{ + public class array_ops + { + public static Tensor placeholder_with_default(Tensor input, int[] shape, string name = null) + => gen_array_ops.placeholder_with_default(input, shape, name); + + /// + /// An identity op that triggers an error if a gradient is requested. + /// + /// + /// any tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PreventGradient'. + /// + /// + /// Will be printed in the error when anyone tries to differentiate + /// this operation. + /// + /// + /// the same input tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// When executed in a graph, this op outputs its input tensor as-is. + /// + /// When building ops to compute gradients, the TensorFlow gradient system + /// will return an error when trying to lookup the gradient of this op, + /// because no gradient must ever be registered for this function. This + /// op exists to prevent subtle bugs from silently returning unimplemented + /// gradients in some corner cases. + /// + public static Tensor prevent_gradient(Tensor input, string message = "", string name = null) + => tf.Context.ExecuteOp("PreventGradient", name, new ExecuteOpArgs(input) + .SetAttributes(new { message })); + + internal static Tensor constant(object value, + TF_DataType dtype = TF_DataType.DtInvalid, + int[] shape = null, + string name = "Const", + bool verify_shape = false) => constant_op.constant(value, + dtype: dtype, + shape: shape, + name: name, + verify_shape: verify_shape, + allow_broadcast: false); + + public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + dtype = dtype.as_base_dtype(); + + if (tf.executing_eagerly()) + { + return tf_with(ops.name_scope(name, "zeros", shape), scope => + { + name = scope; + // var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); + Tensor zeros = dtype switch + { + TF_DataType.TF_BOOL => constant(false), + TF_DataType.TF_DOUBLE => constant(0d), + TF_DataType.TF_FLOAT => constant(0f), + TF_DataType.TF_INT64 => constant(0L), + TF_DataType.TF_UINT64 => constant((ulong)0), + TF_DataType.TF_INT32 => constant(0), + TF_DataType.TF_UINT32 => constant((uint)0), + TF_DataType.TF_INT8 => constant((sbyte)0), + TF_DataType.TF_UINT8 => constant((byte)0), + _ => constant(0) + }; + return fill(shape, zeros, name: name); + }); + } + else + { + return tf_with(ops.name_scope(name, "zeros", shape), scope => + { + name = scope; + switch (dtype) + { + case TF_DataType.TF_BOOL: + return _constant_if_small(false, shape, dtype, name); + case TF_DataType.TF_DOUBLE: + return _constant_if_small(0.0D, shape, dtype, name); + case TF_DataType.TF_FLOAT: + return _constant_if_small(0.0F, shape, dtype, name); + case TF_DataType.TF_INT64: + return _constant_if_small(0L, shape, dtype, name); + case TF_DataType.TF_UINT64: + return _constant_if_small(0, shape, dtype, name); + case TF_DataType.TF_INT32: + return _constant_if_small(0, shape, dtype, name); + case TF_DataType.TF_UINT32: + return _constant_if_small(0, shape, dtype, name); + case TF_DataType.TF_INT8: + return _constant_if_small(0, shape, dtype, name); + case TF_DataType.TF_UINT8: + return _constant_if_small(0, shape, dtype, name); + default: + throw new TypeError("can't find type for zeros"); + } + }); + } + } + + public static Tensor zeros(Tensors shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + dtype = dtype.as_base_dtype(); + Tensor shapeTensor; + if(shape.Length > 1) + { + shapeTensor = ops.convert_to_tensor(shape, dtypes.int32); + if (shapeTensor.ndim > 1) + { + shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1)); + } + } + else + { + shapeTensor = shape[0]; + } + var output = fill(shapeTensor, array_ops.constant(0, dtype), name); + Debug.Assert(output.dtype.as_base_dtype() == dtype); + return output; + } + + public static Tensor boolean_mask(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) + { + return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate + { + var tensor_tensor = ops.convert_to_tensor(tensor, name: "tensor"); + var mask_tensor = ops.convert_to_tensor(mask, name: "mask"); + + var shape_mask = mask_tensor.shape; + var ndims_mask = shape_mask.ndim; + var shape_tensor = tensor_tensor.shape; + + if (ndims_mask < 1) + throw new ValueError("mask cannot be scalar."); + + var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], ops.convert_to_tensor(new[] { 0 })); + if (leading_size.rank == 0) + { + leading_size = expand_dims(leading_size, 0); + } + + var shape1 = concat(new[] + { + shape(tensor_tensor)[$":{axis}"], + leading_size, + shape(tensor_tensor)[$"{axis + ndims_mask}:"] + }, 0); + tensor_tensor = reshape(tensor_tensor, shape1); + var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First(); + var s1 = new Shape(shape_tensor.dims.Take(axis).ToArray()); + var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray()); + tensor_tensor.shape = s2; + + mask_tensor = reshape(mask_tensor, new[] { -1 }); + return _apply_mask_1d(tensor_tensor, mask_tensor, axis); + }); + } + + private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0) + { + var indices = squeeze(where_v2(mask), axis: new[] { 1 }); + return gather(reshaped_tensor, indices, axis: ops.convert_to_tensor(axis)); + } + + public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + dtype = dtype.as_base_dtype(); + return tf_with(ops.name_scope(name, "zeros", shape), scope => + { + name = scope; + switch (dtype) + { + case TF_DataType.TF_BOOL: + return gen_array_ops.fill(shape, tf.constant(false, dtype: dtype), name: name); + case TF_DataType.TF_DOUBLE: + return gen_array_ops.fill(shape, tf.constant(0.0D, dtype: dtype), name: name); + case TF_DataType.TF_FLOAT: + return gen_array_ops.fill(shape, tf.constant(0.0F, dtype: dtype), name: name); + case TF_DataType.TF_INT32: + return gen_array_ops.fill(shape, tf.constant(0, dtype: dtype), name: name); + default: + throw new TypeError("can't find type for zeros"); + } + + }); + } + + private static Tensor _constant_if_small(int value, Tensor shape) + { + if (shape.dtype == TF_DataType.TF_INT64) + return shape < 1000L; + else + return shape < 1000; + } + + private static Tensor _constant_if_small(T value, Shape shape, TF_DataType dtype, string name) + { + if (shape.size < 1000) + { + return constant_op.constant(value, shape: shape, dtype: dtype, name: name); + } + else + { + var shape_t = constant_op._tensor_shape_tensor_conversion_function(shape); + var c = constant_op.constant(0, dtype: dtype); + return gen_array_ops.fill(shape_t, c, name: name); + } + } + + public static Tensor _autopacking_conversion_function(IEnumerable v, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) + { + var inferred_dtype = _get_dtype_from_nested_lists(v); + if (dtype == TF_DataType.DtInvalid) + dtype = inferred_dtype; + + return _autopacking_helper(v, dtype, name == null ? "packed" : name); + } + + private static TF_DataType _get_dtype_from_nested_lists(IEnumerable list_or_tuple) + { + TF_DataType dtype = TF_DataType.DtInvalid; + + foreach (var obj in list_or_tuple) + { + switch (obj) + { + case Tensor t: + dtype = t.dtype.as_base_dtype(); + break; + case int t: + dtype = TF_DataType.TF_INT32; + break; + } + + if (dtype != TF_DataType.DtInvalid) + break; + } + + return dtype; + } + + /// + /// Converts the given list or tuple to a tensor by packing. + /// + /// A (possibly nested) list or tuple containing a tensor. + /// + /// + /// A `tf.Tensor` with value equivalent to `list_or_tuple`. + public static Tensor _autopacking_helper(IEnumerable list_or_tuple, TF_DataType dtype, string name) + { + var must_pack = false; + var converted_elems = new List(); + + bool switch_to_graph = tf.Context.switched_to_graph(list_or_tuple.ToArray()); + + var result = tf_with(ops.name_scope(name), scope => + { + foreach (var (i, elem) in enumerate(list_or_tuple)) + { + converted_elems.Add(elem); + must_pack = true; + } + + if (must_pack) + { + var elems_as_tensors = new List(); + foreach (var (i, elem) in enumerate(converted_elems)) + { + if (elem is EagerTensor eager_tensor) + { + if (switch_to_graph) + elems_as_tensors.Add(constant_op.constant(eager_tensor.numpy(), dtype: dtype, name: i.ToString())); + else + elems_as_tensors.Add(eager_tensor); + } + else if (elem is Tensor tensor) + { + elems_as_tensors.Add(tensor); + } + else if (elem is KerasTensor kt) + { + elems_as_tensors.Add(kt); + } + else + { + var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString()); + elems_as_tensors.Add(elem_tensor); + } + } + + return gen_array_ops.pack(elems_as_tensors.ToArray(), name: scope); + } + else + { + return tf.constant(np.array(new float[0])); + } + }); + + if (switch_to_graph) + tf.Context.restore_mode(); + + return result; + } + + public static Tensor expand_dims(Tensor input, int axis = -1, string name = null) + => gen_array_ops.expand_dims(input, ops.convert_to_tensor(axis), name); + + /// + /// Creates a tensor filled with a scalar value. + /// This operation creates a tensor of shape `dims` and fills it with `value`. + /// + /// A 1-D sequence of non-negative numbers. + /// A value to fill the returned `tf.Tensor`. + /// Optional string. The name of the output `tf.Tensor`. + /// A `tf.Tensor` with shape `dims` and the same dtype as `value`. + public static Tensor fill(Shape dims, T value, string name = null) + => gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); + + public static Tensor fill(Tensor dims, T value, string name = null) + => gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); + + /// + /// Returns the rank of a tensor. + /// + /// + /// + /// + public static Tensor rank(Tensor input, string name = null) + => rank_internal(input, name, optimize: true); + + public static Tensor rank_internal(Tensor input, string name = null, bool optimize = true) + { + return tf_with(ops.name_scope(name, "Rank", new List { input }), scope => + { + name = scope; + var input_shape = input.shape; + if (optimize && input_shape.ndim > 0) + return constant_op.constant(input_shape.ndim, dtype: tf.int32, name: name); + else + return gen_array_ops.rank(input, name); + }); + } + + /// + /// Creates a tensor with all elements set to 1. + /// + /// + /// + /// + /// + /// + public static Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + { + return tf_with(ops.name_scope(name, "ones_like", new Tensor[] { tensor }), scope => + { + name = scope; + tensor = ops.convert_to_tensor(tensor, name: "tensor"); + + // is_fully_defined return unexpected value. + if (optimize && tensor.shape.IsFullyDefined && dtype != TF_DataType.TF_VARIANT) + { + + } + + if (dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT) + { + throw new NotImplementedException("ones_like"); + // return ones(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name); + } + else + { + return gen_array_ops.ones_like(tensor, name: name); + } + }); + } + + public static Tensor reshape(Tensor tensor, Tensor shape, string name = null) + => gen_array_ops.reshape(tensor, shape, name: name); + + public static Tensor reshape(Tensor tensor, Shape shape, string name = null) + => gen_array_ops.reshape(tensor, shape, name: name); + + public static Tensor reshape(Tensor tensor, object[] shape, string name = null) + { + var dims = shape_utils.from_object_array(shape); + return gen_array_ops.reshape(tensor, dims, name: name); + } + + public static Tensor reverse(Tensor tensor, Tensor axis, string name = null) + => tf.Context.ExecuteOp("ReverseV2", name, new ExecuteOpArgs(tensor, axis) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + Tidx = op.get_attr("Tidx") + } + }); + + private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true) + { + return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope => + { + name = scope; + var tensor1 = ops.convert_to_tensor(tensor, name: "tensor"); + var ones_shape = shape_internal(tensor1, optimize: optimize); + if (dtype == TF_DataType.DtInvalid) + dtype = tensor1.dtype; + var ret = ones(ones_shape, dtype: dtype, name: name); + return ret; + }); + } + + public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + return tf_with(ops.name_scope(name, "ones", new { shape }), scope => + { + name = scope; + if (shape._shape_tuple().Length == 0) + { + shape = reshape(shape, new Shape(-1)); + } + var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); + return output; + }); + } + + public static Tensor ones(Tensor[] shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + dtype = dtype.as_base_dtype(); + return tf_with(ops.name_scope(name, "ones", new { shape }), scope => + { + name = scope; + var output = _constant_if_small(1, shape[0]); + var shape1 = ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32); + output = gen_array_ops.fill(shape1, constant_op.constant(1, dtype: dtype), name: name); + return output; + }); + } + + public static Tensor ones(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + => tf_with(ops.name_scope(name, "ones", shape), scope => + { + dtype = dtype.as_base_dtype(); + name = scope; + + Tensor ones = dtype switch + { + TF_DataType.TF_DOUBLE => constant(1.0d), + TF_DataType.TF_FLOAT => constant(1.0f), + _ => constant(1, dtype) + }; + + if (shape.ndim == 0) + return ones; + + // var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); + return fill(shape, ones, name: name); + }); + + public static Tensor one_hot(Tensor indices, Tensor depth, + Tensor on_value = null, + Tensor off_value = null, + TF_DataType dtype = TF_DataType.DtInvalid, + int axis = -1, + string name = null) + { + return tf_with(ops.name_scope(name, "one_hot", new { indices, depth, dtype }), scope => + { + name = scope; + var on_exists = false; + var off_exists = false; + var on_dtype = TF_DataType.DtInvalid; + var off_dtype = TF_DataType.DtInvalid; + + if (dtype == TF_DataType.DtInvalid) + dtype = TF_DataType.TF_FLOAT; + + if (!on_exists) + { + on_value = ops.convert_to_tensor(1, dtype, name: "on_value"); + on_dtype = dtype; + } + + if (!off_exists) + { + off_value = ops.convert_to_tensor(0, dtype, name = "off_value"); + off_dtype = dtype; + } + + return gen_array_ops.one_hot(indices, depth, + on_value: on_value, + off_value: off_value, + axis: axis, + name: name); + }); + } + + public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null) + { + var res = gen_array_ops.unique(x, out_idx: out_idx, name: name); + Debug.Assert(res.Length == 2); + return (res[0], res[1]); + } + + public static Tensor stack(Tensor[] values, int axis = 0, string name = "stack") + { + if (axis == 0) + { + return ops.convert_to_tensor(values, name: name); + } + + return gen_array_ops.pack(values, axis: axis, name: name); + } + + public static Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name = "unstack") + { + num = num ?? value.shape.as_int_list()[axis]; + return gen_array_ops.unpack(value, num: num.Value, axis: axis, name: name); + } + + public static Tensor where(Tensor condition, object x = null, object y = null, string name = null) + { + if (x == null && y == null) + { + return tf_with(ops.name_scope(name, "Where", new { condition }), scope => + { + name = scope; + condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition"); + return gen_array_ops.where(condition, name: name); + }); + } + else if (x != null && y != null) + { + return gen_math_ops.select(condition, ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); + } + else + { + throw new ValueError("x and y must both be non-None or both be None."); + } + } + + public static Tensor where_v2(Tensor condition, object x = null, object y = null, string name = null) + { + if (x == null && y == null) + { + return tf_with(ops.name_scope(name, "Where", new { condition }), scope => + { + name = scope; + condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition"); + return gen_array_ops.where(condition, name: name); + }); + } + else if (x != null && y != null) + { + return gen_math_ops.select_v2(condition, ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); + } + else + { + throw new ValueError("x and y must both be non-None or both be None."); + } + } + + /// + /// Returns the shape of a tensor. + /// + /// A `Tensor` or `SparseTensor`. + /// A name for the operation (optional). + /// + /// (Optional) The specified output type of the operation + /// (`int32` or `int64`). Defaults to `tf.int32`. + /// + /// A `Tensor` of type `out_type`. + public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) + => shape_internal(input, name, optimize: true, out_type: out_type); + + public static Tensor shape_v2(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) + => shape_internal(input, name, optimize: true, out_type: out_type); + + public static Tensor size(T input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) + => size_internal(input, name, optimize: optimize, out_type: out_type); + + public static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) + { + return tf_with(ops.name_scope(name, "Shape", new { input }), scope => + { + name = scope; + + if (!tf.Context.executing_eagerly()) + { + var input_shape = input.shape; + if (optimize && input.ndim > -1 && input_shape.IsFullyDefined) + { + if(out_type == TF_DataType.TF_INT32) + return constant_op.constant(input.shape.as_int_list(), name: name); + else + return constant_op.constant(input.shape.dims, name: name); + } + } + + return tf.Context.ExecuteOp("Shape", name, new ExecuteOpArgs(input) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + out_type = op.get_attr("out_type") + } + }.SetAttributes(new + { + out_type + })).First(); + }); + } + + private static Tensor size_internal(T input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) + { + return tf_with(ops.name_scope(name, "Size", new { input }), scope => + { + name = scope; + + var input_tensor = ops.convert_to_tensor(input); + var input_shape = input_tensor.shape; + if (optimize) + { + if (input_shape.IsFullyDefined) + { + return constant_op.constant(input_shape.size, dtype: out_type, name: name); + } + } + + return gen_array_ops.size(input_tensor, name: name, out_type: out_type); + }); + } + + public static Tensor tile(Tensor input, Tensor multiples, string name = null) + => tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + Tmultiples = op.get_attr("Tmultiples") + } + }); + + /*public static Tensor tile(Tensor input, Shape multiples, string name = null) + { + return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + Tmultiples = op.get_attr("Tmultiples") + } + }); + }*/ + + public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + { + return tf_with(ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope => + { + name = scope; + tensor = ops.convert_to_tensor(tensor, name: "tensor"); + + // is_fully_defined return unexpected value. + if (optimize && tensor.shape.IsFullyDefined && dtype != TF_DataType.TF_VARIANT) + { + + } + + if (dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT) + { + throw new NotImplementedException("zeros_like"); + // return zeros(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name); + } + else + { + return gen_array_ops.zeros_like(tensor, name: name); + } + }); + } + + /// + /// When building ops to compute gradients, this op prevents the contribution of + /// its inputs to be taken into account.Normally, the gradient generator adds ops + /// to a graph to compute the derivatives of a specified 'loss' by recursively + /// finding out inputs that contributed to its computation.If you insert this op + /// in the graph it inputs are masked from the gradient generator. They are not + /// taken into account for computing gradients. + /// + /// + /// + /// + public static Tensor stop_gradient(Tensor input, string name = null) + { + var tape = tf.GradientTape().stop_recording(); + var result = gen_array_ops.stop_gradient(input, name); + tape.StartRecord(); + return result; + } + + /// + /// Extracts a strided slice of a tensor (generalized python array indexing). + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor strided_slice(Tensor input_, Tensor begin, Tensor end, + Tensor strides = null, + int begin_mask = 0, + int end_mask = 0, + int ellipsis_mask = 0, + int new_axis_mask = 0, + int shrink_axis_mask = 0, + string name = null) + => tf.Context.ExecuteOp("StridedSlice", name, new ExecuteOpArgs(input_, begin, end, strides) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + Index = op.get_attr("Index"), + begin_mask = op.get_attr("begin_mask"), + end_mask = op.get_attr("end_mask"), + ellipsis_mask = op.get_attr("ellipsis_mask"), + new_axis_mask = op.get_attr("new_axis_mask"), + shrink_axis_mask = op.get_attr("shrink_axis_mask") + } + }.SetAttributes(new + { + begin_mask, + end_mask, + ellipsis_mask, + new_axis_mask, + shrink_axis_mask + })); + + /// + /// Returns the gradient of `StridedSlice`. + /// + /// Since `StridedSlice` cuts out pieces of its `input` which is size + /// `shape`, its gradient will have the same shape (which is passed here + /// as `shape`). The gradient will be zero in any element that the slice + /// does not select. + /// + /// Must be one of the following types: `int32`, `int64`. + /// Must have the same type as `shape`. + /// Must have the same type as `shape`. + /// Must have the same type as `shape`. + /// A `Tensor`. + /// An optional `int`. Defaults to `0`. + /// An optional `int`. Defaults to `0`. + /// An optional `int`. Defaults to `0`. + /// An optional `int`. Defaults to `0`. + /// An optional `int`. Defaults to `0`. + /// A name for the operation (optional). + /// A `Tensor`. Has the same type as `dy`. + public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, + long begin_mask = 0, long end_mask = 0, long ellipsis_mask = 0, long new_axis_mask = 0, + long shrink_axis_mask = 0, string name = null) + => tf.Context.ExecuteOp("StridedSliceGrad", name, + new ExecuteOpArgs(shape, begin, end, strides, dy) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + Index = op.get_attr("Index"), + begin_mask = op.get_attr("begin_mask"), + end_mask = op.get_attr("end_mask"), + ellipsis_mask = op.get_attr("ellipsis_mask"), + new_axis_mask = op.get_attr("new_axis_mask"), + shrink_axis_mask = op.get_attr("shrink_axis_mask") + } + }.SetAttributes(new + { + begin_mask, + end_mask, + ellipsis_mask, + new_axis_mask, + shrink_axis_mask + })); + + /// + /// Removes dimensions of size 1 from the shape of a tensor. + /// Given a tensor `input`, this operation returns a tensor of the same type with + /// all dimensions of size 1 removed.If you don't want to remove all size 1 + /// dimensions, you can remove specific size 1 dimensions by specifying + /// `axis`. + /// + /// A `Tensor`. The `input` to squeeze. + /// An optional list of `ints`. Defaults to `[]`. + /// If specified, only squeezes the dimensions listed.The dimension + /// index starts at 0. It is an error to squeeze a dimension that is not 1. + /// Must be in the range `[-rank(input), rank(input))`. + /// A name for the operation (optional). + /// Deprecated keyword argument that is now axis. + /// A `Tensor`. Has the same type as `input`. + /// Contains the same data as `input`, but has one or more dimensions of + /// size 1 removed. + public static Tensor squeeze(Tensor input, Axis axis = null, string name = null) + => gen_array_ops.squeeze(input, axis, name); + + public static Tensor identity(Tensor input, string name = null) + => gen_array_ops.identity(input, name); + + public static Tensor invert_permutation(Tensor x, string name = null) + => gen_array_ops.invert_permutation(x, name: name); + + public static Tensor matrix_diag(Tensor diagonal, + string name = "diag", + int k = 0, + int num_rows = -1, + int num_cols = -1, + float padding_value = 0f, + string align = "RIGHT_LEFT") + => tf.Context.ExecuteOp("MatrixDiagV3", name, + new ExecuteOpArgs(diagonal, k, num_rows, num_cols, ops.convert_to_tensor(padding_value, dtype: diagonal.dtype)) + .SetAttributes(new { align })); + + public static Tensor matrix_set_diag(Tensor input, + Tensor diagonal, + string name = "set_diag", + int k = 0, + string align = "RIGHT_LEFT") + => tf.Context.ExecuteOp("MatrixSetDiagV3", name, new ExecuteOpArgs(input, diagonal, k) + .SetAttributes(new { align })); + + public static Tensor[] meshgrid(T[] array, bool copy = true, bool sparse = false, string indexing = "xy") + { + return tf_with(ops.name_scope(null, "meshgrid", array), scope => + { + var ndim = array.Length; + var s0 = range(ndim).Select(x => 1).ToArray(); + + var output = new List(); + foreach (var (i, x) in enumerate(array)) + { + var shape = s0.Take(i).ToArray().concat(new[] { -1 }).concat(s0.Skip(i + 1).ToArray()); + output.add(reshape(stack(x), shape)); + } + + // Create parameters for broadcasting each tensor to the full size + var shapes = array.Select(x => size(x)).ToArray(); + var output_dtype = _get_dtype_from_nested_lists(array).as_base_dtype(); + if (indexing == "xy" && ndim > 1) + { + output[0] = reshape(output[0], new[] { 1, -1 }.concat(range(ndim - 2).Select(x => 1).ToArray())); + output[1] = reshape(output[1], new[] { -1, 1 }.concat(range(ndim - 2).Select(x => 1).ToArray())); + (shapes[0], shapes[1]) = (shapes[1], shapes[0]); + } + + if(sparse) + return output.ToArray(); + else + { + var mult_fact = ones(shapes, output_dtype); + return output.Select(x => x * mult_fact).ToArray(); + } + }); + } + + public static Tensor moveaxis(NDArray array, Axis source, Axis destination) + { + List perm = null; + source = source.axis.Select(x => x < 0 ? array.rank + x : x).ToArray(); + destination = destination.axis.Select(x => x < 0 ? array.rank + x : x).ToArray(); + + if (array.shape.rank > -1) + { + perm = range(0, array.rank).Where(i => !source.axis.Contains(i)).ToList(); + foreach (var (dest, src) in zip(destination.axis, source.axis).OrderBy(x => x.Item1)) + { + perm.Insert(dest, src); + } + } + else + throw new NotImplementedException(""); + + return array_ops.transpose(array, perm.ToArray()); + } + + /// + /// Computes the shape of a broadcast given symbolic shapes. + /// When shape_x and shape_y are Tensors representing shapes(i.e.the result of + /// calling tf.shape on another Tensor) this computes a Tensor which is the shape + /// of the result of a broadcasting op applied in tensors of shapes shape_x and + /// shape_y. + /// For example, if shape_x is [1, 2, 3] and shape_y is [5, 1, 3], the result is a + /// Tensor whose value is [5, 2, 3]. + /// This is useful when validating the result of a broadcasting operation when the + /// tensors do not have statically known shapes. + /// + /// A rank 1 integer `Tensor`, representing the shape of x. + /// A rank 1 integer `Tensor`, representing the shape of y. + /// A rank 1 integer `Tensor` representing the broadcasted shape. + public static Tensor broadcast_dynamic_shape(Tensor shape_x, Tensor shape_y) + => gen_array_ops.broadcast_args(shape_x, shape_y); + + public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y) + => Framework.common_shapes.broadcast_shape(shape_x, shape_y); + + /// + /// Concatenates tensors along one dimension. + /// + /// + /// + /// + /// + public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat") + { + return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis)); + } + + public static Tensor concat(object[] values, int axis, string name = "concat") + { + return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis)); + } + + /// + /// Gather slices from `params` according to `indices`. `indices` must be an integer tensor of any dimension(often 1-D). + /// + /// Element type of the indexed tensor. + /// Element type of the index tensor. + /// The `Tensor` from which to gather values. Must be at least rank `axis + 1`. + /// The index `Tensor`. Must be one of the following types: `int32`, `int64`. The values must be in range `[0, params.shape[axis])`. + /// A name for the operation (optional). + /// + /// A `Tensor`. Must be one of the following types: `int32`, `int64`. + /// The `axis` in `params` to gather `indices` from.Must be greater than or equal to `batch_dims`. + /// Defaults to the first non-batch dimension. Supports negative indexes. + /// + /// An integer. The number of batch dimensions. Must be less than or equal to rank(indices). + /// + public static Tensor gather(Tensor @params, Tensor indices, string name = null, Tensor axis = null, int batch_dims = 0) + { + if (axis is null) + axis = tf.convert_to_tensor(batch_dims); + if(tensor_util.constant_value(axis) != 0) + { + return gen_array_ops.gather_v2(@params, indices, axis, batch_dims: batch_dims, name: name); + } + + return gen_array_ops.gather_v2(@params, indices, axis, name: name); + } + + public static Tensor gather(Tensor @params, Tensor indices, int axis, string name = null, int batch_dims = 0) + => gather(@params, indices, name, ops.convert_to_tensor(axis), batch_dims); + + public static Tensor gather(ResourceVariable @params, Tensor indices, string name = null, Tensor axis = null, int batch_dims = 0) + { + if (axis is null) + axis = tf.convert_to_tensor(batch_dims); + if (tensor_util.constant_value(axis) != 0) + { + throw new NotImplementedException(); + } + + return @params.sparse_read(indices, name); + } + + public static Tensor transpose(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false) + { + return tf_with(ops.name_scope(name, "transpose", new { a }), scope => + { + var a_tensor = ops.convert_to_tensor(a); + if (perm == null) + { + var rank = a_tensor.rank; + perm = range(0, rank).OrderByDescending(x => x).ToArray(); + } + + return gen_array_ops.transpose(a_tensor, perm, name: scope); + }); + } + + public static Tensor transpose(Tensor a, Tensor perm, string name = "transpose", bool conjugate = false) + { + return tf_with(ops.name_scope(name, "transpose", new { a }), scope => + { + return gen_array_ops.transpose(a, perm, name: scope); + }); + } + + /// + /// Transposes last two dimensions of tensor `a`. + /// For example: + /// python + /// x = tf.constant([[1, 2, 3], [4, 5, 6]]) + /// tf.matrix_transpose(x) # [[1, 4], + /// # [2, 5], + /// # [3, 6]] + /// + /// Matrix with two batch dimensions. + /// x.shape is [1, 2, 3, 4] + /// tf.linalg.matrix_transpose(x) is shape [1, 2, 4, 3] + /// + /// + /// + /// + /// + /// + public static Tensor matrix_transpose(Tensor a, string name = "matrix_transpose", bool conjugate = false) + { + return tf_with(ops.name_scope(name, "transpose", new { a }), scope => + { + var a_shape = a.shape; + var ndims = a.shape.ndim; + Axis perm; + if(ndims != 0) + { + if (ndims < 2) + { + throw new ValueError("Argument `a` should be a (batch) matrix with rank " + + $">= 2. Received `a` = {a} with shape: {a_shape}"); + } + perm = new Axis(Enumerable.Range(0, ndims - 2).Concat(new int[] { ndims - 1, ndims - 2 }).ToArray()); + } + else + { + var a_rank = a.rank; + perm = new Axis(Enumerable.Range(0, a_rank - 2).Concat(new int[] { a_rank - 1, a_rank - 2 }).ToArray()); + } + return transpose(a, perm:perm, conjugate:conjugate); + }); + } + + public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis = null, + string name = "split") + { + return gen_array_ops.split(split_dim: axis, value: value, num_split: num_or_size_splits, name); + } + + public static Tensor[] split(Tensor value, int[] num_or_size_splits, Tensor axis = null, int num = -1, + string name = "split") + { + if(num_or_size_splits.Length == 0) + { + throw new ValueError("Rank-0 tensors are not supported as the num_or_size_splits argument to split."); + } + var size_splits = ops.convert_to_tensor(num_or_size_splits); + + if(num == -1) + { + num = (int)size_splits.shape[0]; + } + + return gen_array_ops.split_v(value: value, size_splits: size_splits, split_dim: axis, num_split: num, name: name); + } + + public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null) + => gen_array_ops.slice(input, ops.convert_to_tensor(begin), ops.convert_to_tensor(size), name: name); + + public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) + => gen_array_ops.slice(input, begin, size, name: name); + + + public static Tensor stack(object values, int axis = 0, string name = "stack") + { + if (axis == 0) + // If the input is a constant list, it can be converted to a constant op + return ops.convert_to_tensor(values, name: name); + + throw new NotImplementedException("array_ops.stack"); + } + + public static Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT", string name = null, int constant_values = 0) + { + Tensor result = null; + mode = mode.ToUpper(); + if (mode == "CONSTANT") + { + if (constant_values != 0) + throw new NotImplementedException("gen_array_ops.pad_v2"); + else + result = gen_array_ops.pad(tensor, paddings, name: name); + } + + // Restore shape information where possible. + if (!tf.Context.executing_eagerly()) + { + var paddings_constant = tensor_util.constant_value(paddings); + var input_shape = result.op.inputs[0].shape; + if (input_shape.ndim > -1 && + !result.shape.IsFullyDefined && + !(paddings_constant is null)) + { + var new_shape = new List(); + foreach ((NDArray padding, int dim) in zip(paddings_constant, input_shape.as_int_list())) + { + if (padding is null || dim == -1 || padding.ToArray().Contains(-1)) + new_shape.Add(-1); + else + new_shape.Add((int)np.sum(padding) + dim); + } + result.shape = new_shape.ToArray(); + } + } + + return result; + } + + public static Tensor placeholder(TF_DataType dtype, Shape shape = null, string name = null) + { + if (tf.Context.executing_eagerly()) + throw new RuntimeError("tf.placeholder() is not compatible with eager execution."); + + var _op = tf.OpDefLib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape }); + return _op.output; + } + + public static int get_positive_axis(int axis, int ndims=-100, string axis_name="axis", string ndims_name= "ndims") + { + if(ndims != -100) + { + if (axis >= 0 && axis < ndims) return axis; + else if (-ndims <= axis && axis < 0) return axis + ndims; + else throw new ValueError($"{axis_name}={axis} out of bounds:expected {-ndims}<={axis_name}<{ndims}"); + + } else if(axis < 0) throw new ValueError($"{axis_name}={axis} may only be negative if {ndims_name} is statically known."); + return axis; + } + + } +} diff --git a/src/TensorFlowNET.Core/Operations/audio_ops.cs b/src/TensorFlowNET.Core/Operations/audio_ops.cs new file mode 100644 index 000000000..4f1b5f64c --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/audio_ops.cs @@ -0,0 +1,29 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using Tensorflow.Contexts; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class audio_ops + { + public Tensors decode_wav(Tensor contents, int desired_channels = -1, int desired_samples = -1, string name = null) + => tf.Context.ExecuteOp("DecodeWav", name, new ExecuteOpArgs(contents) + .SetAttributes(new { desired_channels, desired_samples })); + } +} diff --git a/src/TensorFlowNET.Core/Operations/bitwise_ops.cs b/src/TensorFlowNET.Core/Operations/bitwise_ops.cs new file mode 100644 index 000000000..7536357ca --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/bitwise_ops.cs @@ -0,0 +1,111 @@ +/***************************************************************************** + Copyright 2020 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + /// + /// Operations for bitwise manipulation of integers. + /// https://www.tensorflow.org/api_docs/python/tf/bitwise + /// + public class bitwise_ops + { + /// + /// Elementwise computes the bitwise left-shift of `x` and `y`. + /// https://www.tensorflow.org/api_docs/python/tf/bitwise/left_shift + /// + /// + /// + /// + /// + public Tensor left_shift(Tensor x, Tensor y, string name = null) => binary_op(x, y, "LeftShift", name); + + /// + /// Elementwise computes the bitwise right-shift of `x` and `y`. + /// https://www.tensorflow.org/api_docs/python/tf/bitwise/right_shift + /// + /// + /// + /// + /// + public Tensor right_shift(Tensor x, Tensor y, string name = null) => binary_op(x, y, "RightShift", name); + + /// + /// Elementwise computes the bitwise inversion of `x`. + /// https://www.tensorflow.org/api_docs/python/tf/bitwise/invert + /// + /// + /// + /// + public Tensor invert(Tensor x, string name = null) => unary_op(x, "Invert", name); + + /// + /// Elementwise computes the bitwise AND of `x` and `y`. + /// https://www.tensorflow.org/api_docs/python/tf/bitwise/bitwise_and + /// + /// + /// + /// + /// + public Tensor bitwise_and(Tensor x, Tensor y, string name = null) => binary_op(x, y, "BitwiseAnd", name); + + /// + /// Elementwise computes the bitwise OR of `x` and `y`. + /// https://www.tensorflow.org/api_docs/python/tf/bitwise/bitwise_or + /// + /// + /// + /// + /// + public Tensor bitwise_or(Tensor x, Tensor y, string name = null) => binary_op(x, y, "BitwiseOr", name); + + /// + /// Elementwise computes the bitwise XOR of `x` and `y`. + /// https://www.tensorflow.org/api_docs/python/tf/bitwise/bitwise_xor + /// + /// + /// + /// + /// + public Tensor bitwise_xor(Tensor x, Tensor y, string name = null) => binary_op(x, y, "BitwiseXor", name); + + + #region Private helper methods + + /// + /// Helper method to invoke unary operator with specified name. + /// + /// + /// + /// + /// + Tensor unary_op(Tensor x, string opName, string name) + => tf.Context.ExecuteOp(opName, name, new ExecuteOpArgs(x)); + + /// + /// Helper method to invoke binary operator with specified name. + /// + /// + /// + /// + /// + /// + Tensor binary_op(Tensor x, Tensor y, string opName, string name) + => tf.Context.ExecuteOp(opName, name, new ExecuteOpArgs(x, y)); + #endregion + } +} diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 0422aa755..900db8cac 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -1,36 +1,212 @@ -using System; -using System.Collections.Generic; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; using System.Runtime.InteropServices; -using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { + /// + /// Request that `desc` be co-located on the device where `op` + /// is placed. + /// + /// Use of this is discouraged since the implementation of device placement is + /// subject to change. Primarily intended for internal libraries + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_ColocateWith(IntPtr desc, IntPtr op); + + /// + /// Get the OpList of all OpDefs defined in this address space. + /// + /// + [DllImport(TensorFlowLibName)] + public static extern SafeBufferHandle TF_GetAllOpList(); + /// /// For inputs that take a single tensor. /// + /// TF_OperationDescription* + /// TF_Output + [DllImport(TensorFlowLibName)] + public static extern void TF_AddInput(IntPtr desc, TF_Output input); + + /// + /// Call once per control input to `desc`. + /// + /// TF_OperationDescription* + /// TF_Operation* + [DllImport(TensorFlowLibName)] + public static extern void TF_AddControlInput(IntPtr desc, IntPtr input); + + /// + /// + /// + /// TF_Graph* + /// TF_Operation* + /// TF_Operation* + [DllImport(TensorFlowLibName)] + public static extern void AddControlInput(IntPtr graph, IntPtr op, IntPtr input); + + /// + /// + /// + /// TF_Graph* + /// TF_Operation* + [DllImport(TensorFlowLibName)] + public static extern void RemoveAllControlInputs(IntPtr graph, IntPtr op); + + /// + /// For inputs that take a list of tensors. + /// inputs must point to TF_Output[num_inputs]. + /// /// - /// + /// [DllImport(TensorFlowLibName)] - public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input); + public static extern void TF_AddInputList(IntPtr desc, TF_Output[] inputs, int num_inputs); [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_FinishOperation(TF_OperationDescription desc, IntPtr status); + public static extern IntPtr TF_FinishOperation(IntPtr desc, SafeStatusHandle status); + /// + /// Operation will only be added to *graph when TF_FinishOperation() is + /// called (assuming TF_FinishOperation() does not return an error). + /// *graph must not be deleted until after TF_FinishOperation() is + /// called. + /// + /// TF_Graph* + /// const char* + /// const char* + /// TF_OperationDescription* [DllImport(TensorFlowLibName)] - public static unsafe extern TF_OperationDescription TF_NewOperation(IntPtr graph, string opType, string oper_name); + public static extern IntPtr TF_NewOperation(SafeGraphHandle graph, string opType, string oper_name); [DllImport(TensorFlowLibName)] - public static extern unsafe int TF_OperationNumOutputs(IntPtr oper); + public static extern IntPtr TF_OperationDevice(IntPtr oper); + + /// + /// Get list of all control inputs to an operation. `control_inputs` must + /// point to an array of length `max_control_inputs` (ideally set to + /// TF_OperationNumControlInputs(oper)). Returns the number of control + /// inputs (should match TF_OperationNumControlInputs(oper)). + /// + /// TF_Operation* + /// TF_Operation** + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationGetControlInputs(IntPtr oper, IntPtr control_inputs, int max_control_inputs); + + /// + /// Get the list of operations that have `*oper` as a control input. + /// `control_outputs` must point to an array of length at least + /// `max_control_outputs` (ideally set to + /// TF_OperationNumControlOutputs(oper)). Beware that a concurrent + /// modification of the graph can increase the number of control + /// outputs. Returns the number of control outputs (should match + /// TF_OperationNumControlOutputs(oper)). + /// + /// TF_Operation* + /// TF_Operation** + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationGetControlOutputs(IntPtr oper, IntPtr control_outputs, int max_control_outputs); + + /// + /// TF_Output producer = TF_OperationInput(consumer); + /// There is an edge from producer.oper's output (given by + /// producer.index) to consumer.oper's input (given by consumer.index). + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern TF_Output TF_OperationInput(TF_Input oper_in); + + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationInputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern TF_DataType TF_OperationInputType(TF_Input oper_in); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_OperationName(IntPtr oper); + + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationNumInputs(IntPtr oper); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_OperationOpType(IntPtr oper); + + /// + /// Get the number of control inputs to an operation. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationNumControlInputs(IntPtr oper); + + /// + /// Get the number of operations that have `*oper` as a control input. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationNumControlOutputs(IntPtr oper); + + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationNumOutputs(IntPtr oper); + + /// + /// Get the number of current consumers of a specific output of an + /// operation. Note that this number can change when new operations + /// are added to the graph. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationOutputNumConsumers(TF_Output oper_out); + + /// + /// Get list of all current consumers of a specific output of an + /// operation. `consumers` must point to an array of length at least + /// `max_consumers` (ideally set to + /// TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent + /// modification of the graph can increase the number of consumers of + /// an operation. Returns the number of output consumers (should match + /// TF_OperationOutputNumConsumers(oper_out)). + /// + /// TF_Output + /// TF_Input* + /// int + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationOutputConsumers(TF_Output oper_out, IntPtr consumers, int max_consumers); [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); + public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, IntPtr value, IntPtr status); + public static extern void TF_OperationToNodeDef(IntPtr oper, SafeBufferHandle buffer, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); + public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); } } diff --git a/src/TensorFlowNET.Core/Operations/check_ops.cs b/src/TensorFlowNET.Core/Operations/check_ops.cs new file mode 100644 index 000000000..3c4aa535d --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/check_ops.cs @@ -0,0 +1,142 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class check_ops + { + /// + /// Assert the condition `x == y` holds element-wise. + /// + /// + /// + /// + public static Operation assert_equal(T1 t1, T2 t2, object[] data = null, string message = null, string name = null) + { + if (message == null) + message = ""; + + return tf_with(ops.name_scope(name, "assert_equal", new { t1, t2, data }), delegate + { + var x = ops.convert_to_tensor(t1, name: "x"); + var y = ops.convert_to_tensor(t2, name: "y"); + + if (data == null) + { + data = new object[] + { + message, + "Condition x == y did not hold element-wise:", + $"x (%s) = {x.name}", + x, + $"y (%s) = {y.name}", + y + }; + } + + var eq = gen_math_ops.equal(x, y); + var condition = math_ops.reduce_all(eq); + var x_static = tensor_util.constant_value(x); + var y_static = tensor_util.constant_value(y); + return control_flow_ops.Assert(condition, data); + }); + } + + public static Operation assert_greater_equal(Tensor x, Tensor y, object[] data = null, string message = null, + string name = null) + { + if (message == null) + message = ""; + + return tf_with(ops.name_scope(name, "assert_greater_equal", new { x, y, data }), delegate + { + x = ops.convert_to_tensor(x, name: "x"); + y = ops.convert_to_tensor(y, name: "y"); + string x_name = x.name; + string y_name = y.name; + if (data == null) + { + data = new object[] + { + message, + "Condition x >= y did not hold element-wise:", + $"x (%s) = {x_name}", + x, + $"y (%s) = {y_name}", + y + }; + } + + var condition = math_ops.reduce_all(gen_math_ops.greater_equal(x, y)); + return control_flow_ops.Assert(condition, data); + }); + } + + + public static Operation assert_positive(Tensor x, object[] data = null, string message = null, string name = null) + { + if (message == null) + message = ""; + + return tf_with(ops.name_scope(name, "assert_positive", new { x, data }), delegate + { + x = ops.convert_to_tensor(x, name: "x"); + if (data == null) + { + name = x.name; + data = new object[] + { + message, + "Condition x > 0 did not hold element-wise:", + $"x (%s) = {name}", + x + }; + } + var zero = ops.convert_to_tensor(0, dtype: x.dtype); + return assert_less(zero, x, data: data); + }); + } + + public static Operation assert_less(Tensor x, Tensor y, object[] data = null, string message = null, string name = null) + { + if (message == null) + message = ""; + + return tf_with(ops.name_scope(name, "assert_less", new { x, y, data }), delegate + { + x = ops.convert_to_tensor(x, name: "x"); + y = ops.convert_to_tensor(y, name: "y"); + string x_name = x.name; + string y_name = y.name; + if (data == null) + { + data = new object[] + { + message, + "Condition x < y did not hold element-wise:", + $"x (%s) = {x_name}", + $"y (%s) = {y_name}", + y + }; + } + var condition = math_ops.reduce_all(gen_math_ops.less(x, y)); + return control_flow_ops.Assert(condition, data); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/clip_ops.cs b/src/TensorFlowNET.Core/Operations/clip_ops.cs new file mode 100644 index 000000000..59d74fde3 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/clip_ops.cs @@ -0,0 +1,81 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class clip_ops + { + public static (Tensors, Tensor) clip_by_global_norm(Tensor[] t_list, float clip_norm, Tensor use_norm = null, string name = null) + { + use_norm = global_norm(t_list, name); + return tf_with(ops.name_scope(name, "clip_by_global_norm", t_list), delegate + { + // Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm + var scale_for_finite = clip_norm * math_ops.minimum( + 1.0f / use_norm, + constant_op.constant(1.0, dtype: use_norm.dtype) / clip_norm); + + // If use_norm is any finite number, this is a no-op. For inf/-inf/NaN, + // this will make scale NaN. + var scale = scale_for_finite + (use_norm - use_norm); + + Tensors values_clipped = new Tensors(); + foreach (var (i, v) in enumerate(t_list)) + values_clipped.Add(array_ops.identity(v * scale, name: $"{name}_{i}")); + return (values_clipped, use_norm); + }); + } + + public static Tensor clip_by_value(Tensor t, T1 clip_value_min, T2 clip_value_max, string name = null) + { + return tf_with(ops.name_scope(name, "clip_by_value", new { t, clip_value_min, clip_value_max }), delegate + { + var values = ops.convert_to_tensor(t, name: "t"); + // Go through list of tensors, for each value in each tensor clip + var t_min = math_ops.minimum(values, clip_value_max); + // Assert that the shape is compatible with the initial shape, + // to prevent unintentional broadcasting. + _ = values.shape.merge_with(t_min.shape); + var t_max = math_ops.maximum(t_min, clip_value_min, name: name); + _ = values.shape.merge_with(t_max.shape); + + return t_max; + }); + } + + /// + /// Computes the global norm of multiple tensors. + /// + /// + /// + /// + public static Tensor global_norm(Tensor[] t_list, string name = null) + { + return tf_with(ops.name_scope(name, "global_norm", t_list), delegate + { + var half_squared_norms = t_list.Select(v => nn_ops.l2_loss(v)).ToArray(); + var half_squared_norm = math_ops.reduce_sum(array_ops.stack(half_squared_norms)); + var norm = math_ops.sqrt(half_squared_norm * + constant_op.constant(2.0, dtype: half_squared_norm.dtype), + name: "global_norm"); + return norm; + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs b/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs new file mode 100644 index 000000000..8b7989e6e --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs @@ -0,0 +1,61 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class confusion_matrix + { + /// + /// Squeeze last dim if ranks differ from expected by exactly 1. + /// + /// + /// + /// + /// + /// + public static (Tensor, Tensor) remove_squeezable_dimensions(Tensor labels, + Tensor predictions, + int expected_rank_diff = 0, + string name = null) + { + return tf_with(ops.name_scope(name, default_name: "remove_squeezable_dimensions", (labels, predictions)), delegate + { + predictions = ops.convert_to_tensor(predictions); + labels = ops.convert_to_tensor(labels); + var predictions_shape = predictions.shape; + var predictions_rank = predictions_shape.ndim; + var labels_shape = labels.shape; + var labels_rank = labels_shape.ndim; + if (labels_rank > -1 && predictions_rank > -1) + { + // Use static rank. + var rank_diff = predictions_rank - labels_rank; + if (rank_diff == expected_rank_diff + 1) + predictions = array_ops.squeeze(predictions, new int[] { -1 }); + else if (rank_diff == expected_rank_diff - 1) + labels = array_ops.squeeze(labels, new int[] { -1 }); + return (labels, predictions); + } + + // Use dynamic rank. + throw new NotImplementedException("remove_squeezable_dimensions dynamic rank"); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs new file mode 100644 index 000000000..efd9aba35 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -0,0 +1,812 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Operations; +using Tensorflow.Operations.ControlFlows; +using Tensorflow.Util; +using static Tensorflow.Binding; +using util = Tensorflow.control_flow_util; + +namespace Tensorflow +{ + public class control_flow_ops + { + public static Tensor _AddNextAndBackEdge(Tensor m, Tensor v, bool enforce_shape_invariant = true) + { + v = ops.convert_to_tensor(v); + v = _NextIteration(v); + if (enforce_shape_invariant) + _EnforceShapeInvariant(m, v); + m.op._update_input(1, v); + return v; + } + + /// + /// Check if the shapes of the loops variables are invariants. + /// + /// + /// + public static void _EnforceShapeInvariant(Tensor merge_var, Tensor next_var) + { + + } + + public static Tensor exit(Tensor data, string name = null) + { + data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); + if (data.dtype.is_ref_dtype()) + return gen_control_flow_ops.ref_exit(data, name: name); + else + return gen_control_flow_ops._exit(data, name: name); + } + + public static Tensor _NextIteration(Tensor data, string name = null) + { + data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); + if (data.dtype.is_ref_dtype()) + return gen_control_flow_ops.ref_next_iteration(data, name: name); + else + return gen_control_flow_ops.next_iteration(data, name: name); + } + + public static Operation Assert(Tensor condition, object[] data, long summarize = 3, string name = null) + { + if (tf.executing_eagerly()) + { + if (condition == null) + throw new InvalidArgumentError(""); + + return null; + } + + return tf_with(ops.name_scope(name, "Assert", new { condition, data }), scope => + { + name = scope; + var xs = ops.convert_n_to_tensor(data); + condition = ops.convert_to_tensor(condition, name: "Condition"); + Func true_assert = () => + { + var assert = gen_logging_ops.assert(condition, data, summarize, name: "Assert"); + return new Operation[] { assert }; + }; + + Func false_assert = () => + { + var op = gen_control_flow_ops.no_op(); + return new Operation[] { op }; + }; + + var guarded_assert = cond(condition, false_assert, true_assert, name: "AssertGuard"); + + return guarded_assert[0].op; + }); + } + + public static Operation group(T[] inputs, string name = null) where T : ITensorOrOperation + { + return tf_with(ops.name_scope(name, "group_deps", inputs), scope => + { + name = scope; + + // Sorts *inputs according to their devices. + var ops_on_device = new Dictionary>(); + foreach (var inp in inputs) + { + if (ops_on_device.ContainsKey(inp.Device)) + ops_on_device[inp.Device].Add(inp); + else + ops_on_device[inp.Device] = new List { inp }; + } + + // 1-level tree. The root node is the returned NoOp node. + if (ops_on_device.Count == 1) + { + var dev = ops_on_device.Keys.First(); + var deps = ops_on_device.Values.First(); + return _GroupControlDeps(dev, deps.Select(x => x.op).ToArray(), name); + } + + // 2-level tree. The root node is the returned NoOp node. + // deps contains 1 NoOp node for each device. + throw new NotImplementedException("control_flow_ops.group"); + }); + } + + /// + /// Does nothing. Only useful as a placeholder for control edges. + /// + /// + /// + public static Operation no_op(string name = null) + => gen_control_flow_ops.no_op(name: name); + + private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = null) + { + return tf_with(ops.control_dependencies(deps), ctl => + { + if (dev == null) + { + return gen_control_flow_ops.no_op(name); + } + else + { + return gen_control_flow_ops.no_op(name); + } + }); + } + + /// + /// Create the state for all the while loops involved in one gradients(). + /// + /// + /// + /// + public static ControlFlowState MaybeCreateControlFlowState(List between_op_list, List between_ops, bool colocate_gradients_with_ops) + { + var flag = new List(); + ControlFlowState loop_state = null; + + int pos = 0; + while (pos < between_op_list.Count) + { + var op = between_op_list[pos]; + if (IsLoopExit(op)) + { + if (loop_state == null) + { + loop_state = new ControlFlowState(); + } + if (colocate_gradients_with_ops) + ops.colocate_with(op); + loop_state.AddWhileContext(op, between_op_list, between_ops); + } + pos++; + } + + return loop_state; + } + + public static bool IsLoopExit(Operation op) + => op.OpType == "Exit" || op.OpType == "RefExit"; + + public static bool IsLoopSwitch(Operation op) + { + if (IsSwitch(op)) + { + var ctxt = op._get_control_flow_context(); + return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op); + } + return false; + } + + public static bool IsCondSwitch(Operation op) + { + throw new NotImplementedException("IsCondSwitch"); + } + + public static bool IsSwitch(Operation op) + => op.type == "Switch" || op.type == "RefSwitch"; + + public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] control_inputs = null) + { + return tf_with(ops.name_scope(name, "tuple", tensors), scope => + { + name = scope; + var gating_ops = tensors.Where(x => x != null).Select(x => x.op).ToList(); + + if (control_inputs != null) + { + foreach (var c in control_inputs) + gating_ops.Add(c); + } + + // Note that in order to ensure ordering in the pbtxt, we must take care to + // ensure the order here. + gating_ops = gating_ops.OrderBy(x => x._id).ToList(); + var gate = group(gating_ops.ToArray()); + + var tpl = new List(); + foreach (var t in tensors) + { + if (t != null) + tpl.Add(with_dependencies(new Operation[] { gate }, t)); + else + tpl.Add(null); + } + + return tpl.ToArray(); + }); + } + + internal static Tensor _case_helper(Func cond_fn, Tensor[] pred_fn_pairs, Func callable_default, bool exclusive, string name, + bool allow_python_preds = false) + { + /* + (Tensor[] predicates, Tensor[] actions) = _case_verify_and_canonicalize_args( + pred_fn_pairs, exclusive, name, allow_python_preds); + return tf_with(ops.name_scope(name, "case", new [] {predicates}), delegate + { + if (callable_default == null) + { + (callable_default, predicates, actions) = _case_create_default_action( + predicates, actions); + } + var fn = callable_default; + }); + */ + + throw new NotImplementedException("_case_helper"); + } + + internal static (Func, Tensor[], Tensor[]) _case_create_default_action(Tensor[] predicates, Tensor[] actions) + { + throw new NotImplementedException("_case_create_default_action"); + } + + internal static (Tensor[], Tensor[]) _case_verify_and_canonicalize_args(Tensor[] pred_fn_pairs, bool exclusive, string name, bool allow_python_preds) + { + throw new NotImplementedException("_case_verify_and_canonicalize_args"); + } + + public static Tensor case_v2(Tensor[] pred_fn_pairs, Func callable_default = null, bool exclusive = false, bool strict = false, string name = "case") + => _case_helper( + cond_fn: (Tensor x) => cond(x), + pred_fn_pairs, + default, + exclusive, + name, + allow_python_preds: false//, + //strict: strict + ); + + /// + /// Produces the content of `output_tensor` only after `dependencies`. + /// + /// In some cases, a user may want the output of an operation to be + /// consumed externally only after some other dependencies have run + /// first.This function ensures returns `output_tensor`, but only after all + /// operations in `dependencies` have run.Note that this means that there is + /// no guarantee that `output_tensor` will be evaluated after any `dependencies` + /// have run. + /// + /// See also `tf.tuple` and `tf.group`. + /// + /// Iterable of operations to run before this op finishes. + /// A `Tensor` or `IndexedSlices` that will be returned. + /// (Optional) A name for this operation. + /// Same as `output_tensor`. + public static Tensor with_dependencies(Operation[] dependencies, Tensor output_tensor, string name = null) + { + //TODO: missing original code + //if context.executing_eagerly(): + // return output_tensor + return tf_with(ops.name_scope(name, "control_dependency", new { dependencies, output_tensor }), scope => + { + name = scope; + ops.colocate_with(output_tensor); + { + return tf_with(ops.control_dependencies(dependencies), ctl => + { + output_tensor = ops.convert_to_tensor_or_composite(output_tensor); + return _Identity(output_tensor, name: name); + }); + } + }); + } + + public static Tensor _Identity(Tensor data, string name = null) + { + data = ops.internal_convert_to_tensor_or_composite(data, as_ref: true); + if ((int)data.dtype > 100) + throw new NotImplementedException("_Identity"); + else + return gen_array_ops.identity(data, name: name); + } + + public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, Shape[] shapes = null) + { + if (shapes == null) + return; + + var flat_shapes = nest.flatten2(shapes); + foreach (var (inp, var, shape) in zip(input_vars, enter_vars, flat_shapes)) + { + var.shape = shape; + } + } + + /// + /// Forwards `data` to an output determined by `pred`. + /// If `pred` is false, the `data` input is forwarded to the first output. + /// Otherwise, the data goes to the second output. + /// + /// This op handles `Tensor`s and `IndexedSlices`. + /// + /// The tensor to be forwarded to the appropriate output. + /// A scalar that specifies which output port will receive data. + /// A name for this operation (optional). + /// + /// `(output_false, output_true)`: If `pred` is true, data will be forwarded to + /// `output_true`, otherwise it goes to `output_false`. + /// + public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") + { + data = ops.convert_to_tensor_or_composite(data, name: "data"); + // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below + // addresses the following scenario. + // + // Assume you execute Optimizer.apply_gradients() in a branch of a cond(). + // + // 1. The update op is created inside a `with ops.colocate(var):` block + // + // 2. Some tensor `data` is captured and a switch is created in a + // `with ops.colocate_with(data):` block. + // + // with ops.colocate_with(var): + // with ops.colocate_with(data): + // op = ... + // + // var and data may be pinned to different devices, so we want to ops + // created within ops.colocate_with(data) to ignore the existing stack. + ops.colocate_with(data, ignore_existing: true); + { + if (data is Tensor) + { + if (data.dtype.is_ref_dtype()) + return gen_control_flow_ops.ref_switch(data, pred, name: name); + } + return @switch(data, pred, name: name); + } + } + + /// + /// Return `true_fn()` if the predicate `pred` is true else `false_fn()`. + /// + /// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and + /// `false_fn` must have the same non-zero number and type of outputs. + /// + /// **WARNING**: Any Tensors or Operations created outside of `true_fn` and + /// `false_fn` will be executed regardless of which branch is selected at runtime. + /// + /// Although this behavior is consistent with the dataflow model of TensorFlow, + /// it has frequently surprised users who expected a lazier semantics. + /// Consider the following simple program: + /// + /// z = tf.multiply(a, b) + /// result = tf.cond(x < y, ()=> tf.add(x, z), ()=> tf.square(y)) + /// + /// If `x<y`, the `tf.add` operation will be executed and `tf.square` + /// operation will not be executed.Since `z` is needed for at least one + /// branch of the `cond`, the `tf.multiply` operation is always executed, + /// unconditionally. + /// + /// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the + /// call to `cond`, and not at all during `Session.run()`). `cond` + /// stitches together the graph fragments created during the `true_fn` and + /// `false_fn` calls with some additional graph nodes to ensure that the right + /// branch gets executed depending on the value of `pred`. + /// + /// `tf.cond` supports nested structures as implemented in + /// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the + /// same(possibly nested) value structure of lists, tuples, and/or named tuples. + /// Singleton lists and tuples form the only exceptions to this: when returned by + /// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. + /// This behavior is disabled by passing `strict= True`. + /// + /// A scalar determining whether to return the result of `true_fn` or + /// `false_fn`. + /// The callable to be performed if pred is true. + /// The callable to be performed if pred is false. + /// A boolean that enables/disables 'strict' mode; see above. + /// Optional name prefix for the returned tensors. + /// Tensors returned by the call to either `true_fn` or `false_fn`. If the + /// callables return a singleton list, the element is extracted from the list. + public static Tensor cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, + string name = null) + { + return tf_with(ops.name_scope(name, "cond", new { pred }), delegate + { + if (tf.Context.executing_eagerly()) + { + if ((bool)pred) + return true_fn() as Tensor; + else + return false_fn() as Tensor; + } + + // Add the Switch to the graph. + var switch_result = @switch(pred, pred); + var (p_2, p_1) = (switch_result[0], switch_result[1]); + var pivot_1 = array_ops.identity(p_1, name: "switch_t"); + var pivot_2 = array_ops.identity(p_2, name: "switch_f"); + pred = array_ops.identity(pred, name: "pred_id"); + + // Disable the fetching of tensors that are only on one branch of cond. + foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) + tensor.op.graph.prevent_fetching(tensor.op); + + // Build the graph for the true branch in a new context. + var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); + ITensorOrOperation orig_res_t; + Tensor res_t; + try + { + context_t.Enter(); + (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); + context_t.ExitResult(new[] { res_t }); + } + finally + { + context_t.Exit(); + } + // Build the graph for the false branch in a new context. + var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); + ITensorOrOperation orig_res_f; + Tensor res_f; + try + { + context_f.Enter(); + (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); + context_f.ExitResult(new[] { res_f }); + } + finally + { + context_f.Exit(); + } + + var res_t_flat = new Tensor[] { res_t }; + var res_f_flat = new Tensor[] { res_f }; + + var merges = zip(res_f_flat, res_t_flat) + .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })[0]) + .ToArray(); + + if (orig_res_t is Tensor orig_res_tensor) + merges = _convert_flows_to_tensorarrays(new[] { orig_res_tensor }, merges) + .Select(x => x as Tensor) + .ToArray(); + else + { + + } + + if (context_t.outer_context == null) + { + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); + } + + return merges[0]; + }); + } + + public static Tensor[] cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, + bool strict = false, + string name = null) + { + return tf_with(ops.name_scope(name, "cond", new { pred }), delegate + { + if (tf.Context.executing_eagerly()) + { + if (pred.ToArray()[0]) + return true_fn() as Tensor[]; + else + return false_fn() as Tensor[]; + } + + // Add the Switch to the graph. + var switch_result = @switch(pred, pred); + var p_2 = switch_result[0]; + var p_1 = switch_result[1]; + var pivot_1 = array_ops.identity(p_1, name: "switch_t"); + var pivot_2 = array_ops.identity(p_2, name: "switch_f"); + pred = array_ops.identity(pred, name: "pred_id"); + + // Disable the fetching of tensors that are only on one branch of cond. + foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) + tensor.op.graph.prevent_fetching(tensor.op); + + // Build the graph for the true branch in a new context. + var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); + context_t.Enter(); + var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); + context_t.ExitResult(res_t); + context_t.Exit(); + + // Build the graph for the false branch in a new context. + var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); + context_f.Enter(); + var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); + context_f.ExitResult(res_f); + context_f.Exit(); + + var res_t_flat = res_t; + var res_f_flat = res_f; + + var merges = zip(res_f_flat, res_t_flat) + .Select(pair => merge(new[] { pair.Item1, pair.Item2 })[0]) + .ToArray(); + + if (orig_res_t is Tensor[] orig_res_tensor) + merges = _convert_flows_to_tensorarrays(orig_res_tensor, merges) + .Select(x => x as Tensor) + .ToArray(); + else if (orig_res_t is float[] orig_res_float) + { + + } + else + { + + } + + if (context_t.outer_context == null) + { + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); + } + + return merges; + }); + } + + public static ITensorOrTensorArray[] _convert_flows_to_tensorarrays(ITensorOrTensorArray[] tensors_or_tensorarrays, Tensor[] tensors_or_flows) + { + return zip(tensors_or_tensorarrays, tensors_or_flows).Select(x => + { + var (ta, t_or_flow) = (x.Item1, x.Item2); + if (ta is TensorArray ta_1) + return tensor_array_ops.build_ta_with_new_flow(ta_1, t_or_flow) as ITensorOrTensorArray; + else + return t_or_flow as ITensorOrTensorArray; + }).ToArray(); + } + + /// + /// Returns the value of an available element of `inputs`. + /// + /// This op tests each of the tensors in `inputs` in turn to determine if any of + /// them is available.If it finds an available tensor, it returns it and its + /// index in `inputs`. + /// + /// It is an error if more than one tensor in `inputs` is available.If no tensor + /// in `inputs` is available, the returned tensor and index are not set. + /// + /// This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of + /// `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices + /// before merging. + /// + /// inputs: The input tensors, at most one of which is available. + /// A name for this operation (optional). + /// + public static MergeOutput merge(Tensor[] inputs, string name = null) + { + if (inputs.Any(x => x == null)) + throw new ValueError($"At least one of the merge inputs is null: {inputs}"); + return tf_with(ops.name_scope(name, "Merge", inputs), scope => + { + name = scope; + inputs = inputs.Select(inp => + ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) + .ToArray(); + return gen_control_flow_ops.merge(inputs, name); + }); + } + + /// + /// Forwards `data` to an output determined by `pred`. + /// + /// + /// + /// + /// + public static Tensor[] @switch(Tensor data, + Tensor pred, + TF_DataType dtype = TF_DataType.DtInvalid, + string name = null) + { + return tf_with(ops.name_scope(name, "Switch", new { data, pred }), scope => + { + name = scope; + data = ops.internal_convert_to_tensor_or_indexed_slices(data, + dtype: dtype, + name: "data", + as_ref: true); + + pred = ops.convert_to_tensor(pred, name: "pred"); + + return gen_control_flow_ops.@switch(data, pred, name: name); + }); + } + + public static Tensor ZerosLikeOutsideLoop(Operation op, int index) + { + var val = op.outputs[index]; + if (!util.IsSwitch(op)) + { + if (val.dtype == TF_DataType.TF_RESOURCE) + throw new NotImplementedException("ZerosLikeOutsideLoop"); + return array_ops.zeros_like(val, optimize: false); + } + else + { + var op_ctxt = op._get_control_flow_context(); + if (op_ctxt != null) + { + // We are in a cond context. Use a switch to create zeros only when needed. + var pred = op_ctxt.pred; + var branch = op_ctxt.branch; + var switch_val = @switch(op.inputs[0], pred)[1 - branch]; + var pivot = array_ops.identity(switch_val); + if (val.dtype == dtypes.resource) + throw new NotImplementedException(""); + var zeros_shape = array_ops.shape_internal(switch_val, optimize: false); + // Ensure ops created within array_ops.zeros are dominated by switch in + // cond context. + return tf_with(ops.control_dependencies(new[] { pivot }), delegate + { + return array_ops.zeros(zeros_shape, dtype: val.dtype); + }); + } + else + { + return array_ops.zeros_like(val, optimize: false); + } + } + } + + public static Tensors while_loop(Func cond, + Func body, + Tensors loop_vars, + int parallel_iterations = 10, + string name = null) + { + var executing_eagerly = tf.Context.executing_eagerly(); + if (!executing_eagerly) + { + return while_v2.while_loop(cond, body, loop_vars, parallel_iterations: parallel_iterations, + name: name); + } + + return tf_with(ops.name_scope("name", "while"), delegate + { + while ((bool)cond(loop_vars)) + { + loop_vars = body(loop_vars); + } + + return loop_vars; + }); + } + + /// + /// Repeat `body` while the condition `cond` is true. + /// + /// + /// + /// + /// + public static TItem while_loop(Func cond, Func body, TItem loop_vars, + Shape[] shape_invariants = null, + int parallel_iterations = 10, + bool back_prop = true, + bool swap_memory = false, + string name = null, + Tensor maximum_iterations = null, + bool return_same_structure = false) where TItem : IFromMergeVars, new() + { + return tf_with(ops.name_scope(name, "while", loop_vars), scope => + { + if (loop_vars == null) + throw new ValueError("No loop variables provided"); + if (cond == null) + throw new ValueError("cond must be callable."); + if (body == null) + throw new ValueError("body must be callable."); + if (parallel_iterations < 1) + throw new ValueError("parallel_iterations must be a positive integer."); + + var try_to_pack = loop_vars is Tensor && !return_same_structure; + var counter = constant_op.constant(0, dtype: maximum_iterations.dtype, name: "iteration_counter"); + var orig_cond = cond; + var orig_body = body; + + LoopVar loop_vars_1 = null; + Func, LoopVar> body_buildloop = null; + Func, Tensor> cond_buildloop = null; + + if (try_to_pack) + { + + } + else + { + loop_vars_1 = new LoopVar(counter, loop_vars); + cond_buildloop = (item) => + { + var (i, lv) = (item.Counter, item.Item); + var oc = orig_cond(lv); + return math_ops.logical_and(i < maximum_iterations, oc); + }; + + body_buildloop = (item) => + { + var (i, lv) = (item.Counter, item.Item); + var ob = orig_body(lv); + return new LoopVar(i + 1, ob); + }; + } + try_to_pack = false; + + var loop_context = new WhileContext( + maximum_iterations: maximum_iterations, + parallel_iterations: parallel_iterations, + back_prop: back_prop, + swap_memory: swap_memory); + + if (loop_context.outer_context == null) + ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); + + var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars_1, shape_invariants, + return_same_structure); + + //if (maximum_iterations != null) + return results.Item; + //else + //return results; + }); + } + + /// + /// Creates or finds a child frame, and makes `data` available to it. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor _Enter(Tensor data, string frame_name, + bool is_constant = false, + int parallel_iterations = 10, + bool use_ref = true, + bool use_input_shape = true, + string name = null) + { + Tensor result; + data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); + if (data.dtype.is_ref_dtype() && use_ref) + throw new NotImplementedException("_Enter"); + else + result = gen_control_flow_ops.enter( + data, frame_name, is_constant, parallel_iterations, name: name); + + if (use_input_shape) + result.shape = data.shape; + + return result; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs new file mode 100644 index 000000000..536d4e3c2 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs @@ -0,0 +1,277 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Graphs; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class control_flow_util + { + public static readonly bool ENABLE_CONTROL_FLOW_V2 = !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0" || + (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0") || + (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2") != "0") || + (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2") != "0") || + (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2") != "0"); + /// + /// Return true if `op` is an Exit. + /// + /// + /// + public static bool IsLoopExit(Operation op) + { + return op.type == "Exit" || op.type == "RefExit"; + } + + /// + /// Returns true if `op` is an Enter. + /// + /// + /// + public static bool IsLoopEnter(Operation op) + { + return op.type == "Enter" || op.type == "RefEnter"; + } + + /// + /// Return true iff op is a loop invariant. + /// + /// + /// + public static bool IsLoopConstantEnter(Operation op) + { + return IsLoopEnter(op) && op.get_attr("is_constant"); + } + + /// + /// Return true if `op` is a Switch. + /// + /// + /// + public static bool IsSwitch(Operation op) + { + return op.type == "Switch" || op.type == "RefSwitch"; + } + + public static WhileContext GetWhileContext(Operation op) + => op.GetWhileContext(); + + public static bool IsCondSwitch(Operation op) + { + if (!IsSwitch(op)) + return false; + if (op.outputs == null || op.outputs.Length == 0) + return false; + + // Switch nodes are not part of the cond control flow context that they + // represent, so consider the consumers of its outputs to determine if it is + // cond switch or not. A switch is a cond switch iff all its consumers are in + // cond contexts. + var is_cond_switch = true; + foreach (var o in op.outputs) + { + foreach (var c in o.consumers()) + { + var ctxt = c._get_control_flow_context(); + if (IsLoopEnter(c)) + ctxt = ctxt.outer_context; + is_cond_switch = is_cond_switch && (ctxt != null && ctxt.IsCondContext()); + } + } + + return is_cond_switch; + } + + public static bool IsLoopSwitch(Operation op) + { + if (IsSwitch(op)) + { + var ctxt = op._get_control_flow_context(); + return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op); + } + return false; + } + + /// + /// Return the control flow context for the output of an op. + /// + public static ControlFlowContext GetOutputContext(Operation op) + { + var ctxt = op._get_control_flow_context(); + // Exit nodes usually have a control flow context, except in the case where the + // exit node was imported via import_graph_def (in which case no nodes have + // control flow contexts). + if (ctxt != null && IsLoopExit(op)) + ctxt = ctxt.outer_context; + return ctxt; + } + + public static void CheckInputFromValidContext(Operation op, Operation input_op) + { + var op_ctxt = op._get_control_flow_context(); + var input_ctxt = GetOutputContext(input_op); + var valid = false; + if (input_ctxt == null) + valid = true; + else if (op_ctxt == input_ctxt) + valid = true; + else + { + var while_ctxt = GetContainingWhileContext(op_ctxt); + var input_while_ctxt = GetContainingWhileContext(input_ctxt); + + if (while_ctxt == null) + { + // Neither op nor input_op is in a while loop, but one or both are in + // conds. We allow this, although execution will fail if the branch + // corresponding to input_op's cond context isn't taken. + if (input_while_ctxt == null) + valid = true; + // Invalid if op isn't in a while loop and input_op is. Unless... + if (IsLoopEnter(op)) + // WhileContext._BuildLoop clears context for Enter nodes. + valid = true; + if (IsSwitch(op)) + // CondContext.AddValue clears context for Switch nodes. + valid = true; + } + else if (IsContainingContext(while_ctxt, input_while_ctxt)) + { + // input_op is in a while loop which contains op's while loop (or not in a + // while loop at all). + valid = true; + } + else if (while_ctxt.grad_state != null && + IsContainingContext(while_ctxt.grad_state.forward_context, + input_while_ctxt)) + { + valid = true; + } + else + throw new NotImplementedException("CheckInputFromValidContext"); + } + + if (!valid) + { + throw new NotImplementedException("CheckInputFromValidContext"); + } + } + + public static Operation GetLoopConstantEnter(Tensor value) + { + var id_ops = new string[] { "Switch", "RefSwitch", "Identity", "RefIdentity" }; + var op = value.op; + while (id_ops.Contains(op.type)) + op = op.inputs[0].op; + return IsLoopConstantEnter(op) ? op : null; + } + + public static bool IsContainingContext(WhileContext ctxt, WhileContext maybe_containing_ctxt) + { + while (ctxt != maybe_containing_ctxt) + { + if (ctxt == null) + return false; + ctxt = ctxt.outer_context as WhileContext; + } + return true; + } + + public static WhileContext GetContainingWhileContext(ControlFlowContext ctxt, ControlFlowContext stop_ctxt = null) + { + while (ctxt != null) + { + if (ctxt.IsWhileContext() || ctxt == stop_ctxt) + return ctxt as WhileContext; + ctxt = ctxt.outer_context; + } + return null; + } + + public static bool EnableControlFlowV2(Graph graph) + { + return ENABLE_CONTROL_FLOW_V2 || graph.building_function && (graph is not FuncGraph func || func.captures.Length == 0); + + } + + public static string create_new_tf_function(FuncGraph func_graph) + { + var func = new EagerDefinedFunction(func_graph.Name, func_graph, func_graph.Inputs, func_graph.Outputs, new Dictionary()); + func.AddToGraph(func_graph); + return func_graph.Name; + } + + public static (Operation, Tensor[]) get_op_and_outputs(Tensor[] inputs) + { + if(inputs.Length == 0) + { + return (null, new Tensor[0]); + } + else + { + return (inputs[0], inputs); + } + } + + public static Tensor[] run_as_function_for_tape_gradients(Func make_op, Tensor[] inputs) + { + if(gradients_util.PossibleTapeGradientTypes(inputs) == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER + && !(ops.get_default_graph().building_function)) + { + throw new NotImplementedException(); + } + else + { + return make_op(inputs); + } + } + + public static string unique_fn_name(string scope, string name) + { + return $"{scope}{name}_{ops.uid()}".Replace("/", "_"); + } + + public static bool output_all_intermediates() + { + if (in_defun()) + { + return false; + } + if(tf.Context.FunctionCallOptions.ExecutorType == "SINGLE_THREADED_EXECUTOR") + { + return false; + } + // TODO(Rinne): check this after refactoring keras building. + return false; + } + + public static bool in_defun() + { + if (tf.Context.executing_eagerly()) + { + return false; + } + + var graph = ops.get_default_graph(); + // TODO(Rinne): CondBranchFuncGraph, WhileBodyFuncGraph, WhileCondFuncGraph + return graph is FuncGraph; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/ctc_ops.cs b/src/TensorFlowNET.Core/Operations/ctc_ops.cs new file mode 100644 index 000000000..348f4e1a6 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ctc_ops.cs @@ -0,0 +1,62 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public class ctc_ops + { + /// + /// Performs greedy decoding on the logits given in inputs. + /// + /// + /// 3-D, shape: (max_time x batch_size x num_classes), the logits. + /// + /// + /// A vector containing sequence lengths, size (batch_size). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CTCGreedyDecoder'. + /// + /// + /// If True, merge repeated classes in output. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// decoded_indices : Indices matrix, size (total_decoded_outputs x 2), + /// of a SparseTensor&lt;int64, 2&gt;. The rows store: [batch, time]. + /// decoded_values : Values vector, size: (total_decoded_outputs), + /// of a SparseTensor&lt;int64, 2&gt;. The vector stores the decoded classes. + /// decoded_shape : Shape vector, size (2), of the decoded SparseTensor. + /// Values are: [batch_size, max_decoded_length]. + /// log_probability : Matrix, size (batch_size x 1), containing sequence + /// log-probabilities. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// A note about the attribute merge_repeated: if enabled, when + /// consecutive logits' maximum indices are the same, only the first of + /// these is emitted. Labeling the blank '*', the sequence "A B B * B B" + /// becomes "A B B" if merge_repeated = True and "A B B B B" if + /// merge_repeated = False. + /// + /// Regardless of the value of merge_repeated, if the maximum index of a given + /// time and batch corresponds to the blank, index (num_classes - 1), no new + /// element is emitted. + /// + public Tensor[] ctc_greedy_decoder(Tensor inputs, Tensor sequence_length, bool merge_repeated = true, string name = null) + => gen_ctc_ops.ctc_greedy_decoder(inputs, sequence_length, merge_repeated: merge_repeated, name: name); + } +} diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs new file mode 100644 index 000000000..061fb95e3 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs @@ -0,0 +1,370 @@ +using System; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using Tensorflow.Framework.Models; +using Tensorflow.Functions; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class dataset_ops + { + public Tensor tensor_dataset(Tensor[] components, Shape[] output_shapes, string name = null) + => tf.Context.ExecuteOp("TensorDataset", name, new ExecuteOpArgs() + { + OpInputArgs = new object[] { components } + }.SetAttributes(new { output_shapes })); + + /// + /// Creates a dataset that emits each dim-0 slice of `components` once. + /// + /// + /// + /// + /// + public Tensor tensor_slice_dataset(Tensor[] components, Shape[] output_shapes, string name = null) + => tf.Context.ExecuteOp("TensorSliceDataset", name, new ExecuteOpArgs() + { + OpInputArgs = new object[] { components } + }.SetAttributes(new { output_shapes })); + + public Tensor range_dataset(Tensor start, Tensor stop, Tensor step, TF_DataType[] output_types, Shape[] output_shapes, string name = null) + => tf.Context.ExecuteOp("RangeDataset", name, new ExecuteOpArgs(start, stop, step) + .SetAttributes(new { output_types, output_shapes })); + + public Tensor repeat_dataset(Tensor input_dataset, Tensor count, TF_DataType[] output_types, Shape[] output_shapes, string name = null) + => tf.Context.ExecuteOp("RepeatDataset", name, new ExecuteOpArgs(input_dataset, count) + .SetAttributes(new { output_types, output_shapes })); + + public Tensor shard_dataset(Tensor input_dataset, Tensor num_shards, Tensor index, + TF_DataType[] output_types, Shape[] output_shapes, + bool require_non_empty = false, string name = null) + => tf.Context.ExecuteOp("ShardDataset", name, new ExecuteOpArgs(input_dataset, num_shards, index) + .SetAttributes(new { require_non_empty, output_types, output_shapes })); + + public Tensor zip_dataset(Tensor[] input_datasets, + TF_DataType[] output_types, + Shape[] output_shapes, + string name = null) + => tf.Context.ExecuteOp("ZipDataset", name, new ExecuteOpArgs() + { + OpInputArgs = new object[] { input_datasets } + }.SetAttributes(new { output_types, output_shapes })); + + public Tensor shuffle_dataset_v3(Tensor input_dataset, Tensor buffer_size, + Tensor seed, Tensor seed2, Tensor seed_generator, + TF_DataType[] output_types, Shape[] output_shapes, + bool reshuffle_each_iteration = true, + string name = null) + => tf.Context.ExecuteOp("ShuffleDatasetV3", name, new ExecuteOpArgs(input_dataset, buffer_size, seed, seed2, seed_generator) + .SetAttributes(new { reshuffle_each_iteration, output_types, output_shapes })); + + public Tensor skip_dataset(Tensor input_dataset, Tensor count, + TF_DataType[] output_types, Shape[] output_shapes, + string name = null) + => tf.Context.ExecuteOp("SkipDataset", name, new ExecuteOpArgs(input_dataset, count) + .SetAttributes(new { output_types, output_shapes })); + + public Tensor dummy_seed_generator(string name = null) + => tf.Context.ExecuteOp("DummySeedGenerator", name, new ExecuteOpArgs()); + + public Tensor concatenate_dataset(Tensor input_dataset, Tensor another_dataset, + TF_DataType[] output_types, Shape[] output_shapes, + string name = null) + => tf.Context.ExecuteOp("ConcatenateDataset", name, new ExecuteOpArgs(input_dataset, another_dataset) + .SetAttributes(new { output_types, output_shapes })); + + public Tensor cache_dataset_v2(Tensor input_dataset, Tensor filename, Tensor cache, + TF_DataType[] output_types, Shape[] output_shapes, + string name = null) + => tf.Context.ExecuteOp("CacheDatasetV2", name, new ExecuteOpArgs(input_dataset, filename, cache) + .SetAttributes(new { output_types, output_shapes })); + + /// + /// Creates a dataset that batches `batch_size` elements from `input_dataset`. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor batch_dataset_v2(Tensor input_dataset, Tensor buffer_size, + Tensor drop_remainder, + TF_DataType[] output_types, Shape[] output_shapes, + bool parallel_copy = false, + string name = null) + => tf.Context.ExecuteOp("BatchDatasetV2", name, + new ExecuteOpArgs(input_dataset, buffer_size, drop_remainder) + .SetAttributes(new { parallel_copy, output_types, output_shapes })); + + /// + /// + /// + /// + /// + public Tensor dummy_memory_cache(string name = "") + => tf.Context.ExecuteOp("DummyMemoryCache", name, new ExecuteOpArgs()); + + /// + /// Creates a dataset that asynchronously prefetches elements from `input_dataset`. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor prefetch_dataset(Tensor input_dataset, Tensor buffer_size, + TF_DataType[] output_types, Shape[] output_shapes, + int? slack_period = 0, + bool legacy_autotune = true, + string name = null) + => tf.Context.ExecuteOp("PrefetchDataset", name, new ExecuteOpArgs(input_dataset, buffer_size) + .SetAttributes(new + { + output_types, + output_shapes, + slack_period, + legacy_autotune + })); + + /// + /// Creates a dataset that contains `count` elements from the `input_dataset`. + /// + /// + /// + /// + /// + /// + /// + public Tensor take_dataset(Tensor input_dataset, Tensor count, + TF_DataType[] output_types, Shape[] output_shapes, + string name = null) + => tf.Context.ExecuteOp("TakeDataset", name, new ExecuteOpArgs(input_dataset, count) + .SetAttributes(new { output_types, output_shapes })); + + /// + /// Creates a dataset by applying optimizations to `input_dataset`. + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor optimize_dataset(Tensor input_dataset, Tensor optimizations, + TF_DataType[] output_types, Shape[] output_shapes, + string[] optimization_configs = null, + string name = null) + => tf.Context.ExecuteOp("OptimizeDataset", name, new ExecuteOpArgs(input_dataset, optimizations) + .SetAttributes(new + { + output_types, + output_shapes, + optimization_configs = optimization_configs ?? new string[0] + })); + + public Tensor optimize_dataset_v2(Tensor input_dataset, Tensor optimizations_enabled, + Tensor optimizations_disabled, Tensor optimizations_default, + TF_DataType[] output_types, Shape[] output_shapes, + string[] optimization_configs = null, + string name = null) + => tf.Context.ExecuteOp("OptimizeDatasetV2", name, new ExecuteOpArgs(input_dataset, + optimizations_enabled, optimizations_disabled, optimizations_default) + .SetAttributes(new + { + output_types, + output_shapes, + optimization_configs = optimization_configs ?? new string[0] + })); + + /// + /// Identity transformation that models performance. + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor model_dataset(Tensor input_dataset, + TF_DataType[] output_types, Shape[] output_shapes, + AutotuneAlgorithm algorithm, long cpu_budget, long ram_budget, + string name = null) + => tf.Context.ExecuteOp("ModelDataset", name, new ExecuteOpArgs(input_dataset) + .SetAttributes(new + { + algorithm, + cpu_budget, + ram_budget, + output_types, + output_shapes + })); + + /// + /// A container for an iterator resource. + /// + /// + /// + /// + /// A tuple of `Tensor` objects (handle, deleter). + public (Tensor, Tensor) anonymous_iterator_v2(TF_DataType[] output_types, Shape[] output_shapes, string name = null) + { + var results = tf.Context.ExecuteOp("AnonymousIteratorV2", name, + new ExecuteOpArgs().SetAttributes(new { output_types, output_shapes })); + return (results[0], results[1]); + } + + public Tensor anonymous_iterator_v3(TF_DataType[] output_types, Shape[] output_shapes, string name = null) + { + var ctx = tf.Context; + Dictionary attrs = new(); + attrs["output_types"] = output_types; + attrs["output_shapes"] = output_shapes; + if (ctx.executing_eagerly()) + { + try + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "AnonymousIteratorV3", name) + { + attrs = attrs + }); + return result[0]; + } + catch (Exception) + { + return anonymous_iterator_v3_eager_fallback(output_types, output_shapes, name, ctx); + } + } + return tf.OpDefLib._apply_op_helper("AnonymousIteratorV3", name, attrs).outputs[0]; + } + + public Tensor anonymous_iterator_v3_eager_fallback(TF_DataType[] output_types, Shape[] output_shapes, string name, Context ctx) + { + object[] attrs = new object[] { output_types, output_shapes }; + var result = _execute.quick_execute("AnonymousIteratorV3", 1, new Tensor[] { }, attrs, ctx, name); + return result[0]; + } + + /// + /// Makes a new iterator from the given `dataset` and stores it in `iterator`. + /// + /// + /// + /// + /// The created Operation. + public void make_iterator(Tensor dataset, Tensor iterator, string name = null) + => tf.Context.ExecuteOp("MakeIterator", name, new ExecuteOpArgs(dataset, iterator)); + + /// + /// + /// + /// + /// + /// + /// + public Tensor map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] output_types, Shape[] output_shapes, + bool use_inter_op_parallelism = true, bool preserve_cardinality = false, string name = null) + => tf.Context.ExecuteOp("MapDataset", name, new ExecuteOpArgs(dataset, new Tensor[0]) + .SetAttributes(new + { + f, + output_types, + output_shapes, + use_inter_op_parallelism, + preserve_cardinality + })); + + /// + /// Creates a dataset containing elements of `input_dataset` matching `predicate`. + /// + /// + /// + /// + /// + /// + /// + public Tensor filter_dataset(Tensor dataset, ConcreteFunction predicate, TF_DataType[] output_types, Shape[] output_shapes, + string name = null) + => tf.Context.ExecuteOp("FilterDataset", name, new ExecuteOpArgs(dataset, new Tensor[0]) + .SetAttributes(new + { + predicate, + output_types, + output_shapes + })); + + /// + /// Creates a dataset that applies `f` to the outputs of `input_dataset`. + /// + /// + /// + /// + /// + /// + /// + public Tensor flat_map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] output_types, Shape[] output_shapes, + string name = null) + => tf.Context.ExecuteOp("FlatMapDataset", name, new ExecuteOpArgs(dataset, new Tensor[0]) + .SetAttributes(new { f, output_types, output_shapes })); + + /// + /// Creates a dataset that applies `f` to the outputs of `input_dataset`. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor parallel_map_dataset_v2(Tensor dataset, Tensor num_parallel_calls, ConcreteFunction f, + TF_DataType[] output_types, Shape[] output_shapes, + bool use_inter_op_parallelism = true, + string deterministic = "default", + bool preserve_cardinality = false, + string name = null) + => tf.Context.ExecuteOp("ParallelMapDatasetV2", name, + new ExecuteOpArgs(dataset, new Tensor[0], num_parallel_calls) + .SetAttributes(new + { + f, + output_types, + output_shapes, + use_inter_op_parallelism, + deterministic, + preserve_cardinality + })); + + /// + /// A container for an iterator resource. + /// + /// + /// + /// + /// The created Operation. + public void delete_iterator(Tensor handle, Tensor deleter, string name = null) + => tf.Context.ExecuteOp("DeleteIterator", name, new ExecuteOpArgs(handle, deleter)); + + /// + /// Gets the next output from the given iterator . + /// + /// + /// + /// + /// + /// + public Tensor[] iterator_get_next(Tensor iterator, TF_DataType[] output_types, Shape[] output_shapes, string name = null) + => tf.Context.ExecuteOp("IteratorGetNext", name, new ExecuteOpArgs(iterator) + .SetAttributes(new { output_types, output_shapes })); + } +} diff --git a/src/TensorFlowNET.Core/Operations/embedding_ops.cs b/src/TensorFlowNET.Core/Operations/embedding_ops.cs new file mode 100644 index 000000000..2e349ed39 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/embedding_ops.cs @@ -0,0 +1,116 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class embedding_ops + { + /// + /// Helper function for embedding_lookup and _compute_sampled_logits. + /// + /// + /// + /// + /// + /// + /// + public static Tensor _embedding_lookup_and_transform(IVariableV1 @params, + Tensor ids, + string partition_strategy = "mod", + string name = null, + string max_norm = null) + { + return tf_with(ops.name_scope(name, "embedding_lookup", new { @params, ids }), scope => + { + name = scope; + int np = 1; + ids = ops.convert_to_tensor(ids, name: "ids"); + if (np == 1) + { + var gather = array_ops.gather(@params.AsTensor(), ids, name: name); + var result = _clip(gather, ids, max_norm); + + return array_ops.identity(result); + } + + throw new NotImplementedException("_embedding_lookup_and_transform"); + }); + } + + public static Tensor _embedding_lookup_and_transform(Tensor[] @params, + Tensor ids, + string partition_strategy = "mod", + string name = null, + string max_norm = null) + { + return tf_with(ops.name_scope(name, "embedding_lookup", new { @params, ids }), scope => + { + name = scope; + int np = @params.Length; + @params = ops.convert_n_to_tensor_or_indexed_slices(@params, name: "params"); + ids = ops.convert_to_tensor(ids, name: "ids"); + if (np == 1) + { + ops.colocate_with(@params[0]); + var result = _clip(array_ops.gather(@params[0], ids, name: name), ids, max_norm); + return array_ops.identity(result); + } + else + { + // Flatten the ids. There are two cases where we need to do this. + throw new NotImplementedException("_embedding_lookup_and_transform"); + } + }); + } + + public static Tensor _clip(Tensor @params, Tensor ids, string max_norm = null) + { + if (max_norm == null) + return @params; + + throw new NotImplementedException("_clip"); + } + + public static Tensor embedding_lookup(Tensor[] @params, Tensor ids, + string partition_strategy = "mod", + string name = null, + bool validate_indices = true, + string max_norm = null) + { + return _embedding_lookup_and_transform(@params: @params, + ids: ids, + partition_strategy: partition_strategy, + name: name, + max_norm: max_norm); + } + + public static Tensor embedding_lookup(IVariableV1 @params, Tensor ids, + string partition_strategy = "mod", + string name = null, + bool validate_indices = true, + string max_norm = null) + { + return _embedding_lookup_and_transform(@params: @params, + ids: ids, + partition_strategy: partition_strategy, + name: name, + max_norm: max_norm); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs new file mode 100644 index 000000000..105479216 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs @@ -0,0 +1,309 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Framework; +using Tensorflow.Functions; +using Tensorflow.Operations; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class functional_ops + { + public static Tensor[] partitioned_call(Tensors args, EagerDefinedFunction f, DataType[] tout, + bool executing_eagerly, string config, string executor_type) + { + if (tout is null) + { + throw new NotImplementedException(); + } + + if (config is null) + { + config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); + } + + if (executor_type is null) + { + executor_type = ""; + } + + if (executing_eagerly) + { + // TODO(Rinne): implement it. + + throw new NotImplementedException(); + } + + var converted_args = args.Select(x => ops.convert_to_tensor(x)).ToArray(); + AttrValue tin_attr = new() + { + List = new AttrValue.Types.ListValue() + }; + tin_attr.List.Type.AddRange(args.Select(x => x.dtype.as_datatype_enum())); + AttrValue tout_attr = new() + { + List = new AttrValue.Types.ListValue() + }; + tout_attr.List.Type.AddRange(tout); + AttrValue func_attr = new() + { + Func = new NameAttrList() + }; + func_attr.Func.Name = f.Name; + AttrValue executor_type_attr = new AttrValue() + { + S = tf.compat.as_bytes(executor_type) + }; + AttrValue config_proto = new AttrValue() + { + S = ByteString.CopyFromUtf8(executor_type) + }; + + var graph = ops.get_default_graph(); + f.AddToGraph(graph); + // TODO(Rinne): complete it with `f.stateful` + var op_name = "PartitionedCall"; + string xla_compile_attr = "_XlaMustCompile"; + Dictionary op_attrs = new(); + op_attrs["Tin"] = tin_attr; + op_attrs["Tout"] = tout_attr; + op_attrs["f"] = func_attr; + op_attrs["config_proto"] = config_proto; + op_attrs["executor_type"] = executor_type_attr; + // TODO(Rinne): deal with `f.definition`. + var op = graph.create_op(op_name, args, tout.Select(x => x.as_tf_dtype()).ToArray(), + name: op_name, attrs: op_attrs); + var outputs = op.outputs; + // TODO(Rinne): deal with `f.graph`. + return outputs; + } + public static Tensor scan( + Func fn, + Tensor elems, + Tensor initializer = null, + int parallel_iterations = 10, + bool back_prop = true, + bool swap_memory = false, + bool infer_shape = true, + bool reverse = false, + string name = null) + { + bool input_is_sequence = nest.is_sequence(elems); + + Tensor[] input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x).ToArray() : new[] { x }; + Tensor input_pack(Tensor[] x) => input_is_sequence ? (Tensor)nest.pack_sequence_as(elems, x) : x[0]; + + bool output_is_sequence; + Func output_flatten; + Func output_pack; + if (initializer == null) + { + output_is_sequence = input_is_sequence; + output_flatten = input_flatten; + output_pack = input_pack; + } + else + { + output_is_sequence = nest.is_sequence(initializer); + output_flatten = (x) => output_is_sequence ? nest.flatten(x).ToArray() : new[] { x }; + output_pack = (x) => output_is_sequence ? (Tensor)nest.pack_sequence_as(initializer, x) : x[0]; + } + + var elems_flat = input_flatten(elems); + + bool in_graph_mode = tf.Context.executing_eagerly(); + + return tf_with(ops.name_scope(name, "scan", new { elems_flat }), scope => + { + if (in_graph_mode) + { + // todo tf.net doesn't expose .caching_device + //// Any get_variable calls in fn will cache the first call locally + //// and not issue repeated network I/O requests for each iteration. + //var varscope = variable_scope.get_variable_scope(); + //bool varscope_caching_device_was_none = false; + //if (varscope.caching_device = null) + //{ + // // varscope.set_caching_device(lambda op: op.device) + // // varscope_caching_device_was_none = True + //} + } + + elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")).ToArray(); + + var n = tensor_shape.dimension_value(elems_flat[0].shape[0]); + + // todo python had the below but dimension_value returns int which can't be null + //if (n == null) + //{ + // n = array_ops.shape(elems_flat[0])[0]; + //} + + var elems_ta = elems_flat.Select(elem => tf.TensorArray( + elem.dtype, + size: n, + dynamic_size: false, + element_shape: elem.shape.dims.Skip(1).ToArray(), + infer_shape: true)).ToList(); + + for (int index = 0; index < elems_ta.Count; index++) + { + elems_ta[index].unstack(elems_flat[index]); + } + + Tensor[] a_flat; + int i; + if (initializer == null) + { + a_flat = elems_ta.Select(elem => elem.read(tf.constant(reverse ? n - 1 : 0))).ToArray(); + i = 1; + } + else + { + Tensor[] initializer_flat = output_flatten(initializer); + a_flat = initializer_flat.Select(init => ops.convert_to_tensor(init)).ToArray(); + i = 0; + } + + var accs_ta = a_flat.Select(init => tf.TensorArray( + dtype: init.dtype, + size: n, + element_shape: infer_shape ? init.shape : null, + dynamic_size: false, + infer_shape: infer_shape)).ToArray(); + + if (initializer == null) + { + for (int index = 0; index < accs_ta.Length; index++) + { + accs_ta[index].write(reverse ? n - 1 : 0, a_flat[index]); + } + } + + BodyItem compute(BodyItem item) + { + var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(item.I)).ToArray()); + var packed_a = output_pack(item.A_Flat); + var a_out = fn(packed_a, packed_elems); + + var flat_a_out = output_flatten(a_out); + for (int j = 0; j < item.Accs_ta.Length; j++) + { + item.Accs_ta[j].write(item.I, flat_a_out[j]); + } + + var next_i = reverse ? item.I - 1 : item.I + 1; + return new BodyItem(next_i, flat_a_out, item.Accs_ta); + } + + int initial_i; + Func condition; + if (reverse) + { + initial_i = n - 1 - i; + condition = x => x.I >= 0; + } + else + { + initial_i = i; + condition = x => x.I < n; + } + + BodyItem bodyItem = + control_flow_ops.while_loop( + condition, + compute, + new BodyItem(tf.constant(initial_i), a_flat, accs_ta), + parallel_iterations: parallel_iterations, + back_prop: back_prop, + swap_memory: swap_memory, + maximum_iterations: tf.constant(n)); + + var results_flat = bodyItem.Accs_ta.Select(r => r.stack()).ToArray(); + + var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape.with_rank_at_least(1).dims[0])); + + foreach (var elem in elems_flat.Skip(1)) + { + n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape.with_rank_at_least(1).dims[0]))); + } + + foreach (Tensor r in results_flat) + { + r.shape = new Shape(n_static).concatenate(r.dims.Skip(1).ToArray()); + } + + // todo get working when the above caching_device is fixed + //if (in_graph_mode && varscope_caching_device_was_none) { + // varscope.set_caching_device(None); + //} + + return output_pack(results_flat); + }); + } + + internal class BodyItem : ICanBeFlattened, IPackable, IFromMergeVars + { + public Tensor I { get; set; } + public Tensor[] A_Flat { get; set; } + public TensorArray[] Accs_ta { get; set; } + + public BodyItem() + { + } + + public BodyItem(Tensor i, Tensor[] a_flat, TensorArray[] accs_ta) + { + I = i; + A_Flat = a_flat; + Accs_ta = accs_ta; + } + + public object[] Flatten() + { + var elements = new List { I }; + elements.AddRange(A_Flat); + elements.AddRange(Accs_ta); + return elements.ToArray(); + } + + public BodyItem Pack(object[] sequences) + { + I = sequences[0] as Tensor; + A_Flat = new[] { sequences[1] as Tensor }; + Accs_ta = new[] { sequences[2] as TensorArray }; + + return new BodyItem(I, A_Flat, Accs_ta); + } + + public BodyItem FromMergeVars(ITensorOrTensorArray[] merge_vars) + { + I = (Tensor)merge_vars[1]; + A_Flat = new[] { (Tensor)merge_vars[2] }; + Accs_ta = new[] { (TensorArray)merge_vars[3] }; + return this; + } + } + } +} + diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index b4bb76bff..8367c2f94 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -1,49 +1,10806 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Text; -using Tensorflow; +/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/ -namespace Tensorflow +using Tensorflow.Eager; +using Tensorflow.Contexts; +using Tensorflow.Exceptions; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static class gen_array_ops { - public static class gen_array_ops + /// + /// + /// + /// + /// + /// + /// + public static Tensor batch_matrix_band_part(Tensor input, Tensor num_lower, Tensor num_upper, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BatchMatrixBandPart", name) { args = new object[] { input, num_lower, num_upper }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return batch_matrix_band_part_eager_fallback(input, num_lower, num_upper, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["num_lower"] = num_lower; + keywords["num_upper"] = num_upper; + var _op = tf.OpDefLib._apply_op_helper("BatchMatrixBandPart", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("BatchMatrixBandPart", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor batch_matrix_band_part_eager_fallback(Tensor input, Tensor num_lower, Tensor num_upper, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, num_lower, num_upper }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("BatchMatrixBandPart", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BatchMatrixBandPart", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + public static Tensor batch_matrix_diag(Tensor diagonal, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BatchMatrixDiag", name) { args = new object[] { diagonal }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return batch_matrix_diag_eager_fallback(diagonal, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["diagonal"] = diagonal; + var _op = tf.OpDefLib._apply_op_helper("BatchMatrixDiag", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("BatchMatrixDiag", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor batch_matrix_diag_eager_fallback(Tensor diagonal, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { diagonal }; + object[] _attrs = new object[] { "T", diagonal.dtype }; + var _result = _execute.execute("BatchMatrixDiag", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BatchMatrixDiag", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + public static Tensor batch_matrix_diag_part(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BatchMatrixDiagPart", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return batch_matrix_diag_part_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("BatchMatrixDiagPart", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("BatchMatrixDiagPart", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor batch_matrix_diag_part_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("BatchMatrixDiagPart", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BatchMatrixDiagPart", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + /// + public static Tensor batch_matrix_set_diag(Tensor input, Tensor diagonal, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BatchMatrixSetDiag", name) { args = new object[] { input, diagonal }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return batch_matrix_set_diag_eager_fallback(input, diagonal, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["diagonal"] = diagonal; + var _op = tf.OpDefLib._apply_op_helper("BatchMatrixSetDiag", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("BatchMatrixSetDiag", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor batch_matrix_set_diag_eager_fallback(Tensor input, Tensor diagonal, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, diagonal }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("BatchMatrixSetDiag", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BatchMatrixSetDiag", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// BatchToSpace for 4-D tensors of type T. + /// + /// + /// + /// This is a legacy version of the more general BatchToSpaceND. + /// + /// Rearranges (permutes) data from batch into blocks of spatial data, followed by + /// cropping. This is the reverse transformation of SpaceToBatch. More specifically, + /// this op outputs a copy of the input tensor where values from the `batch` + /// dimension are moved in spatial blocks to the `height` and `width` dimensions, + /// followed by cropping along the `height` and `width` dimensions. + /// + /// + /// + /// + /// + /// + public static Tensor batch_to_space(Tensor input, Tensor crops, int block_size = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BatchToSpace", name) { args = new object[] { input, crops }, attrs = new Dictionary() { ["block_size"] = block_size } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return batch_to_space_eager_fallback(input, crops, block_size: block_size, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["crops"] = crops; + keywords["block_size"] = block_size; + var _op = tf.OpDefLib._apply_op_helper("BatchToSpace", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "block_size", _op._get_attr_int("block_size"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("BatchToSpace", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor batch_to_space_eager_fallback(Tensor input, Tensor crops, int block_size, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, crops }; + object[] _attrs = new object[] { "T", input.dtype, "block_size", block_size, "Tidx", crops.dtype }; + var _result = _execute.execute("BatchToSpace", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BatchToSpace", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// BatchToSpace for N-D tensors of type T. + /// + /// + /// + /// This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of shape + /// `block_shape + [batch]`, interleaves these blocks back into the grid defined by + /// the spatial dimensions `[1, ..., M]`, to obtain a result with the same rank as + /// the input. The spatial dimensions of this intermediate result are then + /// optionally cropped according to `crops` to produce the output. This is the + /// reverse of SpaceToBatch. See below for a precise description. + /// + /// + /// + /// + /// + /// + public static Tensor batch_to_space_nd(Tensor input, Tensor block_shape, Tensor crops, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BatchToSpaceND", name) { args = new object[] { input, block_shape, crops }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return batch_to_space_nd_eager_fallback(input, block_shape, crops, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["block_shape"] = block_shape; + keywords["crops"] = crops; + var _op = tf.OpDefLib._apply_op_helper("BatchToSpaceND", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tblock_shape", _op._get_attr_type("Tblock_shape"), "Tcrops", _op._get_attr_type("Tcrops") }; + _execute.record_gradient("BatchToSpaceND", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor batch_to_space_nd_eager_fallback(Tensor input, Tensor block_shape, Tensor crops, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, block_shape, crops }; + object[] _attrs = new object[] { "T", input.dtype, "Tblock_shape", block_shape.dtype, "Tcrops", crops.dtype }; + var _result = _execute.execute("BatchToSpaceND", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BatchToSpaceND", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Bitcasts a tensor from one type to another without copying data. + /// + /// + /// + /// Given a tensor `input`, this operation returns a tensor that has the same buffer + /// data as `input` with datatype `type`. + /// + /// If the input datatype `T` is larger than the output datatype `type` then the + /// shape changes from [...] to [..., sizeof(`T`)/sizeof(`type`)]. + /// + /// If `T` is smaller than `type`, the operator requires that the rightmost + /// dimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from + /// [..., sizeof(`type`)/sizeof(`T`)] to [...]. + /// + /// tf.bitcast() and tf.cast() work differently when real dtype is casted as a complex dtype + /// (e.g. tf.complex64 or tf.complex128) as tf.cast() make imaginary part 0 while tf.bitcast() + /// gives module error. + /// For example, + /// + /// Example 1: + /// + /// >>> a = [1., 2., 3.] + /// >>> equality_bitcast = tf.bitcast(a, tf.complex128) + /// Traceback (most recent call last): + /// ... + /// InvalidArgumentError: Cannot bitcast from 1 to 18 [Op:Bitcast] + /// >>> equality_cast = tf.cast(a, tf.complex128) + /// >>> print(equality_cast) + /// tf.Tensor([1.+0.j 2.+0.j 3.+0.j], shape=(3,), dtype=complex128) + /// + /// Example 2: + /// + /// >>> tf.bitcast(tf.constant(0xffffffff, dtype=tf.uint32), tf.uint8) + /// + /// + /// Example 3: + /// + /// >>> x = [1., 2., 3.] + /// >>> y = [0., 2., 3.] + /// >>> equality= tf.equal(x,y) + /// >>> equality_cast = tf.cast(equality,tf.float32) + /// >>> equality_bitcast = tf.bitcast(equality_cast,tf.uint8) + /// >>> print(equality) + /// tf.Tensor([False True True], shape=(3,), dtype=bool) + /// >>> print(equality_cast) + /// tf.Tensor([0. 1. 1.], shape=(3,), dtype=float32) + /// >>> print(equality_bitcast) + /// tf.Tensor( + /// [[ 0 0 0 0] + /// [ 0 0 128 63] + /// [ 0 0 128 63]], shape=(3, 4), dtype=uint8) + /// + /// *NOTE*: Bitcast is implemented as a low-level cast, so machines with different + /// endian orderings will give different results. + /// + /// + /// + /// + /// + public static Tensor bitcast(Tensor input, TF_DataType type, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Bitcast", name) { args = new object[] { input }, attrs = new Dictionary() { ["type"] = type } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return bitcast_eager_fallback(input, type: type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["type"] = type; + var _op = tf.OpDefLib._apply_op_helper("Bitcast", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "type", _op._get_attr_type("type") }; + _execute.record_gradient("Bitcast", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor bitcast_eager_fallback(Tensor input, TF_DataType type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "type", type }; + var _result = _execute.execute("Bitcast", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Bitcast", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Return the shape of s0 op s1 with broadcast. + /// + /// + /// + /// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the + /// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. + /// + /// + /// + /// + /// + public static Tensor broadcast_args(Tensor s0, Tensor s1, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BroadcastArgs", name) { args = new object[] { s0, s1 }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return broadcast_args_eager_fallback(s0, s1, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["s0"] = s0; + keywords["s1"] = s1; + var _op = tf.OpDefLib._apply_op_helper("BroadcastArgs", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("BroadcastArgs", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor broadcast_args_eager_fallback(Tensor s0, Tensor s1, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { s0, s1 }; + object[] _attrs = new object[] { "T", s0.dtype }; + var _result = _execute.execute("BroadcastArgs", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BroadcastArgs", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Return the reduction indices for computing gradients of s0 op s1 with broadcast. + /// + /// + /// + /// This is typically used by gradient computations for a broadcasting operation. + /// + /// + /// + /// + /// + public static Tensor[] broadcast_gradient_args(Tensor s0, Tensor s1, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BroadcastGradientArgs", name) { args = new object[] { s0, s1 }, attrs = new Dictionary() { } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return broadcast_gradient_args_eager_fallback(s0, s1, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["s0"] = s0; + keywords["s1"] = s1; + var _op = tf.OpDefLib._apply_op_helper("BroadcastGradientArgs", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("BroadcastGradientArgs", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] broadcast_gradient_args_eager_fallback(Tensor s0, Tensor s1, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { s0, s1 }; + object[] _attrs = new object[] { "T", s0.dtype }; + var _result = _execute.execute("BroadcastGradientArgs", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BroadcastGradientArgs", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Broadcast an array for a compatible shape. + /// + /// + /// + /// Broadcasting is the process of making arrays to have compatible shapes + /// for arithmetic operations. Two shapes are compatible if for each + /// dimension pair they are either equal or one of them is one. + /// + /// For example: + /// + /// >>> x = tf.constant([[1, 2, 3]]) # Shape (1, 3,) + /// >>> y = tf.broadcast_to(x, [2, 3]) + /// >>> print(y) + /// tf.Tensor( + /// [[1 2 3] + /// [1 2 3]], shape=(2, 3), dtype=int32) + /// + /// In the above example, the input Tensor with the shape of `[1, 3]` + /// is broadcasted to output Tensor with shape of `[2, 3]`. + /// + /// When broadcasting, if a tensor has fewer axes than necessary its shape is + /// padded on the left with ones. So this gives the same result as the previous + /// example: + /// + /// >>> x = tf.constant([1, 2, 3]) # Shape (3,) + /// >>> y = tf.broadcast_to(x, [2, 3]) + /// + /// + /// When doing broadcasted operations such as multiplying a tensor + /// by a scalar, broadcasting (usually) confers some time or space + /// benefit, as the broadcasted tensor is never materialized. + /// + /// However, `broadcast_to` does not carry with it any such benefits. + /// The newly-created tensor takes the full memory of the broadcasted + /// shape. (In a graph context, `broadcast_to` might be fused to + /// subsequent operation and then be optimized away, however.) + /// + /// + /// + /// + /// + public static Tensor broadcast_to(Tensor input, Tensor shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BroadcastTo", name) { args = new object[] { input, shape }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return broadcast_to_eager_fallback(input, shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["shape"] = shape; + var _op = tf.OpDefLib._apply_op_helper("BroadcastTo", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("BroadcastTo", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor broadcast_to_eager_fallback(Tensor input, Tensor shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, shape }; + object[] _attrs = new object[] { "T", input.dtype, "Tidx", shape.dtype }; + var _result = _execute.execute("BroadcastTo", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BroadcastTo", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Checks a tensor for NaN and Inf values. + /// + /// + /// + /// When run, reports an `InvalidArgument` error if `tensor` has any values + /// that are not a number (NaN) or infinity (Inf). Otherwise, returns the input + /// tensor. + /// + /// Example usage: + /// + /// ``` python + /// a = tf.Variable(1.0) + /// tf.debugging.check_numerics(a, message='') + /// + /// b = tf.Variable(np.nan) + /// try: + /// tf.debugging.check_numerics(b, message='Checking b') + /// except Exception as e: + /// assert "Checking b : Tensor had NaN values" in e.message + /// + /// c = tf.Variable(np.inf) + /// try: + /// tf.debugging.check_numerics(c, message='Checking c') + /// except Exception as e: + /// assert "Checking c : Tensor had Inf values" in e.message + /// ``` + /// + /// + /// + /// + /// + /// + /// Prefix of the error message. + /// + /// + /// + public static Tensor check_numerics(Tensor tensor, string message, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "CheckNumerics", name) { args = new object[] { tensor }, attrs = new Dictionary() { ["message"] = message } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return check_numerics_eager_fallback(tensor, message: message, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["message"] = message; + var _op = tf.OpDefLib._apply_op_helper("CheckNumerics", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "message", _op.get_attr("message") }; + _execute.record_gradient("CheckNumerics", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor check_numerics_eager_fallback(Tensor tensor, string message, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor }; + object[] _attrs = new object[] { "T", tensor.dtype, "message", message }; + var _result = _execute.execute("CheckNumerics", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("CheckNumerics", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Checks a tensor for NaN, -Inf and +Inf values. + /// + /// + /// + /// When run, reports an `InvalidArgument` error if `tensor` has any values + /// that are not a number (NaN) or infinity (Inf). Otherwise, returns the input + /// tensor. Unlike CheckNumerics (V1), CheckNumericsV2 distinguishes -Inf and +Inf + /// in the errors it throws. + /// + /// + /// + /// + /// + /// Prefix of the error message. + /// + /// + /// + public static Tensor check_numerics_v2(Tensor tensor, string message, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "CheckNumericsV2", name) { args = new object[] { tensor }, attrs = new Dictionary() { ["message"] = message } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return check_numerics_v2_eager_fallback(tensor, message: message, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["message"] = message; + var _op = tf.OpDefLib._apply_op_helper("CheckNumericsV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "message", _op.get_attr("message") }; + _execute.record_gradient("CheckNumericsV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor check_numerics_v2_eager_fallback(Tensor tensor, string message, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor }; + object[] _attrs = new object[] { "T", tensor.dtype, "message", message }; + var _result = _execute.execute("CheckNumericsV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("CheckNumericsV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Concatenates tensors along one dimension. + /// + /// + /// + /// + public static Tensor concat(Tensor concat_dim, Tensors values, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Concat", name) { args = new object[] { concat_dim, values }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return concat_eager_fallback(concat_dim, values, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["concat_dim"] = concat_dim; + keywords["values"] = values; + var _op = tf.OpDefLib._apply_op_helper("Concat", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "N", _op._get_attr_int("N"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("Concat", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor concat_eager_fallback(Tensor concat_dim, Tensors values, string name, Context ctx) + { + List _inputs_flat_list = new(); + _inputs_flat_list.Add(concat_dim); + _inputs_flat_list.AddRange(values); + var _inputs_flat = _inputs_flat_list.ToArray(); + object[] _attrs = new object[] { "N", values.Length, "T", values.dtype }; + var _result = _execute.execute("Concat", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Concat", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes offsets of concat inputs within its output. + /// + /// + /// + /// For example: + /// + /// >>> x = [2, 2, 7] + /// >>> y = [2, 3, 7] + /// >>> z = [2, 9, 7] + /// >>> offsets = concat_offset(1, [x, y, z]) + /// >>> [list(off.numpy()) for off in offsets] + /// [[0, 0, 0], [0, 2, 0], [0, 5, 0]] + /// + /// This is typically used by gradient computations for a concat operation. + /// + /// + /// + /// + /// + public static Tensor[] concat_offset(Tensor concat_dim, Tensors shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ConcatOffset", name) { args = new object[] { concat_dim, shape }, attrs = new Dictionary() { } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return concat_offset_eager_fallback(concat_dim, shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["concat_dim"] = concat_dim; + keywords["shape"] = shape; + var _op = tf.OpDefLib._apply_op_helper("ConcatOffset", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "N", _op._get_attr_int("N") }; + _execute.record_gradient("ConcatOffset", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] concat_offset_eager_fallback(Tensor concat_dim, Tensors shape, string name, Context ctx) + { + List _inputs_flat_list = new(); + _inputs_flat_list.Add(concat_dim); + _inputs_flat_list.AddRange(shape); + var _inputs_flat = _inputs_flat_list.ToArray(); + object[] _attrs = new object[] { "N", shape.Length }; + var _result = _execute.execute("ConcatOffset", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ConcatOffset", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Concatenates tensors along one dimension. + /// + /// + /// + /// + public static Tensor concat_v2(Tensors values, Tensor axis, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ConcatV2", name) { args = new object[] { values, axis }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return concat_v2_eager_fallback(values, axis, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["values"] = values; + keywords["axis"] = axis; + var _op = tf.OpDefLib._apply_op_helper("ConcatV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "N", _op._get_attr_int("N"), "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("ConcatV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor concat_v2_eager_fallback(Tensors values, Tensor axis, string name, Context ctx) + { + List _inputs_flat_list = new(); + _inputs_flat_list.AddRange(values); + _inputs_flat_list.Add(axis); + var _inputs_flat = _inputs_flat_list.ToArray(); + object[] _attrs = new object[] { "N", values.Length, "T", values.dtype, "Tidx", axis.dtype }; + var _result = _execute.execute("ConcatV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ConcatV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Shuffle dimensions of x according to a permutation and conjugate the result. + /// + /// + /// + /// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: + /// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` + /// `y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])` + /// + /// + /// + /// + /// + public static Tensor conjugate_transpose(Tensor x, Tensor perm, string? name = null) { - public static OpDefLibrary _op_def_lib = _InitOpDefLibrary(); + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ConjugateTranspose", name) { args = new object[] { x, perm }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return conjugate_transpose_eager_fallback(x, perm, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["perm"] = perm; + var _op = tf.OpDefLib._apply_op_helper("ConjugateTranspose", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tperm", _op._get_attr_type("Tperm") }; + _execute.record_gradient("ConjugateTranspose", _op.inputs, _attrs, _result); + } + return _result[0]; + } - public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null) + public static Tensor conjugate_transpose_eager_fallback(Tensor x, Tensor perm, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, perm }; + object[] _attrs = new object[] { "T", x.dtype, "Tperm", perm.dtype }; + var _result = _execute.execute("ConjugateTranspose", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ConjugateTranspose", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a constant tensor. + /// + /// + /// + /// Attr `value` is the tensor to return. + /// + /// + /// + /// + public static Tensor _const(TensorProto value, TF_DataType dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Const", name) { args = new object[] { }, attrs = new Dictionary() { ["value"] = value, ["dtype"] = dtype } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return const_eager_fallback(value: value, dtype: dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["value"] = value; + keywords["dtype"] = dtype; + var _op = tf.OpDefLib._apply_op_helper("Const", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) { - /*var g = ops.get_default_graph(); - var op = new Operation(g, "Placeholder", "feed"); + object[] _attrs = new object[] { "value", _op.get_attr("value"), "dtype", _op._get_attr_type("dtype") }; + _execute.record_gradient("Const", _op.inputs, _attrs, _result); + } + return _result[0]; + } - var tensor = new Tensor(op, 0, dtype); + public static Tensor const_eager_fallback(TensorProto value, TF_DataType dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "value", value, "dtype", dtype }; + var _result = _execute.execute("Const", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Const", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Identity op for gradient debugging. + /// + /// + /// + /// This op is hidden from public in Python. It is used by TensorFlow Debugger to + /// register gradient tensors for gradient debugging. + /// This op operates on non-reference-type tensors. + /// + /// + /// + /// + public static Tensor debug_gradient_identity(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DebugGradientIdentity", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return debug_gradient_identity_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("DebugGradientIdentity", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("DebugGradientIdentity", _op.inputs, _attrs, _result); + } + return _result[0]; + } - return tensor;*/ + public static Tensor debug_gradient_identity_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("DebugGradientIdentity", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DebugGradientIdentity", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Identity op for gradient debugging. + /// + /// + /// + /// This op is hidden from public in Python. It is used by TensorFlow Debugger to + /// register gradient tensors for gradient debugging. + /// This op operates on reference-type tensors. + /// + /// + /// + /// + public static Tensor debug_gradient_ref_identity(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + throw new RuntimeError("debug_gradient_ref_identity op does not support eager execution. Arg input is a ref."); + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("DebugGradientRefIdentity", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("DebugGradientRefIdentity", _op.inputs, _attrs, _result); + } + return _result[0]; + } - var keywords = new Dictionary(); - keywords.Add("dtype", dtype); - keywords.Add("shape", shape); + public static Tensor debug_gradient_ref_identity_eager_fallback(Tensor input, string name, Context ctx) + { + throw new RuntimeError($"debug_gradient_ref_identity op does not support eager execution. Arg 'input' is a ref."); + } + /// + /// Makes a copy of `x`. + /// + /// + /// + public static Tensor deep_copy(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DeepCopy", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return deep_copy_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("DeepCopy", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("DeepCopy", _op.inputs, _attrs, _result); + } + return _result[0]; + } - var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords); - var _result = _op.outputs; - var _inputs_flat = _op.inputs; - var _attrs = new Dictionary(); + public static Tensor deep_copy_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("DeepCopy", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DeepCopy", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// DepthToSpace for tensors of type T. + /// + /// + /// + /// Rearranges data from depth into blocks of spatial data. + /// This is the reverse transformation of SpaceToDepth. More specifically, + /// this op outputs a copy of the input tensor where values from the `depth` + /// dimension are moved in spatial blocks to the `height` and `width` dimensions. + /// The attr `block_size` indicates the input block size and how the data is moved. + /// + /// * Chunks of data of size `block_size * block_size` from depth are rearranged + /// into non-overlapping blocks of size `block_size x block_size` + /// * The width of the output tensor is `input_depth * block_size`, whereas the + /// height is `input_height * block_size`. + /// * The Y, X coordinates within each block of the output image are determined + /// by the high order component of the input channel index. + /// * The depth of the input tensor must be divisible by + /// `block_size * block_size`. + /// + /// The `data_format` attr specifies the layout of the input and output tensors + /// with the following options: + /// "NHWC": `[ batch, height, width, channels ]` + /// "NCHW": `[ batch, channels, height, width ]` + /// "NCHW_VECT_C": + /// `qint8 [ batch, channels / 4, height, width, 4 ]` + /// + /// It is useful to consider the operation as transforming a 6-D Tensor. + /// e.g. for data_format = NHWC, + /// Each element in the input tensor can be specified via 6 coordinates, + /// ordered by decreasing memory layout significance as: + /// n,iY,iX,bY,bX,oC (where n=batch index, iX, iY means X or Y coordinates + /// within the input image, bX, bY means coordinates + /// within the output block, oC means output channels). + /// The output would be the input transposed to the following layout: + /// n,iY,bY,iX,bX,oC + /// + /// This operation is useful for resizing the activations between convolutions + /// (but keeping all data), e.g. instead of pooling. It is also useful for training + /// purely convolutional models. + /// + /// For example, given an input of shape `[1, 1, 1, 4]`, data_format = "NHWC" and + /// block_size = 2: + /// + /// ``` + /// x = [[[[1, 2, 3, 4]]]] + /// + /// ``` + /// + /// This operation will output a tensor of shape `[1, 2, 2, 1]`: + /// + /// ``` + /// [[[[1], [2]], + /// [[3], [4]]]] + /// ``` + /// + /// Here, the input has a batch of 1 and each batch element has shape `[1, 1, 4]`, + /// the corresponding output will have 2x2 elements and will have a depth of + /// 1 channel (1 = `4 / (block_size * block_size)`). + /// The output element shape is `[2, 2, 1]`. + /// + /// For an input tensor with larger depth, here of shape `[1, 1, 1, 12]`, e.g. + /// + /// ``` + /// x = [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] + /// ``` + /// + /// This operation, for block size of 2, will return the following tensor of shape + /// `[1, 2, 2, 3]` + /// + /// ``` + /// [[[[1, 2, 3], [4, 5, 6]], + /// [[7, 8, 9], [10, 11, 12]]]] + /// + /// ``` + /// + /// Similarly, for the following input of shape `[1 2 2 4]`, and a block size of 2: + /// + /// ``` + /// x = [[[[1, 2, 3, 4], + /// [5, 6, 7, 8]], + /// [[9, 10, 11, 12], + /// [13, 14, 15, 16]]]] + /// ``` + /// + /// the operator will return the following tensor of shape `[1 4 4 1]`: + /// + /// ``` + /// x = [[[ [1], [2], [5], [6]], + /// [ [3], [4], [7], [8]], + /// [ [9], [10], [13], [14]], + /// [ [11], [12], [15], [16]]]] + /// + /// ``` + /// + /// + /// + /// + /// + /// The size of the spatial block, same as in Space2Depth. + /// + /// + /// + /// + public static Tensor depth_to_space(Tensor input, int block_size = 0, string data_format = "NHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DepthToSpace", name) { args = new object[] { input }, attrs = new Dictionary() { ["block_size"] = block_size, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return depth_to_space_eager_fallback(input, block_size: block_size, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["block_size"] = block_size; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("DepthToSpace", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "block_size", _op._get_attr_int("block_size"), "data_format", _op.get_attr("data_format") }; + _execute.record_gradient("DepthToSpace", _op.inputs, _attrs, _result); + } + return _result[0]; + } - _attrs["dtype"] = _op.get_attr("dtype"); - _attrs["shape"] = _op.get_attr("shape"); + public static Tensor depth_to_space_eager_fallback(Tensor input, int block_size, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "block_size", block_size, "data_format", data_format }; + var _result = _execute.execute("DepthToSpace", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DepthToSpace", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Dequantize the 'input' tensor into a float or bfloat16 Tensor. + /// + /// + /// + /// [min_range, max_range] are scalar floats that specify the range for + /// the output. The 'mode' attribute controls exactly which calculations are + /// used to convert the float values to their quantized equivalents. + /// + /// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: + /// + /// ``` + /// if T == qint8: in[i] += (range(T) + 1)/ 2.0 + /// out[i] = min_range + (in[i]* (max_range - min_range) / range(T)) + /// ``` + /// here `range(T) = numeric_limits::max() - numeric_limits::min()` + /// + /// *MIN_COMBINED Mode Example* + /// + /// If the input comes from a QuantizedRelu6, the output type is + /// quint8 (range of 0-255) but the possible range of QuantizedRelu6 is + /// 0-6. The min_range and max_range values are therefore 0.0 and 6.0. + /// Dequantize on quint8 will take each value, cast to float, and multiply + /// by 6 / 255. + /// Note that if quantizedtype is qint8, the operation will additionally add + /// each value by 128 prior to casting. + /// + /// If the mode is 'MIN_FIRST', then this approach is used: + /// + /// ```c++ + /// num_discrete_values = 1 << (# of bits in T) + /// range_adjust = num_discrete_values / (num_discrete_values - 1) + /// range = (range_max - range_min) * range_adjust + /// range_scale = range / num_discrete_values + /// const double offset_input = static_cast(input) - lowest_quantized; + /// result = range_min + ((input - numeric_limits::min()) * range_scale) + /// ``` + /// + /// If the mode is `SCALED`, dequantization is performed by multiplying each + /// input value by a scaling_factor. (Thus an input of 0 always maps to 0.0). + /// + /// The scaling_factor is determined from `min_range`, `max_range`, and + /// `narrow_range` in a way that is compatible with `QuantizeAndDequantize{V2|V3}` + /// and `QuantizeV2`, using the following algorithm: + /// + /// ```c++ + /// + /// const int min_expected_T = std::numeric_limits::min() + + /// (narrow_range ? 1 : 0); + /// const int max_expected_T = std::numeric_limits::max(); + /// const float max_expected_T = std::numeric_limits::max(); + /// + /// const float scale_factor = + /// (std::numeric_limits::min() == 0) ? (max_range / max_expected_T) + /// : std::max(min_range / min_expected_T, + /// max_range / max_expected_T); + /// ``` + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Type of the output tensor. Currently Dequantize supports float and bfloat16. + /// If 'dtype' is 'bfloat16', it only supports 'MIN_COMBINED' mode. + /// + /// + /// + public static Tensor dequantize(Tensor input, Tensor min_range, Tensor max_range, string mode = "MIN_COMBINED", bool narrow_range = false, int axis = -1, TF_DataType dtype = TF_DataType.TF_FLOAT, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Dequantize", name) { args = new object[] { input, min_range, max_range }, attrs = new Dictionary() { ["mode"] = mode, ["narrow_range"] = narrow_range, ["axis"] = axis, ["dtype"] = dtype } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return dequantize_eager_fallback(input, min_range, max_range, mode: mode, narrow_range: narrow_range, axis: axis, dtype: dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (mode is null) + { + mode = "MIN_COMBINED"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["min_range"] = min_range; + keywords["max_range"] = max_range; + keywords["mode"] = mode; + keywords["narrow_range"] = narrow_range; + keywords["axis"] = axis; + keywords["dtype"] = dtype; + var _op = tf.OpDefLib._apply_op_helper("Dequantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "mode", _op.get_attr("mode"), "narrow_range", _op._get_attr_bool("narrow_range"), "axis", _op._get_attr_int("axis"), "dtype", _op._get_attr_type("dtype") }; + _execute.record_gradient("Dequantize", _op.inputs, _attrs, _result); + } + return _result[0]; + } - var tensor = new Tensor(_op, 0, dtype); - return tensor; + public static Tensor dequantize_eager_fallback(Tensor input, Tensor min_range, Tensor max_range, string mode, bool narrow_range, int axis, TF_DataType dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, min_range, max_range }; + object[] _attrs = new object[] { "T", input.dtype, "mode", mode, "narrow_range", narrow_range, "axis", axis, "dtype", dtype }; + var _result = _execute.execute("Dequantize", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Dequantize", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a diagonal tensor with a given diagonal values. + /// + /// + /// + /// Given a `diagonal`, this operation returns a tensor with the `diagonal` and + /// everything else padded with zeros. The diagonal is computed as follows: + /// + /// Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of + /// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: + /// + /// `output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else. + /// + /// For example: + /// + /// ``` + /// # 'diagonal' is [1, 2, 3, 4] + /// tf.diag(diagonal) ==> [[1, 0, 0, 0] + /// [0, 2, 0, 0] + /// [0, 0, 3, 0] + /// [0, 0, 0, 4]] + /// ``` + /// + /// + /// + /// + public static Tensor diag(Tensor diagonal, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Diag", name) { args = new object[] { diagonal }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return diag_eager_fallback(diagonal, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["diagonal"] = diagonal; + var _op = tf.OpDefLib._apply_op_helper("Diag", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Diag", _op.inputs, _attrs, _result); } + return _result[0]; + } - private static OpDefLibrary _InitOpDefLibrary() + public static Tensor diag_eager_fallback(Tensor diagonal, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { diagonal }; + object[] _attrs = new object[] { "T", diagonal.dtype }; + var _result = _execute.execute("Diag", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Diag", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the diagonal part of the tensor. + /// + /// + /// + /// This operation returns a tensor with the `diagonal` part + /// of the `input`. The `diagonal` part is computed as follows: + /// + /// Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a + /// tensor of rank `k` with dimensions `[D1,..., Dk]` where: + /// + /// `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`. + /// + /// For example: + /// + /// ``` + /// # 'input' is [[1, 0, 0, 0] + /// [0, 2, 0, 0] + /// [0, 0, 3, 0] + /// [0, 0, 0, 4]] + /// + /// tf.diag_part(input) ==> [1, 2, 3, 4] + /// ``` + /// + /// + /// + /// + public static Tensor diag_part(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DiagPart", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return diag_part_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("DiagPart", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) { - // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); - var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_array.bin"); - var op_list = OpList.Parser.ParseFrom(bytes); - var op_def_lib = new OpDefLibrary(); - op_def_lib.add_op_list(op_list); + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("DiagPart", _op.inputs, _attrs, _result); + } + return _result[0]; + } - return op_def_lib; + public static Tensor diag_part_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("DiagPart", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DiagPart", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the (possibly normalized) Levenshtein Edit Distance. + /// + /// + /// + /// The inputs are variable-length sequences provided by SparseTensors + /// (hypothesis_indices, hypothesis_values, hypothesis_shape) + /// and + /// (truth_indices, truth_values, truth_shape). + /// + /// The inputs are: + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// boolean (if true, edit distances are normalized by length of truth). + /// + /// The output is: + /// + /// + /// + public static Tensor edit_distance(Tensor hypothesis_indices, Tensor hypothesis_values, Tensor hypothesis_shape, Tensor truth_indices, Tensor truth_values, Tensor truth_shape, bool normalize = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "EditDistance", name) { args = new object[] { hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape }, attrs = new Dictionary() { ["normalize"] = normalize } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return edit_distance_eager_fallback(hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape, normalize: normalize, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["hypothesis_indices"] = hypothesis_indices; + keywords["hypothesis_values"] = hypothesis_values; + keywords["hypothesis_shape"] = hypothesis_shape; + keywords["truth_indices"] = truth_indices; + keywords["truth_values"] = truth_values; + keywords["truth_shape"] = truth_shape; + keywords["normalize"] = normalize; + var _op = tf.OpDefLib._apply_op_helper("EditDistance", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "normalize", _op._get_attr_bool("normalize"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("EditDistance", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor edit_distance_eager_fallback(Tensor hypothesis_indices, Tensor hypothesis_values, Tensor hypothesis_shape, Tensor truth_indices, Tensor truth_values, Tensor truth_shape, bool normalize, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape }; + object[] _attrs = new object[] { "normalize", normalize, "T", hypothesis_values.dtype }; + var _result = _execute.execute("EditDistance", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("EditDistance", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + /// + /// + public static Tensor empty(Tensor shape, TF_DataType dtype, bool init = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Empty", name) { args = new object[] { shape }, attrs = new Dictionary() { ["dtype"] = dtype, ["init"] = init } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return empty_eager_fallback(shape, dtype: dtype, init: init, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["shape"] = shape; + keywords["dtype"] = dtype; + keywords["init"] = init; + var _op = tf.OpDefLib._apply_op_helper("Empty", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "init", _op._get_attr_bool("init") }; + _execute.record_gradient("Empty", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor empty_eager_fallback(Tensor shape, TF_DataType dtype, bool init, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { shape }; + object[] _attrs = new object[] { "dtype", dtype, "init", init }; + var _result = _execute.execute("Empty", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Empty", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Ensures that the tensor's shape matches the expected shape. + /// + /// + /// + /// Raises an error if the input tensor's shape does not match the specified shape. + /// Returns the input tensor otherwise. + /// + /// + /// + /// + /// + /// The expected (possibly partially specified) shape of the input tensor. + /// + /// + /// + public static Tensor ensure_shape(Tensor input, Shape shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "EnsureShape", name) { args = new object[] { input }, attrs = new Dictionary() { ["shape"] = shape } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return ensure_shape_eager_fallback(input, shape: shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["shape"] = shape; + var _op = tf.OpDefLib._apply_op_helper("EnsureShape", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "shape", _op.get_attr("shape"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("EnsureShape", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor ensure_shape_eager_fallback(Tensor input, Shape shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "shape", shape, "T", input.dtype }; + var _result = _execute.execute("EnsureShape", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("EnsureShape", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Inserts a dimension of 1 into a tensor's shape. + /// + /// + /// + /// Given a tensor `input`, this operation inserts a dimension of 1 at the + /// dimension index `dim` of `input`'s shape. The dimension index `dim` starts at + /// zero; if you specify a negative number for `dim` it is counted backward from + /// the end. + /// + /// This operation is useful if you want to add a batch dimension to a single + /// element. For example, if you have a single image of shape `[height, width, + /// channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`, + /// which will make the shape `[1, height, width, channels]`. + /// + /// Other examples: + /// + /// ``` + /// # 't' is a tensor of shape [2] + /// shape(expand_dims(t, 0)) ==> [1, 2] + /// shape(expand_dims(t, 1)) ==> [2, 1] + /// shape(expand_dims(t, -1)) ==> [2, 1] + /// + /// # 't2' is a tensor of shape [2, 3, 5] + /// shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5] + /// shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5] + /// shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1] + /// ``` + /// + /// This operation requires that: + /// + /// `-1-input.dims() <= dim <= input.dims()` + /// + /// This operation is related to `squeeze()`, which removes dimensions of + /// size 1. + /// + /// + /// + /// + /// + public static Tensor expand_dims(Tensor input, Tensor dim, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ExpandDims", name) { args = new object[] { input, dim }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return expand_dims_eager_fallback(input, dim, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["dim"] = dim; + var _op = tf.OpDefLib._apply_op_helper("ExpandDims", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tdim", _op._get_attr_type("Tdim") }; + _execute.record_gradient("ExpandDims", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor expand_dims_eager_fallback(Tensor input, Tensor dim, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, dim }; + object[] _attrs = new object[] { "T", input.dtype, "Tdim", dim.dtype }; + var _result = _execute.execute("ExpandDims", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ExpandDims", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Extract `patches` from `images` and put them in the "depth" output dimension. + /// + /// + /// + /// + /// The size of the sliding window for each dimension of `images`. + /// + /// + /// + /// + /// How far the centers of two consecutive patches are in + /// the images. Must be: `[1, stride_rows, stride_cols, 1]`. + /// + /// + /// + /// + /// Must be: `[1, rate_rows, rate_cols, 1]`. This is the + /// input stride, specifying how far two consecutive patch samples are in the + /// input. Equivalent to extracting patches with + /// `patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1)`, followed by + /// subsampling them spatially by a factor of `rates`. This is equivalent to + /// `rate` in dilated (a.k.a. Atrous) convolutions. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + public static Tensor extract_image_patches(Tensor images, int[] ksizes, int[] strides, int[] rates, string padding, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ExtractImagePatches", name) { args = new object[] { images }, attrs = new Dictionary() { ["ksizes"] = ksizes, ["strides"] = strides, ["rates"] = rates, ["padding"] = padding } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return extract_image_patches_eager_fallback(images, ksizes: ksizes, strides: strides, rates: rates, padding: padding, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["images"] = images; + keywords["ksizes"] = ksizes; + keywords["strides"] = strides; + keywords["rates"] = rates; + keywords["padding"] = padding; + var _op = tf.OpDefLib._apply_op_helper("ExtractImagePatches", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksizes", _op.get_attr("ksizes"), "strides", _op.get_attr("strides"), "rates", _op.get_attr("rates"), "T", _op._get_attr_type("T"), "padding", _op.get_attr("padding") }; + _execute.record_gradient("ExtractImagePatches", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor extract_image_patches_eager_fallback(Tensor images, int[] ksizes, int[] strides, int[] rates, string padding, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { images }; + object[] _attrs = new object[] { "ksizes", ksizes, "strides", strides, "rates", rates, "T", images.dtype, "padding", padding }; + var _result = _execute.execute("ExtractImagePatches", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ExtractImagePatches", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Extract `patches` from `input` and put them in the `"depth"` output dimension. 3D extension of `extract_image_patches`. + /// + /// + /// + /// + /// The size of the sliding window for each dimension of `input`. + /// + /// + /// + /// + /// 1-D of length 5. How far the centers of two consecutive patches are in + /// `input`. Must be: `[1, stride_planes, stride_rows, stride_cols, 1]`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// The size-related attributes are specified as follows: + /// + /// ```python + /// ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1] + /// strides = [1, stride_planes, strides_rows, strides_cols, 1] + /// ``` + /// + /// + /// + public static Tensor extract_volume_patches(Tensor input, int[] ksizes, int[] strides, string padding, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ExtractVolumePatches", name) { args = new object[] { input }, attrs = new Dictionary() { ["ksizes"] = ksizes, ["strides"] = strides, ["padding"] = padding } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return extract_volume_patches_eager_fallback(input, ksizes: ksizes, strides: strides, padding: padding, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["ksizes"] = ksizes; + keywords["strides"] = strides; + keywords["padding"] = padding; + var _op = tf.OpDefLib._apply_op_helper("ExtractVolumePatches", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksizes", _op.get_attr("ksizes"), "strides", _op.get_attr("strides"), "T", _op._get_attr_type("T"), "padding", _op.get_attr("padding") }; + _execute.record_gradient("ExtractVolumePatches", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor extract_volume_patches_eager_fallback(Tensor input, int[] ksizes, int[] strides, string padding, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "ksizes", ksizes, "strides", strides, "T", input.dtype, "padding", padding }; + var _result = _execute.execute("ExtractVolumePatches", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ExtractVolumePatches", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. + /// + /// + /// + /// Attributes + /// + /// * `[min; max]` define the clamping range for the `inputs` data. + /// * `inputs` values are quantized into the quantization range ( + /// `[0; 2^num_bits - 1]` when `narrow_range` is false and `[1; 2^num_bits - 1]` + /// when it is true) and then de-quantized and output as floats in `[min; max]` + /// interval. + /// * `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. + /// + /// Before quantization, `min` and `max` values are adjusted with the following + /// logic. + /// It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values, + /// the behavior can be unexpected: + /// + /// * If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`. + /// * If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`. + /// * If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `, + /// `min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`. + /// + /// Quantization is called fake since the output is still in floating point. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor fake_quant_with_min_max_args(Tensor inputs, float min = -6f, float max = 6f, int num_bits = 8, bool narrow_range = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FakeQuantWithMinMaxArgs", name) { args = new object[] { inputs }, attrs = new Dictionary() { ["min"] = min, ["max"] = max, ["num_bits"] = num_bits, ["narrow_range"] = narrow_range } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fake_quant_with_min_max_args_eager_fallback(inputs, min: min, max: max, num_bits: num_bits, narrow_range: narrow_range, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["inputs"] = inputs; + keywords["min"] = min; + keywords["max"] = max; + keywords["num_bits"] = num_bits; + keywords["narrow_range"] = narrow_range; + var _op = tf.OpDefLib._apply_op_helper("FakeQuantWithMinMaxArgs", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "min", _op.get_attr("min"), "max", _op.get_attr("max"), "num_bits", _op._get_attr_int("num_bits"), "narrow_range", _op._get_attr_bool("narrow_range") }; + _execute.record_gradient("FakeQuantWithMinMaxArgs", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fake_quant_with_min_max_args_eager_fallback(Tensor inputs, float min, float max, int num_bits, bool narrow_range, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { inputs }; + object[] _attrs = new object[] { "min", min, "max", max, "num_bits", num_bits, "narrow_range", narrow_range }; + var _result = _execute.execute("FakeQuantWithMinMaxArgs", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FakeQuantWithMinMaxArgs", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Compute gradients for a FakeQuantWithMinMaxArgs operation. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor fake_quant_with_min_max_args_gradient(Tensor gradients, Tensor inputs, float min = -6f, float max = 6f, int num_bits = 8, bool narrow_range = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FakeQuantWithMinMaxArgsGradient", name) { args = new object[] { gradients, inputs }, attrs = new Dictionary() { ["min"] = min, ["max"] = max, ["num_bits"] = num_bits, ["narrow_range"] = narrow_range } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fake_quant_with_min_max_args_gradient_eager_fallback(gradients, inputs, min: min, max: max, num_bits: num_bits, narrow_range: narrow_range, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["gradients"] = gradients; + keywords["inputs"] = inputs; + keywords["min"] = min; + keywords["max"] = max; + keywords["num_bits"] = num_bits; + keywords["narrow_range"] = narrow_range; + var _op = tf.OpDefLib._apply_op_helper("FakeQuantWithMinMaxArgsGradient", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "min", _op.get_attr("min"), "max", _op.get_attr("max"), "num_bits", _op._get_attr_int("num_bits"), "narrow_range", _op._get_attr_bool("narrow_range") }; + _execute.record_gradient("FakeQuantWithMinMaxArgsGradient", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fake_quant_with_min_max_args_gradient_eager_fallback(Tensor gradients, Tensor inputs, float min, float max, int num_bits, bool narrow_range, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { gradients, inputs }; + object[] _attrs = new object[] { "min", min, "max", max, "num_bits", num_bits, "narrow_range", narrow_range }; + var _result = _execute.execute("FakeQuantWithMinMaxArgsGradient", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FakeQuantWithMinMaxArgsGradient", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Fake-quantize the 'inputs' tensor of type float via global float scalars + /// + /// + /// + /// Fake-quantize the `inputs` tensor of type float via global float scalars + /// `min` and `max` to `outputs` tensor of same shape as `inputs`. + /// + /// Attributes + /// + /// * `[min; max]` define the clamping range for the `inputs` data. + /// * `inputs` values are quantized into the quantization range ( + /// `[0; 2^num_bits - 1]` when `narrow_range` is false and `[1; 2^num_bits - 1]` + /// when it is true) and then de-quantized and output as floats in `[min; max]` + /// interval. + /// * `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. + /// + /// Before quantization, `min` and `max` values are adjusted with the following + /// logic. + /// It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values, + /// the behavior can be unexpected: + /// + /// * If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`. + /// * If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`. + /// * If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `, + /// `min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`. + /// + /// This operation has a gradient and thus allows for training `min` and `max` + /// values. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor fake_quant_with_min_max_vars(Tensor inputs, Tensor min, Tensor max, int num_bits = 8, bool narrow_range = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FakeQuantWithMinMaxVars", name) { args = new object[] { inputs, min, max }, attrs = new Dictionary() { ["num_bits"] = num_bits, ["narrow_range"] = narrow_range } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fake_quant_with_min_max_vars_eager_fallback(inputs, min, max, num_bits: num_bits, narrow_range: narrow_range, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["inputs"] = inputs; + keywords["min"] = min; + keywords["max"] = max; + keywords["num_bits"] = num_bits; + keywords["narrow_range"] = narrow_range; + var _op = tf.OpDefLib._apply_op_helper("FakeQuantWithMinMaxVars", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "num_bits", _op._get_attr_int("num_bits"), "narrow_range", _op._get_attr_bool("narrow_range") }; + _execute.record_gradient("FakeQuantWithMinMaxVars", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fake_quant_with_min_max_vars_eager_fallback(Tensor inputs, Tensor min, Tensor max, int num_bits, bool narrow_range, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { inputs, min, max }; + object[] _attrs = new object[] { "num_bits", num_bits, "narrow_range", narrow_range }; + var _result = _execute.execute("FakeQuantWithMinMaxVars", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FakeQuantWithMinMaxVars", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Compute gradients for a FakeQuantWithMinMaxVars operation. + /// + /// + /// + /// + /// + /// + /// + /// The bitwidth of the quantization; between 2 and 8, inclusive. + /// + /// + /// + /// + /// Whether to quantize into 2^num_bits - 1 distinct values. + /// + /// + /// + public static Tensor[] fake_quant_with_min_max_vars_gradient(Tensor gradients, Tensor inputs, Tensor min, Tensor max, int num_bits = 8, bool narrow_range = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FakeQuantWithMinMaxVarsGradient", name) { args = new object[] { gradients, inputs, min, max }, attrs = new Dictionary() { ["num_bits"] = num_bits, ["narrow_range"] = narrow_range } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fake_quant_with_min_max_vars_gradient_eager_fallback(gradients, inputs, min, max, num_bits: num_bits, narrow_range: narrow_range, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["gradients"] = gradients; + keywords["inputs"] = inputs; + keywords["min"] = min; + keywords["max"] = max; + keywords["num_bits"] = num_bits; + keywords["narrow_range"] = narrow_range; + var _op = tf.OpDefLib._apply_op_helper("FakeQuantWithMinMaxVarsGradient", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "num_bits", _op._get_attr_int("num_bits"), "narrow_range", _op._get_attr_bool("narrow_range") }; + _execute.record_gradient("FakeQuantWithMinMaxVarsGradient", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] fake_quant_with_min_max_vars_gradient_eager_fallback(Tensor gradients, Tensor inputs, Tensor min, Tensor max, int num_bits, bool narrow_range, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { gradients, inputs, min, max }; + object[] _attrs = new object[] { "num_bits", num_bits, "narrow_range", narrow_range }; + var _result = _execute.execute("FakeQuantWithMinMaxVarsGradient", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FakeQuantWithMinMaxVarsGradient", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Fake-quantize the 'inputs' tensor of type float via per-channel floats + /// + /// + /// + /// Fake-quantize the `inputs` tensor of type float per-channel and one of the + /// shapes: `[d]`, `[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` + /// of shape `[d]` to `outputs` tensor of same shape as `inputs`. + /// + /// Attributes + /// + /// * `[min; max]` define the clamping range for the `inputs` data. + /// * `inputs` values are quantized into the quantization range ( + /// `[0; 2^num_bits - 1]` when `narrow_range` is false and `[1; 2^num_bits - 1]` + /// when it is true) and then de-quantized and output as floats in `[min; max]` + /// interval. + /// * `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. + /// + /// Before quantization, `min` and `max` values are adjusted with the following + /// logic. + /// It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values, + /// the behavior can be unexpected: + /// + /// * If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`. + /// * If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`. + /// * If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `, + /// `min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`. + /// + /// This operation has a gradient and thus allows for training `min` and `max` + /// values. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor fake_quant_with_min_max_vars_per_channel(Tensor inputs, Tensor min, Tensor max, int num_bits = 8, bool narrow_range = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FakeQuantWithMinMaxVarsPerChannel", name) { args = new object[] { inputs, min, max }, attrs = new Dictionary() { ["num_bits"] = num_bits, ["narrow_range"] = narrow_range } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fake_quant_with_min_max_vars_per_channel_eager_fallback(inputs, min, max, num_bits: num_bits, narrow_range: narrow_range, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["inputs"] = inputs; + keywords["min"] = min; + keywords["max"] = max; + keywords["num_bits"] = num_bits; + keywords["narrow_range"] = narrow_range; + var _op = tf.OpDefLib._apply_op_helper("FakeQuantWithMinMaxVarsPerChannel", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "num_bits", _op._get_attr_int("num_bits"), "narrow_range", _op._get_attr_bool("narrow_range") }; + _execute.record_gradient("FakeQuantWithMinMaxVarsPerChannel", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fake_quant_with_min_max_vars_per_channel_eager_fallback(Tensor inputs, Tensor min, Tensor max, int num_bits, bool narrow_range, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { inputs, min, max }; + object[] _attrs = new object[] { "num_bits", num_bits, "narrow_range", narrow_range }; + var _result = _execute.execute("FakeQuantWithMinMaxVarsPerChannel", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FakeQuantWithMinMaxVarsPerChannel", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation. + /// + /// + /// + /// + /// + /// + /// + /// The bitwidth of the quantization; between 2 and 16, inclusive. + /// + /// + /// + /// + /// Whether to quantize into 2^num_bits - 1 distinct values. + /// + /// + /// + public static Tensor[] fake_quant_with_min_max_vars_per_channel_gradient(Tensor gradients, Tensor inputs, Tensor min, Tensor max, int num_bits = 8, bool narrow_range = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FakeQuantWithMinMaxVarsPerChannelGradient", name) { args = new object[] { gradients, inputs, min, max }, attrs = new Dictionary() { ["num_bits"] = num_bits, ["narrow_range"] = narrow_range } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fake_quant_with_min_max_vars_per_channel_gradient_eager_fallback(gradients, inputs, min, max, num_bits: num_bits, narrow_range: narrow_range, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["gradients"] = gradients; + keywords["inputs"] = inputs; + keywords["min"] = min; + keywords["max"] = max; + keywords["num_bits"] = num_bits; + keywords["narrow_range"] = narrow_range; + var _op = tf.OpDefLib._apply_op_helper("FakeQuantWithMinMaxVarsPerChannelGradient", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "num_bits", _op._get_attr_int("num_bits"), "narrow_range", _op._get_attr_bool("narrow_range") }; + _execute.record_gradient("FakeQuantWithMinMaxVarsPerChannelGradient", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] fake_quant_with_min_max_vars_per_channel_gradient_eager_fallback(Tensor gradients, Tensor inputs, Tensor min, Tensor max, int num_bits, bool narrow_range, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { gradients, inputs, min, max }; + object[] _attrs = new object[] { "num_bits", num_bits, "narrow_range", narrow_range }; + var _result = _execute.execute("FakeQuantWithMinMaxVarsPerChannelGradient", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FakeQuantWithMinMaxVarsPerChannelGradient", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Creates a tensor filled with a scalar value. + /// + /// + /// + /// This operation creates a tensor of shape `dims` and fills it with `value`. + /// + /// For example: + /// + /// ``` + /// # Output tensor has shape [2, 3]. + /// fill([2, 3], 9) ==> [[9, 9, 9] + /// [9, 9, 9]] + /// ``` + /// + /// `tf.fill` differs from `tf.constant` in a few ways: + /// + /// * `tf.fill` only supports scalar contents, whereas `tf.constant` supports + /// Tensor values. + /// * `tf.fill` creates an Op in the computation graph that constructs the actual + /// Tensor value at runtime. This is in contrast to `tf.constant` which embeds + /// the entire Tensor into the graph with a `Const` node. + /// * Because `tf.fill` evaluates at graph runtime, it supports dynamic shapes + /// based on other runtime Tensors, unlike `tf.constant`. + /// + /// + /// + /// + /// + public static Tensor fill(Tensor dims, Tensor value, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Fill", name) { args = new object[] { dims, value }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fill_eager_fallback(dims, value, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["dims"] = dims; + keywords["value"] = value; + var _op = tf.OpDefLib._apply_op_helper("Fill", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "index_type", _op._get_attr_type("index_type") }; + _execute.record_gradient("Fill", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fill_eager_fallback(Tensor dims, Tensor value, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { dims, value }; + object[] _attrs = new object[] { "T", value.dtype, "index_type", dims.dtype }; + var _result = _execute.execute("Fill", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Fill", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Generates fingerprint values. + /// + /// + /// + /// Generates fingerprint values of `data`. + /// + /// Fingerprint op considers the first dimension of `data` as the batch dimension, + /// and `output[i]` contains the fingerprint value generated from contents in + /// `data[i, ...]` for all `i`. + /// + /// Fingerprint op writes fingerprint values as byte arrays. For example, the + /// default method `farmhash64` generates a 64-bit fingerprint value at a time. + /// This 8-byte value is written out as an `uint8` array of size 8, in little-endian + /// order. + /// + /// For example, suppose that `data` has data type `DT_INT32` and shape (2, 3, 4), + /// and that the fingerprint method is `farmhash64`. In this case, the output shape + /// is (2, 8), where 2 is the batch dimension size of `data`, and 8 is the size of + /// each fingerprint value in bytes. `output[0, :]` is generated from 12 integers in + /// `data[0, :, :]` and similarly `output[1, :]` is generated from other 12 integers + /// in `data[1, :, :]`. + /// + /// Note that this op fingerprints the raw underlying buffer, and it does not + /// fingerprint Tensor's metadata such as data type and/or shape. For example, the + /// fingerprint values are invariant under reshapes and bitcasts as long as the + /// batch dimension remain the same: + /// + /// ``` + /// Fingerprint(data) == Fingerprint(Reshape(data, ...)) + /// Fingerprint(data) == Fingerprint(Bitcast(data, ...)) + /// ``` + /// + /// For string data, one should expect `Fingerprint(data) != + /// Fingerprint(ReduceJoin(data))` in general. + /// + /// + /// + /// + /// + public static Tensor fingerprint(Tensor data, Tensor method, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Fingerprint", name) { args = new object[] { data, method }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fingerprint_eager_fallback(data, method, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["method"] = method; + var _op = tf.OpDefLib._apply_op_helper("Fingerprint", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Fingerprint", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fingerprint_eager_fallback(Tensor data, Tensor method, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, method }; + object[] _attrs = new object[] { "T", data.dtype }; + var _result = _execute.execute("Fingerprint", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Fingerprint", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Gather slices from `params` according to `indices`. + /// + /// + /// + /// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). + /// Produces an output tensor with shape `indices.shape + params.shape[1:]` where: + /// + /// ```python + /// # Scalar indices + /// output[:, ..., :] = params[indices, :, ... :] + /// + /// # Vector indices + /// output[i, :, ..., :] = params[indices[i], :, ... :] + /// + /// # Higher rank indices + /// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] + /// ``` + /// + /// If `indices` is a permutation and `len(indices) == params.shape[0]` then + /// this operation will permute `params` accordingly. + /// + /// `validate_indices`: DEPRECATED. If this operation is assigned to CPU, values in + /// `indices` are always validated to be within range. If assigned to GPU, + /// out-of-bound indices result in safe but unspecified behavior, which may include + /// raising an error. + /// + ///
+ /// + ///
+ /// + ///
+ /// + /// + /// + /// + public static Tensor gather(Tensor params_, Tensor indices, bool validate_indices = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Gather", name) { args = new object[] { params_, indices }, attrs = new Dictionary() { ["validate_indices"] = validate_indices } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return gather_eager_fallback(params_, indices, validate_indices: validate_indices, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["params"] = params_; + keywords["indices"] = indices; + keywords["validate_indices"] = validate_indices; + var _op = tf.OpDefLib._apply_op_helper("Gather", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "validate_indices", _op._get_attr_bool("validate_indices"), "Tparams", _op._get_attr_type("Tparams"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("Gather", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor gather_eager_fallback(Tensor params_, Tensor indices, bool validate_indices, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { params_, indices }; + object[] _attrs = new object[] { "validate_indices", validate_indices, "Tparams", params_.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("Gather", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Gather", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Gather slices from `params` into a Tensor with shape specified by `indices`. + /// + /// + /// + /// `indices` is a K-dimensional integer tensor, best thought of as a + /// (K-1)-dimensional tensor of indices into `params`, where each element defines a + /// slice of `params`: + /// + /// output[\(i_0, ..., i_{K-2}\)] = params[indices[\(i_0, ..., i_{K-2}\)]] + /// + /// Whereas in `tf.gather` `indices` defines slices into the `axis` + /// dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the + /// first `N` dimensions of `params`, where `N = indices.shape[-1]`. + /// + /// The last dimension of `indices` can be at most the rank of + /// `params`: + /// + /// indices.shape[-1] <= params.rank + /// + /// The last dimension of `indices` corresponds to elements + /// (if `indices.shape[-1] == params.rank`) or slices + /// (if `indices.shape[-1] < params.rank`) along dimension `indices.shape[-1]` + /// of `params`. The output tensor has shape + /// + /// indices.shape[:-1] + params.shape[indices.shape[-1]:] + /// + /// Note that on CPU, if an out of bound index is found, an error is returned. + /// On GPU, if an out of bound index is found, a 0 is stored in the + /// corresponding output value. + /// + /// Some examples below. + /// + /// Simple indexing into a matrix: + /// + /// ```python + /// indices = [[0, 0], [1, 1]] + /// params = [['a', 'b'], ['c', 'd']] + /// output = ['a', 'd'] + /// ``` + /// + /// Slice indexing into a matrix: + /// + /// ```python + /// indices = [[1], [0]] + /// params = [['a', 'b'], ['c', 'd']] + /// output = [['c', 'd'], ['a', 'b']] + /// ``` + /// + /// Indexing into a 3-tensor: + /// + /// ```python + /// indices = [[1]] + /// params = [[['a0', 'b0'], ['c0', 'd0']], + /// [['a1', 'b1'], ['c1', 'd1']]] + /// output = [[['a1', 'b1'], ['c1', 'd1']]] + /// + /// + /// indices = [[0, 1], [1, 0]] + /// params = [[['a0', 'b0'], ['c0', 'd0']], + /// [['a1', 'b1'], ['c1', 'd1']]] + /// output = [['c0', 'd0'], ['a1', 'b1']] + /// + /// + /// indices = [[0, 0, 1], [1, 0, 1]] + /// params = [[['a0', 'b0'], ['c0', 'd0']], + /// [['a1', 'b1'], ['c1', 'd1']]] + /// output = ['b0', 'b1'] + /// ``` + /// + /// Batched indexing into a matrix: + /// + /// ```python + /// indices = [[[0, 0]], [[0, 1]]] + /// params = [['a', 'b'], ['c', 'd']] + /// output = [['a'], ['b']] + /// ``` + /// + /// Batched slice indexing into a matrix: + /// + /// ```python + /// indices = [[[1]], [[0]]] + /// params = [['a', 'b'], ['c', 'd']] + /// output = [[['c', 'd']], [['a', 'b']]] + /// ``` + /// + /// Batched indexing into a 3-tensor: + /// + /// ```python + /// indices = [[[1]], [[0]]] + /// params = [[['a0', 'b0'], ['c0', 'd0']], + /// [['a1', 'b1'], ['c1', 'd1']]] + /// output = [[[['a1', 'b1'], ['c1', 'd1']]], + /// [[['a0', 'b0'], ['c0', 'd0']]]] + /// + /// indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]] + /// params = [[['a0', 'b0'], ['c0', 'd0']], + /// [['a1', 'b1'], ['c1', 'd1']]] + /// output = [[['c0', 'd0'], ['a1', 'b1']], + /// [['a0', 'b0'], ['c1', 'd1']]] + /// + /// + /// indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]] + /// params = [[['a0', 'b0'], ['c0', 'd0']], + /// [['a1', 'b1'], ['c1', 'd1']]] + /// output = [['b0', 'b1'], ['d0', 'c1']] + /// ``` + /// + /// See also `tf.gather` and `tf.batch_gather`. + /// + /// + /// + /// + /// + public static Tensor gather_nd(Tensor params_, Tensor indices, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "GatherNd", name) { args = new object[] { params_, indices }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return gather_nd_eager_fallback(params_, indices, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["params"] = params_; + keywords["indices"] = indices; + var _op = tf.OpDefLib._apply_op_helper("GatherNd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tparams", _op._get_attr_type("Tparams"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("GatherNd", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor gather_nd_eager_fallback(Tensor params_, Tensor indices, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { params_, indices }; + object[] _attrs = new object[] { "Tparams", params_.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("GatherNd", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("GatherNd", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Gather slices from `params` axis `axis` according to `indices`. + /// + /// + /// + /// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). + /// Produces an output tensor with shape `params.shape[:axis] + + /// indices.shape[batch_dims:] + params.shape[axis + 1:]` where: + /// + /// ```python + /// # Scalar indices (output is rank(params) - 1). + /// output[a_0, ..., a_n, b_0, ..., b_n] = + /// params[a_0, ..., a_n, indices, b_0, ..., b_n] + /// + /// # Vector indices (output is rank(params)). + /// output[a_0, ..., a_n, i, b_0, ..., b_n] = + /// params[a_0, ..., a_n, indices[i], b_0, ..., b_n] + /// + /// # Higher rank indices (output is rank(params) + rank(indices) - 1). + /// output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] = + /// params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n] + /// ``` + /// + ///
+ /// + ///
+ /// + /// Note that on CPU, if an out of bound index is found, an error is returned. + /// On GPU, if an out of bound index is found, a 0 is stored in the + /// corresponding output value. + /// + /// See also `tf.batch_gather` and `tf.gather_nd`. + /// + ///
+ /// + /// + /// + /// + /// + public static Tensor gather_v2(Tensor params_, Tensor indices, Tensor axis, int batch_dims = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "GatherV2", name) { args = new object[] { params_, indices, axis }, attrs = new Dictionary() { ["batch_dims"] = batch_dims } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return gather_v2_eager_fallback(params_, indices, axis, batch_dims: batch_dims, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["params"] = params_; + keywords["indices"] = indices; + keywords["axis"] = axis; + keywords["batch_dims"] = batch_dims; + var _op = tf.OpDefLib._apply_op_helper("GatherV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "batch_dims", _op._get_attr_int("batch_dims"), "Tparams", _op._get_attr_type("Tparams"), "Tindices", _op._get_attr_type("Tindices"), "Taxis", _op._get_attr_type("Taxis") }; + _execute.record_gradient("GatherV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor gather_v2_eager_fallback(Tensor params_, Tensor indices, Tensor axis, int batch_dims, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { params_, indices, axis }; + object[] _attrs = new object[] { "batch_dims", batch_dims, "Tparams", params_.dtype, "Tindices", indices.dtype, "Taxis", axis.dtype }; + var _result = _execute.execute("GatherV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("GatherV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Gives a guarantee to the TF runtime that the input tensor is a constant. + /// + /// + /// + /// The runtime is then free to make optimizations based on this. + /// + /// Only accepts value typed tensors as inputs and rejects resource variable handles + /// as input. + /// + /// Returns the input tensor without modification. + /// + /// + /// + /// + public static Tensor guarantee_const(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "GuaranteeConst", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return guarantee_const_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("GuaranteeConst", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("GuaranteeConst", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor guarantee_const_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("GuaranteeConst", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("GuaranteeConst", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Return a tensor with the same shape and contents as the input tensor or value. + /// + /// + /// + public static Tensor identity(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Identity", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return identity_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("Identity", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Identity", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor identity_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("Identity", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Identity", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a list of tensors with the same shapes and contents as the input + /// + /// + /// + /// tensors. + /// + /// This op can be used to override the gradient for complicated functions. For + /// example, suppose y = f(x) and we wish to apply a custom function g for backprop + /// such that dx = g(dy). In Python, + /// + /// ```python + /// with tf.get_default_graph().gradient_override_map( + /// {'IdentityN': 'OverrideGradientWithG'}): + /// y, _ = identity_n([f(x), x]) + /// + /// @tf.RegisterGradient('OverrideGradientWithG') + /// def ApplyG(op, dy, _): + /// return [None, g(dy)] # Do not backprop to f(x). + /// ``` + /// + /// + /// + /// + public static Tensor[] identity_n(Tensors input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "IdentityN", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return identity_n_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("IdentityN", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op.get_attr("T") }; + _execute.record_gradient("IdentityN", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] identity_n_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("IdentityN", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("IdentityN", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Returns immutable tensor from memory region. + /// + /// + /// + /// The current implementation memmaps the tensor from a file. + /// + /// + /// + /// + /// Type of the returned tensor. + /// + /// + /// + /// + /// Shape of the returned tensor. + /// + /// + /// + /// + /// Name of readonly memory region used by the tensor, see + /// NewReadOnlyMemoryRegionFromFile in tensorflow::Env. + /// + /// + /// + public static Tensor immutable_const(TF_DataType dtype, Shape shape, string memory_region_name, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ImmutableConst", name) { args = new object[] { }, attrs = new Dictionary() { ["dtype"] = dtype, ["shape"] = shape, ["memory_region_name"] = memory_region_name } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return immutable_const_eager_fallback(dtype: dtype, shape: shape, memory_region_name: memory_region_name, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["dtype"] = dtype; + keywords["shape"] = shape; + keywords["memory_region_name"] = memory_region_name; + var _op = tf.OpDefLib._apply_op_helper("ImmutableConst", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "shape", _op.get_attr("shape"), "memory_region_name", _op.get_attr("memory_region_name") }; + _execute.record_gradient("ImmutableConst", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor immutable_const_eager_fallback(TF_DataType dtype, Shape shape, string memory_region_name, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "dtype", dtype, "shape", shape, "memory_region_name", memory_region_name }; + var _result = _execute.execute("ImmutableConst", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ImmutableConst", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + /// + /// + public static Tensor inplace_add(Tensor x, Tensor i, Tensor v, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "InplaceAdd", name) { args = new object[] { x, i, v }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return inplace_add_eager_fallback(x, i, v, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["i"] = i; + keywords["v"] = v; + var _op = tf.OpDefLib._apply_op_helper("InplaceAdd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("InplaceAdd", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor inplace_add_eager_fallback(Tensor x, Tensor i, Tensor v, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, i, v }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("InplaceAdd", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("InplaceAdd", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + /// + /// + public static Tensor inplace_sub(Tensor x, Tensor i, Tensor v, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "InplaceSub", name) { args = new object[] { x, i, v }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return inplace_sub_eager_fallback(x, i, v, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["i"] = i; + keywords["v"] = v; + var _op = tf.OpDefLib._apply_op_helper("InplaceSub", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("InplaceSub", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor inplace_sub_eager_fallback(Tensor x, Tensor i, Tensor v, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, i, v }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("InplaceSub", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("InplaceSub", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + /// + /// + public static Tensor inplace_update(Tensor x, Tensor i, Tensor v, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "InplaceUpdate", name) { args = new object[] { x, i, v }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return inplace_update_eager_fallback(x, i, v, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["i"] = i; + keywords["v"] = v; + var _op = tf.OpDefLib._apply_op_helper("InplaceUpdate", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("InplaceUpdate", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor inplace_update_eager_fallback(Tensor x, Tensor i, Tensor v, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, i, v }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("InplaceUpdate", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("InplaceUpdate", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the inverse permutation of a tensor. + /// + /// + /// + /// This operation computes the inverse of an index permutation. It takes a 1-D + /// integer tensor `x`, which represents the indices of a zero-based array, and + /// swaps each value with its index position. In other words, for an output tensor + /// `y` and an input tensor `x`, this operation computes the following: + /// + /// `y[x[i]] = i for i in [0, 1, ..., len(x) - 1]` + /// + /// The values must include 0. There can be no duplicate values or negative values. + /// + /// For example: + /// + /// ``` + /// # tensor `x` is [3, 4, 0, 2, 1] + /// invert_permutation(x) ==> [2, 4, 3, 0, 1] + /// ``` + /// + /// + /// + /// + public static Tensor invert_permutation(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "InvertPermutation", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return invert_permutation_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("InvertPermutation", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("InvertPermutation", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor invert_permutation_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("InvertPermutation", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("InvertPermutation", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the difference between two lists of numbers or strings. + /// + /// + /// + /// Given a list `x` and a list `y`, this operation returns a list `out` that + /// represents all values that are in `x` but not in `y`. The returned list `out` + /// is sorted in the same order that the numbers appear in `x` (duplicates are + /// preserved). This operation also returns a list `idx` that represents the + /// position of each `out` element in `x`. In other words: + /// + /// `out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]` + /// + /// For example, given this input: + /// + /// ``` + /// x = [1, 2, 3, 4, 5, 6] + /// y = [1, 3, 5] + /// ``` + /// + /// This operation would return: + /// + /// ``` + /// out ==> [2, 4, 6] + /// idx ==> [1, 3, 5] + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor[] list_diff(Tensor x, Tensor y, TF_DataType out_idx = TF_DataType.TF_INT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ListDiff", name) { args = new object[] { x, y }, attrs = new Dictionary() { ["out_idx"] = out_idx } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return list_diff_eager_fallback(x, y, out_idx: out_idx, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + keywords["out_idx"] = out_idx; + var _op = tf.OpDefLib._apply_op_helper("ListDiff", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "out_idx", _op._get_attr_type("out_idx") }; + _execute.record_gradient("ListDiff", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] list_diff_eager_fallback(Tensor x, Tensor y, TF_DataType out_idx, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype, "out_idx", out_idx }; + var _result = _execute.execute("ListDiff", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ListDiff", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Applies lower_bound(sorted_search_values, values) along each row. + /// + /// + /// + /// Each set of rows with the same index in (sorted_inputs, values) is treated + /// independently. The resulting row is the equivalent of calling + /// `np.searchsorted(sorted_inputs, values, side='left')`. + /// + /// The result is not a global index to the entire + /// `Tensor`, but rather just the index in the last dimension. + /// + /// A 2-D example: + /// sorted_sequence = [[0, 3, 9, 9, 10], + /// [1, 2, 3, 4, 5]] + /// values = [[2, 4, 9], + /// [0, 2, 6]] + /// + /// result = LowerBound(sorted_sequence, values) + /// + /// result == [[1, 2, 2], + /// [0, 1, 5]] + /// + /// + /// + /// + /// + /// + public static Tensor lower_bound(Tensor sorted_inputs, Tensor values, TF_DataType out_type = TF_DataType.TF_INT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "LowerBound", name) { args = new object[] { sorted_inputs, values }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return lower_bound_eager_fallback(sorted_inputs, values, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["sorted_inputs"] = sorted_inputs; + keywords["values"] = values; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("LowerBound", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("LowerBound", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor lower_bound_eager_fallback(Tensor sorted_inputs, Tensor values, TF_DataType out_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { sorted_inputs, values }; + object[] _attrs = new object[] { "T", sorted_inputs.dtype, "out_type", out_type }; + var _result = _execute.execute("LowerBound", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("LowerBound", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Copy a tensor setting everything outside a central band in each innermost matrix to zero. + /// + /// + /// + /// The `band` part is computed as follows: + /// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a + /// tensor with the same shape where + /// + /// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`. + /// + /// The indicator function + /// + /// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && + /// (num_upper < 0 || (n-m) <= num_upper)`. + /// + /// For example: + /// + /// ``` + /// # if 'input' is [[ 0, 1, 2, 3] + /// # [-1, 0, 1, 2] + /// # [-2, -1, 0, 1] + /// # [-3, -2, -1, 0]], + /// + /// tf.linalg.band_part(input, 1, -1) ==> [[ 0, 1, 2, 3] + /// [-1, 0, 1, 2] + /// [ 0, -1, 0, 1] + /// [ 0, 0, -1, 0]], + /// + /// tf.linalg.band_part(input, 2, 1) ==> [[ 0, 1, 0, 0] + /// [-1, 0, 1, 0] + /// [-2, -1, 0, 1] + /// [ 0, -2, -1, 0]] + /// ``` + /// + /// Useful special cases: + /// + /// ``` + /// tf.linalg.band_part(input, 0, -1) ==> Upper triangular part. + /// tf.linalg.band_part(input, -1, 0) ==> Lower triangular part. + /// tf.linalg.band_part(input, 0, 0) ==> Diagonal. + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor matrix_band_part(Tensor input, Tensor num_lower, Tensor num_upper, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatrixBandPart", name) { args = new object[] { input, num_lower, num_upper }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return matrix_band_part_eager_fallback(input, num_lower, num_upper, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["num_lower"] = num_lower; + keywords["num_upper"] = num_upper; + var _op = tf.OpDefLib._apply_op_helper("MatrixBandPart", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindex", _op._get_attr_type("Tindex") }; + _execute.record_gradient("MatrixBandPart", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor matrix_band_part_eager_fallback(Tensor input, Tensor num_lower, Tensor num_upper, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, num_lower, num_upper }; + object[] _attrs = new object[] { "T", input.dtype, "Tindex", num_lower.dtype }; + var _result = _execute.execute("MatrixBandPart", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MatrixBandPart", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a batched diagonal tensor with a given batched diagonal values. + /// + /// + /// + /// Given a `diagonal`, this operation returns a tensor with the `diagonal` and + /// everything else padded with zeros. The diagonal is computed as follows: + /// + /// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a + /// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: + /// + /// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. + /// + /// For example: + /// + /// ``` + /// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] + /// + /// and diagonal.shape = (2, 4) + /// + /// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] + /// [0, 2, 0, 0] + /// [0, 0, 3, 0] + /// [0, 0, 0, 4]], + /// [[5, 0, 0, 0] + /// [0, 6, 0, 0] + /// [0, 0, 7, 0] + /// [0, 0, 0, 8]]] + /// + /// which has shape (2, 4, 4) + /// ``` + /// + /// + /// + /// + public static Tensor matrix_diag(Tensor diagonal, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatrixDiag", name) { args = new object[] { diagonal }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return matrix_diag_eager_fallback(diagonal, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["diagonal"] = diagonal; + var _op = tf.OpDefLib._apply_op_helper("MatrixDiag", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("MatrixDiag", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor matrix_diag_eager_fallback(Tensor diagonal, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { diagonal }; + object[] _attrs = new object[] { "T", diagonal.dtype }; + var _result = _execute.execute("MatrixDiag", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MatrixDiag", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the batched diagonal part of a batched tensor. + /// + /// + /// + /// This operation returns a tensor with the `diagonal` part + /// of the batched `input`. The `diagonal` part is computed as follows: + /// + /// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a + /// tensor of rank `k - 1` with dimensions `[I, J, K, ..., min(M, N)]` where: + /// + /// `diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]`. + /// + /// The input must be at least a matrix. + /// + /// For example: + /// + /// ``` + /// # 'input' is [[[1, 0, 0, 0] + /// [0, 2, 0, 0] + /// [0, 0, 3, 0] + /// [0, 0, 0, 4]], + /// [[5, 0, 0, 0] + /// [0, 6, 0, 0] + /// [0, 0, 7, 0] + /// [0, 0, 0, 8]]] + /// + /// and input.shape = (2, 4, 4) + /// + /// tf.matrix_diag_part(input) ==> [[1, 2, 3, 4], [5, 6, 7, 8]] + /// + /// which has shape (2, 4) + /// ``` + /// + /// + /// + /// + public static Tensor matrix_diag_part(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatrixDiagPart", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return matrix_diag_part_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("MatrixDiagPart", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("MatrixDiagPart", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor matrix_diag_part_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("MatrixDiagPart", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MatrixDiagPart", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the batched diagonal part of a batched tensor. + /// + /// + /// + /// Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched + /// `input`. + /// + /// Assume `input` has `r` dimensions `[I, J, ..., L, M, N]`. + /// Let `max_diag_len` be the maximum length among all diagonals to be extracted, + /// `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))` + /// Let `num_diags` be the number of diagonals to extract, + /// `num_diags = k[1] - k[0] + 1`. + /// + /// If `num_diags == 1`, the output tensor is of rank `r - 1` with shape + /// `[I, J, ..., L, max_diag_len]` and values: + /// + /// ``` + /// diagonal[i, j, ..., l, n] + /// = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N, + /// padding_value ; otherwise. + /// ``` + /// where `y = max(-k[1], 0)`, `x = max(k[1], 0)`. + /// + /// Otherwise, the output tensor has rank `r` with dimensions + /// `[I, J, ..., L, num_diags, max_diag_len]` with values: + /// + /// ``` + /// diagonal[i, j, ..., l, m, n] + /// = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N, + /// padding_value ; otherwise. + /// ``` + /// where `d = k[1] - m`, `y = max(-d, 0)`, and `x = max(d, 0)`. + /// + /// The input must be at least a matrix. + /// + /// For example: + /// + /// ``` + /// input = np.array([[[1, 2, 3, 4], # Input shape: (2, 3, 4) + /// [5, 6, 7, 8], + /// [9, 8, 7, 6]], + /// [[5, 4, 3, 2], + /// [1, 2, 3, 4], + /// [5, 6, 7, 8]]]) + /// + /// # A main diagonal from each batch. + /// tf.matrix_diag_part(input) ==> [[1, 6, 7], # Output shape: (2, 3) + /// [5, 2, 7]] + /// + /// # A superdiagonal from each batch. + /// tf.matrix_diag_part(input, k = 1) + /// ==> [[2, 7, 6], # Output shape: (2, 3) + /// [4, 3, 8]] + /// + /// # A tridiagonal band from each batch. + /// tf.matrix_diag_part(input, k = (-1, 1)) + /// ==> [[[2, 7, 6], # Output shape: (2, 3, 3) + /// [1, 6, 7], + /// [5, 8, 0]], + /// [[4, 3, 8], + /// [5, 2, 7], + /// [1, 6, 0]]] + /// + /// # Padding value = 9 + /// tf.matrix_diag_part(input, k = (1, 3), padding_value = 9) + /// ==> [[[4, 9, 9], # Output shape: (2, 3, 3) + /// [3, 8, 9], + /// [2, 7, 6]], + /// [[2, 9, 9], + /// [3, 4, 9], + /// [4, 3, 8]]] + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor matrix_diag_part_v2(Tensor input, Tensor k, Tensor padding_value, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatrixDiagPartV2", name) { args = new object[] { input, k, padding_value }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return matrix_diag_part_v2_eager_fallback(input, k, padding_value, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["k"] = k; + keywords["padding_value"] = padding_value; + var _op = tf.OpDefLib._apply_op_helper("MatrixDiagPartV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("MatrixDiagPartV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor matrix_diag_part_v2_eager_fallback(Tensor input, Tensor k, Tensor padding_value, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, k, padding_value }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("MatrixDiagPartV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MatrixDiagPartV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the batched diagonal part of a batched tensor. + /// + /// + /// + /// Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched + /// `input`. + /// + /// Assume `input` has `r` dimensions `[I, J, ..., L, M, N]`. + /// Let `max_diag_len` be the maximum length among all diagonals to be extracted, + /// `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))` + /// Let `num_diags` be the number of diagonals to extract, + /// `num_diags = k[1] - k[0] + 1`. + /// + /// If `num_diags == 1`, the output tensor is of rank `r - 1` with shape + /// `[I, J, ..., L, max_diag_len]` and values: + /// + /// ``` + /// diagonal[i, j, ..., l, n] + /// = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N, + /// padding_value ; otherwise. + /// ``` + /// where `y = max(-k[1], 0)`, `x = max(k[1], 0)`. + /// + /// Otherwise, the output tensor has rank `r` with dimensions + /// `[I, J, ..., L, num_diags, max_diag_len]` with values: + /// + /// ``` + /// diagonal[i, j, ..., l, m, n] + /// = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N, + /// padding_value ; otherwise. + /// ``` + /// where `d = k[1] - m`, `y = max(-d, 0) - offset`, and `x = max(d, 0) - offset`. + /// + /// `offset` is zero except when the alignment of the diagonal is to the right. + /// ``` + /// offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT} + /// and `d >= 0`) or + /// (`align` in {LEFT_RIGHT, RIGHT_RIGHT} + /// and `d <= 0`) + /// 0 ; otherwise + /// ``` + /// where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`. + /// + /// The input must be at least a matrix. + /// + /// For example: + /// + /// ``` + /// input = np.array([[[1, 2, 3, 4], # Input shape: (2, 3, 4) + /// [5, 6, 7, 8], + /// [9, 8, 7, 6]], + /// [[5, 4, 3, 2], + /// [1, 2, 3, 4], + /// [5, 6, 7, 8]]]) + /// + /// # A main diagonal from each batch. + /// tf.matrix_diag_part(input) ==> [[1, 6, 7], # Output shape: (2, 3) + /// [5, 2, 7]] + /// + /// # A superdiagonal from each batch. + /// tf.matrix_diag_part(input, k = 1) + /// ==> [[2, 7, 6], # Output shape: (2, 3) + /// [4, 3, 8]] + /// + /// # A band from each batch. + /// tf.matrix_diag_part(input, k = (-1, 2)) + /// ==> [[[0, 3, 8], # Output shape: (2, 4, 3) + /// [2, 7, 6], + /// [1, 6, 7], + /// [5, 8, 0]], + /// [[0, 3, 4], + /// [4, 3, 8], + /// [5, 2, 7], + /// [1, 6, 0]]] + /// + /// # LEFT_RIGHT alignment. + /// tf.matrix_diag_part(input, k = (-1, 2), align="LEFT_RIGHT") + /// ==> [[[3, 8, 0], # Output shape: (2, 4, 3) + /// [2, 7, 6], + /// [1, 6, 7], + /// [0, 5, 8]], + /// [[3, 4, 0], + /// [4, 3, 8], + /// [5, 2, 7], + /// [0, 1, 6]]] + /// + /// # max_diag_len can be shorter than the main diagonal. + /// tf.matrix_diag_part(input, k = (-2, -1)) + /// ==> [[[5, 8], + /// [9, 0]], + /// [[1, 6], + /// [5, 0]]] + /// + /// # padding_value = 9 + /// tf.matrix_diag_part(input, k = (1, 3), padding_value = 9) + /// ==> [[[9, 9, 4], # Output shape: (2, 3, 3) + /// [9, 3, 8], + /// [2, 7, 6]], + /// [[9, 9, 2], + /// [9, 3, 4], + /// [4, 3, 8]]] + /// + /// ``` + /// + /// + /// + /// + /// + /// + /// + /// Some diagonals are shorter than `max_diag_len` and need to be padded. `align` is + /// a string specifying how superdiagonals and subdiagonals should be aligned, + /// respectively. There are four possible alignments: "RIGHT_LEFT" (default), + /// "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals + /// to the right (left-pads the row) and subdiagonals to the left (right-pads the + /// row). It is the packing format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is + /// the opposite alignment. + /// + /// + /// + public static Tensor matrix_diag_part_v3(Tensor input, Tensor k, Tensor padding_value, string align = "RIGHT_LEFT", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatrixDiagPartV3", name) { args = new object[] { input, k, padding_value }, attrs = new Dictionary() { ["align"] = align } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return matrix_diag_part_v3_eager_fallback(input, k, padding_value, align: align, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (align is null) + { + align = "RIGHT_LEFT"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["k"] = k; + keywords["padding_value"] = padding_value; + keywords["align"] = align; + var _op = tf.OpDefLib._apply_op_helper("MatrixDiagPartV3", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "align", _op.get_attr("align") }; + _execute.record_gradient("MatrixDiagPartV3", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor matrix_diag_part_v3_eager_fallback(Tensor input, Tensor k, Tensor padding_value, string align, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, k, padding_value }; + object[] _attrs = new object[] { "T", input.dtype, "align", align }; + var _result = _execute.execute("MatrixDiagPartV3", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MatrixDiagPartV3", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a batched diagonal tensor with given batched diagonal values. + /// + /// + /// + /// Returns a tensor with the contents in `diagonal` as `k[0]`-th to `k[1]`-th + /// diagonals of a matrix, with everything else padded with `padding`. `num_rows` + /// and `num_cols` specify the dimension of the innermost matrix of the output. If + /// both are not specified, the op assumes the innermost matrix is square and infers + /// its size from `k` and the innermost dimension of `diagonal`. If only one of them + /// is specified, the op assumes the unspecified value is the smallest possible + /// based on other criteria. + /// + /// Let `diagonal` have `r` dimensions `[I, J, ..., L, M, N]`. The output tensor has + /// rank `r+1` with shape `[I, J, ..., L, M, num_rows, num_cols]` when only one + /// diagonal is given (`k` is an integer or `k[0] == k[1]`). Otherwise, it has rank + /// `r` with shape `[I, J, ..., L, num_rows, num_cols]`. + /// + /// The second innermost dimension of `diagonal` has double meaning. + /// When `k` is scalar or `k[0] == k[1]`, `M` is part of the batch size + /// [I, J, ..., M], and the output tensor is: + /// + /// ``` + /// output[i, j, ..., l, m, n] + /// = diagonal[i, j, ..., l, n-max(d_upper, 0)] ; if n - m == d_upper + /// padding_value ; otherwise + /// ``` + /// + /// Otherwise, `M` is treated as the number of diagonals for the matrix in the + /// same batch (`M = k[1]-k[0]+1`), and the output tensor is: + /// + /// ``` + /// output[i, j, ..., l, m, n] + /// = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1] + /// padding_value ; otherwise + /// ``` + /// where `d = n - m`, `diag_index = k[1] - d`, and `index_in_diag = n - max(d, 0)`. + /// + /// For example: + /// + /// ``` + /// # The main diagonal. + /// diagonal = np.array([[1, 2, 3, 4], # Input shape: (2, 4) + /// [5, 6, 7, 8]]) + /// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0], # Output shape: (2, 4, 4) + /// [0, 2, 0, 0], + /// [0, 0, 3, 0], + /// [0, 0, 0, 4]], + /// [[5, 0, 0, 0], + /// [0, 6, 0, 0], + /// [0, 0, 7, 0], + /// [0, 0, 0, 8]]] + /// + /// # A superdiagonal (per batch). + /// diagonal = np.array([[1, 2, 3], # Input shape: (2, 3) + /// [4, 5, 6]]) + /// tf.matrix_diag(diagonal, k = 1) + /// ==> [[[0, 1, 0, 0], # Output shape: (2, 4, 4) + /// [0, 0, 2, 0], + /// [0, 0, 0, 3], + /// [0, 0, 0, 0]], + /// [[0, 4, 0, 0], + /// [0, 0, 5, 0], + /// [0, 0, 0, 6], + /// [0, 0, 0, 0]]] + /// + /// # A band of diagonals. + /// diagonals = np.array([[[1, 2, 3], # Input shape: (2, 2, 3) + /// [4, 5, 0]], + /// [[6, 7, 9], + /// [9, 1, 0]]]) + /// tf.matrix_diag(diagonals, k = (-1, 0)) + /// ==> [[[1, 0, 0], # Output shape: (2, 3, 3) + /// [4, 2, 0], + /// [0, 5, 3]], + /// [[6, 0, 0], + /// [9, 7, 0], + /// [0, 1, 9]]] + /// + /// # Rectangular matrix. + /// diagonal = np.array([1, 2]) # Input shape: (2) + /// tf.matrix_diag(diagonal, k = -1, num_rows = 3, num_cols = 4) + /// ==> [[0, 0, 0, 0], # Output shape: (3, 4) + /// [1, 0, 0, 0], + /// [0, 2, 0, 0]] + /// + /// # Rectangular matrix with inferred num_cols and padding_value = 9. + /// tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9) + /// ==> [[9, 9], # Output shape: (3, 2) + /// [1, 9], + /// [9, 2]] + /// ``` + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor matrix_diag_v2(Tensor diagonal, Tensor k, Tensor num_rows, Tensor num_cols, Tensor padding_value, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatrixDiagV2", name) { args = new object[] { diagonal, k, num_rows, num_cols, padding_value }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return matrix_diag_v2_eager_fallback(diagonal, k, num_rows, num_cols, padding_value, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["diagonal"] = diagonal; + keywords["k"] = k; + keywords["num_rows"] = num_rows; + keywords["num_cols"] = num_cols; + keywords["padding_value"] = padding_value; + var _op = tf.OpDefLib._apply_op_helper("MatrixDiagV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("MatrixDiagV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor matrix_diag_v2_eager_fallback(Tensor diagonal, Tensor k, Tensor num_rows, Tensor num_cols, Tensor padding_value, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { diagonal, k, num_rows, num_cols, padding_value }; + object[] _attrs = new object[] { "T", diagonal.dtype }; + var _result = _execute.execute("MatrixDiagV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MatrixDiagV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a batched diagonal tensor with given batched diagonal values. + /// + /// + /// + /// Returns a tensor with the contents in `diagonal` as `k[0]`-th to `k[1]`-th + /// diagonals of a matrix, with everything else padded with `padding`. `num_rows` + /// and `num_cols` specify the dimension of the innermost matrix of the output. If + /// both are not specified, the op assumes the innermost matrix is square and infers + /// its size from `k` and the innermost dimension of `diagonal`. If only one of them + /// is specified, the op assumes the unspecified value is the smallest possible + /// based on other criteria. + /// + /// Let `diagonal` have `r` dimensions `[I, J, ..., L, M, N]`. The output tensor has + /// rank `r+1` with shape `[I, J, ..., L, M, num_rows, num_cols]` when only one + /// diagonal is given (`k` is an integer or `k[0] == k[1]`). Otherwise, it has rank + /// `r` with shape `[I, J, ..., L, num_rows, num_cols]`. + /// + /// The second innermost dimension of `diagonal` has double meaning. + /// When `k` is scalar or `k[0] == k[1]`, `M` is part of the batch size + /// [I, J, ..., M], and the output tensor is: + /// + /// ``` + /// output[i, j, ..., l, m, n] + /// = diagonal[i, j, ..., l, n-max(d_upper, 0)] ; if n - m == d_upper + /// padding_value ; otherwise + /// ``` + /// + /// Otherwise, `M` is treated as the number of diagonals for the matrix in the + /// same batch (`M = k[1]-k[0]+1`), and the output tensor is: + /// + /// ``` + /// output[i, j, ..., l, m, n] + /// = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1] + /// padding_value ; otherwise + /// ``` + /// where `d = n - m`, `diag_index = [k] - d`, and + /// `index_in_diag = n - max(d, 0) + offset`. + /// + /// `offset` is zero except when the alignment of the diagonal is to the right. + /// ``` + /// offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT} + /// and `d >= 0`) or + /// (`align` in {LEFT_RIGHT, RIGHT_RIGHT} + /// and `d <= 0`) + /// 0 ; otherwise + /// ``` + /// where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`. + /// + /// For example: + /// + /// ``` + /// # The main diagonal. + /// diagonal = np.array([[1, 2, 3, 4], # Input shape: (2, 4) + /// [5, 6, 7, 8]]) + /// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0], # Output shape: (2, 4, 4) + /// [0, 2, 0, 0], + /// [0, 0, 3, 0], + /// [0, 0, 0, 4]], + /// [[5, 0, 0, 0], + /// [0, 6, 0, 0], + /// [0, 0, 7, 0], + /// [0, 0, 0, 8]]] + /// + /// # A superdiagonal (per batch). + /// diagonal = np.array([[1, 2, 3], # Input shape: (2, 3) + /// [4, 5, 6]]) + /// tf.matrix_diag(diagonal, k = 1) + /// ==> [[[0, 1, 0, 0], # Output shape: (2, 4, 4) + /// [0, 0, 2, 0], + /// [0, 0, 0, 3], + /// [0, 0, 0, 0]], + /// [[0, 4, 0, 0], + /// [0, 0, 5, 0], + /// [0, 0, 0, 6], + /// [0, 0, 0, 0]]] + /// + /// # A tridiagonal band (per batch). + /// diagonals = np.array([[[0, 8, 9], # Input shape: (2, 2, 3) + /// [1, 2, 3], + /// [4, 5, 0]], + /// [[0, 2, 3], + /// [6, 7, 9], + /// [9, 1, 0]]]) + /// tf.matrix_diag(diagonals, k = (-1, 1)) + /// ==> [[[1, 8, 0], # Output shape: (2, 3, 3) + /// [4, 2, 9], + /// [0, 5, 3]], + /// [[6, 2, 0], + /// [9, 7, 3], + /// [0, 1, 9]]] + /// + /// # LEFT_RIGHT alignment. + /// diagonals = np.array([[[8, 9, 0], # Input shape: (2, 2, 3) + /// [1, 2, 3], + /// [0, 4, 5]], + /// [[2, 3, 0], + /// [6, 7, 9], + /// [0, 9, 1]]]) + /// tf.matrix_diag(diagonals, k = (-1, 1), align="LEFT_RIGHT") + /// ==> [[[1, 8, 0], # Output shape: (2, 3, 3) + /// [4, 2, 9], + /// [0, 5, 3]], + /// [[6, 2, 0], + /// [9, 7, 3], + /// [0, 1, 9]]] + /// + /// # Rectangular matrix. + /// diagonal = np.array([1, 2]) # Input shape: (2) + /// tf.matrix_diag(diagonal, k = -1, num_rows = 3, num_cols = 4) + /// ==> [[0, 0, 0, 0], # Output shape: (3, 4) + /// [1, 0, 0, 0], + /// [0, 2, 0, 0]] + /// + /// # Rectangular matrix with inferred num_cols and padding_value = 9. + /// tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9) + /// ==> [[9, 9], # Output shape: (3, 2) + /// [1, 9], + /// [9, 2]] + /// + /// ``` + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Some diagonals are shorter than `max_diag_len` and need to be padded. `align` is + /// a string specifying how superdiagonals and subdiagonals should be aligned, + /// respectively. There are four possible alignments: "RIGHT_LEFT" (default), + /// "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals + /// to the right (left-pads the row) and subdiagonals to the left (right-pads the + /// row). It is the packing format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is + /// the opposite alignment. + /// + /// + /// + public static Tensor matrix_diag_v3(Tensor diagonal, Tensor k, Tensor num_rows, Tensor num_cols, Tensor padding_value, string align = "RIGHT_LEFT", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatrixDiagV3", name) { args = new object[] { diagonal, k, num_rows, num_cols, padding_value }, attrs = new Dictionary() { ["align"] = align } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return matrix_diag_v3_eager_fallback(diagonal, k, num_rows, num_cols, padding_value, align: align, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (align is null) + { + align = "RIGHT_LEFT"; + } + Dictionary keywords = new(); + keywords["diagonal"] = diagonal; + keywords["k"] = k; + keywords["num_rows"] = num_rows; + keywords["num_cols"] = num_cols; + keywords["padding_value"] = padding_value; + keywords["align"] = align; + var _op = tf.OpDefLib._apply_op_helper("MatrixDiagV3", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "align", _op.get_attr("align") }; + _execute.record_gradient("MatrixDiagV3", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor matrix_diag_v3_eager_fallback(Tensor diagonal, Tensor k, Tensor num_rows, Tensor num_cols, Tensor padding_value, string align, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { diagonal, k, num_rows, num_cols, padding_value }; + object[] _attrs = new object[] { "T", diagonal.dtype, "align", align }; + var _result = _execute.execute("MatrixDiagV3", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MatrixDiagV3", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a batched matrix tensor with new batched diagonal values. + /// + /// + /// + /// Given `input` and `diagonal`, this operation returns a tensor with the + /// same shape and values as `input`, except for the main diagonal of the + /// innermost matrices. These will be overwritten by the values in `diagonal`. + /// + /// The output is computed as follows: + /// + /// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has + /// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a + /// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where: + /// + /// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`. + /// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`. + /// + /// + /// + /// + /// + public static Tensor matrix_set_diag(Tensor input, Tensor diagonal, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatrixSetDiag", name) { args = new object[] { input, diagonal }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return matrix_set_diag_eager_fallback(input, diagonal, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["diagonal"] = diagonal; + var _op = tf.OpDefLib._apply_op_helper("MatrixSetDiag", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("MatrixSetDiag", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor matrix_set_diag_eager_fallback(Tensor input, Tensor diagonal, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, diagonal }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("MatrixSetDiag", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MatrixSetDiag", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a batched matrix tensor with new batched diagonal values. + /// + /// + /// + /// Given `input` and `diagonal`, this operation returns a tensor with the + /// same shape and values as `input`, except for the specified diagonals of the + /// innermost matrices. These will be overwritten by the values in `diagonal`. + /// + /// `input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or + /// `k[0] == k[1]`, `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`. + /// Otherwise, it has `r+1` dimensions `[I, J, ..., L, num_diags, max_diag_len]`. + /// `num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`. + /// `max_diag_len` is the longest diagonal in the range `[k[0], k[1]]`, + /// `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))` + /// + /// The output is a tensor of rank `k+1` with dimensions `[I, J, ..., L, M, N]`. + /// If `k` is scalar or `k[0] == k[1]`: + /// + /// ``` + /// output[i, j, ..., l, m, n] + /// = diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1] + /// input[i, j, ..., l, m, n] ; otherwise + /// ``` + /// + /// Otherwise, + /// + /// ``` + /// output[i, j, ..., l, m, n] + /// = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1] + /// input[i, j, ..., l, m, n] ; otherwise + /// ``` + /// where `d = n - m`, `diag_index = k[1] - d`, and `index_in_diag = n - max(d, 0)`. + /// + /// For example: + /// + /// ``` + /// # The main diagonal. + /// input = np.array([[[7, 7, 7, 7], # Input shape: (2, 3, 4) + /// [7, 7, 7, 7], + /// [7, 7, 7, 7]], + /// [[7, 7, 7, 7], + /// [7, 7, 7, 7], + /// [7, 7, 7, 7]]]) + /// diagonal = np.array([[1, 2, 3], # Diagonal shape: (2, 3) + /// [4, 5, 6]]) + /// tf.matrix_set_diag(diagonal) ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4) + /// [7, 2, 7, 7], + /// [7, 7, 3, 7]], + /// [[4, 7, 7, 7], + /// [7, 5, 7, 7], + /// [7, 7, 6, 7]]] + /// + /// # A superdiagonal (per batch). + /// tf.matrix_set_diag(diagonal, k = 1) + /// ==> [[[7, 1, 7, 7], # Output shape: (2, 3, 4) + /// [7, 7, 2, 7], + /// [7, 7, 7, 3]], + /// [[7, 4, 7, 7], + /// [7, 7, 5, 7], + /// [7, 7, 7, 6]]] + /// + /// # A band of diagonals. + /// diagonals = np.array([[[1, 2, 3], # Diagonal shape: (2, 2, 3) + /// [4, 5, 0]], + /// [[6, 1, 2], + /// [3, 4, 0]]]) + /// tf.matrix_set_diag(diagonals, k = (-1, 0)) + /// ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4) + /// [4, 2, 7, 7], + /// [0, 5, 3, 7]], + /// [[6, 7, 7, 7], + /// [3, 1, 7, 7], + /// [7, 4, 2, 7]]] + /// + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor matrix_set_diag_v2(Tensor input, Tensor diagonal, Tensor k, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatrixSetDiagV2", name) { args = new object[] { input, diagonal, k }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return matrix_set_diag_v2_eager_fallback(input, diagonal, k, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["diagonal"] = diagonal; + keywords["k"] = k; + var _op = tf.OpDefLib._apply_op_helper("MatrixSetDiagV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("MatrixSetDiagV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor matrix_set_diag_v2_eager_fallback(Tensor input, Tensor diagonal, Tensor k, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, diagonal, k }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("MatrixSetDiagV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MatrixSetDiagV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a batched matrix tensor with new batched diagonal values. + /// + /// + /// + /// Given `input` and `diagonal`, this operation returns a tensor with the + /// same shape and values as `input`, except for the specified diagonals of the + /// innermost matrices. These will be overwritten by the values in `diagonal`. + /// + /// `input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or + /// `k[0] == k[1]`, `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`. + /// Otherwise, it has `r+1` dimensions `[I, J, ..., L, num_diags, max_diag_len]`. + /// `num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`. + /// `max_diag_len` is the longest diagonal in the range `[k[0], k[1]]`, + /// `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))` + /// + /// The output is a tensor of rank `k+1` with dimensions `[I, J, ..., L, M, N]`. + /// If `k` is scalar or `k[0] == k[1]`: + /// + /// ``` + /// output[i, j, ..., l, m, n] + /// = diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1] + /// input[i, j, ..., l, m, n] ; otherwise + /// ``` + /// + /// Otherwise, + /// + /// ``` + /// output[i, j, ..., l, m, n] + /// = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1] + /// input[i, j, ..., l, m, n] ; otherwise + /// ``` + /// where `d = n - m`, `diag_index = k[1] - d`, and + /// `index_in_diag = n - max(d, 0) + offset`. + /// + /// `offset` is zero except when the alignment of the diagonal is to the right. + /// ``` + /// offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT} + /// and `d >= 0`) or + /// (`align` in {LEFT_RIGHT, RIGHT_RIGHT} + /// and `d <= 0`) + /// 0 ; otherwise + /// ``` + /// where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`. + /// + /// For example: + /// + /// ``` + /// # The main diagonal. + /// input = np.array([[[7, 7, 7, 7], # Input shape: (2, 3, 4) + /// [7, 7, 7, 7], + /// [7, 7, 7, 7]], + /// [[7, 7, 7, 7], + /// [7, 7, 7, 7], + /// [7, 7, 7, 7]]]) + /// diagonal = np.array([[1, 2, 3], # Diagonal shape: (2, 3) + /// [4, 5, 6]]) + /// tf.matrix_set_diag(input, diagonal) + /// ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4) + /// [7, 2, 7, 7], + /// [7, 7, 3, 7]], + /// [[4, 7, 7, 7], + /// [7, 5, 7, 7], + /// [7, 7, 6, 7]]] + /// + /// # A superdiagonal (per batch). + /// tf.matrix_set_diag(input, diagonal, k = 1) + /// ==> [[[7, 1, 7, 7], # Output shape: (2, 3, 4) + /// [7, 7, 2, 7], + /// [7, 7, 7, 3]], + /// [[7, 4, 7, 7], + /// [7, 7, 5, 7], + /// [7, 7, 7, 6]]] + /// + /// # A band of diagonals. + /// diagonals = np.array([[[0, 9, 1], # Diagonal shape: (2, 4, 3) + /// [6, 5, 8], + /// [1, 2, 3], + /// [4, 5, 0]], + /// [[0, 1, 2], + /// [5, 6, 4], + /// [6, 1, 2], + /// [3, 4, 0]]]) + /// tf.matrix_set_diag(input, diagonals, k = (-1, 2)) + /// ==> [[[1, 6, 9, 7], # Output shape: (2, 3, 4) + /// [4, 2, 5, 1], + /// [7, 5, 3, 8]], + /// [[6, 5, 1, 7], + /// [3, 1, 6, 2], + /// [7, 4, 2, 4]]] + /// + /// # LEFT_RIGHT alignment. + /// diagonals = np.array([[[9, 1, 0], # Diagonal shape: (2, 4, 3) + /// [6, 5, 8], + /// [1, 2, 3], + /// [0, 4, 5]], + /// [[1, 2, 0], + /// [5, 6, 4], + /// [6, 1, 2], + /// [0, 3, 4]]]) + /// tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="LEFT_RIGHT") + /// ==> [[[1, 6, 9, 7], # Output shape: (2, 3, 4) + /// [4, 2, 5, 1], + /// [7, 5, 3, 8]], + /// [[6, 5, 1, 7], + /// [3, 1, 6, 2], + /// [7, 4, 2, 4]]] + /// + /// ``` + /// + /// + /// + /// + /// + /// + /// + /// Some diagonals are shorter than `max_diag_len` and need to be padded. `align` is + /// a string specifying how superdiagonals and subdiagonals should be aligned, + /// respectively. There are four possible alignments: "RIGHT_LEFT" (default), + /// "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals + /// to the right (left-pads the row) and subdiagonals to the left (right-pads the + /// row). It is the packing format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is + /// the opposite alignment. + /// + /// + /// + public static Tensor matrix_set_diag_v3(Tensor input, Tensor diagonal, Tensor k, string align = "RIGHT_LEFT", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatrixSetDiagV3", name) { args = new object[] { input, diagonal, k }, attrs = new Dictionary() { ["align"] = align } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return matrix_set_diag_v3_eager_fallback(input, diagonal, k, align: align, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (align is null) + { + align = "RIGHT_LEFT"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["diagonal"] = diagonal; + keywords["k"] = k; + keywords["align"] = align; + var _op = tf.OpDefLib._apply_op_helper("MatrixSetDiagV3", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "align", _op.get_attr("align") }; + _execute.record_gradient("MatrixSetDiagV3", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor matrix_set_diag_v3_eager_fallback(Tensor input, Tensor diagonal, Tensor k, string align, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, diagonal, k }; + object[] _attrs = new object[] { "T", input.dtype, "align", align }; + var _result = _execute.execute("MatrixSetDiagV3", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MatrixSetDiagV3", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Pads a tensor with mirrored values. + /// + /// + /// + /// This operation pads a `input` with mirrored values according to the `paddings` + /// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is + /// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates + /// how many values to add before the contents of `input` in that dimension, and + /// `paddings[D, 1]` indicates how many values to add after the contents of `input` + /// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater + /// than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true + /// (if false, respectively). + /// + /// The padded size of each dimension D of the output is: + /// + /// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` + /// + /// For example: + /// + /// ``` + /// # 't' is [[1, 2, 3], [4, 5, 6]]. + /// # 'paddings' is [[1, 1]], [2, 2]]. + /// # 'mode' is SYMMETRIC. + /// # rank of 't' is 2. + /// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2] + /// [2, 1, 1, 2, 3, 3, 2] + /// [5, 4, 4, 5, 6, 6, 5] + /// [5, 4, 4, 5, 6, 6, 5]] + /// ``` + /// + /// + /// + /// + /// + /// + /// Either `REFLECT` or `SYMMETRIC`. In reflect mode the padded regions + /// do not include the borders, while in symmetric mode the padded regions + /// do include the borders. For example, if `input` is `[1, 2, 3]` and `paddings` + /// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and + /// it is `[1, 2, 3, 3, 2]` in symmetric mode. + /// + /// + /// + public static Tensor mirror_pad(Tensor input, Tensor paddings, string mode, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MirrorPad", name) { args = new object[] { input, paddings }, attrs = new Dictionary() { ["mode"] = mode } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return mirror_pad_eager_fallback(input, paddings, mode: mode, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["paddings"] = paddings; + keywords["mode"] = mode; + var _op = tf.OpDefLib._apply_op_helper("MirrorPad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tpaddings", _op._get_attr_type("Tpaddings"), "mode", _op.get_attr("mode") }; + _execute.record_gradient("MirrorPad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor mirror_pad_eager_fallback(Tensor input, Tensor paddings, string mode, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, paddings }; + object[] _attrs = new object[] { "T", input.dtype, "Tpaddings", paddings.dtype, "mode", mode }; + var _result = _execute.execute("MirrorPad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MirrorPad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor. + /// + /// + /// + /// This operation folds the padded areas of `input` by `MirrorPad` according to the + /// `paddings` you specify. `paddings` must be the same as `paddings` argument + /// given to the corresponding `MirrorPad` op. + /// + /// The folded size of each dimension D of the output is: + /// + /// `input.dim_size(D) - paddings(D, 0) - paddings(D, 1)` + /// + /// For example: + /// + /// ``` + /// # 't' is [[1, 2, 3], [4, 5, 6], [7, 8, 9]]. + /// # 'paddings' is [[0, 1]], [0, 1]]. + /// # 'mode' is SYMMETRIC. + /// # rank of 't' is 2. + /// pad(t, paddings) ==> [[ 1, 5] + /// [11, 28]] + /// ``` + /// + /// + /// + /// + /// + /// + /// The mode used in the `MirrorPad` op. + /// + /// + /// + public static Tensor mirror_pad_grad(Tensor input, Tensor paddings, string mode, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MirrorPadGrad", name) { args = new object[] { input, paddings }, attrs = new Dictionary() { ["mode"] = mode } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return mirror_pad_grad_eager_fallback(input, paddings, mode: mode, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["paddings"] = paddings; + keywords["mode"] = mode; + var _op = tf.OpDefLib._apply_op_helper("MirrorPadGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tpaddings", _op._get_attr_type("Tpaddings"), "mode", _op.get_attr("mode") }; + _execute.record_gradient("MirrorPadGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor mirror_pad_grad_eager_fallback(Tensor input, Tensor paddings, string mode, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, paddings }; + object[] _attrs = new object[] { "T", input.dtype, "Tpaddings", paddings.dtype, "mode", mode }; + var _result = _execute.execute("MirrorPadGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MirrorPadGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a one-hot tensor. + /// + /// + /// + /// The locations represented by indices in `indices` take value `on_value`, + /// while all other locations take value `off_value`. + /// + /// If the input `indices` is rank `N`, the output will have rank `N+1`, + /// The new axis is created at dimension `axis` (default: the new axis is + /// appended at the end). + /// + /// If `indices` is a scalar the output shape will be a vector of length `depth`. + /// + /// If `indices` is a vector of length `features`, the output shape will be: + /// ``` + /// features x depth if axis == -1 + /// depth x features if axis == 0 + /// ``` + /// + /// If `indices` is a matrix (batch) with shape `[batch, features]`, + /// the output shape will be: + /// ``` + /// batch x features x depth if axis == -1 + /// batch x depth x features if axis == 1 + /// depth x batch x features if axis == 0 + /// ``` + /// + /// + /// Examples + /// ========= + /// + /// Suppose that + /// ``` + /// indices = [0, 2, -1, 1] + /// depth = 3 + /// on_value = 5.0 + /// off_value = 0.0 + /// axis = -1 + /// ``` + /// + /// Then output is `[4 x 3]`: + /// ``` + /// output = + /// [5.0 0.0 0.0] // one_hot(0) + /// [0.0 0.0 5.0] // one_hot(2) + /// [0.0 0.0 0.0] // one_hot(-1) + /// [0.0 5.0 0.0] // one_hot(1) + /// ``` + /// + /// Suppose that + /// ``` + /// indices = [0, 2, -1, 1] + /// depth = 3 + /// on_value = 0.0 + /// off_value = 3.0 + /// axis = 0 + /// ``` + /// + /// Then output is `[3 x 4]`: + /// ``` + /// output = + /// [0.0 3.0 3.0 3.0] + /// [3.0 3.0 3.0 0.0] + /// [3.0 3.0 3.0 3.0] + /// [3.0 0.0 3.0 3.0] + /// // ^ one_hot(0) + /// // ^ one_hot(2) + /// // ^ one_hot(-1) + /// // ^ one_hot(1) + /// ``` + /// + /// Suppose that + /// ``` + /// indices = [[0, 2], [1, -1]] + /// depth = 3 + /// on_value = 1.0 + /// off_value = 0.0 + /// axis = -1 + /// ``` + /// + /// Then output is `[2 x 2 x 3]`: + /// ``` + /// output = + /// [ + /// [1.0, 0.0, 0.0] // one_hot(0) + /// [0.0, 0.0, 1.0] // one_hot(2) + /// ][ + /// [0.0, 1.0, 0.0] // one_hot(1) + /// [0.0, 0.0, 0.0] // one_hot(-1) + /// ] + /// ``` + /// + /// + /// + /// + /// + /// + /// + /// + /// The axis to fill (default: -1, a new inner-most axis). + /// + /// + /// + public static Tensor one_hot(Tensor indices, Tensor depth, Tensor on_value, Tensor off_value, int axis = -1, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "OneHot", name) { args = new object[] { indices, depth, on_value, off_value }, attrs = new Dictionary() { ["axis"] = axis } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return one_hot_eager_fallback(indices, depth, on_value, off_value, axis: axis, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["indices"] = indices; + keywords["depth"] = depth; + keywords["on_value"] = on_value; + keywords["off_value"] = off_value; + keywords["axis"] = axis; + var _op = tf.OpDefLib._apply_op_helper("OneHot", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "axis", _op._get_attr_int("axis"), "T", _op._get_attr_type("T"), "TI", _op._get_attr_type("TI") }; + _execute.record_gradient("OneHot", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor one_hot_eager_fallback(Tensor indices, Tensor depth, Tensor on_value, Tensor off_value, int axis, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { indices, depth, on_value, off_value }; + object[] _attrs = new object[] { "axis", axis, "T", on_value.dtype, "TI", indices.dtype }; + var _result = _execute.execute("OneHot", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("OneHot", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a tensor of ones with the same shape and type as x. + /// + /// + /// + public static Tensor ones_like(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "OnesLike", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return ones_like_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("OnesLike", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("OnesLike", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor ones_like_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("OnesLike", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("OnesLike", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor. + /// + /// + /// + /// Packs the `N` tensors in `values` into a tensor with rank one higher than each + /// tensor in `values`, by packing them along the `axis` dimension. + /// Given a list of tensors of shape `(A, B, C)`; + /// + /// if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`. + /// if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`. + /// Etc. + /// + /// For example: + /// + /// ``` + /// # 'x' is [1, 4] + /// # 'y' is [2, 5] + /// # 'z' is [3, 6] + /// pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. + /// pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] + /// ``` + /// + /// This is the opposite of `unpack`. + /// + /// + /// + /// + /// + /// Dimension along which to pack. Negative values wrap around, so the + /// valid range is `[-(R+1), R+1)`. + /// + /// + /// + public static Tensor pack(Tensors values, int axis = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Pack", name) { args = new object[] { values }, attrs = new Dictionary() { ["axis"] = axis } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return pack_eager_fallback(values, axis: axis, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["values"] = values; + keywords["axis"] = axis; + var _op = tf.OpDefLib._apply_op_helper("Pack", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "N", _op._get_attr_int("N"), "T", _op._get_attr_type("T"), "axis", _op._get_attr_int("axis") }; + _execute.record_gradient("Pack", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor pack_eager_fallback(Tensors values, int axis, string name, Context ctx) + { + List _inputs_flat_list = new(); + _inputs_flat_list.AddRange(values); + var _inputs_flat = _inputs_flat_list.ToArray(); + object[] _attrs = new object[] { "N", values.Length, "T", values.dtype, "axis", axis }; + var _result = _execute.execute("Pack", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Pack", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Pads a tensor with zeros. + /// + /// + /// + /// This operation pads a `input` with zeros according to the `paddings` you + /// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the + /// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates + /// how many zeros to add before the contents of `input` in that dimension, and + /// `paddings[D, 1]` indicates how many zeros to add after the contents of `input` + /// in that dimension. + /// + /// The padded size of each dimension D of the output is: + /// + /// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` + /// + /// For example: + /// + /// ``` + /// # 't' is [[1, 1], [2, 2]] + /// # 'paddings' is [[1, 1], [2, 2]] + /// # rank of 't' is 2 + /// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] + /// [0, 0, 1, 1, 0, 0] + /// [0, 0, 2, 2, 0, 0] + /// [0, 0, 0, 0, 0, 0]] + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor pad(Tensor input, Tensor paddings, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Pad", name) { args = new object[] { input, paddings }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return pad_eager_fallback(input, paddings, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["paddings"] = paddings; + var _op = tf.OpDefLib._apply_op_helper("Pad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tpaddings", _op._get_attr_type("Tpaddings") }; + _execute.record_gradient("Pad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor pad_eager_fallback(Tensor input, Tensor paddings, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, paddings }; + object[] _attrs = new object[] { "T", input.dtype, "Tpaddings", paddings.dtype }; + var _result = _execute.execute("Pad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Pad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Pads a tensor. + /// + /// + /// + /// This operation pads `input` according to the `paddings` and `constant_values` + /// you specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is + /// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates + /// how many padding values to add before the contents of `input` in that dimension, + /// and `paddings[D, 1]` indicates how many padding values to add after the contents + /// of `input` in that dimension. `constant_values` is a scalar tensor of the same + /// type as `input` that indicates the value to use for padding `input`. + /// + /// The padded size of each dimension D of the output is: + /// + /// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` + /// + /// For example: + /// + /// ``` + /// # 't' is [[1, 1], [2, 2]] + /// # 'paddings' is [[1, 1], [2, 2]] + /// # 'constant_values' is 0 + /// # rank of 't' is 2 + /// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] + /// [0, 0, 1, 1, 0, 0] + /// [0, 0, 2, 2, 0, 0] + /// [0, 0, 0, 0, 0, 0]] + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor pad_v2(Tensor input, Tensor paddings, Tensor constant_values, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "PadV2", name) { args = new object[] { input, paddings, constant_values }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return pad_v2_eager_fallback(input, paddings, constant_values, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["paddings"] = paddings; + keywords["constant_values"] = constant_values; + var _op = tf.OpDefLib._apply_op_helper("PadV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tpaddings", _op._get_attr_type("Tpaddings") }; + _execute.record_gradient("PadV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor pad_v2_eager_fallback(Tensor input, Tensor paddings, Tensor constant_values, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, paddings, constant_values }; + object[] _attrs = new object[] { "T", input.dtype, "Tpaddings", paddings.dtype }; + var _result = _execute.execute("PadV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("PadV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Concatenates a list of `N` tensors along the first dimension. + /// + /// + /// + /// The input tensors are all required to have size 1 in the first dimension. + /// + /// For example: + /// + /// ``` + /// # 'x' is [[1, 4]] + /// # 'y' is [[2, 5]] + /// # 'z' is [[3, 6]] + /// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. + /// ``` + /// + /// The difference between concat and parallel_concat is that concat requires all + /// of the inputs be computed before the operation will begin but doesn't require + /// that the input shapes be known during graph construction. Parallel concat + /// will copy pieces of the input into the output as they become available, in + /// some situations this can provide a performance benefit. + /// + /// + /// + /// + /// + /// the final shape of the result; should be equal to the shapes of any input + /// but with the number of input values in the first dimension. + /// + /// + /// + public static Tensor parallel_concat(Tensors values, Shape shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ParallelConcat", name) { args = new object[] { values }, attrs = new Dictionary() { ["shape"] = shape } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return parallel_concat_eager_fallback(values, shape: shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["values"] = values; + keywords["shape"] = shape; + var _op = tf.OpDefLib._apply_op_helper("ParallelConcat", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "N", _op._get_attr_int("N"), "T", _op._get_attr_type("T"), "shape", _op.get_attr("shape") }; + _execute.record_gradient("ParallelConcat", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor parallel_concat_eager_fallback(Tensors values, Shape shape, string name, Context ctx) + { + List _inputs_flat_list = new(); + _inputs_flat_list.AddRange(values); + var _inputs_flat = _inputs_flat_list.ToArray(); + object[] _attrs = new object[] { "N", values.Length, "T", values.dtype, "shape", shape }; + var _result = _execute.execute("ParallelConcat", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ParallelConcat", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// A placeholder op for a value that will be fed into the computation. + /// + /// + /// + /// N.B. This operation will fail with an error if it is executed. It is + /// intended as a way to represent a value that will always be fed, and to + /// provide attrs that enable the fed value to be checked at runtime. + /// + /// + /// + /// + /// The type of elements in the tensor. + /// + /// + /// + /// + /// (Optional) The shape of the tensor. If the shape has 0 dimensions, the + /// shape is unconstrained. + /// + /// + /// + public static Tensor placeholder(TF_DataType dtype, Shape shape = null, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Placeholder", name) { args = new object[] { }, attrs = new Dictionary() { ["dtype"] = dtype, ["shape"] = shape } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return placeholder_eager_fallback(dtype: dtype, shape: shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["dtype"] = dtype; + keywords["shape"] = shape; + var _op = tf.OpDefLib._apply_op_helper("Placeholder", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "shape", _op.get_attr("shape") }; + _execute.record_gradient("Placeholder", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor placeholder_eager_fallback(TF_DataType dtype, Shape shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "dtype", dtype, "shape", shape }; + var _result = _execute.execute("Placeholder", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// A placeholder op for a value that will be fed into the computation. + /// + /// + /// + /// N.B. This operation will fail with an error if it is executed. It is + /// intended as a way to represent a value that will always be fed, and to + /// provide attrs that enable the fed value to be checked at runtime. + /// + /// + /// + /// + /// The type of elements in the tensor. + /// + /// + /// + /// + /// The shape of the tensor. The shape can be any partially-specified + /// shape. To be unconstrained, pass in a shape with unknown rank. + /// + /// + /// + public static Tensor placeholder_v2(TF_DataType dtype, Shape shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "PlaceholderV2", name) { args = new object[] { }, attrs = new Dictionary() { ["dtype"] = dtype, ["shape"] = shape } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return placeholder_v2_eager_fallback(dtype: dtype, shape: shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["dtype"] = dtype; + keywords["shape"] = shape; + var _op = tf.OpDefLib._apply_op_helper("PlaceholderV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "shape", _op.get_attr("shape") }; + _execute.record_gradient("PlaceholderV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor placeholder_v2_eager_fallback(TF_DataType dtype, Shape shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "dtype", dtype, "shape", shape }; + var _result = _execute.execute("PlaceholderV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("PlaceholderV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// A placeholder op that passes through `input` when its output is not fed. + /// + /// + /// + /// + /// The (possibly partial) shape of the tensor. + /// + /// + /// + public static Tensor placeholder_with_default(Tensor input, Shape shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "PlaceholderWithDefault", name) { args = new object[] { input }, attrs = new Dictionary() { ["shape"] = shape } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return placeholder_with_default_eager_fallback(input, shape: shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["shape"] = shape; + var _op = tf.OpDefLib._apply_op_helper("PlaceholderWithDefault", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "shape", _op.get_attr("shape") }; + _execute.record_gradient("PlaceholderWithDefault", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor placeholder_with_default_eager_fallback(Tensor input, Shape shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "dtype", input.dtype, "shape", shape }; + var _result = _execute.execute("PlaceholderWithDefault", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("PlaceholderWithDefault", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// An identity op that triggers an error if a gradient is requested. + /// + /// + /// + /// When executed in a graph, this op outputs its input tensor as-is. + /// + /// When building ops to compute gradients, the TensorFlow gradient system + /// will return an error when trying to lookup the gradient of this op, + /// because no gradient must ever be registered for this function. This + /// op exists to prevent subtle bugs from silently returning unimplemented + /// gradients in some corner cases. + /// + /// + /// + /// + /// + /// Will be printed in the error when anyone tries to differentiate + /// this operation. + /// + /// + /// + public static Tensor prevent_gradient(Tensor input, string message = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "PreventGradient", name) { args = new object[] { input }, attrs = new Dictionary() { ["message"] = message } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return prevent_gradient_eager_fallback(input, message: message, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (message is null) + { + message = ""; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["message"] = message; + var _op = tf.OpDefLib._apply_op_helper("PreventGradient", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "message", _op.get_attr("message") }; + _execute.record_gradient("PreventGradient", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor prevent_gradient_eager_fallback(Tensor input, string message, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "message", message }; + var _result = _execute.execute("PreventGradient", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("PreventGradient", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Use QuantizeAndDequantizeV2 instead. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor quantize_and_dequantize(Tensor input, bool signed_input = true, int num_bits = 8, bool range_given = false, float input_min = 0f, float input_max = 0f, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizeAndDequantize", name) { args = new object[] { input }, attrs = new Dictionary() { ["signed_input"] = signed_input, ["num_bits"] = num_bits, ["range_given"] = range_given, ["input_min"] = input_min, ["input_max"] = input_max } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantize_and_dequantize_eager_fallback(input, signed_input: signed_input, num_bits: num_bits, range_given: range_given, input_min: input_min, input_max: input_max, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["signed_input"] = signed_input; + keywords["num_bits"] = num_bits; + keywords["range_given"] = range_given; + keywords["input_min"] = input_min; + keywords["input_max"] = input_max; + var _op = tf.OpDefLib._apply_op_helper("QuantizeAndDequantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "signed_input", _op._get_attr_bool("signed_input"), "num_bits", _op._get_attr_int("num_bits"), "range_given", _op._get_attr_bool("range_given"), "input_min", _op.get_attr("input_min"), "input_max", _op.get_attr("input_max"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("QuantizeAndDequantize", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor quantize_and_dequantize_eager_fallback(Tensor input, bool signed_input, int num_bits, bool range_given, float input_min, float input_max, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "signed_input", signed_input, "num_bits", num_bits, "range_given", range_given, "input_min", input_min, "input_max", input_max, "T", input.dtype }; + var _result = _execute.execute("QuantizeAndDequantize", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizeAndDequantize", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Quantizes then dequantizes a tensor. + /// + /// + /// + /// This op simulates the precision loss from the quantized forward pass by: + /// + /// 1. Quantizing the tensor to fixed point numbers, which should match the target + /// quantization method when it is used in inference. + /// 2. Dequantizing it back to floating point numbers for the following ops, most + /// likely matmul. + /// + /// There are different ways to quantize. This version uses only scaling, so 0.0 + /// maps to 0. + /// + /// From the specified 'num_bits' in the quantized output type, it determines + /// minimum and maximum representable quantized values. + /// + /// e.g. + /// + /// * [-128, 127] for signed, num_bits = 8, or + /// * [0, 255] for unsigned, num_bits = 8. + /// + /// If range_given == False, the initial input_min, input_max will be determined + /// automatically as the minimum and maximum values in the input tensor, otherwise + /// the specified values of input_min, input_max are used. + /// + /// Note: If the input_min, input_max are specified, they do not need to equal the + /// actual minimum and maximum values in the tensor. e.g. in some cases it may be + /// beneficial to specify these values such that the low probability extremes of the + /// input distribution are clipped. + /// + /// This op determines the maximum scale_factor that would map the initial + /// [input_min, input_max] range to a range that lies within the representable + /// quantized range. + /// + /// It determines the scale from one of input_min and input_max, then updates the + /// other one to maximize the representable range. + /// + /// e.g. + /// + /// * if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0, + /// 5.0]: it would use a scale_factor of -128 / -10.0 = 12.8 In this case, it + /// would update input_max to be 127 / 12.8 = 9.921875 + /// * if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0, + /// 10.0]: it would use a scale_factor of 127 / 10.0 = 12.7 In this case, it + /// would update input_min to be 128.0 / 12.7 = -10.07874 + /// * if the output is unsigned, input_min is forced to be 0, and only the + /// specified input_max is used. + /// + /// After determining the scale_factor and updating the input range, it applies the + /// following to each value in the 'input' tensor. + /// + /// output = round(clamp(value, input_min, input_max) * scale_factor) / scale_factor. + /// + /// The above round function rounds the value based on the given round_mode. + /// + /// + /// + /// + /// + /// + /// + /// + /// Whether the quantization is signed or unsigned. (actually this parameter should + /// have been called `signed_output`) + /// + /// + /// + /// + /// The bitwidth of the quantization. + /// + /// + /// + /// + /// Whether the range is given or should be determined from the `input` tensor. + /// + /// + /// + /// + /// The 'round_mode' attribute controls which rounding tie-breaking algorithm is + /// used when rounding float values to their quantized equivalents. The following + /// rounding modes are currently supported: + /// + /// * HALF_TO_EVEN: this is the default round_mode. + /// * HALF_UP: round towards positive. In this mode 7.5 rounds up to 8 and -7.5 + /// rounds up to -7. + /// + /// + /// + /// + /// + /// If True, then the absolute value of the quantized minimum value is the same as + /// the quantized maximum value, instead of 1 greater. + /// i.e. for 8 bit quantization, the minimum value is -127 instead of -128. + /// + /// + /// + /// + /// If specified, this axis is treated as a channel or slice axis, and a separate + /// quantization range is used for each channel or slice along this axis. + /// + /// + /// + public static Tensor quantize_and_dequantize_v2(Tensor input, Tensor input_min, Tensor input_max, bool signed_input = true, int num_bits = 8, bool range_given = false, string round_mode = "HALF_TO_EVEN", bool narrow_range = false, int axis = -1, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizeAndDequantizeV2", name) { args = new object[] { input, input_min, input_max }, attrs = new Dictionary() { ["signed_input"] = signed_input, ["num_bits"] = num_bits, ["range_given"] = range_given, ["round_mode"] = round_mode, ["narrow_range"] = narrow_range, ["axis"] = axis } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantize_and_dequantize_v2_eager_fallback(input, input_min, input_max, signed_input: signed_input, num_bits: num_bits, range_given: range_given, round_mode: round_mode, narrow_range: narrow_range, axis: axis, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (round_mode is null) + { + round_mode = "HALF_TO_EVEN"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["input_min"] = input_min; + keywords["input_max"] = input_max; + keywords["signed_input"] = signed_input; + keywords["num_bits"] = num_bits; + keywords["range_given"] = range_given; + keywords["round_mode"] = round_mode; + keywords["narrow_range"] = narrow_range; + keywords["axis"] = axis; + var _op = tf.OpDefLib._apply_op_helper("QuantizeAndDequantizeV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "signed_input", _op._get_attr_bool("signed_input"), "num_bits", _op._get_attr_int("num_bits"), "range_given", _op._get_attr_bool("range_given"), "T", _op._get_attr_type("T"), "round_mode", _op.get_attr("round_mode"), "narrow_range", _op._get_attr_bool("narrow_range"), "axis", _op._get_attr_int("axis") }; + _execute.record_gradient("QuantizeAndDequantizeV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor quantize_and_dequantize_v2_eager_fallback(Tensor input, Tensor input_min, Tensor input_max, bool signed_input, int num_bits, bool range_given, string round_mode, bool narrow_range, int axis, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, input_min, input_max }; + object[] _attrs = new object[] { "signed_input", signed_input, "num_bits", num_bits, "range_given", range_given, "T", input.dtype, "round_mode", round_mode, "narrow_range", narrow_range, "axis", axis }; + var _result = _execute.execute("QuantizeAndDequantizeV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizeAndDequantizeV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Quantizes then dequantizes a tensor. + /// + /// + /// + /// This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a + /// tensor, so its value can change during training. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor quantize_and_dequantize_v3(Tensor input, Tensor input_min, Tensor input_max, Tensor num_bits, bool signed_input = true, bool range_given = true, bool narrow_range = false, int axis = -1, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizeAndDequantizeV3", name) { args = new object[] { input, input_min, input_max, num_bits }, attrs = new Dictionary() { ["signed_input"] = signed_input, ["range_given"] = range_given, ["narrow_range"] = narrow_range, ["axis"] = axis } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantize_and_dequantize_v3_eager_fallback(input, input_min, input_max, num_bits, signed_input: signed_input, range_given: range_given, narrow_range: narrow_range, axis: axis, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["input_min"] = input_min; + keywords["input_max"] = input_max; + keywords["num_bits"] = num_bits; + keywords["signed_input"] = signed_input; + keywords["range_given"] = range_given; + keywords["narrow_range"] = narrow_range; + keywords["axis"] = axis; + var _op = tf.OpDefLib._apply_op_helper("QuantizeAndDequantizeV3", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "signed_input", _op._get_attr_bool("signed_input"), "range_given", _op._get_attr_bool("range_given"), "T", _op._get_attr_type("T"), "narrow_range", _op._get_attr_bool("narrow_range"), "axis", _op._get_attr_int("axis") }; + _execute.record_gradient("QuantizeAndDequantizeV3", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor quantize_and_dequantize_v3_eager_fallback(Tensor input, Tensor input_min, Tensor input_max, Tensor num_bits, bool signed_input, bool range_given, bool narrow_range, int axis, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, input_min, input_max, num_bits }; + object[] _attrs = new object[] { "signed_input", signed_input, "range_given", range_given, "T", input.dtype, "narrow_range", narrow_range, "axis", axis }; + var _result = _execute.execute("QuantizeAndDequantizeV3", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizeAndDequantizeV3", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Quantizes then dequantizes a tensor. + /// + /// + /// + /// This is almost identical to QuantizeAndDequantizeV2, except that it returns a + /// gradient of 1 for inputs that are within the quantization range, or 0 otherwise. + /// + /// + /// + /// + /// + /// + /// + /// Whether the quantization is signed or unsigned. (actually this parameter should + /// have been called `signed_output`) + /// + /// + /// + /// + /// The bitwidth of the quantization. + /// + /// + /// + /// + /// Whether the range is given or should be determined from the `input` tensor. + /// + /// + /// + /// + /// The 'round_mode' attribute controls which rounding tie-breaking algorithm is + /// used when rounding float values to their quantized equivalents. The following + /// rounding modes are currently supported: + /// + /// * HALF_TO_EVEN: this is the default round_mode. + /// * HALF_UP: round towards positive. In this mode 7.5 rounds up to 8 and -7.5 + /// rounds up to -7. + /// + /// + /// + /// + /// + /// If True, then the absolute value of the quantized minimum value is the same as + /// the quantized maximum value, instead of 1 greater. + /// i.e. for 8 bit quantization, the minimum value is -127 instead of -128. + /// + /// + /// + /// + /// If specified, this axis is treated as a channel or slice axis, and a separate + /// quantization range is used for each channel or slice along this axis. + /// + /// + /// + public static Tensor quantize_and_dequantize_v4(Tensor input, Tensor input_min, Tensor input_max, bool signed_input = true, int num_bits = 8, bool range_given = false, string round_mode = "HALF_TO_EVEN", bool narrow_range = false, int axis = -1, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizeAndDequantizeV4", name) { args = new object[] { input, input_min, input_max }, attrs = new Dictionary() { ["signed_input"] = signed_input, ["num_bits"] = num_bits, ["range_given"] = range_given, ["round_mode"] = round_mode, ["narrow_range"] = narrow_range, ["axis"] = axis } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantize_and_dequantize_v4_eager_fallback(input, input_min, input_max, signed_input: signed_input, num_bits: num_bits, range_given: range_given, round_mode: round_mode, narrow_range: narrow_range, axis: axis, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (round_mode is null) + { + round_mode = "HALF_TO_EVEN"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["input_min"] = input_min; + keywords["input_max"] = input_max; + keywords["signed_input"] = signed_input; + keywords["num_bits"] = num_bits; + keywords["range_given"] = range_given; + keywords["round_mode"] = round_mode; + keywords["narrow_range"] = narrow_range; + keywords["axis"] = axis; + var _op = tf.OpDefLib._apply_op_helper("QuantizeAndDequantizeV4", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "signed_input", _op._get_attr_bool("signed_input"), "num_bits", _op._get_attr_int("num_bits"), "range_given", _op._get_attr_bool("range_given"), "T", _op._get_attr_type("T"), "round_mode", _op.get_attr("round_mode"), "narrow_range", _op._get_attr_bool("narrow_range"), "axis", _op._get_attr_int("axis") }; + _execute.record_gradient("QuantizeAndDequantizeV4", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor quantize_and_dequantize_v4_eager_fallback(Tensor input, Tensor input_min, Tensor input_max, bool signed_input, int num_bits, bool range_given, string round_mode, bool narrow_range, int axis, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, input_min, input_max }; + object[] _attrs = new object[] { "signed_input", signed_input, "num_bits", num_bits, "range_given", range_given, "T", input.dtype, "round_mode", round_mode, "narrow_range", narrow_range, "axis", axis }; + var _result = _execute.execute("QuantizeAndDequantizeV4", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizeAndDequantizeV4", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Quantize the 'input' tensor of type float to 'output' tensor of type 'T'. + /// + /// + /// + /// [min_range, max_range] are scalar floats that specify the range for + /// the 'input' data. The 'mode' attribute controls exactly which calculations are + /// used to convert the float values to their quantized equivalents. The + /// 'round_mode' attribute controls which rounding tie-breaking algorithm is used + /// when rounding float values to their quantized equivalents. + /// + /// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: + /// + /// ``` + /// out[i] = (in[i] - min_range) * range(T) / (max_range - min_range) + /// if T == qint8: out[i] -= (range(T) + 1) / 2.0 + /// ``` + /// + /// here `range(T) = numeric_limits::max() - numeric_limits::min()` + /// + /// *MIN_COMBINED Mode Example* + /// + /// Assume the input is type float and has a possible range of [0.0, 6.0] and the + /// output type is quint8 ([0, 255]). The min_range and max_range values should be + /// specified as 0.0 and 6.0. Quantizing from float to quint8 will multiply each + /// value of the input by 255/6 and cast to quint8. + /// + /// If the output type was qint8 ([-128, 127]), the operation will additionally + /// subtract each value by 128 prior to casting, so that the range of values aligns + /// with the range of qint8. + /// + /// If the mode is 'MIN_FIRST', then this approach is used: + /// + /// ``` + /// num_discrete_values = 1 << (# of bits in T) + /// range_adjust = num_discrete_values / (num_discrete_values - 1) + /// range = (range_max - range_min) * range_adjust + /// range_scale = num_discrete_values / range + /// quantized = round(input * range_scale) - round(range_min * range_scale) + + /// numeric_limits::min() + /// quantized = max(quantized, numeric_limits::min()) + /// quantized = min(quantized, numeric_limits::max()) + /// ``` + /// + /// The biggest difference between this and MIN_COMBINED is that the minimum range + /// is rounded first, before it's subtracted from the rounded value. With + /// MIN_COMBINED, a small bias is introduced where repeated iterations of quantizing + /// and dequantizing will introduce a larger and larger error. + /// + /// *SCALED mode Example* + /// + /// `SCALED` mode matches the quantization approach used in + /// `QuantizeAndDequantize{V2|V3}`. + /// + /// If the mode is `SCALED`, the quantization is performed by multiplying each + /// input value by a scaling_factor. + /// The scaling_factor is determined from `min_range` and `max_range` to be as large + /// as possible such that the range from `min_range` to `max_range` is representable + /// within values of type T. + /// + /// ```c++ + /// + /// const int min_T = std::numeric_limits::min(); + /// const int max_T = std::numeric_limits::max(); + /// const float max_float = std::numeric_limits::max(); + /// + /// const float scale_factor_from_min_side = + /// (min_T * min_range > 0) ? min_T / min_range : max_float; + /// const float scale_factor_from_max_side = + /// (max_T * max_range > 0) ? max_T / max_range : max_float; + /// + /// const float scale_factor = std::min(scale_factor_from_min_side, + /// scale_factor_from_max_side); + /// ``` + /// + /// We next use the scale_factor to adjust min_range and max_range as follows: + /// + /// ```c++ + /// min_range = min_T / scale_factor; + /// max_range = max_T / scale_factor; + /// ``` + /// + /// + /// e.g. if T = qint8, and initially min_range = -10, and max_range = 9, we would + /// compare -128/-10.0 = 12.8 to 127/9.0 = 14.11, and set scaling_factor = 12.8 + /// In this case, min_range would remain -10, but max_range would be adjusted to + /// 127 / 12.8 = 9.921875 + /// + /// So we will quantize input values in the range (-10, 9.921875) to (-128, 127). + /// + /// The input tensor can now be quantized by clipping values to the range + /// `min_range` to `max_range`, then multiplying by scale_factor as follows: + /// + /// ```c++ + /// result = round(min(max_range, max(min_range, input)) * scale_factor) + /// ``` + /// + /// The adjusted `min_range` and `max_range` are returned as outputs 2 and 3 of + /// this operation. These outputs should be used as the range for any further + /// calculations. + /// + /// + /// *narrow_range (bool) attribute* + /// + /// If true, we do not use the minimum quantized value. + /// i.e. for int8 the quantized output, it would be restricted to the range + /// -127..127 instead of the full -128..127 range. + /// This is provided for compatibility with certain inference backends. + /// (Only applies to SCALED mode) + /// + /// + /// *axis (int) attribute* + /// + /// An optional `axis` attribute can specify a dimension index of the input tensor, + /// such that quantization ranges will be calculated and applied separately for each + /// slice of the tensor along that dimension. This is useful for per-channel + /// quantization. + /// + /// If axis is specified, min_range and max_range + /// + /// if `axis`=None, per-tensor quantization is performed as normal. + /// + /// + /// *ensure_minimum_range (float) attribute* + /// + /// Ensures the minimum quantization range is at least this value. + /// The legacy default value for this is 0.01, but it is strongly suggested to + /// set it to 0 for new uses. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantize_v2(Tensor input, Tensor min_range, Tensor max_range, TF_DataType T, string mode = "MIN_COMBINED", string round_mode = "HALF_AWAY_FROM_ZERO", bool narrow_range = false, int axis = -1, float ensure_minimum_range = 0.01f, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizeV2", name) { args = new object[] { input, min_range, max_range }, attrs = new Dictionary() { ["T"] = T, ["mode"] = mode, ["round_mode"] = round_mode, ["narrow_range"] = narrow_range, ["axis"] = axis, ["ensure_minimum_range"] = ensure_minimum_range } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantize_v2_eager_fallback(input, min_range, max_range, T: T, mode: mode, round_mode: round_mode, narrow_range: narrow_range, axis: axis, ensure_minimum_range: ensure_minimum_range, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (mode is null) + { + mode = "MIN_COMBINED"; + } + if (round_mode is null) + { + round_mode = "HALF_AWAY_FROM_ZERO"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["min_range"] = min_range; + keywords["max_range"] = max_range; + keywords["T"] = T; + keywords["mode"] = mode; + keywords["round_mode"] = round_mode; + keywords["narrow_range"] = narrow_range; + keywords["axis"] = axis; + keywords["ensure_minimum_range"] = ensure_minimum_range; + var _op = tf.OpDefLib._apply_op_helper("QuantizeV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "mode", _op.get_attr("mode"), "round_mode", _op.get_attr("round_mode"), "narrow_range", _op._get_attr_bool("narrow_range"), "axis", _op._get_attr_int("axis"), "ensure_minimum_range", _op.get_attr("ensure_minimum_range") }; + _execute.record_gradient("QuantizeV2", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantize_v2_eager_fallback(Tensor input, Tensor min_range, Tensor max_range, TF_DataType T, string mode, string round_mode, bool narrow_range, int axis, float ensure_minimum_range, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, min_range, max_range }; + object[] _attrs = new object[] { "T", T, "mode", mode, "round_mode", round_mode, "narrow_range", narrow_range, "axis", axis, "ensure_minimum_range", ensure_minimum_range }; + var _result = _execute.execute("QuantizeV2", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizeV2", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Concatenates quantized tensors along one dimension. + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_concat(Tensor concat_dim, Tensors values, Tensors input_mins, Tensors input_maxes, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConcat", name) { args = new object[] { concat_dim, values, input_mins, input_maxes }, attrs = new Dictionary() { } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_concat_eager_fallback(concat_dim, values, input_mins, input_maxes, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["concat_dim"] = concat_dim; + keywords["values"] = values; + keywords["input_mins"] = input_mins; + keywords["input_maxes"] = input_maxes; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConcat", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "N", _op._get_attr_int("N"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("QuantizedConcat", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_concat_eager_fallback(Tensor concat_dim, Tensors values, Tensors input_mins, Tensors input_maxes, string name, Context ctx) + { + List _inputs_flat_list = new(); + _inputs_flat_list.Add(concat_dim); + _inputs_flat_list.AddRange(values); + _inputs_flat_list.AddRange(input_mins); + _inputs_flat_list.AddRange(input_maxes); + var _inputs_flat = _inputs_flat_list.ToArray(); + object[] _attrs = new object[] { "N", values.Length, "T", values.dtype }; + var _result = _execute.execute("QuantizedConcat", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConcat", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Quantized Instance normalization. + /// + /// + /// + /// + /// + /// + /// If True, `given_y_min` and `given_y_min` + /// and `given_y_max` are used as the output range. Otherwise, + /// the implementation computes the output range. + /// + /// + /// + /// + /// Output in `y_min` if `output_range_given` is True. + /// + /// + /// + /// + /// Output in `y_max` if `output_range_given` is True. + /// + /// + /// + /// + /// A small float number to avoid dividing by 0. + /// + /// + /// + /// + /// Minimum value of `y_max - y_min` + /// + /// + /// + public static Tensor[] quantized_instance_norm(Tensor x, Tensor x_min, Tensor x_max, bool output_range_given = false, float given_y_min = 0f, float given_y_max = 0f, float variance_epsilon = 1E-05f, float min_separation = 0.001f, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedInstanceNorm", name) { args = new object[] { x, x_min, x_max }, attrs = new Dictionary() { ["output_range_given"] = output_range_given, ["given_y_min"] = given_y_min, ["given_y_max"] = given_y_max, ["variance_epsilon"] = variance_epsilon, ["min_separation"] = min_separation } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_instance_norm_eager_fallback(x, x_min, x_max, output_range_given: output_range_given, given_y_min: given_y_min, given_y_max: given_y_max, variance_epsilon: variance_epsilon, min_separation: min_separation, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["x_min"] = x_min; + keywords["x_max"] = x_max; + keywords["output_range_given"] = output_range_given; + keywords["given_y_min"] = given_y_min; + keywords["given_y_max"] = given_y_max; + keywords["variance_epsilon"] = variance_epsilon; + keywords["min_separation"] = min_separation; + var _op = tf.OpDefLib._apply_op_helper("QuantizedInstanceNorm", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "output_range_given", _op._get_attr_bool("output_range_given"), "given_y_min", _op.get_attr("given_y_min"), "given_y_max", _op.get_attr("given_y_max"), "variance_epsilon", _op.get_attr("variance_epsilon"), "min_separation", _op.get_attr("min_separation") }; + _execute.record_gradient("QuantizedInstanceNorm", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_instance_norm_eager_fallback(Tensor x, Tensor x_min, Tensor x_max, bool output_range_given, float given_y_min, float given_y_max, float variance_epsilon, float min_separation, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, x_min, x_max }; + object[] _attrs = new object[] { "T", x.dtype, "output_range_given", output_range_given, "given_y_min", given_y_min, "given_y_max", given_y_max, "variance_epsilon", variance_epsilon, "min_separation", min_separation }; + var _result = _execute.execute("QuantizedInstanceNorm", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedInstanceNorm", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Reshapes a quantized tensor as per the Reshape op. + /// + /// + /// + /// ``` + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_reshape(Tensor tensor, Tensor shape, Tensor input_min, Tensor input_max, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedReshape", name) { args = new object[] { tensor, shape, input_min, input_max }, attrs = new Dictionary() { } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_reshape_eager_fallback(tensor, shape, input_min, input_max, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["shape"] = shape; + keywords["input_min"] = input_min; + keywords["input_max"] = input_max; + var _op = tf.OpDefLib._apply_op_helper("QuantizedReshape", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tshape", _op._get_attr_type("Tshape") }; + _execute.record_gradient("QuantizedReshape", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_reshape_eager_fallback(Tensor tensor, Tensor shape, Tensor input_min, Tensor input_max, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, shape, input_min, input_max }; + object[] _attrs = new object[] { "T", tensor.dtype, "Tshape", shape.dtype }; + var _result = _execute.execute("QuantizedReshape", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedReshape", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Returns the rank of a tensor. + /// + /// + /// + /// This operation returns an integer representing the rank of `input`. + /// + /// For example: + /// + /// ``` + /// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] + /// # shape of tensor 't' is [2, 2, 3] + /// rank(t) ==> 3 + /// ``` + /// + /// **Note**: The rank of a tensor is not the same as the rank of a matrix. The rank + /// of a tensor is the number of indices required to uniquely select each element + /// of the tensor. Rank is also known as "order", "degree", or "ndims." + /// + /// + /// + /// + public static Tensor rank(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Rank", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return rank_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("Rank", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Rank", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor rank_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("Rank", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Rank", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Return the same ref tensor as the input ref tensor. + /// + /// + /// + public static Tensor ref_identity(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + throw new RuntimeError("ref_identity op does not support eager execution. Arg input is a ref."); + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("RefIdentity", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("RefIdentity", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor ref_identity_eager_fallback(Tensor input, string name, Context ctx) + { + throw new RuntimeError($"ref_identity op does not support eager execution. Arg 'input' is a ref."); + } + /// + /// Reshapes a tensor. + /// + /// + /// + /// Given `tensor`, this operation returns a tensor that has the same values + /// as `tensor` with shape `shape`. + /// + /// If one component of 1-D tensor `shape` is the special value -1, the size of that + /// dimension is computed so that the total size remains constant. In particular, a + /// `shape` of `[-1]` flattens into 1-D. At most one component of `shape` may be + /// unknown. + /// + /// The `shape` must be 1-D and the operation returns a tensor with shape + /// `shape` filled with the values of `tensor`. In this case, the number of elements + /// implied by `shape` must be the same as the number of elements in `tensor`. + /// + /// It is an error if `shape` is not 1-D. + /// + /// For example: + /// + /// ``` + /// # tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9] + /// # tensor 't' has shape [9] + /// reshape(t, [3, 3]) ==> [[1, 2, 3], + /// [4, 5, 6], + /// [7, 8, 9]] + /// + /// # tensor 't' is [[[1, 1], [2, 2]], + /// # [[3, 3], [4, 4]]] + /// # tensor 't' has shape [2, 2, 2] + /// reshape(t, [2, 4]) ==> [[1, 1, 2, 2], + /// [3, 3, 4, 4]] + /// + /// # tensor 't' is [[[1, 1, 1], + /// # [2, 2, 2]], + /// # [[3, 3, 3], + /// # [4, 4, 4]], + /// # [[5, 5, 5], + /// # [6, 6, 6]]] + /// # tensor 't' has shape [3, 2, 3] + /// # pass '[-1]' to flatten 't' + /// reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6] + /// + /// # -1 can also be used to infer the shape + /// + /// # -1 is inferred to be 9: + /// reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], + /// [4, 4, 4, 5, 5, 5, 6, 6, 6]] + /// # -1 is inferred to be 2: + /// reshape(t, [-1, 9]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], + /// [4, 4, 4, 5, 5, 5, 6, 6, 6]] + /// # -1 is inferred to be 3: + /// reshape(t, [ 2, -1, 3]) ==> [[[1, 1, 1], + /// [2, 2, 2], + /// [3, 3, 3]], + /// [[4, 4, 4], + /// [5, 5, 5], + /// [6, 6, 6]]] + /// + /// # tensor 't' is [7] + /// # shape `[]` reshapes to a scalar + /// reshape(t, []) ==> 7 + /// ``` + /// + /// + /// + /// + /// + public static Tensor reshape(Tensor tensor, Tensor shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Reshape", name) { args = new object[] { tensor, shape }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reshape_eager_fallback(tensor, shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["shape"] = shape; + var _op = tf.OpDefLib._apply_op_helper("Reshape", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tshape", _op._get_attr_type("Tshape") }; + _execute.record_gradient("Reshape", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor reshape_eager_fallback(Tensor tensor, Tensor shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, shape }; + object[] _attrs = new object[] { "T", tensor.dtype, "Tshape", shape.dtype }; + var _result = _execute.execute("Reshape", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Reshape", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Assign `value` to the sliced l-value reference of `ref`. + /// + /// + /// + /// The values of `value` are assigned to the positions in the variable + /// `ref` that are selected by the slice parameters. The slice parameters + /// `begin, `end`, `strides`, etc. work exactly as in `StridedSlice`. + /// + /// NOTE this op currently does not support broadcasting and so `value`'s + /// shape must be exactly the shape produced by the slice of `ref`. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Operation resource_strided_slice_assign(Tensor ref_, Tensor begin, Tensor end, Tensor strides, Tensor value, int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ResourceStridedSliceAssign", name) { args = new object[] { ref_, begin, end, strides, value }, attrs = new Dictionary() { ["begin_mask"] = begin_mask, ["end_mask"] = end_mask, ["ellipsis_mask"] = ellipsis_mask, ["new_axis_mask"] = new_axis_mask, ["shrink_axis_mask"] = shrink_axis_mask } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return resource_strided_slice_assign_eager_fallback(ref_, begin, end, strides, value, begin_mask: begin_mask, end_mask: end_mask, ellipsis_mask: ellipsis_mask, new_axis_mask: new_axis_mask, shrink_axis_mask: shrink_axis_mask, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["ref"] = ref_; + keywords["begin"] = begin; + keywords["end"] = end; + keywords["strides"] = strides; + keywords["value"] = value; + keywords["begin_mask"] = begin_mask; + keywords["end_mask"] = end_mask; + keywords["ellipsis_mask"] = ellipsis_mask; + keywords["new_axis_mask"] = new_axis_mask; + keywords["shrink_axis_mask"] = shrink_axis_mask; + var _op = tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Index", _op._get_attr_type("Index"), "begin_mask", _op._get_attr_int("begin_mask"), "end_mask", _op._get_attr_int("end_mask"), "ellipsis_mask", _op._get_attr_int("ellipsis_mask"), "new_axis_mask", _op._get_attr_int("new_axis_mask"), "shrink_axis_mask", _op._get_attr_int("shrink_axis_mask") }; + _execute.record_gradient("ResourceStridedSliceAssign", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation resource_strided_slice_assign_eager_fallback(Tensor ref_, Tensor begin, Tensor end, Tensor strides, Tensor value, int begin_mask, int end_mask, int ellipsis_mask, int new_axis_mask, int shrink_axis_mask, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { ref_, begin, end, strides, value }; + object[] _attrs = new object[] { "T", value.dtype, "Index", begin.dtype, "begin_mask", begin_mask, "end_mask", end_mask, "ellipsis_mask", ellipsis_mask, "new_axis_mask", new_axis_mask, "shrink_axis_mask", shrink_axis_mask }; + var _result = _execute.execute("ResourceStridedSliceAssign", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ResourceStridedSliceAssign", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Reverses specific dimensions of a tensor. + /// + /// + /// + /// Given a `tensor`, and a `bool` tensor `dims` representing the dimensions + /// of `tensor`, this operation reverses each dimension i of `tensor` where + /// `dims[i]` is `True`. + /// + /// `tensor` can have up to 8 dimensions. The number of dimensions + /// of `tensor` must equal the number of elements in `dims`. In other words: + /// + /// `rank(tensor) = size(dims)` + /// + /// For example: + /// + /// ``` + /// # tensor 't' is [[[[ 0, 1, 2, 3], + /// # [ 4, 5, 6, 7], + /// # [ 8, 9, 10, 11]], + /// # [[12, 13, 14, 15], + /// # [16, 17, 18, 19], + /// # [20, 21, 22, 23]]]] + /// # tensor 't' shape is [1, 2, 3, 4] + /// + /// # 'dims' is [False, False, False, True] + /// reverse(t, dims) ==> [[[[ 3, 2, 1, 0], + /// [ 7, 6, 5, 4], + /// [ 11, 10, 9, 8]], + /// [[15, 14, 13, 12], + /// [19, 18, 17, 16], + /// [23, 22, 21, 20]]]] + /// + /// # 'dims' is [False, True, False, False] + /// reverse(t, dims) ==> [[[[12, 13, 14, 15], + /// [16, 17, 18, 19], + /// [20, 21, 22, 23] + /// [[ 0, 1, 2, 3], + /// [ 4, 5, 6, 7], + /// [ 8, 9, 10, 11]]]] + /// + /// # 'dims' is [False, False, True, False] + /// reverse(t, dims) ==> [[[[8, 9, 10, 11], + /// [4, 5, 6, 7], + /// [0, 1, 2, 3]] + /// [[20, 21, 22, 23], + /// [16, 17, 18, 19], + /// [12, 13, 14, 15]]]] + /// ``` + /// + /// + /// + /// + /// + public static Tensor reverse(Tensor tensor, Tensor dims, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Reverse", name) { args = new object[] { tensor, dims }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reverse_eager_fallback(tensor, dims, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["dims"] = dims; + var _op = tf.OpDefLib._apply_op_helper("Reverse", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Reverse", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor reverse_eager_fallback(Tensor tensor, Tensor dims, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, dims }; + object[] _attrs = new object[] { "T", tensor.dtype }; + var _result = _execute.execute("Reverse", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Reverse", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Reverses variable length slices. + /// + /// + /// + /// This op first slices `input` along the dimension `batch_dim`, and for each + /// slice `i`, reverses the first `seq_lengths[i]` elements along + /// the dimension `seq_dim`. + /// + /// The elements of `seq_lengths` must obey `seq_lengths[i] <= input.dims[seq_dim]`, + /// and `seq_lengths` must be a vector of length `input.dims[batch_dim]`. + /// + /// The output slice `i` along dimension `batch_dim` is then given by input + /// slice `i`, with the first `seq_lengths[i]` slices along dimension + /// `seq_dim` reversed. + /// + /// For example: + /// + /// ``` + /// # Given this: + /// batch_dim = 0 + /// seq_dim = 1 + /// input.dims = (4, 8, ...) + /// seq_lengths = [7, 2, 3, 5] + /// + /// # then slices of input are reversed on seq_dim, but only up to seq_lengths: + /// output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...] + /// output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...] + /// output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...] + /// output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...] + /// + /// # while entries past seq_lens are copied through: + /// output[0, 7:, :, ...] = input[0, 7:, :, ...] + /// output[1, 2:, :, ...] = input[1, 2:, :, ...] + /// output[2, 3:, :, ...] = input[2, 3:, :, ...] + /// output[3, 2:, :, ...] = input[3, 2:, :, ...] + /// ``` + /// + /// In contrast, if: + /// + /// ``` + /// # Given this: + /// batch_dim = 2 + /// seq_dim = 0 + /// input.dims = (8, ?, 4, ...) + /// seq_lengths = [7, 2, 3, 5] + /// + /// # then slices of input are reversed on seq_dim, but only up to seq_lengths: + /// output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...] + /// output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...] + /// output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...] + /// output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...] + /// + /// # while entries past seq_lens are copied through: + /// output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...] + /// output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...] + /// output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...] + /// output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...] + /// ``` + /// + /// + /// + /// + /// + /// + /// The dimension which is partially reversed. + /// + /// + /// + /// + /// The dimension along which reversal is performed. + /// + /// + /// + public static Tensor reverse_sequence(Tensor input, Tensor seq_lengths, int seq_dim = 0, int batch_dim = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReverseSequence", name) { args = new object[] { input, seq_lengths }, attrs = new Dictionary() { ["seq_dim"] = seq_dim, ["batch_dim"] = batch_dim } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reverse_sequence_eager_fallback(input, seq_lengths, seq_dim: seq_dim, batch_dim: batch_dim, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["seq_lengths"] = seq_lengths; + keywords["seq_dim"] = seq_dim; + keywords["batch_dim"] = batch_dim; + var _op = tf.OpDefLib._apply_op_helper("ReverseSequence", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "seq_dim", _op._get_attr_int("seq_dim"), "batch_dim", _op._get_attr_int("batch_dim"), "T", _op._get_attr_type("T"), "Tlen", _op._get_attr_type("Tlen") }; + _execute.record_gradient("ReverseSequence", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor reverse_sequence_eager_fallback(Tensor input, Tensor seq_lengths, int seq_dim, int batch_dim, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, seq_lengths }; + object[] _attrs = new object[] { "seq_dim", seq_dim, "batch_dim", batch_dim, "T", input.dtype, "Tlen", seq_lengths.dtype }; + var _result = _execute.execute("ReverseSequence", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReverseSequence", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Reverses specific dimensions of a tensor. + /// + /// + /// + /// Given a `tensor`, and a `int32` tensor `axis` representing the set of + /// dimensions of `tensor` to reverse. This operation reverses each dimension + /// `i` for which there exists `j` s.t. `axis[j] == i`. + /// + /// `tensor` can have up to 8 dimensions. The number of dimensions specified + /// in `axis` may be 0 or more entries. If an index is specified more than + /// once, a InvalidArgument error is raised. + /// + /// For example: + /// + /// ``` + /// # tensor 't' is [[[[ 0, 1, 2, 3], + /// # [ 4, 5, 6, 7], + /// # [ 8, 9, 10, 11]], + /// # [[12, 13, 14, 15], + /// # [16, 17, 18, 19], + /// # [20, 21, 22, 23]]]] + /// # tensor 't' shape is [1, 2, 3, 4] + /// + /// # 'dims' is [3] or 'dims' is [-1] + /// reverse(t, dims) ==> [[[[ 3, 2, 1, 0], + /// [ 7, 6, 5, 4], + /// [ 11, 10, 9, 8]], + /// [[15, 14, 13, 12], + /// [19, 18, 17, 16], + /// [23, 22, 21, 20]]]] + /// + /// # 'dims' is '[1]' (or 'dims' is '[-3]') + /// reverse(t, dims) ==> [[[[12, 13, 14, 15], + /// [16, 17, 18, 19], + /// [20, 21, 22, 23] + /// [[ 0, 1, 2, 3], + /// [ 4, 5, 6, 7], + /// [ 8, 9, 10, 11]]]] + /// + /// # 'dims' is '[2]' (or 'dims' is '[-2]') + /// reverse(t, dims) ==> [[[[8, 9, 10, 11], + /// [4, 5, 6, 7], + /// [0, 1, 2, 3]] + /// [[20, 21, 22, 23], + /// [16, 17, 18, 19], + /// [12, 13, 14, 15]]]] + /// ``` + /// + /// + /// + /// + /// + public static Tensor reverse_v2(Tensor tensor, Tensor axis, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReverseV2", name) { args = new object[] { tensor, axis }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reverse_v2_eager_fallback(tensor, axis, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["axis"] = axis; + var _op = tf.OpDefLib._apply_op_helper("ReverseV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tidx", _op._get_attr_type("Tidx"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("ReverseV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor reverse_v2_eager_fallback(Tensor tensor, Tensor axis, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, axis }; + object[] _attrs = new object[] { "Tidx", axis.dtype, "T", tensor.dtype }; + var _result = _execute.execute("ReverseV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReverseV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Scatters `updates` into a tensor of shape `shape` according to `indices`. + /// + /// + /// + /// Scatter sparse `updates` according to individual values at the specified + /// `indices`. This op returns an output tensor with the `shape` you specify. This + /// op is the inverse of the `tf.gather_nd` operator which extracts values or slices + /// from a given tensor. + /// + /// This operation is similar to `tf.tensor_scatter_nd_add`, except that the tensor + /// is zero-initialized. Calling `tf.scatter_nd(indices, updates, shape)` + /// is identical to calling + /// `tf.tensor_scatter_nd_add(tf.zeros(shape, updates.dtype), indices, updates)` + /// + /// If `indices` contains duplicates, the associated `updates` are accumulated + /// (summed) into the output tensor. + /// + /// **WARNING**: For floating-point data types, the output may be nondeterministic. + /// This is because the order in which the updates are applied is nondeterministic + /// and when floating-point numbers are added in different orders the resulting + /// numerical approximation error can be slightly different. However, the output + /// will be deterministic if op determinism is enabled via + /// `tf.config.experimental.enable_op_determinism`. + /// + /// `indices` is an integer tensor containing indices into the output tensor. The + /// last dimension of `indices` can be at most the rank of `shape`: + /// + /// indices.shape[-1] <= shape.rank + /// + /// The last dimension of `indices` corresponds to indices of elements + /// (if `indices.shape[-1] = shape.rank`) or slices + /// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of + /// `shape`. + /// + /// `updates` is a tensor with shape: + /// + /// indices.shape[:-1] + shape[indices.shape[-1]:] + /// + /// The simplest form of the scatter op is to insert individual elements in + /// a tensor by index. Consider an example where you want to insert 4 scattered + /// elements in a rank-1 tensor with 8 elements. + /// + ///
+ /// + ///
+ /// + /// In Python, this scatter operation would look like this: + /// + /// ```python + /// indices = tf.constant([[4], [3], [1], [7]]) + /// updates = tf.constant([9, 10, 11, 12]) + /// shape = tf.constant([8]) + /// scatter = tf.scatter_nd(indices, updates, shape) + /// print(scatter) + /// ``` + /// + /// The resulting tensor would look like this: + /// + /// [0, 11, 0, 10, 9, 0, 0, 12] + /// + /// You can also insert entire slices of a higher rank tensor all at once. For + /// example, you can insert two slices in the first dimension of a rank-3 tensor + /// with two matrices of new values. + /// + ///
+ /// + ///
+ /// + /// In Python, this scatter operation would look like this: + /// + /// ```python + /// indices = tf.constant([[1], [3]]) + /// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], + /// [7, 7, 7, 7], [8, 8, 8, 8]], + /// [[5, 5, 5, 5], [6, 6, 6, 6], + /// [7, 7, 7, 7], [8, 8, 8, 8]]]) + /// shape = tf.constant([4, 4, 4]) + /// scatter = tf.scatter_nd(indices, updates, shape) + /// print(scatter) + /// ``` + /// + /// The resulting tensor would look like this: + /// + /// [[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], + /// [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + /// [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], + /// [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]]] + /// + /// Note that on CPU, if an out of bound index is found, an error is returned. + /// On GPU, if an out of bound index is found, the index is ignored. + /// + ///
+ /// + /// + /// + /// + public static Tensor scatter_nd(Tensor indices, Tensor updates, Tensor shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ScatterNd", name) { args = new object[] { indices, updates, shape }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return scatter_nd_eager_fallback(indices, updates, shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["indices"] = indices; + keywords["updates"] = updates; + keywords["shape"] = shape; + var _op = tf.OpDefLib._apply_op_helper("ScatterNd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("ScatterNd", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor scatter_nd_eager_fallback(Tensor indices, Tensor updates, Tensor shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { indices, updates, shape }; + object[] _attrs = new object[] { "T", updates.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("ScatterNd", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ScatterNd", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Applies sparse addition to `input` using individual values or slices + /// + /// + /// + /// from `updates` according to indices `indices`. The updates are non-aliasing: + /// `input` is only modified in-place if no other operations will use it. + /// Otherwise, a copy of `input` is made. This operation has a gradient with + /// respect to both `input` and `updates`. + /// + /// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + /// + /// `indices` must be integer tensor, containing indices into `input`. + /// It must be shape \([d_0, ..., d_{Q-2}, K]\) where `0 < K <= P`. + /// + /// The innermost dimension of `indices` (with length `K`) corresponds to + /// indices into elements (if `K = P`) or `(P-K)`-dimensional slices + /// (if `K < P`) along the `K`th dimension of `input`. + /// + /// `updates` is `Tensor` of rank `Q-1+P-K` with shape: + /// + /// $$[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]].$$ + /// + /// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 + /// elements. In Python, that addition would look like this: + /// + /// input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) + /// indices = tf.constant([[4], [3], [1], [7]]) + /// updates = tf.constant([9, 10, 11, 12]) + /// output = tf.scatter_nd_non_aliasing_add(input, indices, updates) + /// with tf.Session() as sess: + /// print(sess.run(output)) + /// + /// The resulting value `output` would look like this: + /// + /// [1, 13, 3, 14, 14, 6, 7, 20] + /// + /// See `tf.scatter_nd` for more details about how to make updates to slices. + /// + /// + /// + /// + /// + /// + public static Tensor scatter_nd_non_aliasing_add(Tensor input, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ScatterNdNonAliasingAdd", name) { args = new object[] { input, indices, updates }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return scatter_nd_non_aliasing_add_eager_fallback(input, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("ScatterNdNonAliasingAdd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("ScatterNdNonAliasingAdd", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor scatter_nd_non_aliasing_add_eager_fallback(Tensor input, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, indices, updates }; + object[] _attrs = new object[] { "T", input.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("ScatterNdNonAliasingAdd", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ScatterNdNonAliasingAdd", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the shape of a tensor. + /// + /// + /// + /// This operation returns a 1-D integer tensor representing the shape of `input`. + /// + /// For example: + /// + /// ``` + /// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] + /// shape(t) ==> [2, 2, 3] + /// ``` + /// + /// + /// + /// + /// + public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Shape", name) { args = new object[] { input }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return shape_eager_fallback(input, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("Shape", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("Shape", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor shape_eager_fallback(Tensor input, TF_DataType out_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "out_type", out_type }; + var _result = _execute.execute("Shape", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Shape", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns shape of tensors. + /// + /// + /// + /// This operation returns N 1-D integer tensors representing shape of `input[i]s`. + /// + /// + /// + /// + /// + public static Tensor[] shape_n(Tensors input, TF_DataType out_type = TF_DataType.TF_INT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ShapeN", name) { args = new object[] { input }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return shape_n_eager_fallback(input, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("ShapeN", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "N", _op._get_attr_int("N"), "T", _op._get_attr_type("T"), "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("ShapeN", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] shape_n_eager_fallback(Tensors input, TF_DataType out_type, string name, Context ctx) + { + List _inputs_flat_list = new(); + _inputs_flat_list.AddRange(input); + var _inputs_flat = _inputs_flat_list.ToArray(); + object[] _attrs = new object[] { "N", input.Length, "T", input.dtype, "out_type", out_type }; + var _result = _execute.execute("ShapeN", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ShapeN", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Returns the size of a tensor. + /// + /// + /// + /// This operation returns an integer representing the number of elements in + /// `input`. + /// + /// For example: + /// + /// ``` + /// # 't' is [[[1, 1,, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]] + /// size(t) ==> 12 + /// ``` + /// + /// + /// + /// + /// + public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Size", name) { args = new object[] { input }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return size_eager_fallback(input, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("Size", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("Size", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor size_eager_fallback(Tensor input, TF_DataType out_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "out_type", out_type }; + var _result = _execute.execute("Size", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Size", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Return a slice from 'input'. + /// + /// + /// + /// The output tensor is a tensor with dimensions described by 'size' + /// whose values are extracted from 'input' starting at the offsets in + /// 'begin'. + /// + /// *Requirements*: + /// 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n) + /// + /// + /// + /// + /// + /// + public static Tensor slice(Tensor input, Tensor begin, Tensor size, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Slice", name) { args = new object[] { input, begin, size }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return slice_eager_fallback(input, begin, size, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["begin"] = begin; + keywords["size"] = size; + var _op = tf.OpDefLib._apply_op_helper("Slice", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Index", _op._get_attr_type("Index") }; + _execute.record_gradient("Slice", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor slice_eager_fallback(Tensor input, Tensor begin, Tensor size, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, begin, size }; + object[] _attrs = new object[] { "T", input.dtype, "Index", begin.dtype }; + var _result = _execute.execute("Slice", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Slice", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a copy of the input tensor. + /// + /// + /// + public static Tensor snapshot(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Snapshot", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return snapshot_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("Snapshot", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Snapshot", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor snapshot_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("Snapshot", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Snapshot", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// SpaceToBatch for 4-D tensors of type T. + /// + /// + /// + /// This is a legacy version of the more general SpaceToBatchND. + /// + /// Zero-pads and then rearranges (permutes) blocks of spatial data into batch. + /// More specifically, this op outputs a copy of the input tensor where values from + /// the `height` and `width` dimensions are moved to the `batch` dimension. After + /// the zero-padding, both `height` and `width` of the input must be divisible by the + /// block size. + /// + /// The attr `block_size` must be greater than one. It indicates the block size. + /// + /// * Non-overlapping blocks of size `block_size x block size` in the height and + /// width dimensions are rearranged into the batch dimension at each location. + /// * The batch of the output tensor is `batch * block_size * block_size`. + /// * Both height_pad and width_pad must be divisible by block_size. + /// + /// The shape of the output will be: + /// + /// [batch*block_size*block_size, height_pad/block_size, width_pad/block_size, + /// depth] + /// + /// Some examples: + /// + /// (1) For the following input of shape `[1, 2, 2, 1]` and block_size of 2: + /// + /// ``` + /// x = [[[[1], [2]], [[3], [4]]]] + /// ``` + /// + /// The output tensor has shape `[4, 1, 1, 1]` and value: + /// + /// ``` + /// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] + /// ``` + /// + /// (2) For the following input of shape `[1, 2, 2, 3]` and block_size of 2: + /// + /// ``` + /// x = [[[[1, 2, 3], [4, 5, 6]], + /// [[7, 8, 9], [10, 11, 12]]]] + /// ``` + /// + /// The output tensor has shape `[4, 1, 1, 3]` and value: + /// + /// ``` + /// [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]] + /// ``` + /// + /// (3) For the following input of shape `[1, 4, 4, 1]` and block_size of 2: + /// + /// ``` + /// x = [[[[1], [2], [3], [4]], + /// [[5], [6], [7], [8]], + /// [[9], [10], [11], [12]], + /// [[13], [14], [15], [16]]]] + /// ``` + /// + /// The output tensor has shape `[4, 2, 2, 1]` and value: + /// + /// ``` + /// x = [[[[1], [3]], [[9], [11]]], + /// [[[2], [4]], [[10], [12]]], + /// [[[5], [7]], [[13], [15]]], + /// [[[6], [8]], [[14], [16]]]] + /// ``` + /// + /// (4) For the following input of shape `[2, 2, 4, 1]` and block_size of 2: + /// + /// ``` + /// x = [[[[1], [2], [3], [4]], + /// [[5], [6], [7], [8]]], + /// [[[9], [10], [11], [12]], + /// [[13], [14], [15], [16]]]] + /// ``` + /// + /// The output tensor has shape `[8, 1, 2, 1]` and value: + /// + /// ``` + /// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], + /// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] + /// ``` + /// + /// Among others, this operation is useful for reducing atrous convolution into + /// regular convolution. + /// + /// + /// + /// + /// + /// + public static Tensor space_to_batch(Tensor input, Tensor paddings, int block_size = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SpaceToBatch", name) { args = new object[] { input, paddings }, attrs = new Dictionary() { ["block_size"] = block_size } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return space_to_batch_eager_fallback(input, paddings, block_size: block_size, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["paddings"] = paddings; + keywords["block_size"] = block_size; + var _op = tf.OpDefLib._apply_op_helper("SpaceToBatch", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tpaddings", _op._get_attr_type("Tpaddings"), "block_size", _op._get_attr_int("block_size") }; + _execute.record_gradient("SpaceToBatch", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor space_to_batch_eager_fallback(Tensor input, Tensor paddings, int block_size, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, paddings }; + object[] _attrs = new object[] { "T", input.dtype, "Tpaddings", paddings.dtype, "block_size", block_size }; + var _result = _execute.execute("SpaceToBatch", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SpaceToBatch", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// SpaceToBatch for N-D tensors of type T. + /// + /// + /// + /// This operation divides "spatial" dimensions `[1, ..., M]` of the input into a + /// grid of blocks of shape `block_shape`, and interleaves these blocks with the + /// "batch" dimension (0) such that in the output, the spatial dimensions + /// `[1, ..., M]` correspond to the position within the grid, and the batch + /// dimension combines both the position within a spatial block and the original + /// batch position. Prior to division into blocks, the spatial dimensions of the + /// input are optionally zero padded according to `paddings`. See below for a + /// precise description. + /// + /// This operation is equivalent to the following steps: + /// + /// 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the + /// input according to `paddings` to produce `padded` of shape `padded_shape`. + /// + /// 2. Reshape `padded` to `reshaped_padded` of shape: + /// + /// [batch] + + /// [padded_shape[1] / block_shape[0], + /// block_shape[0], + /// ..., + /// padded_shape[M] / block_shape[M-1], + /// block_shape[M-1]] + + /// remaining_shape + /// + /// 3. Permute dimensions of `reshaped_padded` to produce + /// `permuted_reshaped_padded` of shape: + /// + /// block_shape + + /// [batch] + + /// [padded_shape[1] / block_shape[0], + /// ..., + /// padded_shape[M] / block_shape[M-1]] + + /// remaining_shape + /// + /// 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the batch + /// dimension, producing an output tensor of shape: + /// + /// [batch * prod(block_shape)] + + /// [padded_shape[1] / block_shape[0], + /// ..., + /// padded_shape[M] / block_shape[M-1]] + + /// remaining_shape + /// + /// Some examples: + /// + /// (1) For the following input of shape `[1, 2, 2, 1]`, `block_shape = [2, 2]`, and + /// `paddings = [[0, 0], [0, 0]]`: + /// + /// ``` + /// x = [[[[1], [2]], [[3], [4]]]] + /// ``` + /// + /// The output tensor has shape `[4, 1, 1, 1]` and value: + /// + /// ``` + /// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] + /// ``` + /// + /// (2) For the following input of shape `[1, 2, 2, 3]`, `block_shape = [2, 2]`, and + /// `paddings = [[0, 0], [0, 0]]`: + /// + /// ``` + /// x = [[[[1, 2, 3], [4, 5, 6]], + /// [[7, 8, 9], [10, 11, 12]]]] + /// ``` + /// + /// The output tensor has shape `[4, 1, 1, 3]` and value: + /// + /// ``` + /// [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]] + /// ``` + /// + /// (3) For the following input of shape `[1, 4, 4, 1]`, `block_shape = [2, 2]`, and + /// `paddings = [[0, 0], [0, 0]]`: + /// + /// ``` + /// x = [[[[1], [2], [3], [4]], + /// [[5], [6], [7], [8]], + /// [[9], [10], [11], [12]], + /// [[13], [14], [15], [16]]]] + /// ``` + /// + /// The output tensor has shape `[4, 2, 2, 1]` and value: + /// + /// ``` + /// x = [[[[1], [3]], [[9], [11]]], + /// [[[2], [4]], [[10], [12]]], + /// [[[5], [7]], [[13], [15]]], + /// [[[6], [8]], [[14], [16]]]] + /// ``` + /// + /// (4) For the following input of shape `[2, 2, 4, 1]`, block_shape = `[2, 2]`, and + /// paddings = `[[0, 0], [2, 0]]`: + /// + /// ``` + /// x = [[[[1], [2], [3], [4]], + /// [[5], [6], [7], [8]]], + /// [[[9], [10], [11], [12]], + /// [[13], [14], [15], [16]]]] + /// ``` + /// + /// The output tensor has shape `[8, 1, 3, 1]` and value: + /// + /// ``` + /// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], + /// [[[0], [2], [4]]], [[[0], [10], [12]]], + /// [[[0], [5], [7]]], [[[0], [13], [15]]], + /// [[[0], [6], [8]]], [[[0], [14], [16]]]] + /// ``` + /// + /// Among others, this operation is useful for reducing atrous convolution into + /// regular convolution. + /// + /// + /// + /// + /// + /// + public static Tensor space_to_batch_nd(Tensor input, Tensor block_shape, Tensor paddings, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SpaceToBatchND", name) { args = new object[] { input, block_shape, paddings }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return space_to_batch_nd_eager_fallback(input, block_shape, paddings, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["block_shape"] = block_shape; + keywords["paddings"] = paddings; + var _op = tf.OpDefLib._apply_op_helper("SpaceToBatchND", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tblock_shape", _op._get_attr_type("Tblock_shape"), "Tpaddings", _op._get_attr_type("Tpaddings") }; + _execute.record_gradient("SpaceToBatchND", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor space_to_batch_nd_eager_fallback(Tensor input, Tensor block_shape, Tensor paddings, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, block_shape, paddings }; + object[] _attrs = new object[] { "T", input.dtype, "Tblock_shape", block_shape.dtype, "Tpaddings", paddings.dtype }; + var _result = _execute.execute("SpaceToBatchND", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SpaceToBatchND", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// SpaceToDepth for tensors of type T. + /// + /// + /// + /// Rearranges blocks of spatial data, into depth. More specifically, + /// this op outputs a copy of the input tensor where values from the `height` + /// and `width` dimensions are moved to the `depth` dimension. + /// The attr `block_size` indicates the input block size. + /// + /// * Non-overlapping blocks of size `block_size x block size` are rearranged + /// into depth at each location. + /// * The depth of the output tensor is `block_size * block_size * input_depth`. + /// * The Y, X coordinates within each block of the input become the high order + /// component of the output channel index. + /// * The input tensor's height and width must be divisible by block_size. + /// + /// The `data_format` attr specifies the layout of the input and output tensors + /// with the following options: + /// "NHWC": `[ batch, height, width, channels ]` + /// "NCHW": `[ batch, channels, height, width ]` + /// "NCHW_VECT_C": + /// `qint8 [ batch, channels / 4, height, width, 4 ]` + /// + /// It is useful to consider the operation as transforming a 6-D Tensor. + /// e.g. for data_format = NHWC, + /// Each element in the input tensor can be specified via 6 coordinates, + /// ordered by decreasing memory layout significance as: + /// n,oY,bY,oX,bX,iC (where n=batch index, oX, oY means X or Y coordinates + /// within the output image, bX, bY means coordinates + /// within the input block, iC means input channels). + /// The output would be a transpose to the following layout: + /// n,oY,oX,bY,bX,iC + /// + /// This operation is useful for resizing the activations between convolutions + /// (but keeping all data), e.g. instead of pooling. It is also useful for training + /// purely convolutional models. + /// + /// For example, given an input of shape `[1, 2, 2, 1]`, data_format = "NHWC" and + /// block_size = 2: + /// + /// ``` + /// x = [[[[1], [2]], + /// [[3], [4]]]] + /// ``` + /// + /// This operation will output a tensor of shape `[1, 1, 1, 4]`: + /// + /// ``` + /// [[[[1, 2, 3, 4]]]] + /// ``` + /// + /// Here, the input has a batch of 1 and each batch element has shape `[2, 2, 1]`, + /// the corresponding output will have a single element (i.e. width and height are + /// both 1) and will have a depth of 4 channels (1 * block_size * block_size). + /// The output element shape is `[1, 1, 4]`. + /// + /// For an input tensor with larger depth, here of shape `[1, 2, 2, 3]`, e.g. + /// + /// ``` + /// x = [[[[1, 2, 3], [4, 5, 6]], + /// [[7, 8, 9], [10, 11, 12]]]] + /// ``` + /// + /// This operation, for block_size of 2, will return the following tensor of shape + /// `[1, 1, 1, 12]` + /// + /// ``` + /// [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] + /// ``` + /// + /// Similarly, for the following input of shape `[1 4 4 1]`, and a block size of 2: + /// + /// ``` + /// x = [[[[1], [2], [5], [6]], + /// [[3], [4], [7], [8]], + /// [[9], [10], [13], [14]], + /// [[11], [12], [15], [16]]]] + /// ``` + /// + /// the operator will return the following tensor of shape `[1 2 2 4]`: + /// + /// ``` + /// x = [[[[1, 2, 3, 4], + /// [5, 6, 7, 8]], + /// [[9, 10, 11, 12], + /// [13, 14, 15, 16]]]] + /// ``` + /// + /// + /// + /// + /// + /// The size of the spatial block. + /// + /// + /// + /// + public static Tensor space_to_depth(Tensor input, int block_size = 0, string data_format = "NHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SpaceToDepth", name) { args = new object[] { input }, attrs = new Dictionary() { ["block_size"] = block_size, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return space_to_depth_eager_fallback(input, block_size: block_size, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["block_size"] = block_size; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("SpaceToDepth", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "block_size", _op._get_attr_int("block_size"), "data_format", _op.get_attr("data_format") }; + _execute.record_gradient("SpaceToDepth", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor space_to_depth_eager_fallback(Tensor input, int block_size, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "block_size", block_size, "data_format", data_format }; + var _result = _execute.execute("SpaceToDepth", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SpaceToDepth", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Splits a tensor into `num_split` tensors along one dimension. + /// + /// + /// + /// + /// + /// The number of ways to split. Must evenly divide + /// `value.shape[split_dim]`. + /// + /// + /// + public static Tensor[] split(Tensor split_dim, Tensor value, int num_split = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Split", name) { args = new object[] { split_dim, value }, attrs = new Dictionary() { ["num_split"] = num_split } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return split_eager_fallback(split_dim, value, num_split: num_split, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["split_dim"] = split_dim; + keywords["value"] = value; + keywords["num_split"] = num_split; + var _op = tf.OpDefLib._apply_op_helper("Split", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "num_split", _op._get_attr_int("num_split"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("Split", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] split_eager_fallback(Tensor split_dim, Tensor value, int num_split, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { split_dim, value }; + object[] _attrs = new object[] { "num_split", num_split, "T", value.dtype }; + var _result = _execute.execute("Split", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Split", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Splits a tensor into `num_split` tensors along one dimension. + /// + /// + /// + /// + /// + /// + public static Tensor[] split_v(Tensor value, Tensor size_splits, Tensor split_dim, int num_split = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SplitV", name) { args = new object[] { value, size_splits, split_dim }, attrs = new Dictionary() { ["num_split"] = num_split } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return split_v_eager_fallback(value, size_splits, split_dim, num_split: num_split, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["value"] = value; + keywords["size_splits"] = size_splits; + keywords["split_dim"] = split_dim; + keywords["num_split"] = num_split; + var _op = tf.OpDefLib._apply_op_helper("SplitV", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "num_split", _op._get_attr_int("num_split"), "T", _op._get_attr_type("T"), "Tlen", _op._get_attr_type("Tlen") }; + _execute.record_gradient("SplitV", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] split_v_eager_fallback(Tensor value, Tensor size_splits, Tensor split_dim, int num_split, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { value, size_splits, split_dim }; + object[] _attrs = new object[] { "num_split", num_split, "T", value.dtype, "Tlen", size_splits.dtype }; + var _result = _execute.execute("SplitV", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SplitV", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Removes dimensions of size 1 from the shape of a tensor. + /// + /// + /// + /// Given a tensor `input`, this operation returns a tensor of the same type with + /// all dimensions of size 1 removed. If you don't want to remove all size 1 + /// dimensions, you can remove specific size 1 dimensions by specifying + /// `squeeze_dims`. + /// + /// For example: + /// + /// ``` + /// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] + /// shape(squeeze(t)) ==> [2, 3] + /// ``` + /// + /// Or, to remove specific size 1 dimensions: + /// + /// ``` + /// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] + /// shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] + /// ``` + /// + /// + /// + /// + /// + /// If specified, only squeezes the dimensions listed. The dimension + /// index starts at 0. It is an error to squeeze a dimension that is not 1. Must + /// be in the range `[-rank(input), rank(input))`. + /// + /// + /// + public static Tensor squeeze(Tensor input, int[] squeeze_dims = null, string? name = null) + { + var _ctx = tf.Context; + if (squeeze_dims is null) + { + squeeze_dims = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Squeeze", name) { args = new object[] { input }, attrs = new Dictionary() { ["squeeze_dims"] = squeeze_dims } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return squeeze_eager_fallback(input, squeeze_dims: squeeze_dims, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["squeeze_dims"] = squeeze_dims; + var _op = tf.OpDefLib._apply_op_helper("Squeeze", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "squeeze_dims", _op.get_attr("squeeze_dims") }; + _execute.record_gradient("Squeeze", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor squeeze_eager_fallback(Tensor input, int[] squeeze_dims, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "squeeze_dims", squeeze_dims }; + var _result = _execute.execute("Squeeze", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Squeeze", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Stops gradient computation. + /// + /// + /// + /// When executed in a graph, this op outputs its input tensor as-is. + /// + /// When building ops to compute gradients, this op prevents the contribution of + /// its inputs to be taken into account. Normally, the gradient generator adds ops + /// to a graph to compute the derivatives of a specified 'loss' by recursively + /// finding out inputs that contributed to its computation. If you insert this op + /// in the graph it inputs are masked from the gradient generator. They are not + /// taken into account for computing gradients. + /// + /// This is useful any time you want to compute a value with TensorFlow but need + /// to pretend that the value was a constant. For example, the softmax function + /// for a vector x can be written as + /// + /// ```python + /// + /// def softmax(x): + /// numerator = tf.exp(x) + /// denominator = tf.reduce_sum(numerator) + /// return numerator / denominator + /// ``` + /// + /// This however is susceptible to overflow if the values in x are large. An + /// alternative more stable way is to subtract the maximum of x from each of the + /// values. + /// + /// ```python + /// + /// def stable_softmax(x): + /// z = x - tf.reduce_max(x) + /// numerator = tf.exp(z) + /// denominator = tf.reduce_sum(numerator) + /// return numerator / denominator + /// ``` + /// + /// However, when we backprop through the softmax to x, we dont want to backprop + /// through the `tf.reduce_max(x)` (if the max values are not unique then the + /// gradient could flow to the wrong input) calculation and treat that as a + /// constant. Therefore, we should write this out as + /// + /// ```python + /// + /// def stable_softmax(x): + /// z = x - tf.stop_gradient(tf.reduce_max(x)) + /// numerator = tf.exp(z) + /// denominator = tf.reduce_sum(numerator) + /// return numerator / denominator + /// ``` + /// + /// Some other examples include: + /// + /// * The *EM* algorithm where the *M-step* should not involve backpropagation + /// through the output of the *E-step*. + /// * Contrastive divergence training of Boltzmann machines where, when + /// differentiating the energy function, the training must not backpropagate + /// through the graph that generated the samples from the model. + /// * Adversarial training, where no backprop should happen through the adversarial + /// example generation process. + /// + /// + /// + /// + public static Tensor stop_gradient(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "StopGradient", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return stop_gradient_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("StopGradient", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("StopGradient", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor stop_gradient_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("StopGradient", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("StopGradient", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Return a strided slice from `input`. + /// + /// + /// + /// Note, most python users will want to use the Python `Tensor.__getitem__` + /// or `Variable.__getitem__` rather than this op directly. + /// + /// The goal of this op is to produce a new tensor with a subset of + /// the elements from the `n` dimensional `input` tensor. The subset is chosen using + /// a sequence of `m` sparse range specifications encoded into the arguments + /// of this function. Note, in some cases + /// `m` could be equal to `n`, but this need not be the case. Each + /// range specification entry can be one of the following: + /// + /// - An ellipsis (...). Ellipses are used to imply zero or more + /// dimensions of full-dimension selection and are produced using + /// `ellipsis_mask`. For example, `foo[...]` is the identity slice. + /// + /// - A new axis. This is used to insert a new shape=1 dimension and is + /// produced using `new_axis_mask`. For example, `foo[:, ...]` where + /// `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. + /// + /// + /// - A range `begin:end:stride`. This is used to specify how much to choose from + /// a given dimension. `stride` can be any integer but 0. `begin` is an integer + /// which represents the index of the first value to select while `end` represents + /// the index of the last value to select. The number of values selected in each + /// dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. + /// `begin` and `end` can be negative where `-1` is the last element, `-2` is + /// the second to last. `begin_mask` controls whether to replace the explicitly + /// given `begin` with an implicit effective value of `0` if `stride > 0` and + /// `-1` if `stride < 0`. `end_mask` is analogous but produces the number + /// required to create the largest open interval. For example, given a shape + /// `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do + /// not assume this is equivalent to `foo[0:-1]` which has an effective `begin` + /// and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the + /// first dimension of a tensor while dropping the last two (in the original + /// order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. + /// + /// - A single index. This is used to keep only elements that have a given + /// index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a + /// shape `(6,)` tensor. This is encoded in `begin` and `end` and + /// `shrink_axis_mask`. + /// + /// Each conceptual range specification is encoded in the op's argument. This + /// encoding is best understand by considering a non-trivial example. In + /// particular, + /// `foo[1, 2:4, None, ..., :-3:-1, :]` will be encoded as + /// + /// ``` + /// begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0) + /// end = [2, 4, x, x, -3, x] + /// strides = [1, 1, x, x, -1, 1] + /// begin_mask = 1<<4 | 1<<5 = 48 + /// end_mask = 1<<5 = 32 + /// ellipsis_mask = 1<<3 = 8 + /// new_axis_mask = 1<<2 = 4 + /// shrink_axis_mask = 1<<0 = 1 + /// ``` + /// + /// In this case if `foo.shape` is (5, 5, 5, 5, 5, 5) the final shape of + /// the slice becomes (2, 1, 5, 5, 2, 5). + /// Let us walk step by step through each argument specification. + /// + /// 1. The first argument in the example slice is turned into `begin = 1` and + /// `end = begin + 1 = 2`. To disambiguate from the original spec `2:4` we + /// also set the appropriate bit in `shrink_axis_mask`. + /// + /// 2. `2:4` is contributes 2, 4, 1 to begin, end, and stride. All masks have + /// zero bits contributed. + /// + /// 3. None is a synonym for `tf.newaxis`. This means insert a dimension of size 1 + /// dimension in the final shape. Dummy values are contributed to begin, + /// end and stride, while the new_axis_mask bit is set. + /// + /// 4. `...` grab the full ranges from as many dimensions as needed to + /// fully specify a slice for every dimension of the input shape. + /// + /// 5. `:-3:-1` shows the use of negative indices. A negative index `i` associated + /// with a dimension that has shape `s` is converted to a positive index + /// `s + i`. So `-1` becomes `s-1` (i.e. the last element). This conversion + /// is done internally so begin, end and strides receive x, -3, and -1. + /// The appropriate begin_mask bit is set to indicate the start range is the + /// full range (ignoring the x). + /// + /// 6. `:` indicates that the entire contents of the corresponding dimension + /// is selected. This is equivalent to `::` or `0::1`. begin, end, and strides + /// receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and + /// `end_mask` are also set. + /// + /// *Requirements*: + /// `0 != strides[i] for i in [0, m)` + /// `ellipsis_mask must be a power of two (only one ellipsis)` + /// + /// + /// + /// + /// + /// + /// + /// + /// a bitmask where a bit i being 1 means to ignore the begin + /// value and instead use the largest interval possible. At runtime + /// begin[i] will be replaced with `[0, n-1)` if `stride[i] > 0` or + /// `[-1, n-1]` if `stride[i] < 0` + /// + /// + /// + /// + /// analogous to `begin_mask` + /// + /// + /// + /// + /// a bitmask where bit `i` being 1 means the `i`th + /// position is actually an ellipsis. One bit at most can be 1. + /// If `ellipsis_mask == 0`, then an implicit ellipsis mask of `1 << (m+1)` + /// is provided. This means that `foo[3:5] == foo[3:5, ...]`. An ellipsis + /// implicitly creates as many range specifications as necessary to fully + /// specify the sliced range for every dimension. For example for a 4-dimensional + /// tensor `foo` the slice `foo[2, ..., 5:8]` implies `foo[2, :, :, 5:8]`. + /// + /// + /// + /// + /// a bitmask where bit `i` being 1 means the `i`th + /// specification creates a new shape 1 dimension. For example + /// `foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor. + /// + /// + /// + /// + /// a bitmask where bit `i` implies that the `i`th + /// specification should shrink the dimensionality. begin and end + /// must imply a slice of size 1 in the dimension. For example in + /// python one might do `foo[:, 3, :]` which would result in + /// `shrink_axis_mask` being 2. + /// + /// + /// + public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides, int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "StridedSlice", name) { args = new object[] { input, begin, end, strides }, attrs = new Dictionary() { ["begin_mask"] = begin_mask, ["end_mask"] = end_mask, ["ellipsis_mask"] = ellipsis_mask, ["new_axis_mask"] = new_axis_mask, ["shrink_axis_mask"] = shrink_axis_mask } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return strided_slice_eager_fallback(input, begin, end, strides, begin_mask: begin_mask, end_mask: end_mask, ellipsis_mask: ellipsis_mask, new_axis_mask: new_axis_mask, shrink_axis_mask: shrink_axis_mask, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["begin"] = begin; + keywords["end"] = end; + keywords["strides"] = strides; + keywords["begin_mask"] = begin_mask; + keywords["end_mask"] = end_mask; + keywords["ellipsis_mask"] = ellipsis_mask; + keywords["new_axis_mask"] = new_axis_mask; + keywords["shrink_axis_mask"] = shrink_axis_mask; + var _op = tf.OpDefLib._apply_op_helper("StridedSlice", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Index", _op._get_attr_type("Index"), "begin_mask", _op._get_attr_int("begin_mask"), "end_mask", _op._get_attr_int("end_mask"), "ellipsis_mask", _op._get_attr_int("ellipsis_mask"), "new_axis_mask", _op._get_attr_int("new_axis_mask"), "shrink_axis_mask", _op._get_attr_int("shrink_axis_mask") }; + _execute.record_gradient("StridedSlice", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor strided_slice_eager_fallback(Tensor input, Tensor begin, Tensor end, Tensor strides, int begin_mask, int end_mask, int ellipsis_mask, int new_axis_mask, int shrink_axis_mask, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, begin, end, strides }; + object[] _attrs = new object[] { "T", input.dtype, "Index", begin.dtype, "begin_mask", begin_mask, "end_mask", end_mask, "ellipsis_mask", ellipsis_mask, "new_axis_mask", new_axis_mask, "shrink_axis_mask", shrink_axis_mask }; + var _result = _execute.execute("StridedSlice", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("StridedSlice", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Assign `value` to the sliced l-value reference of `ref`. + /// + /// + /// + /// The values of `value` are assigned to the positions in the variable + /// `ref` that are selected by the slice parameters. The slice parameters + /// `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`. + /// + /// NOTE this op currently does not support broadcasting and so `value`'s + /// shape must be exactly the shape produced by the slice of `ref`. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor strided_slice_assign(Tensor ref_, Tensor begin, Tensor end, Tensor strides, Tensor value, int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + throw new RuntimeError("strided_slice_assign op does not support eager execution. Arg ref is a ref."); + } + Dictionary keywords = new(); + keywords["ref"] = ref_; + keywords["begin"] = begin; + keywords["end"] = end; + keywords["strides"] = strides; + keywords["value"] = value; + keywords["begin_mask"] = begin_mask; + keywords["end_mask"] = end_mask; + keywords["ellipsis_mask"] = ellipsis_mask; + keywords["new_axis_mask"] = new_axis_mask; + keywords["shrink_axis_mask"] = shrink_axis_mask; + var _op = tf.OpDefLib._apply_op_helper("StridedSliceAssign", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Index", _op._get_attr_type("Index"), "begin_mask", _op._get_attr_int("begin_mask"), "end_mask", _op._get_attr_int("end_mask"), "ellipsis_mask", _op._get_attr_int("ellipsis_mask"), "new_axis_mask", _op._get_attr_int("new_axis_mask"), "shrink_axis_mask", _op._get_attr_int("shrink_axis_mask") }; + _execute.record_gradient("StridedSliceAssign", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor strided_slice_assign_eager_fallback(Tensor ref_, Tensor begin, Tensor end, Tensor strides, Tensor value, int begin_mask, int end_mask, int ellipsis_mask, int new_axis_mask, int shrink_axis_mask, string name, Context ctx) + { + throw new RuntimeError($"strided_slice_assign op does not support eager execution. Arg 'ref' is a ref."); + } + /// + /// Returns the gradient of `StridedSlice`. + /// + /// + /// + /// Since `StridedSlice` cuts out pieces of its `input` which is size + /// `shape`, its gradient will have the same shape (which is passed here + /// as `shape`). The gradient will be zero in any element that the slice + /// does not select. + /// + /// Arguments are the same as StridedSliceGrad with the exception that + /// `dy` is the input gradient to be propagated and `shape` is the + /// shape of `StridedSlice`'s `input`. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "StridedSliceGrad", name) { args = new object[] { shape, begin, end, strides, dy }, attrs = new Dictionary() { ["begin_mask"] = begin_mask, ["end_mask"] = end_mask, ["ellipsis_mask"] = ellipsis_mask, ["new_axis_mask"] = new_axis_mask, ["shrink_axis_mask"] = shrink_axis_mask } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return strided_slice_grad_eager_fallback(shape, begin, end, strides, dy, begin_mask: begin_mask, end_mask: end_mask, ellipsis_mask: ellipsis_mask, new_axis_mask: new_axis_mask, shrink_axis_mask: shrink_axis_mask, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["shape"] = shape; + keywords["begin"] = begin; + keywords["end"] = end; + keywords["strides"] = strides; + keywords["dy"] = dy; + keywords["begin_mask"] = begin_mask; + keywords["end_mask"] = end_mask; + keywords["ellipsis_mask"] = ellipsis_mask; + keywords["new_axis_mask"] = new_axis_mask; + keywords["shrink_axis_mask"] = shrink_axis_mask; + var _op = tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Index", _op._get_attr_type("Index"), "begin_mask", _op._get_attr_int("begin_mask"), "end_mask", _op._get_attr_int("end_mask"), "ellipsis_mask", _op._get_attr_int("ellipsis_mask"), "new_axis_mask", _op._get_attr_int("new_axis_mask"), "shrink_axis_mask", _op._get_attr_int("shrink_axis_mask") }; + _execute.record_gradient("StridedSliceGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor strided_slice_grad_eager_fallback(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, int begin_mask, int end_mask, int ellipsis_mask, int new_axis_mask, int shrink_axis_mask, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { shape, begin, end, strides, dy }; + object[] _attrs = new object[] { "T", dy.dtype, "Index", shape.dtype, "begin_mask", begin_mask, "end_mask", end_mask, "ellipsis_mask", ellipsis_mask, "new_axis_mask", new_axis_mask, "shrink_axis_mask", shrink_axis_mask }; + var _result = _execute.execute("StridedSliceGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("StridedSliceGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Adds sparse `updates` to an existing tensor according to `indices`. + /// + /// + /// + /// This operation creates a new tensor by adding sparse `updates` to the passed + /// in `tensor`. + /// This operation is very similar to `tf.compat.v1.scatter_nd_add`, except that the + /// updates are added onto an existing tensor (as opposed to a variable). If the + /// memory for the existing tensor cannot be re-used, a copy is made and updated. + /// + /// `indices` is an integer tensor containing indices into a new tensor of shape + /// `tensor.shape`. The last dimension of `indices` can be at most the rank of + /// `tensor.shape`: + /// + /// ``` + /// indices.shape[-1] <= tensor.shape.rank + /// ``` + /// + /// The last dimension of `indices` corresponds to indices into elements + /// (if `indices.shape[-1] = tensor.shape.rank`) or slices + /// (if `indices.shape[-1] < tensor.shape.rank`) along dimension + /// `indices.shape[-1]` of `tensor.shape`. `updates` is a tensor with shape + /// + /// ``` + /// indices.shape[:-1] + tensor.shape[indices.shape[-1]:] + /// ``` + /// + /// The simplest form of `tensor_scatter_nd_add` is to add individual elements to a + /// tensor by index. For example, say we want to add 4 elements in a rank-1 + /// tensor with 8 elements. + /// + /// In Python, this scatter add operation would look like this: + /// + /// >>> indices = tf.constant([[4], [3], [1], [7]]) + /// >>> updates = tf.constant([9, 10, 11, 12]) + /// >>> tensor = tf.ones([8], dtype=tf.int32) + /// >>> updated = tf.tensor_scatter_nd_add(tensor, indices, updates) + /// >>> updated + /// + /// + /// We can also, insert entire slices of a higher rank tensor all at once. For + /// example, if we wanted to insert two slices in the first dimension of a + /// rank-3 tensor with two matrices of new values. + /// + /// In Python, this scatter add operation would look like this: + /// + /// >>> indices = tf.constant([[0], [2]]) + /// >>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], + /// ... [7, 7, 7, 7], [8, 8, 8, 8]], + /// ... [[5, 5, 5, 5], [6, 6, 6, 6], + /// ... [7, 7, 7, 7], [8, 8, 8, 8]]]) + /// >>> tensor = tf.ones([4, 4, 4],dtype=tf.int32) + /// >>> updated = tf.tensor_scatter_nd_add(tensor, indices, updates) + /// >>> updated + /// + /// + /// Note: on CPU, if an out of bound index is found, an error is returned. + /// On GPU, if an out of bound index is found, the index is ignored. + /// + /// + /// + /// + /// + /// + public static Tensor tensor_scatter_add(Tensor tensor, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorScatterAdd", name) { args = new object[] { tensor, indices, updates }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_scatter_add_eager_fallback(tensor, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("TensorScatterAdd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("TensorScatterAdd", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_scatter_add_eager_fallback(Tensor tensor, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, indices, updates }; + object[] _attrs = new object[] { "T", tensor.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("TensorScatterAdd", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorScatterAdd", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Apply a sparse update to a tensor taking the element-wise maximum. + /// + /// + /// + /// Returns a new tensor copied from `tensor` whose values are element-wise maximum between + /// tensor and updates according to the indices. + /// + /// >>> tensor = [0, 0, 0, 0, 0, 0, 0, 0] + /// >>> indices = [[1], [4], [5]] + /// >>> updates = [1, -1, 1] + /// >>> tf.tensor_scatter_nd_max(tensor, indices, updates).numpy() + /// array([0, 1, 0, 0, 0, 1, 0, 0], dtype=int32) + /// + /// Refer to `tf.tensor_scatter_nd_update` for more details. + /// + /// + /// + /// + /// + /// + public static Tensor tensor_scatter_max(Tensor tensor, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorScatterMax", name) { args = new object[] { tensor, indices, updates }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_scatter_max_eager_fallback(tensor, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("TensorScatterMax", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("TensorScatterMax", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_scatter_max_eager_fallback(Tensor tensor, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, indices, updates }; + object[] _attrs = new object[] { "T", tensor.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("TensorScatterMax", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorScatterMax", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + /// + /// + public static Tensor tensor_scatter_min(Tensor tensor, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorScatterMin", name) { args = new object[] { tensor, indices, updates }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_scatter_min_eager_fallback(tensor, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("TensorScatterMin", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("TensorScatterMin", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_scatter_min_eager_fallback(Tensor tensor, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, indices, updates }; + object[] _attrs = new object[] { "T", tensor.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("TensorScatterMin", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorScatterMin", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Subtracts sparse `updates` from an existing tensor according to `indices`. + /// + /// + /// + /// This operation creates a new tensor by subtracting sparse `updates` from the + /// passed in `tensor`. + /// This operation is very similar to `tf.scatter_nd_sub`, except that the updates + /// are subtracted from an existing tensor (as opposed to a variable). If the memory + /// for the existing tensor cannot be re-used, a copy is made and updated. + /// + /// `indices` is an integer tensor containing indices into a new tensor of shape + /// `shape`. The last dimension of `indices` can be at most the rank of `shape`: + /// + /// indices.shape[-1] <= shape.rank + /// + /// The last dimension of `indices` corresponds to indices into elements + /// (if `indices.shape[-1] = shape.rank`) or slices + /// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of + /// `shape`. `updates` is a tensor with shape + /// + /// indices.shape[:-1] + shape[indices.shape[-1]:] + /// + /// The simplest form of tensor_scatter_sub is to subtract individual elements + /// from a tensor by index. For example, say we want to insert 4 scattered elements + /// in a rank-1 tensor with 8 elements. + /// + /// In Python, this scatter subtract operation would look like this: + /// + /// ```python + /// indices = tf.constant([[4], [3], [1], [7]]) + /// updates = tf.constant([9, 10, 11, 12]) + /// tensor = tf.ones([8], dtype=tf.int32) + /// updated = tf.tensor_scatter_nd_sub(tensor, indices, updates) + /// print(updated) + /// ``` + /// + /// The resulting tensor would look like this: + /// + /// [1, -10, 1, -9, -8, 1, 1, -11] + /// + /// We can also, insert entire slices of a higher rank tensor all at once. For + /// example, if we wanted to insert two slices in the first dimension of a + /// rank-3 tensor with two matrices of new values. + /// + /// In Python, this scatter add operation would look like this: + /// + /// ```python + /// indices = tf.constant([[0], [2]]) + /// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], + /// [7, 7, 7, 7], [8, 8, 8, 8]], + /// [[5, 5, 5, 5], [6, 6, 6, 6], + /// [7, 7, 7, 7], [8, 8, 8, 8]]]) + /// tensor = tf.ones([4, 4, 4],dtype=tf.int32) + /// updated = tf.tensor_scatter_nd_sub(tensor, indices, updates) + /// print(updated) + /// ``` + /// + /// The resulting tensor would look like this: + /// + /// [[[-4, -4, -4, -4], [-5, -5, -5, -5], [-6, -6, -6, -6], [-7, -7, -7, -7]], + /// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], + /// [[-4, -4, -4, -4], [-5, -5, -5, -5], [-6, -6, -6, -6], [-7, -7, -7, -7]], + /// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]] + /// + /// Note that on CPU, if an out of bound index is found, an error is returned. + /// On GPU, if an out of bound index is found, the index is ignored. + /// + /// + /// + /// + /// + /// + public static Tensor tensor_scatter_sub(Tensor tensor, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorScatterSub", name) { args = new object[] { tensor, indices, updates }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_scatter_sub_eager_fallback(tensor, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("TensorScatterSub", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("TensorScatterSub", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_scatter_sub_eager_fallback(Tensor tensor, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, indices, updates }; + object[] _attrs = new object[] { "T", tensor.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("TensorScatterSub", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorScatterSub", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Scatter `updates` into an existing tensor according to `indices`. + /// + /// + /// + /// This operation creates a new tensor by applying sparse `updates` to the passed + /// in `tensor`. + /// This operation is very similar to `tf.scatter_nd`, except that the updates are + /// scattered onto an existing tensor (as opposed to a zero-tensor). If the memory + /// for the existing tensor cannot be re-used, a copy is made and updated. + /// + /// If `indices` contains duplicates, then we pick the last update for the index. + /// + /// If an out of bound index is found on CPU, an error is returned. + /// + /// **WARNING**: There are some GPU specific semantics for this operation. + /// - If an out of bound index is found, the index is ignored. + /// - The order in which updates are applied is nondeterministic, so the output + /// will be nondeterministic if `indices` contains duplicates. + /// + /// `indices` is an integer tensor containing indices into a new tensor of shape + /// `shape`. + /// + /// * `indices` must have at least 2 axes: `(num_updates, index_depth)`. + /// * The last axis of `indices` is how deep to index into `tensor` so this index + /// depth must be less than the rank of `tensor`: `indices.shape[-1] <= tensor.ndim` + /// + /// if `indices.shape[-1] = tensor.rank` this Op indexes and updates scalar elements. + /// if `indices.shape[-1] < tensor.rank` it indexes and updates slices of the input + /// `tensor`. + /// + /// Each `update` has a rank of `tensor.rank - indices.shape[-1]`. + /// The overall shape of `updates` is: + /// + /// ``` + /// indices.shape[:-1] + tensor.shape[indices.shape[-1]:] + /// ``` + /// + /// For usage examples see the python [tf.tensor_scatter_nd_update]( + /// https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function + /// + /// + /// + /// + /// + /// + /// + public static Tensor tensor_scatter_update(Tensor tensor, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorScatterUpdate", name) { args = new object[] { tensor, indices, updates }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_scatter_update_eager_fallback(tensor, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("TensorScatterUpdate", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("TensorScatterUpdate", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_scatter_update_eager_fallback(Tensor tensor, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, indices, updates }; + object[] _attrs = new object[] { "T", tensor.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("TensorScatterUpdate", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorScatterUpdate", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Assign `value` to the sliced l-value reference of `input`. + /// + /// + /// + /// The values of `value` are assigned to the positions in the tensor `input` that + /// are selected by the slice parameters. The slice parameters `begin` `end` + /// `strides` etc. work exactly as in `StridedSlice`. + /// + /// NOTE this op currently does not support broadcasting and so `value`'s shape + /// must be exactly the shape produced by the slice of `input`. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor tensor_strided_slice_update(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorStridedSliceUpdate", name) { args = new object[] { input, begin, end, strides, value }, attrs = new Dictionary() { ["begin_mask"] = begin_mask, ["end_mask"] = end_mask, ["ellipsis_mask"] = ellipsis_mask, ["new_axis_mask"] = new_axis_mask, ["shrink_axis_mask"] = shrink_axis_mask } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_strided_slice_update_eager_fallback(input, begin, end, strides, value, begin_mask: begin_mask, end_mask: end_mask, ellipsis_mask: ellipsis_mask, new_axis_mask: new_axis_mask, shrink_axis_mask: shrink_axis_mask, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["begin"] = begin; + keywords["end"] = end; + keywords["strides"] = strides; + keywords["value"] = value; + keywords["begin_mask"] = begin_mask; + keywords["end_mask"] = end_mask; + keywords["ellipsis_mask"] = ellipsis_mask; + keywords["new_axis_mask"] = new_axis_mask; + keywords["shrink_axis_mask"] = shrink_axis_mask; + var _op = tf.OpDefLib._apply_op_helper("TensorStridedSliceUpdate", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Index", _op._get_attr_type("Index"), "begin_mask", _op._get_attr_int("begin_mask"), "end_mask", _op._get_attr_int("end_mask"), "ellipsis_mask", _op._get_attr_int("ellipsis_mask"), "new_axis_mask", _op._get_attr_int("new_axis_mask"), "shrink_axis_mask", _op._get_attr_int("shrink_axis_mask") }; + _execute.record_gradient("TensorStridedSliceUpdate", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_strided_slice_update_eager_fallback(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, int begin_mask, int end_mask, int ellipsis_mask, int new_axis_mask, int shrink_axis_mask, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, begin, end, strides, value }; + object[] _attrs = new object[] { "T", input.dtype, "Index", begin.dtype, "begin_mask", begin_mask, "end_mask", end_mask, "ellipsis_mask", ellipsis_mask, "new_axis_mask", new_axis_mask, "shrink_axis_mask", shrink_axis_mask }; + var _result = _execute.execute("TensorStridedSliceUpdate", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorStridedSliceUpdate", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Constructs a tensor by tiling a given tensor. + /// + /// + /// + /// This operation creates a new tensor by replicating `input` `multiples` times. + /// The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements, + /// and the values of `input` are replicated `multiples[i]` times along the 'i'th + /// dimension. For example, tiling `[a b c d]` by `[2]` produces + /// `[a b c d a b c d]`. + /// + /// >>> a = tf.constant([[1,2,3],[4,5,6]], tf.int32) + /// >>> b = tf.constant([1,2], tf.int32) + /// >>> tf.tile(a, b) + /// + /// >>> c = tf.constant([2,1], tf.int32) + /// >>> tf.tile(a, c) + /// + /// >>> d = tf.constant([2,2], tf.int32) + /// >>> tf.tile(a, d) + /// + /// + /// + /// + /// + /// + public static Tensor tile(Tensor input, Tensor multiples, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Tile", name) { args = new object[] { input, multiples }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tile_eager_fallback(input, multiples, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["multiples"] = multiples; + var _op = tf.OpDefLib._apply_op_helper("Tile", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tmultiples", _op._get_attr_type("Tmultiples") }; + _execute.record_gradient("Tile", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tile_eager_fallback(Tensor input, Tensor multiples, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, multiples }; + object[] _attrs = new object[] { "T", input.dtype, "Tmultiples", multiples.dtype }; + var _result = _execute.execute("Tile", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Tile", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the gradient of `Tile`. + /// + /// + /// + /// Since `Tile` takes an input and repeats the input `multiples` times + /// along each dimension, `TileGrad` takes in `multiples` and aggregates + /// each repeated tile of `input` into `output`. + /// + /// + /// + /// + /// + public static Tensor tile_grad(Tensor input, Tensor multiples, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TileGrad", name) { args = new object[] { input, multiples }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tile_grad_eager_fallback(input, multiples, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["multiples"] = multiples; + var _op = tf.OpDefLib._apply_op_helper("TileGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("TileGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tile_grad_eager_fallback(Tensor input, Tensor multiples, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, multiples }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("TileGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TileGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Shuffle dimensions of x according to a permutation. + /// + /// + /// + /// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: + /// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` + /// + /// + /// + /// + /// + public static Tensor transpose(Tensor x, Tensor perm, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Transpose", name) { args = new object[] { x, perm }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return transpose_eager_fallback(x, perm, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["perm"] = perm; + var _op = tf.OpDefLib._apply_op_helper("Transpose", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tperm", _op._get_attr_type("Tperm") }; + _execute.record_gradient("Transpose", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor transpose_eager_fallback(Tensor x, Tensor perm, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, perm }; + object[] _attrs = new object[] { "T", x.dtype, "Tperm", perm.dtype }; + var _result = _execute.execute("Transpose", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Transpose", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Finds unique elements in a 1-D tensor. + /// + /// + /// + /// This operation returns a tensor `y` containing all of the unique elements of `x` + /// sorted in the same order that they occur in `x`; `x` does not need to be sorted. + /// This operation also returns a tensor `idx` the same size as `x` that contains + /// the index of each value of `x` in the unique output `y`. In other words: + /// + /// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` + /// + /// Examples: + /// + /// ``` + /// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] + /// y, idx = unique(x) + /// y ==> [1, 2, 4, 7, 8] + /// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] + /// ``` + /// + /// ``` + /// # tensor 'x' is [4, 5, 1, 2, 3, 3, 4, 5] + /// y, idx = unique(x) + /// y ==> [4, 5, 1, 2, 3] + /// idx ==> [0, 1, 2, 3, 4, 4, 0, 1] + /// ``` + /// + /// + /// + /// + /// + public static Tensor[] unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Unique", name) { args = new object[] { x }, attrs = new Dictionary() { ["out_idx"] = out_idx } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return unique_eager_fallback(x, out_idx: out_idx, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["out_idx"] = out_idx; + var _op = tf.OpDefLib._apply_op_helper("Unique", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "out_idx", _op._get_attr_type("out_idx") }; + _execute.record_gradient("Unique", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] unique_eager_fallback(Tensor x, TF_DataType out_idx, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype, "out_idx", out_idx }; + var _result = _execute.execute("Unique", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Unique", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Finds unique elements along an axis of a tensor. + /// + /// + /// + /// This operation either returns a tensor `y` containing unique elements + /// along the `axis` of a tensor. The returned unique elements is sorted + /// in the same order as they occur along `axis` in `x`. + /// This operation also returns a tensor `idx` that is the same size as + /// the number of the elements in `x` along the `axis` dimension. It + /// contains the index in the unique output `y`. + /// In other words, for an `1-D` tensor `x` with `axis = None: + /// + /// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` + /// + /// For example: + /// + /// ``` + /// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] + /// y, idx = unique(x) + /// y ==> [1, 2, 4, 7, 8] + /// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] + /// ``` + /// + /// For an `2-D` tensor `x` with `axis = 0`: + /// + /// ``` + /// # tensor 'x' is [[1, 0, 0], + /// # [1, 0, 0], + /// # [2, 0, 0]] + /// y, idx = unique(x, axis=0) + /// y ==> [[1, 0, 0], + /// [2, 0, 0]] + /// idx ==> [0, 0, 1] + /// ``` + /// + /// For an `2-D` tensor `x` with `axis = 1`: + /// + /// ``` + /// # tensor 'x' is [[1, 0, 0], + /// # [1, 0, 0], + /// # [2, 0, 0]] + /// y, idx = unique(x, axis=1) + /// y ==> [[1, 0], + /// [1, 0], + /// [2, 0]] + /// idx ==> [0, 1, 1] + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor[] unique_v2(Tensor x, Tensor axis, TF_DataType out_idx = TF_DataType.TF_INT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "UniqueV2", name) { args = new object[] { x, axis }, attrs = new Dictionary() { ["out_idx"] = out_idx } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return unique_v2_eager_fallback(x, axis, out_idx: out_idx, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["axis"] = axis; + keywords["out_idx"] = out_idx; + var _op = tf.OpDefLib._apply_op_helper("UniqueV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Taxis", _op._get_attr_type("Taxis"), "out_idx", _op._get_attr_type("out_idx") }; + _execute.record_gradient("UniqueV2", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] unique_v2_eager_fallback(Tensor x, Tensor axis, TF_DataType out_idx, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, axis }; + object[] _attrs = new object[] { "T", x.dtype, "Taxis", axis.dtype, "out_idx", out_idx }; + var _result = _execute.execute("UniqueV2", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("UniqueV2", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Finds unique elements in a 1-D tensor. + /// + /// + /// + /// This operation returns a tensor `y` containing all of the unique elements of `x` + /// sorted in the same order that they occur in `x`. This operation also returns a + /// tensor `idx` the same size as `x` that contains the index of each value of `x` + /// in the unique output `y`. Finally, it returns a third tensor `count` that + /// contains the count of each element of `y` in `x`. In other words: + /// + /// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` + /// + /// For example: + /// + /// ``` + /// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] + /// y, idx, count = unique_with_counts(x) + /// y ==> [1, 2, 4, 7, 8] + /// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] + /// count ==> [2, 1, 3, 1, 2] + /// ``` + /// + /// + /// + /// + /// + public static Tensor[] unique_with_counts(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "UniqueWithCounts", name) { args = new object[] { x }, attrs = new Dictionary() { ["out_idx"] = out_idx } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return unique_with_counts_eager_fallback(x, out_idx: out_idx, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["out_idx"] = out_idx; + var _op = tf.OpDefLib._apply_op_helper("UniqueWithCounts", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "out_idx", _op._get_attr_type("out_idx") }; + _execute.record_gradient("UniqueWithCounts", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] unique_with_counts_eager_fallback(Tensor x, TF_DataType out_idx, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype, "out_idx", out_idx }; + var _result = _execute.execute("UniqueWithCounts", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("UniqueWithCounts", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Finds unique elements along an axis of a tensor. + /// + /// + /// + /// This operation either returns a tensor `y` containing unique elements + /// along the `axis` of a tensor. The returned unique elements is sorted + /// in the same order as they occur along `axis` in `x`. + /// This operation also returns a tensor `idx` and a tensor `count` + /// that are the same size as the number of the elements in `x` along the + /// `axis` dimension. The `idx` contains the index in the unique output `y` + /// and the `count` contains the count in the unique output `y`. + /// In other words, for an `1-D` tensor `x` with `axis = None: + /// + /// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` + /// + /// For example: + /// + /// ``` + /// x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8]) + /// y, idx, count = UniqueWithCountsV2(x, axis = [0]) + /// y ==> [1, 2, 4, 7, 8] + /// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] + /// count ==> [2, 1, 3, 1, 2] + /// ``` + /// + /// For a `2-D` tensor `x` with `axis = 0`: + /// + /// ``` + /// x = tf.constant([[1, 0, 0], + /// [1, 0, 0], + /// [2, 0, 0]]) + /// y, idx, count = UniqueWithCountsV2(x, axis=[0]) + /// y ==> [[1, 0, 0], + /// [2, 0, 0]] + /// idx ==> [0, 0, 1] + /// count ==> [2, 1] + /// ``` + /// + /// For a `2-D` tensor `x` with `axis = 1`: + /// + /// ``` + /// x = tf.constant([[1, 0, 0], + /// [1, 0, 0], + /// [2, 0, 0]]) + /// y, idx, count = UniqueWithCountsV2(x, axis=[1]) + /// y ==> [[1, 0], + /// [1, 0], + /// [2, 0]] + /// idx ==> [0, 1, 1] + /// count ==> [1, 2] + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor[] unique_with_counts_v2(Tensor x, Tensor axis, TF_DataType out_idx = TF_DataType.TF_INT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "UniqueWithCountsV2", name) { args = new object[] { x, axis }, attrs = new Dictionary() { ["out_idx"] = out_idx } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return unique_with_counts_v2_eager_fallback(x, axis, out_idx: out_idx, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["axis"] = axis; + keywords["out_idx"] = out_idx; + var _op = tf.OpDefLib._apply_op_helper("UniqueWithCountsV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Taxis", _op._get_attr_type("Taxis"), "out_idx", _op._get_attr_type("out_idx") }; + _execute.record_gradient("UniqueWithCountsV2", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] unique_with_counts_v2_eager_fallback(Tensor x, Tensor axis, TF_DataType out_idx, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, axis }; + object[] _attrs = new object[] { "T", x.dtype, "Taxis", axis.dtype, "out_idx", out_idx }; + var _result = _execute.execute("UniqueWithCountsV2", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("UniqueWithCountsV2", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. + /// + /// + /// + /// Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. + /// For example, given a tensor of shape `(A, B, C, D)`; + /// + /// If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` + /// and each tensor in `output` will have shape `(B, C, D)`. (Note that the + /// dimension unpacked along is gone, unlike `split`). + /// + /// If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` + /// and each tensor in `output` will have shape `(A, C, D)`. + /// Etc. + /// + /// This is the opposite of `pack`. + /// + /// + /// + /// + /// + /// + /// Dimension along which to unpack. Negative values wrap around, so the + /// valid range is `[-R, R)`. + /// + /// + /// + public static Tensor[] unpack(Tensor value, int num = 0, int axis = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Unpack", name) { args = new object[] { value }, attrs = new Dictionary() { ["num"] = num, ["axis"] = axis } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return unpack_eager_fallback(value, num: num, axis: axis, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["value"] = value; + keywords["num"] = num; + keywords["axis"] = axis; + var _op = tf.OpDefLib._apply_op_helper("Unpack", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "num", _op._get_attr_int("num"), "T", _op._get_attr_type("T"), "axis", _op._get_attr_int("axis") }; + _execute.record_gradient("Unpack", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] unpack_eager_fallback(Tensor value, int num, int axis, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { value }; + object[] _attrs = new object[] { "num", num, "T", value.dtype, "axis", axis }; + var _result = _execute.execute("Unpack", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Unpack", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Converts an array of flat indices into a tuple of coordinate arrays. + /// + /// + /// + /// + /// Example: + /// + /// ``` + /// y = tf.unravel_index(indices=[2, 5, 7], dims=[3, 3]) + /// # 'dims' represent a hypothetical (3, 3) tensor of indices: + /// # [[0, 1, *2*], + /// # [3, 4, *5*], + /// # [6, *7*, 8]] + /// # For each entry from 'indices', this operation returns + /// # its coordinates (marked with '*'), such as + /// # 2 ==> (0, 2) + /// # 5 ==> (1, 2) + /// # 7 ==> (2, 1) + /// y ==> [[0, 1, 2], [2, 2, 1]] + /// ``` + /// + /// @compatibility(numpy) + /// Equivalent to np.unravel_index + /// @end_compatibility + /// + /// + /// + /// + /// + public static Tensor unravel_index(Tensor indices, Tensor dims, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "UnravelIndex", name) { args = new object[] { indices, dims }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return unravel_index_eager_fallback(indices, dims, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["indices"] = indices; + keywords["dims"] = dims; + var _op = tf.OpDefLib._apply_op_helper("UnravelIndex", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("UnravelIndex", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor unravel_index_eager_fallback(Tensor indices, Tensor dims, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { indices, dims }; + object[] _attrs = new object[] { "Tidx", indices.dtype }; + var _result = _execute.execute("UnravelIndex", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("UnravelIndex", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Applies upper_bound(sorted_search_values, values) along each row. + /// + /// + /// + /// Each set of rows with the same index in (sorted_inputs, values) is treated + /// independently. The resulting row is the equivalent of calling + /// `np.searchsorted(sorted_inputs, values, side='right')`. + /// + /// The result is not a global index to the entire + /// `Tensor`, but rather just the index in the last dimension. + /// + /// A 2-D example: + /// sorted_sequence = [[0, 3, 9, 9, 10], + /// [1, 2, 3, 4, 5]] + /// values = [[2, 4, 9], + /// [0, 2, 6]] + /// + /// result = UpperBound(sorted_sequence, values) + /// + /// result == [[1, 2, 4], + /// [0, 2, 5]] + /// + /// + /// + /// + /// + /// + public static Tensor upper_bound(Tensor sorted_inputs, Tensor values, TF_DataType out_type = TF_DataType.TF_INT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "UpperBound", name) { args = new object[] { sorted_inputs, values }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return upper_bound_eager_fallback(sorted_inputs, values, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["sorted_inputs"] = sorted_inputs; + keywords["values"] = values; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("UpperBound", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("UpperBound", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor upper_bound_eager_fallback(Tensor sorted_inputs, Tensor values, TF_DataType out_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { sorted_inputs, values }; + object[] _attrs = new object[] { "T", sorted_inputs.dtype, "out_type", out_type }; + var _result = _execute.execute("UpperBound", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("UpperBound", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns locations of nonzero / true values in a tensor. + /// + /// + /// + /// This operation returns the coordinates of true elements in `input`. The + /// coordinates are returned in a 2-D tensor where the first dimension (rows) + /// represents the number of true elements, and the second dimension (columns) + /// represents the coordinates of the true elements. Keep in mind, the shape of + /// the output tensor can vary depending on how many true values there are in + /// `input`. Indices are output in row-major order. + /// + /// For example: + /// + /// ``` + /// # 'input' tensor is [[True, False] + /// # [True, False]] + /// # 'input' has two true values, so output has two coordinates. + /// # 'input' has rank of 2, so coordinates have two indices. + /// where(input) ==> [[0, 0], + /// [1, 0]] + /// + /// # `input` tensor is [[[True, False] + /// # [True, False]] + /// # [[False, True] + /// # [False, True]] + /// # [[False, False] + /// # [False, True]]] + /// # 'input' has 5 true values, so output has 5 coordinates. + /// # 'input' has rank of 3, so coordinates have three indices. + /// where(input) ==> [[0, 0, 0], + /// [0, 1, 0], + /// [1, 0, 1], + /// [1, 1, 1], + /// [2, 1, 1]] + /// + /// # `input` tensor is [[[1.5, 0.0] + /// # [-0.5, 0.0]] + /// # [[0.0, 0.25] + /// # [0.0, 0.75]] + /// # [[0.0, 0.0] + /// # [0.0, 0.01]]] + /// # 'input' has 5 nonzero values, so output has 5 coordinates. + /// # 'input' has rank of 3, so coordinates have three indices. + /// where(input) ==> [[0, 0, 0], + /// [0, 1, 0], + /// [1, 0, 1], + /// [1, 1, 1], + /// [2, 1, 1]] + /// + /// # `input` tensor is [[[1.5 + 0.0j, 0.0 + 0.0j] + /// # [0.0 + 0.5j, 0.0 + 0.0j]] + /// # [[0.0 + 0.0j, 0.25 + 1.5j] + /// # [0.0 + 0.0j, 0.75 + 0.0j]] + /// # [[0.0 + 0.0j, 0.0 + 0.0j] + /// # [0.0 + 0.0j, 0.01 + 0.0j]]] + /// # 'input' has 5 nonzero magnitude values, so output has 5 coordinates. + /// # 'input' has rank of 3, so coordinates have three indices. + /// where(input) ==> [[0, 0, 0], + /// [0, 1, 0], + /// [1, 0, 1], + /// [1, 1, 1], + /// [2, 1, 1]] + /// ``` + /// + /// + /// + /// + public static Tensor where(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Where", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return where_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("Where", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Where", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor where_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("Where", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Where", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns a tensor of zeros with the same shape and type as x. + /// + /// + /// + public static Tensor zeros_like(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ZerosLike", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return zeros_like_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("ZerosLike", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("ZerosLike", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor zeros_like_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("ZerosLike", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ZerosLike", _inputs_flat, _attrs, _result); } + return _result[0]; } } diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs new file mode 100644 index 000000000..2901e5fcc --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs @@ -0,0 +1,177 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class gen_control_flow_ops + { + public static Operation control_trigger(string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("ControlTrigger", name, new + { + }); + + return _op; + } + + /// + /// Creates or finds a child frame, and makes `data` available to the child frame. + /// + /// + /// + /// + /// + /// + /// + public static Tensor enter(Tensor data, string frame_name = "frame_name", bool is_constant = false, int parallel_iterations = 10, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("Enter", name, new + { + data, + frame_name, + is_constant, + parallel_iterations + }); + + return _op.output; + } + + /// + /// Forwards the input to the output. + /// + /// + /// + /// + public static Tensor loop_cond(Tensor input, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("LoopCond", name, new { input }); + + return _op.output; + } + + /// + /// Makes its input available to the next iteration. + /// + /// + /// + /// + public static Tensor ref_next_iteration(Tensor data, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("RefNextIteration", name, new { data }); + + return _op; + } + + /// + /// Makes its input available to the next iteration. + /// + /// + /// + /// + public static Tensor next_iteration(Tensor data, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("NextIteration", name, new { data }); + + return _op; + } + + /// + /// Exits the current frame to its parent frame. + /// + /// + /// + /// + public static Tensor ref_exit(Tensor data, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("RefExit", name, new { data }); + + return _op; + } + + /// + /// Exits the current frame to its parent frame. + /// + /// + /// + /// + public static Tensor _exit(Tensor data, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("Exit", name, new { data }); + + return _op; + } + + public static Operation no_op(string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("NoOp", name, null); + + return _op; + } + + public static Tensor[] ref_switch(Tensor data, Tensor pred, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("RefSwitch", name, new { data, pred }); + return _op.outputs; + } + + /// + /// Forwards `data` to the output port determined by `pred`. + /// + /// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise, + /// the data goes to `output_false`. + /// + /// See also `RefSwitch` and `Merge`. + /// + /// A `Tensor`. The tensor to be forwarded to the appropriate output. + /// A `Tensor` of type `bool`. + /// A scalar that specifies which output port will receive data. + /// + /// A name for the operation (optional). + /// A tuple of `Tensor` objects (output_false, output_true). + /// + /// output_false: A `Tensor`. Has the same type as `data`. + /// output_true: A `Tensor`. Has the same type as `data`. + /// + public static Tensor[] @switch(Tensor data, Tensor pred, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("Switch", name, new { data, pred }); + var _inputs_flat = _op.inputs; +#pragma warning disable CS0219 // Variable is assigned but its value is never used + var _attrs = ("T", _op.get_attr("T")); +#pragma warning restore CS0219 // Variable is assigned but its value is never used + // TODO: missing original code + //_execute.record_gradient("Switch", _inputs_flat, _attrs, _result, name); + return new[] { _op.outputs[0], _op.outputs[1] }; + } + + public static MergeOutput ref_merge(Tensor[] inputs, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("RefMerge", name, new { inputs }); + + return new MergeOutput(_op.outputs); + } + + public static MergeOutput merge(Tensor[] inputs, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("Merge", name, new { inputs }); + + return new MergeOutput(_op.outputs); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_ctc_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ctc_ops.cs new file mode 100644 index 000000000..37ecbba83 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_ctc_ops.cs @@ -0,0 +1,38 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class gen_ctc_ops + { + public static Tensor[] ctc_greedy_decoder(Tensor inputs, Tensor sequence_length, bool merge_repeated = true, string name = "CTCGreedyDecoder") + { + var op = tf.OpDefLib._apply_op_helper("CTCGreedyDecoder", name: name, args: new + { + inputs, + sequence_length, + merge_repeated + }); + /*var decoded_indices = op.outputs[0]; + var decoded_values = op.outputs[1]; + var decoded_shape = op.outputs[2]; + var log_probability = op.outputs[3];*/ + return op.outputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs new file mode 100644 index 000000000..4a6377285 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -0,0 +1,313 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class gen_data_flow_ops + { + public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("DynamicStitch", name, new { indices, data }); + + return _op.output; + } + + public static Tensor[] dynamic_partition(Tensor data, Tensor partitions, int num_partitions, + string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("DynamicPartition", name, new + { + data, + partitions, + num_partitions + }); + + return _op.outputs; + } + + public static (Tensor, Tensor) tensor_array_v3(T size, TF_DataType dtype = TF_DataType.DtInvalid, + Shape element_shape = null, bool dynamic_size = false, bool clear_after_read = true, + bool identical_element_shapes = false, string tensor_array_name = "", string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("TensorArrayV3", name, new + { + size, + dtype, + element_shape, + dynamic_size, + clear_after_read, + identical_element_shapes, + tensor_array_name + }); + + return (_op.outputs[0], _op.outputs[1]); + } + + public static Tensor tensor_array_scatter_v3(Tensor handle, Tensor indices, Tensor value, + Tensor flow_in, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("TensorArrayScatterV3", name, new + { + handle, + indices, + value, + flow_in + }); + + return _op.output; + } + + public static Tensor padding_fifo_queue_v2(TF_DataType[] component_types, Shape[] shapes, + int capacity = -1, string container = "", string shared_name = "", + string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("PaddingFIFOQueueV2", name, new + { + component_types, + shapes, + capacity, + container, + shared_name + }); + + return _op.output; + } + + public static Tensor fifo_queue_v2(TF_DataType[] component_types, Shape[] shapes, + int capacity = -1, string container = "", string shared_name = "", + string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("FIFOQueueV2", name, new + { + component_types, + shapes, + capacity, + container, + shared_name + }); + + return _op.output; + } + + public static Tensor priority_queue_v2(TF_DataType[] component_types, Shape[] shapes, + int capacity = -1, string container = "", string shared_name = "", + string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("PriorityQueueV2", name, new + { + component_types, + shapes, + capacity, + container, + shared_name + }); + + return _op.output; + } + + public static Tensor random_shuffle_queue_v2(TF_DataType[] component_types, Shape[] shapes, + int capacity = -1, int min_after_dequeue = 0, int seed = 0, int seed2 = 0, + string container = "", string shared_name = "", string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("RandomShuffleQueueV2", name, new + { + component_types, + shapes, + capacity, + min_after_dequeue, + seed, + seed2, + container, + shared_name + }); + + return _op.output; + } + + public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("QueueEnqueue", name, new + { + handle, + components, + timeout_ms + }); + + return _op; + } + + public static Operation queue_enqueue_v2(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("QueueEnqueueV2", name, new + { + handle, + components, + timeout_ms + }); + + return _op; + } + + public static Tensor[] queue_dequeue_v2(Tensor handle, TF_DataType[] component_types, int timeout_ms = -1, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("QueueDequeueV2", name, new + { + handle, + component_types, + timeout_ms + }); + + return _op.outputs; + } + + public static Tensor[] queue_dequeue(Tensor handle, TF_DataType[] component_types, int timeout_ms = -1, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("QueueDequeue", name, new + { + handle, + component_types, + timeout_ms + }); + + return _op.outputs; + } + + public static Operation queue_enqueue_many_v2(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("QueueEnqueueManyV2", name, new + { + handle, + components, + timeout_ms + }); + + return _op; + } + + public static Tensor[] queue_dequeue_many_v2(Tensor handle, int n, TF_DataType[] component_types, int timeout_ms = -1, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("QueueDequeueManyV2", name, new + { + handle, + n, + component_types, + timeout_ms + }); + + return _op.outputs; + } + + /// + /// Read an element from the TensorArray into output `value`. + /// + /// + /// + /// + /// + /// + /// + public static Tensor tensor_array_read_v3(Tensor handle, Tensor index, Tensor flow_in, TF_DataType dtype, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("TensorArrayReadV3", name, new + { + handle, + index, + flow_in, + dtype + }); + + return _op.output; + } + + public static Tensor tensor_array_write_v3(Tensor handle, Tensor index, Tensor value, Tensor flow_in, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("TensorArrayWriteV3", name, new + { + handle, + index, + value, + flow_in + }); + + return _op.output; + } + + public static Tensor tensor_array_size_v3(Tensor handle, Tensor flow_in, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("TensorArraySizeV3", name, new + { + handle, + flow_in + }); + + return _op.output; + } + + public static Tensor tensor_array_gather_v3(Tensor handle, Tensor indices, Tensor flow_in, + TF_DataType dtype, Shape element_shape = null, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("TensorArrayGatherV3", name, new + { + handle, + indices, + dtype, + element_shape, + flow_in + }); + + return _op.output; + } + + public static Tensor stack_v2(Tensor max_size, TF_DataType elem_type, string stack_name = "", + string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("StackV2", name, new + { + max_size, + elem_type, + stack_name + }); + + return _op.output; + } + + public static Tensor stack_push_v2(Tensor handle, Tensor elem, bool swap_memory = false, + string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("StackPushV2", name, new + { + handle, + elem, + swap_memory + }); + + return _op.output; + } + + public static Tensor stack_pop_v2(Tensor handle, TF_DataType elem_type, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("StackPopV2", name, new + { + handle, + elem_type + }); + + return _op.output; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs b/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs new file mode 100644 index 000000000..6ec426f58 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs @@ -0,0 +1,1089 @@ +/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/ + +using Tensorflow.Eager; +using Tensorflow.Contexts; +using Tensorflow.Exceptions; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static class gen_functional_ops +{ + /// + /// An n-way switch statement which calls a single branch function. + /// + /// + /// + /// An n-way switch statement, implementing the following: + /// ``` + /// switch (branch_index) { + /// case 0: + /// output = branches[0](input); + /// break; + /// case 1: + /// output = branches[1](input); + /// break; + /// ... + /// case [[nbranches-1]]: + /// default: + /// output = branches[nbranches-1](input); + /// break; + /// } + /// ``` + /// + /// + /// + /// + /// + /// A list of output types. + /// + /// + /// + /// A list of functions each of which takes 'inputs' and returns a list of + /// tensors, whose types are the same as what every other branch returns. + /// + /// + /// + /// + public static Tensor[] _case(Tensor branch_index, Tensors input, TF_DataType[] Tout, object[] branches, Shape[] output_shapes, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Case", name) { args = new object[] { branch_index, input }, attrs = new Dictionary() { ["Tout"] = Tout, ["branches"] = branches, ["output_shapes"] = output_shapes } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return case_eager_fallback(branch_index, input, Tout: Tout, branches: branches, output_shapes: output_shapes, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["branch_index"] = branch_index; + keywords["input"] = input; + keywords["Tout"] = Tout; + keywords["branches"] = branches; + keywords["output_shapes"] = output_shapes; + var _op = tf.OpDefLib._apply_op_helper("Case", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "branches", _op.get_attr("branches"), "output_shapes", _op.get_attr("output_shapes") }; + _execute.record_gradient("Case", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] case_eager_fallback(Tensor branch_index, Tensor input, TF_DataType[] Tout, object[] branches, Shape[] output_shapes, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { branch_index, input }; + object[] _attrs = new object[] { "branches", branches, "output_shapes", output_shapes }; + var _result = _execute.execute("Case", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Case", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Return the index of device the op runs. + /// + /// + /// + /// Given a list of device names, this operation returns the index of the device + /// this op runs. The length of the list is returned in two cases: + /// (1) Device does not exist in the given device list. + /// (2) It is in XLA compilation. + /// + /// + /// + /// + public static Tensor device_index(string[] device_names, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DeviceIndex", name) { args = new object[] { }, attrs = new Dictionary() { ["device_names"] = device_names } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return device_index_eager_fallback(device_names: device_names, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["device_names"] = device_names; + var _op = tf.OpDefLib._apply_op_helper("DeviceIndex", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "device_names", _op.get_attr("device_names") }; + _execute.record_gradient("DeviceIndex", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor device_index_eager_fallback(string[] device_names, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "device_names", device_names }; + var _result = _execute.execute("DeviceIndex", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DeviceIndex", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// ~~%~~ This op is used as a placeholder in If branch functions. It doesn't provide a~~%~~ valid output when run, so must either be removed (e.g. replaced with a~~%~~ function input) or guaranteed not to be used (e.g. if mirroring an~~%~~ intermediate output needed for the gradient computation of the other branch).~~%~~ + /// + /// + /// The type of the output. + /// + /// + /// + /// The purported shape of the output. This is only used for shape inference; + /// the output will not necessarily have this shape. Can be a partial shape. + /// + /// + /// + public static Tensor fake_param(TF_DataType dtype, Shape shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FakeParam", name) { args = new object[] { }, attrs = new Dictionary() { ["dtype"] = dtype, ["shape"] = shape } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fake_param_eager_fallback(dtype: dtype, shape: shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["dtype"] = dtype; + keywords["shape"] = shape; + var _op = tf.OpDefLib._apply_op_helper("FakeParam", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "shape", _op.get_attr("shape") }; + _execute.record_gradient("FakeParam", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fake_param_eager_fallback(TF_DataType dtype, Shape shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "dtype", dtype, "shape", shape }; + var _result = _execute.execute("FakeParam", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FakeParam", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Applies a for loop. + /// + /// + /// + /// ```python + /// output = input; + /// for i in range(start, limit, delta) + /// output = body(i, output); + /// ``` + /// + /// + /// + /// + /// + /// + /// + /// + /// A function that takes a list of tensors (int32, T) and returns another + /// list of tensors (T). + /// + /// + /// + public static Tensor[] _for(Tensor start, Tensor limit, Tensor delta, Tensors input, object body, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "For", name) { args = new object[] { start, limit, delta, input }, attrs = new Dictionary() { ["body"] = body } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return for_eager_fallback(start, limit, delta, input, body: body, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["start"] = start; + keywords["limit"] = limit; + keywords["delta"] = delta; + keywords["input"] = input; + keywords["body"] = body; + var _op = tf.OpDefLib._apply_op_helper("For", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op.get_attr("T"), "body", _op.get_attr("body") }; + _execute.record_gradient("For", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] for_eager_fallback(Tensor start, Tensor limit, Tensor delta, Tensor input, object body, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { start, limit, delta, input }; + object[] _attrs = new object[] { "body", body }; + var _result = _execute.execute("For", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("For", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// output = cond ? then_branch(input) : else_branch(input) + /// + /// + /// + /// + /// A list of output types. + /// + /// + /// + /// A function that takes 'inputs' and returns a list of tensors, whose + /// types are the same as what else_branch returns. + /// + /// + /// + /// + /// A function that takes 'inputs' and returns a list of tensors, whose + /// types are the same as what then_branch returns. + /// + /// + /// + /// + public static Tensor[] _if(Tensor cond, Tensors input, TF_DataType[] Tout, object then_branch, object else_branch, Shape[] output_shapes, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "If", name) { args = new object[] { cond, input }, attrs = new Dictionary() { ["Tout"] = Tout, ["then_branch"] = then_branch, ["else_branch"] = else_branch, ["output_shapes"] = output_shapes } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return if_eager_fallback(cond, input, Tout: Tout, then_branch: then_branch, else_branch: else_branch, output_shapes: output_shapes, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["cond"] = cond; + keywords["input"] = input; + keywords["Tout"] = Tout; + keywords["then_branch"] = then_branch; + keywords["else_branch"] = else_branch; + keywords["output_shapes"] = output_shapes; + var _op = tf.OpDefLib._apply_op_helper("If", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tcond", _op._get_attr_type("Tcond"), "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "then_branch", _op.get_attr("then_branch"), "else_branch", _op.get_attr("else_branch"), "output_shapes", _op.get_attr("output_shapes") }; + _execute.record_gradient("If", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] if_eager_fallback(Tensor cond, Tensor input, TF_DataType[] Tout, object then_branch, object else_branch, Shape[] output_shapes, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { cond, input }; + object[] _attrs = new object[] { "Tcond", cond.dtype, "then_branch", then_branch, "else_branch", else_branch, "output_shapes", output_shapes }; + var _result = _execute.execute("If", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("If", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// returns `f(inputs)`, where `f`'s body is placed and partitioned. + /// + /// + /// + /// Asynchronously executes a function, potentially across multiple devices but + /// within a single process. The kernel places and partitions a given function's + /// underlying graph, and executes each of the partitioned subgraphs as a function. + /// + /// + /// + /// + /// A list of output types. + /// + /// + /// + /// A function that takes 'args', a list of tensors, and returns 'output', + /// another list of tensors. Input and output types are specified by 'Tin' + /// and 'Tout'. The function body of f will be placed and partitioned across + /// devices, setting this op apart from the regular Call op. + /// + /// + /// + /// + /// + /// + public static Tensor[] partitioned_call(Tensors args, TF_DataType[] Tout, object f, string config = "", string config_proto = "", string executor_type = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "PartitionedCall", name) { args = new object[] { args }, attrs = new Dictionary() { ["Tout"] = Tout, ["f"] = f, ["config"] = config, ["config_proto"] = config_proto, ["executor_type"] = executor_type } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return partitioned_call_eager_fallback(args, Tout: Tout, f: f, config: config, config_proto: config_proto, executor_type: executor_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (config is null) + { + config = ""; + } + if (config_proto is null) + { + config_proto = ""; + } + if (executor_type is null) + { + executor_type = ""; + } + Dictionary keywords = new(); + keywords["args"] = args; + keywords["Tout"] = Tout; + keywords["f"] = f; + keywords["config"] = config; + keywords["config_proto"] = config_proto; + keywords["executor_type"] = executor_type; + var _op = tf.OpDefLib._apply_op_helper("PartitionedCall", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f", _op.get_attr("f"), "config", _op.get_attr("config"), "config_proto", _op.get_attr("config_proto"), "executor_type", _op.get_attr("executor_type") }; + _execute.record_gradient("PartitionedCall", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] partitioned_call_eager_fallback(Tensor args, TF_DataType[] Tout, object f, string config, string config_proto, string executor_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { args }; + object[] _attrs = new object[] { "f", f, "config", config, "config_proto", config_proto, "executor_type", executor_type }; + var _result = _execute.execute("PartitionedCall", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("PartitionedCall", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Runs function `f` on a remote device indicated by `target`. + /// + /// + /// + /// + /// + /// The type list for the return values. + /// + /// + /// + /// + /// The function to run remotely. + /// + /// + /// + public static Tensor[] remote_call(Tensor target, Tensors args, TF_DataType[] Tout, object f, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "RemoteCall", name) { args = new object[] { target, args }, attrs = new Dictionary() { ["Tout"] = Tout, ["f"] = f } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return remote_call_eager_fallback(target, args, Tout: Tout, f: f, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["target"] = target; + keywords["args"] = args; + keywords["Tout"] = Tout; + keywords["f"] = f; + var _op = tf.OpDefLib._apply_op_helper("RemoteCall", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f", _op.get_attr("f") }; + _execute.record_gradient("RemoteCall", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] remote_call_eager_fallback(Tensor target, Tensor args, TF_DataType[] Tout, object f, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { target, args }; + object[] _attrs = new object[] { "f", f }; + var _result = _execute.execute("RemoteCall", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("RemoteCall", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// returns `f(inputs)`, where `f`'s body is placed and partitioned. + /// + /// + /// + /// A list of output types. + /// + /// + /// + /// A function that takes 'args', a list of tensors, and returns 'output', + /// another list of tensors. Input and output types are specified by 'Tin' + /// and 'Tout'. The function body of f will be placed and partitioned across + /// devices, setting this op apart from the regular Call op. This op is + /// stateful. + /// + /// + /// + /// + /// + /// + public static Tensor[] stateful_partitioned_call(Tensors args, TF_DataType[] Tout, object f, string config = "", string config_proto = "", string executor_type = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "StatefulPartitionedCall", name) { args = new object[] { args }, attrs = new Dictionary() { ["Tout"] = Tout, ["f"] = f, ["config"] = config, ["config_proto"] = config_proto, ["executor_type"] = executor_type } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return stateful_partitioned_call_eager_fallback(args, Tout: Tout, f: f, config: config, config_proto: config_proto, executor_type: executor_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (config is null) + { + config = ""; + } + if (config_proto is null) + { + config_proto = ""; + } + if (executor_type is null) + { + executor_type = ""; + } + Dictionary keywords = new(); + keywords["args"] = args; + keywords["Tout"] = Tout; + keywords["f"] = f; + keywords["config"] = config; + keywords["config_proto"] = config_proto; + keywords["executor_type"] = executor_type; + var _op = tf.OpDefLib._apply_op_helper("StatefulPartitionedCall", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f", _op.get_attr("f"), "config", _op.get_attr("config"), "config_proto", _op.get_attr("config_proto"), "executor_type", _op.get_attr("executor_type") }; + _execute.record_gradient("StatefulPartitionedCall", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] stateful_partitioned_call_eager_fallback(Tensor args, TF_DataType[] Tout, object f, string config, string config_proto, string executor_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { args }; + object[] _attrs = new object[] { "f", f, "config", config, "config_proto", config_proto, "executor_type", executor_type }; + var _result = _execute.execute("StatefulPartitionedCall", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("StatefulPartitionedCall", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// An n-way switch statement which calls a single branch function. + /// + /// + /// + /// An n-way switch statement, implementing the following: + /// ``` + /// switch (branch_index) { + /// case 0: + /// output = branches[0](input); + /// break; + /// case 1: + /// output = branches[1](input); + /// break; + /// ... + /// case [[nbranches-1]]: + /// default: + /// output = branches[nbranches-1](input); + /// break; + /// } + /// ``` + /// + /// This should only be used when the none of branches has stateful ops. + /// + /// + /// + /// + /// + /// A list of output types. + /// + /// + /// + /// A list of functions each of which takes 'inputs' and returns a list of + /// tensors, whose types are the same as what every other branch returns. + /// + /// + /// + /// + public static Tensor[] stateless_case(Tensor branch_index, Tensors input, TF_DataType[] Tout, object[] branches, Shape[] output_shapes, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "StatelessCase", name) { args = new object[] { branch_index, input }, attrs = new Dictionary() { ["Tout"] = Tout, ["branches"] = branches, ["output_shapes"] = output_shapes } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return stateless_case_eager_fallback(branch_index, input, Tout: Tout, branches: branches, output_shapes: output_shapes, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["branch_index"] = branch_index; + keywords["input"] = input; + keywords["Tout"] = Tout; + keywords["branches"] = branches; + keywords["output_shapes"] = output_shapes; + var _op = tf.OpDefLib._apply_op_helper("StatelessCase", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "branches", _op.get_attr("branches"), "output_shapes", _op.get_attr("output_shapes") }; + _execute.record_gradient("StatelessCase", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] stateless_case_eager_fallback(Tensor branch_index, Tensor input, TF_DataType[] Tout, object[] branches, Shape[] output_shapes, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { branch_index, input }; + object[] _attrs = new object[] { "branches", branches, "output_shapes", output_shapes }; + var _result = _execute.execute("StatelessCase", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("StatelessCase", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// output = cond ? then_branch(input) : else_branch(input) + /// + /// + /// + /// + /// A list of output types. + /// + /// + /// + /// A function that takes 'inputs' and returns a list of tensors, whose + /// types are the same as what else_branch returns. + /// + /// + /// + /// + /// A function that takes 'inputs' and returns a list of tensors, whose + /// types are the same as what then_branch returns. + /// + /// + /// + /// + public static Tensor[] stateless_if(Tensor cond, Tensors input, TF_DataType[] Tout, object then_branch, object else_branch, Shape[] output_shapes, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "StatelessIf", name) { args = new object[] { cond, input }, attrs = new Dictionary() { ["Tout"] = Tout, ["then_branch"] = then_branch, ["else_branch"] = else_branch, ["output_shapes"] = output_shapes } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return stateless_if_eager_fallback(cond, input, Tout: Tout, then_branch: then_branch, else_branch: else_branch, output_shapes: output_shapes, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["cond"] = cond; + keywords["input"] = input; + keywords["Tout"] = Tout; + keywords["then_branch"] = then_branch; + keywords["else_branch"] = else_branch; + keywords["output_shapes"] = output_shapes; + var _op = tf.OpDefLib._apply_op_helper("StatelessIf", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tcond", _op._get_attr_type("Tcond"), "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "then_branch", _op.get_attr("then_branch"), "else_branch", _op.get_attr("else_branch"), "output_shapes", _op.get_attr("output_shapes") }; + _execute.record_gradient("StatelessIf", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] stateless_if_eager_fallback(Tensor cond, Tensor input, TF_DataType[] Tout, object then_branch, object else_branch, Shape[] output_shapes, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { cond, input }; + object[] _attrs = new object[] { "Tcond", cond.dtype, "then_branch", then_branch, "else_branch", else_branch, "output_shapes", output_shapes }; + var _result = _execute.execute("StatelessIf", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("StatelessIf", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// output = input; While (Cond(output)) { output = Body(output) } + /// + /// + /// + /// + /// A function takes 'input' and returns a tensor. If the tensor is + /// a scalar of non-boolean, the scalar is converted to a boolean + /// according to the following rule: if the scalar is a numerical + /// value, non-zero means True and zero means False; if the scalar is + /// a string, non-empty means True and empty means False. If the + /// tensor is not a scalar, non-emptiness means True and False + /// otherwise. + /// + /// This should only be used when the while condition and body functions + /// do not have stateful ops. + /// + /// + /// + /// + /// A function that takes a list of tensors and returns another + /// list of tensors. Both lists have the same types as specified + /// by T. + /// + /// + /// + /// + /// + public static Tensor[] stateless_while(Tensors input, object cond, object body, Shape[] output_shapes, int parallel_iterations = 10, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "StatelessWhile", name) { args = new object[] { input }, attrs = new Dictionary() { ["cond"] = cond, ["body"] = body, ["output_shapes"] = output_shapes, ["parallel_iterations"] = parallel_iterations } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return stateless_while_eager_fallback(input, cond: cond, body: body, output_shapes: output_shapes, parallel_iterations: parallel_iterations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["cond"] = cond; + keywords["body"] = body; + keywords["output_shapes"] = output_shapes; + keywords["parallel_iterations"] = parallel_iterations; + var _op = tf.OpDefLib._apply_op_helper("StatelessWhile", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op.get_attr("T"), "cond", _op.get_attr("cond"), "body", _op.get_attr("body"), "output_shapes", _op.get_attr("output_shapes"), "parallel_iterations", _op._get_attr_int("parallel_iterations") }; + _execute.record_gradient("StatelessWhile", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] stateless_while_eager_fallback(Tensor input, object cond, object body, Shape[] output_shapes, int parallel_iterations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "cond", cond, "body", body, "output_shapes", output_shapes, "parallel_iterations", parallel_iterations }; + var _result = _execute.execute("StatelessWhile", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("StatelessWhile", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes the gradient function for function f via backpropagation. + /// + /// + /// + /// + /// the type list for the input list. + /// + /// + /// + /// + /// The function we want to compute the gradient for. + /// + /// The function 'f' must be a numerical function which takes N inputs and + /// produces M outputs. Its gradient function 'g', which is computed by + /// this SymbolicGradient op is a function taking N + M inputs and + /// produces N outputs. + /// + /// I.e. if we have + /// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), + /// then, g is + /// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, + /// dL/dy1, dL/dy2, ..., dL/dy_M), + /// + /// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the + /// loss function). dL/dx_i is the partial derivative of L with respect + /// to x_i. + /// + /// (Needs some math expert to say the comment above better.) + /// + /// + /// + public static Tensor[] symbolic_gradient(Tensors input, TF_DataType[] Tout, object f, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SymbolicGradient", name) { args = new object[] { input }, attrs = new Dictionary() { ["Tout"] = Tout, ["f"] = f } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return symbolic_gradient_eager_fallback(input, Tout: Tout, f: f, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["Tout"] = Tout; + keywords["f"] = f; + var _op = tf.OpDefLib._apply_op_helper("SymbolicGradient", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f", _op.get_attr("f") }; + _execute.record_gradient("SymbolicGradient", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] symbolic_gradient_eager_fallback(Tensor input, TF_DataType[] Tout, object f, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "f", f }; + var _result = _execute.execute("SymbolicGradient", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SymbolicGradient", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Converts a tensor to a scalar predicate. + /// + /// + /// + /// Converts a tensor to a scalar predicate with the following rules: + /// + /// - For 0D tensors, truthiness is determined by comparing against a "zero" + /// value. For numerical types it is the obvious zero. For strings it is the + /// empty string. + /// + /// - For >0D tensors, truthiness is determined by looking at the number of + /// elements. If has zero elements, then the result is false. Otherwise the + /// result is true. + /// + /// This matches the behavior of If and While for determining if a tensor counts + /// as true/false for a branch condition. + /// + /// + /// + /// + public static Tensor to_bool(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ToBool", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return to_bool_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("ToBool", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("ToBool", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor to_bool_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("ToBool", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ToBool", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// output = input; While (Cond(output)) { output = Body(output) } + /// + /// + /// + /// + /// A function takes 'input' and returns a tensor. If the tensor is + /// a scalar of non-boolean, the scalar is converted to a boolean + /// according to the following rule: if the scalar is a numerical + /// value, non-zero means True and zero means False; if the scalar is + /// a string, non-empty means True and empty means False. If the + /// tensor is not a scalar, non-emptiness means True and False + /// otherwise. + /// + /// + /// + /// + /// A function that takes a list of tensors and returns another + /// list of tensors. Both lists have the same types as specified + /// by T. + /// + /// + /// + /// + /// + public static Tensor[] _while(Tensors input, object cond, object body, Shape[] output_shapes, int parallel_iterations = 10, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "While", name) { args = new object[] { input }, attrs = new Dictionary() { ["cond"] = cond, ["body"] = body, ["output_shapes"] = output_shapes, ["parallel_iterations"] = parallel_iterations } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return while_eager_fallback(input, cond: cond, body: body, output_shapes: output_shapes, parallel_iterations: parallel_iterations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["cond"] = cond; + keywords["body"] = body; + keywords["output_shapes"] = output_shapes; + keywords["parallel_iterations"] = parallel_iterations; + var _op = tf.OpDefLib._apply_op_helper("While", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op.get_attr("T"), "cond", _op.get_attr("cond"), "body", _op.get_attr("body"), "output_shapes", _op.get_attr("output_shapes"), "parallel_iterations", _op._get_attr_int("parallel_iterations") }; + _execute.record_gradient("While", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] while_eager_fallback(Tensor input, object cond, object body, Shape[] output_shapes, int parallel_iterations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "cond", cond, "body", body, "output_shapes", output_shapes, "parallel_iterations", parallel_iterations }; + var _result = _execute.execute("While", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("While", _inputs_flat, _attrs, _result); + } + return _result; + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_image_ops.cs b/src/TensorFlowNET.Core/Operations/gen_image_ops.cs new file mode 100644 index 000000000..cbe661ae5 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_image_ops.cs @@ -0,0 +1,492 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Eager; +using static Tensorflow.Binding; +using Tensorflow.Exceptions; +using Tensorflow.Contexts; +using System.Xml.Linq; +using Google.Protobuf; + +namespace Tensorflow +{ + public class gen_image_ops + { + public static Tensor adjust_contrastv2(Tensor images, Tensor contrast_factor, string name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AdjustContrastv2", name) { + args = new object[] { images, contrast_factor }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return adjust_contrastv2_eager_fallback(images, contrast_factor, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["images"] = images; + keywords["contrast_factor"] = contrast_factor; + var _op = tf.OpDefLib._apply_op_helper("AdjustContrastv2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("AdjustContrastv2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + public static Tensor adjust_contrastv2(Tensor image, float contrast_factor, string name = null) + { + return adjust_contrastv2(image, tf.convert_to_tensor(contrast_factor), name: name); + } + + public static Tensor adjust_contrastv2_eager_fallback(Tensor images, Tensor contrast_factor, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { images, contrast_factor}; + object[] _attrs = new object[] { "T", images.dtype }; + var _result = _execute.execute("AdjustContrastv2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AdjustContrastv2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + + public static Tensor adjust_hue(Tensor images, Tensor delta, string name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AdjustHue", name) { + args = new object[] { images, delta }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return adjust_hue_eager_fallback(images, delta, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["images"] = images; + keywords["delta"] = delta; + var _op = tf.OpDefLib._apply_op_helper("AdjustHue", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("AdjustHue", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor adjust_hue(Tensor images, float delta, string name = null) + => adjust_hue(images, delta, name: name); + + public static Tensor adjust_hue_eager_fallback(Tensor images, Tensor delta, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { images, delta}; + object[] _attrs = new object[] { "T", images.dtype }; + var _result = _execute.execute("AdjustHue", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AdjustHue", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + + public static Tensor adjust_saturation(Tensor images, Tensor scale, string name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AdjustSaturation", name) + { + args = new object[] { images, scale }, + attrs = new Dictionary() { } + }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return adjust_hue_eager_fallback(images, scale, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["images"] = images; + keywords["scale"] = scale; + var _op = tf.OpDefLib._apply_op_helper("AdjustSaturation", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("AdjustSaturation", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor adjust_saturation(Tensor images, float scale, string name = null) + => adjust_saturation(images, ops.convert_to_tensor(scale), name: name); + + public static Tensor adjust_saturation_eager_fallback(Tensor images, Tensor scale, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { images, scale }; + object[] _attrs = new object[] { "T", images.dtype }; + var _result = _execute.execute("AdjustSaturation", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AdjustSaturation", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + + public static (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression(Tensor boxes, Tensor scores, Tensor max_output_size_per_class, Tensor max_total_size, + Tensor iou_threshold, Tensor score_threshold, bool pad_per_class = false, bool clip_boxes = true, string name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "CombinedNonMaxSuppression", name){ + args = new object[] { + boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, score_threshold, + "pad_per_class", pad_per_class, "clip_boxes", clip_boxes}, + attrs = new Dictionary() { }}); + return (_fast_path_result[0], _fast_path_result[1], _fast_path_result[2], _fast_path_result[3]); + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return combined_non_max_suppression_eager_fallback( + boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, + score_threshold, pad_per_class, clip_boxes, name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["boxes"] = boxes; + keywords["scores"] = scores; + keywords["max_output_size_per_class"] = max_output_size_per_class; + keywords["max_total_size"] = max_total_size; + keywords["iou_threshold"] = iou_threshold; + keywords["score_threshold"] = score_threshold; + keywords["pad_per_class"] = pad_per_class; + keywords["clip_boxes"] = clip_boxes; + + var _op = tf.OpDefLib._apply_op_helper("CombinedNonMaxSuppression", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "pad_per_class", _op._get_attr_type("pad_per_class") ,"clip_boxes", _op._get_attr_type("clip_boxes")}; + _execute.record_gradient("CombinedNonMaxSuppression", _op.inputs, _attrs, _result); + } + return (_result[0], _result[1], _result[2], _result[3]); + } + + public static (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression_eager_fallback(Tensor boxes, Tensor scores, Tensor max_output_size_per_class, Tensor max_total_size, + Tensor iou_threshold, Tensor score_threshold, bool pad_per_class, bool clip_boxes, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, score_threshold }; + object[] _attrs = new object[] { "pad_per_class", pad_per_class, "clip_boxes", clip_boxes }; + var _result = _execute.execute("CombinedNonMaxSuppression", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("CombinedNonMaxSuppression", _inputs_flat, _attrs, _result); + } + return (_result[0], _result[1], _result[2], _result[3]); + } + + public static Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "CropAndResize", name) { + args = new object[] { + image, boxes, box_ind, crop_size, "method", method, "extrapolation_value", extrapolation_value }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return crop_and_resize_eager_fallback( + image, boxes, box_ind, crop_size, method: method, extrapolation_value: extrapolation_value, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["image"] = image; + keywords["boxes"] = boxes; + keywords["box_ind"] = box_ind; + keywords["crop_size"] = crop_size; + keywords["method"] = method; + keywords["extrapolation_value"] = extrapolation_value; + var _op = tf.OpDefLib._apply_op_helper("CropAndResize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") ,"method", _op._get_attr_type("method") , + "extrapolation_value", _op.get_attr("extrapolation_value")}; + _execute.record_gradient("CropAndResize", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor crop_and_resize_eager_fallback(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method, float extrapolation_value, string name, Context ctx) + { + if (method is null) + method = "bilinear"; + //var method_cpmpat = ByteString.CopyFromUtf8(method ?? string.Empty); + //var extrapolation_value_float = (float)extrapolation_value; + + Tensor[] _inputs_flat = new Tensor[] { image, boxes, box_ind, crop_size, tf.convert_to_tensor(method), tf.convert_to_tensor(extrapolation_value) }; + object[] _attrs = new object[] { "T", image.dtype }; + var _result = _execute.execute("CropAndResize", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("CropAndResize", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + + + public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null) + { + if (dtype == image.dtype) + return array_ops.identity(image, name: name); + + return tf_with(ops.name_scope(name, "convert_image", image), scope => + { + name = scope; + + if (image.dtype.is_integer() && dtype.is_integer()) + { + throw new NotImplementedException("convert_image_dtype is_integer"); + } + else if (image.dtype.is_floating() && dtype.is_floating()) + { + throw new NotImplementedException("convert_image_dtype is_floating"); + } + else + { + if (image.dtype.is_integer()) + { + // Converting to float: first cast, then scale. No saturation possible. + var cast = math_ops.cast(image, dtype); + var scale = 1.0f / image.dtype.max(); + return math_ops.multiply(cast, scale, name: name); + } + else + { + throw new NotImplementedException("convert_image_dtype is_integer"); + } + } + }); + } + + public static Tensor decode_image(Tensor contents, + long channels = 0, + TF_DataType dtype = TF_DataType.TF_UINT8, + bool expand_animations = true, + string name = null) + => tf.Context.ExecuteOp("DecodeImage", name, + new ExecuteOpArgs(contents).SetAttributes(new + { + channels, + dtype, + expand_animations + })); + + public static Tensor decode_jpeg(Tensor contents, + long channels = 0, + long ratio = 1, + bool fancy_upscaling = true, + bool try_recover_truncated = false, + float acceptable_fraction = 1, + string dct_method = "", + string name = null) + => tf.Context.ExecuteOp("DecodeJpeg", name, + new ExecuteOpArgs(contents).SetAttributes( + new + { + channels, + ratio, + fancy_upscaling, + try_recover_truncated, + acceptable_fraction, + dct_method + })); + + public static Tensor decode_gif(Tensor contents, + string name = null) + { + // Add nodes to the TensorFlow graph. + if (tf.Context.executing_eagerly()) + { + throw new NotImplementedException("decode_gif"); + } + else + { + var _op = tf.OpDefLib._apply_op_helper("DecodeGif", name: name, args: new + { + contents + }); + + return _op.output; + } + } + + public static Tensor decode_png(Tensor contents, + int channels = 0, + TF_DataType dtype = TF_DataType.TF_UINT8, + string name = null) + { + // Add nodes to the TensorFlow graph. + if (tf.Context.executing_eagerly()) + { + throw new NotImplementedException("decode_png"); + } + else + { + var _op = tf.OpDefLib._apply_op_helper("DecodePng", name: name, args: new + { + contents, + channels, + dtype + }); + + return _op.output; + } + } + + public static Tensor decode_bmp(Tensor contents, + int channels = 0, + string name = null) + { + // Add nodes to the TensorFlow graph. + if (tf.Context.executing_eagerly()) + { + throw new NotImplementedException("decode_bmp"); + } + else + { + var _op = tf.OpDefLib._apply_op_helper("DecodeBmp", name: name, args: new + { + contents, + channels + }); + + return _op.output; + } + } + + public static Tensor resize_bilinear(Tensor images, + Tensor size, + bool align_corners = false, + bool half_pixel_centers = false, + string name = null) + => tf.Context.ExecuteOp("ResizeBilinear", name, + new ExecuteOpArgs(images, size).SetAttributes(new + { + align_corners, + half_pixel_centers + })); + + public static Tensor resize_bicubic(Tensor images, + Tensor size, + bool align_corners = false, + bool half_pixel_centers = false, + string name = null) + => tf.Context.ExecuteOp("ResizeBicubic", name, + new ExecuteOpArgs(images, size).SetAttributes(new { align_corners, half_pixel_centers })); + + public static Tensor resize_nearest_neighbor(Tensor images, Tsize size, bool align_corners = false, + bool half_pixel_centers = false, string name = null) + => tf.Context.ExecuteOp("ResizeNearestNeighbor", name, + new ExecuteOpArgs(images, size).SetAttributes(new { align_corners, half_pixel_centers })); + + public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false, + bool half_pixel_centers = false, string name = null) + => tf.Context.ExecuteOp("ResizeNearestNeighborGrad", name, new ExecuteOpArgs(grads, size) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + align_corners = op.get_attr("align_corners"), + half_pixel_centers = op.get_attr("half_pixel_centers") + } + }.SetAttributes(new { align_corners, half_pixel_centers })); + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_io_ops.cs b/src/TensorFlowNET.Core/Operations/gen_io_ops.cs new file mode 100644 index 000000000..0b92ff360 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_io_ops.cs @@ -0,0 +1,2096 @@ +/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/ + +using Tensorflow.Eager; +using Tensorflow.Contexts; +using Tensorflow.Exceptions; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static class gen_io_ops +{ + /// + /// A Reader that outputs fixed-length records from a file. + /// + /// + /// + /// Number of bytes in the header, defaults to 0. + /// + /// + /// + /// + /// Number of bytes in the record. + /// + /// + /// + /// + /// Number of bytes in the footer, defaults to 0. + /// + /// + /// + /// + /// Number of bytes to hop before each read. Default of 0 means using + /// record_bytes. + /// + /// + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// + public static Tensor fixed_length_record_reader(int header_bytes = 0, int record_bytes = 0, int footer_bytes = 0, int hop_bytes = 0, string container = "", string shared_name = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FixedLengthRecordReader", name) { args = new object[] { }, attrs = new Dictionary() { ["header_bytes"] = header_bytes, ["record_bytes"] = record_bytes, ["footer_bytes"] = footer_bytes, ["hop_bytes"] = hop_bytes, ["container"] = container, ["shared_name"] = shared_name } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fixed_length_record_reader_eager_fallback(header_bytes: header_bytes, record_bytes: record_bytes, footer_bytes: footer_bytes, hop_bytes: hop_bytes, container: container, shared_name: shared_name, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (container is null) + { + container = ""; + } + if (shared_name is null) + { + shared_name = ""; + } + Dictionary keywords = new(); + keywords["header_bytes"] = header_bytes; + keywords["record_bytes"] = record_bytes; + keywords["footer_bytes"] = footer_bytes; + keywords["hop_bytes"] = hop_bytes; + keywords["container"] = container; + keywords["shared_name"] = shared_name; + var _op = tf.OpDefLib._apply_op_helper("FixedLengthRecordReader", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "header_bytes", _op._get_attr_int("header_bytes"), "record_bytes", _op._get_attr_int("record_bytes"), "footer_bytes", _op._get_attr_int("footer_bytes"), "hop_bytes", _op._get_attr_int("hop_bytes"), "container", _op.get_attr("container"), "shared_name", _op.get_attr("shared_name") }; + _execute.record_gradient("FixedLengthRecordReader", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fixed_length_record_reader_eager_fallback(int header_bytes, int record_bytes, int footer_bytes, int hop_bytes, string container, string shared_name, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "header_bytes", header_bytes, "record_bytes", record_bytes, "footer_bytes", footer_bytes, "hop_bytes", hop_bytes, "container", container, "shared_name", shared_name }; + var _result = _execute.execute("FixedLengthRecordReader", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FixedLengthRecordReader", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// A Reader that outputs fixed-length records from a file. + /// + /// + /// + /// Number of bytes in the header, defaults to 0. + /// + /// + /// + /// + /// Number of bytes in the record. + /// + /// + /// + /// + /// Number of bytes in the footer, defaults to 0. + /// + /// + /// + /// + /// Number of bytes to hop before each read. Default of 0 means using + /// record_bytes. + /// + /// + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// + /// + /// The type of encoding for the file. Currently ZLIB and GZIP + /// are supported. Defaults to none. + /// + /// + /// + public static Tensor fixed_length_record_reader_v2(int header_bytes = 0, int record_bytes = 0, int footer_bytes = 0, int hop_bytes = 0, string container = "", string shared_name = "", string encoding = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FixedLengthRecordReaderV2", name) { args = new object[] { }, attrs = new Dictionary() { ["header_bytes"] = header_bytes, ["record_bytes"] = record_bytes, ["footer_bytes"] = footer_bytes, ["hop_bytes"] = hop_bytes, ["container"] = container, ["shared_name"] = shared_name, ["encoding"] = encoding } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fixed_length_record_reader_v2_eager_fallback(header_bytes: header_bytes, record_bytes: record_bytes, footer_bytes: footer_bytes, hop_bytes: hop_bytes, container: container, shared_name: shared_name, encoding: encoding, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (container is null) + { + container = ""; + } + if (shared_name is null) + { + shared_name = ""; + } + if (encoding is null) + { + encoding = ""; + } + Dictionary keywords = new(); + keywords["header_bytes"] = header_bytes; + keywords["record_bytes"] = record_bytes; + keywords["footer_bytes"] = footer_bytes; + keywords["hop_bytes"] = hop_bytes; + keywords["container"] = container; + keywords["shared_name"] = shared_name; + keywords["encoding"] = encoding; + var _op = tf.OpDefLib._apply_op_helper("FixedLengthRecordReaderV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "header_bytes", _op._get_attr_int("header_bytes"), "record_bytes", _op._get_attr_int("record_bytes"), "footer_bytes", _op._get_attr_int("footer_bytes"), "hop_bytes", _op._get_attr_int("hop_bytes"), "container", _op.get_attr("container"), "shared_name", _op.get_attr("shared_name"), "encoding", _op.get_attr("encoding") }; + _execute.record_gradient("FixedLengthRecordReaderV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fixed_length_record_reader_v2_eager_fallback(int header_bytes, int record_bytes, int footer_bytes, int hop_bytes, string container, string shared_name, string encoding, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "header_bytes", header_bytes, "record_bytes", record_bytes, "footer_bytes", footer_bytes, "hop_bytes", hop_bytes, "container", container, "shared_name", shared_name, "encoding", encoding }; + var _result = _execute.execute("FixedLengthRecordReaderV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FixedLengthRecordReaderV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// A Reader that outputs the queued work as both the key and value. + /// + /// + /// + /// To use, enqueue strings in a Queue. ReaderRead will take the front + /// work string and output (work, work). + /// + /// + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// + public static Tensor identity_reader(string container = "", string shared_name = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "IdentityReader", name) { args = new object[] { }, attrs = new Dictionary() { ["container"] = container, ["shared_name"] = shared_name } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return identity_reader_eager_fallback(container: container, shared_name: shared_name, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (container is null) + { + container = ""; + } + if (shared_name is null) + { + shared_name = ""; + } + Dictionary keywords = new(); + keywords["container"] = container; + keywords["shared_name"] = shared_name; + var _op = tf.OpDefLib._apply_op_helper("IdentityReader", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "container", _op.get_attr("container"), "shared_name", _op.get_attr("shared_name") }; + _execute.record_gradient("IdentityReader", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor identity_reader_eager_fallback(string container, string shared_name, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "container", container, "shared_name", shared_name }; + var _result = _execute.execute("IdentityReader", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("IdentityReader", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// A Reader that outputs the queued work as both the key and value. + /// + /// + /// + /// To use, enqueue strings in a Queue. ReaderRead will take the front + /// work string and output (work, work). + /// + /// + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// + public static Tensor identity_reader_v2(string container = "", string shared_name = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "IdentityReaderV2", name) { args = new object[] { }, attrs = new Dictionary() { ["container"] = container, ["shared_name"] = shared_name } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return identity_reader_v2_eager_fallback(container: container, shared_name: shared_name, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (container is null) + { + container = ""; + } + if (shared_name is null) + { + shared_name = ""; + } + Dictionary keywords = new(); + keywords["container"] = container; + keywords["shared_name"] = shared_name; + var _op = tf.OpDefLib._apply_op_helper("IdentityReaderV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "container", _op.get_attr("container"), "shared_name", _op.get_attr("shared_name") }; + _execute.record_gradient("IdentityReaderV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor identity_reader_v2_eager_fallback(string container, string shared_name, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "container", container, "shared_name", shared_name }; + var _result = _execute.execute("IdentityReaderV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("IdentityReaderV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the set of files matching one or more glob patterns. + /// + /// + /// + /// Note that this routine only supports wildcard characters in the + /// basename portion of the pattern, not in the directory portion. + /// Note also that the order of filenames returned is deterministic. + /// + /// + /// + /// + public static Tensor matching_files(Tensor pattern, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatchingFiles", name) { args = new object[] { pattern }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return matching_files_eager_fallback(pattern, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["pattern"] = pattern; + var _op = tf.OpDefLib._apply_op_helper("MatchingFiles", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("MatchingFiles", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor matching_files_eager_fallback(Tensor pattern, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { pattern }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("MatchingFiles", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MatchingFiles", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Reads and outputs the entire contents of the input filename. + /// + /// + /// + public static Tensor read_file(Tensor filename, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReadFile", name) { args = new object[] { filename }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return read_file_eager_fallback(filename, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["filename"] = filename; + var _op = tf.OpDefLib._apply_op_helper("ReadFile", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReadFile", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor read_file_eager_fallback(Tensor filename, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { filename }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("ReadFile", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReadFile", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the number of records this Reader has produced. + /// + /// + /// + /// This is the same as the number of ReaderRead executions that have + /// succeeded. + /// + /// + /// + /// + public static Tensor reader_num_records_produced(Tensor reader_handle, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + throw new RuntimeError("reader_num_records_produced op does not support eager execution. Arg reader_handle is a ref."); + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + var _op = tf.OpDefLib._apply_op_helper("ReaderNumRecordsProduced", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderNumRecordsProduced", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor reader_num_records_produced_eager_fallback(Tensor reader_handle, string name, Context ctx) + { + throw new RuntimeError($"reader_num_records_produced op does not support eager execution. Arg 'reader_handle' is a ref."); + } + /// + /// Returns the number of records this Reader has produced. + /// + /// + /// + /// This is the same as the number of ReaderRead executions that have + /// succeeded. + /// + /// + /// + /// + public static Tensor reader_num_records_produced_v2(Tensor reader_handle, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReaderNumRecordsProducedV2", name) { args = new object[] { reader_handle }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reader_num_records_produced_v2_eager_fallback(reader_handle, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + var _op = tf.OpDefLib._apply_op_helper("ReaderNumRecordsProducedV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderNumRecordsProducedV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor reader_num_records_produced_v2_eager_fallback(Tensor reader_handle, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { reader_handle }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("ReaderNumRecordsProducedV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReaderNumRecordsProducedV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the number of work units this Reader has finished processing. + /// + /// + /// + public static Tensor reader_num_work_units_completed(Tensor reader_handle, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + throw new RuntimeError("reader_num_work_units_completed op does not support eager execution. Arg reader_handle is a ref."); + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + var _op = tf.OpDefLib._apply_op_helper("ReaderNumWorkUnitsCompleted", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderNumWorkUnitsCompleted", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor reader_num_work_units_completed_eager_fallback(Tensor reader_handle, string name, Context ctx) + { + throw new RuntimeError($"reader_num_work_units_completed op does not support eager execution. Arg 'reader_handle' is a ref."); + } + /// + /// Returns the number of work units this Reader has finished processing. + /// + /// + /// + public static Tensor reader_num_work_units_completed_v2(Tensor reader_handle, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReaderNumWorkUnitsCompletedV2", name) { args = new object[] { reader_handle }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reader_num_work_units_completed_v2_eager_fallback(reader_handle, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + var _op = tf.OpDefLib._apply_op_helper("ReaderNumWorkUnitsCompletedV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderNumWorkUnitsCompletedV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor reader_num_work_units_completed_v2_eager_fallback(Tensor reader_handle, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { reader_handle }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("ReaderNumWorkUnitsCompletedV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReaderNumWorkUnitsCompletedV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the next record (key, value pair) produced by a Reader. + /// + /// + /// + /// Will dequeue from the input queue if necessary (e.g. when the + /// Reader needs to start reading from a new file since it has finished + /// with the previous file). + /// + /// + /// + /// + /// + public static Tensor[] reader_read(Tensor reader_handle, Tensor queue_handle, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + throw new RuntimeError("reader_read op does not support eager execution. Arg reader_handle is a ref."); + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + keywords["queue_handle"] = queue_handle; + var _op = tf.OpDefLib._apply_op_helper("ReaderRead", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderRead", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] reader_read_eager_fallback(Tensor reader_handle, Tensor queue_handle, string name, Context ctx) + { + throw new RuntimeError($"reader_read op does not support eager execution. Arg 'reader_handle' is a ref."); + } + /// + /// Returns up to `num_records` (key, value) pairs produced by a Reader. + /// + /// + /// + /// Will dequeue from the input queue if necessary (e.g. when the + /// Reader needs to start reading from a new file since it has finished + /// with the previous file). + /// It may return less than `num_records` even before the last batch. + /// + /// + /// + /// + /// + /// + public static Tensor[] reader_read_up_to(Tensor reader_handle, Tensor queue_handle, Tensor num_records, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + throw new RuntimeError("reader_read_up_to op does not support eager execution. Arg reader_handle is a ref."); + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + keywords["queue_handle"] = queue_handle; + keywords["num_records"] = num_records; + var _op = tf.OpDefLib._apply_op_helper("ReaderReadUpTo", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderReadUpTo", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] reader_read_up_to_eager_fallback(Tensor reader_handle, Tensor queue_handle, Tensor num_records, string name, Context ctx) + { + throw new RuntimeError($"reader_read_up_to op does not support eager execution. Arg 'reader_handle' is a ref."); + } + /// + /// Returns up to `num_records` (key, value) pairs produced by a Reader. + /// + /// + /// + /// Will dequeue from the input queue if necessary (e.g. when the + /// Reader needs to start reading from a new file since it has finished + /// with the previous file). + /// It may return less than `num_records` even before the last batch. + /// + /// + /// + /// + /// + /// + public static Tensor[] reader_read_up_to_v2(Tensor reader_handle, Tensor queue_handle, Tensor num_records, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReaderReadUpToV2", name) { args = new object[] { reader_handle, queue_handle, num_records }, attrs = new Dictionary() { } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reader_read_up_to_v2_eager_fallback(reader_handle, queue_handle, num_records, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + keywords["queue_handle"] = queue_handle; + keywords["num_records"] = num_records; + var _op = tf.OpDefLib._apply_op_helper("ReaderReadUpToV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderReadUpToV2", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] reader_read_up_to_v2_eager_fallback(Tensor reader_handle, Tensor queue_handle, Tensor num_records, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { reader_handle, queue_handle, num_records }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("ReaderReadUpToV2", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReaderReadUpToV2", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Returns the next record (key, value pair) produced by a Reader. + /// + /// + /// + /// Will dequeue from the input queue if necessary (e.g. when the + /// Reader needs to start reading from a new file since it has finished + /// with the previous file). + /// + /// + /// + /// + /// + public static Tensor[] reader_read_v2(Tensor reader_handle, Tensor queue_handle, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReaderReadV2", name) { args = new object[] { reader_handle, queue_handle }, attrs = new Dictionary() { } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reader_read_v2_eager_fallback(reader_handle, queue_handle, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + keywords["queue_handle"] = queue_handle; + var _op = tf.OpDefLib._apply_op_helper("ReaderReadV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderReadV2", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] reader_read_v2_eager_fallback(Tensor reader_handle, Tensor queue_handle, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { reader_handle, queue_handle }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("ReaderReadV2", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReaderReadV2", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Restore a Reader to its initial clean state. + /// + /// + /// + public static Operation reader_reset(Tensor reader_handle, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + throw new RuntimeError("reader_reset op does not support eager execution. Arg reader_handle is a ref."); + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + var _op = tf.OpDefLib._apply_op_helper("ReaderReset", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderReset", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation reader_reset_eager_fallback(Tensor reader_handle, string name, Context ctx) + { + throw new RuntimeError($"reader_reset op does not support eager execution. Arg 'reader_handle' is a ref."); + } + /// + /// Restore a Reader to its initial clean state. + /// + /// + /// + public static Operation reader_reset_v2(Tensor reader_handle, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReaderResetV2", name) { args = new object[] { reader_handle }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reader_reset_v2_eager_fallback(reader_handle, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + var _op = tf.OpDefLib._apply_op_helper("ReaderResetV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderResetV2", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation reader_reset_v2_eager_fallback(Tensor reader_handle, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { reader_handle }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("ReaderResetV2", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReaderResetV2", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Restore a reader to a previously saved state. + /// + /// + /// + /// Not all Readers support being restored, so this can produce an + /// Unimplemented error. + /// + /// + /// + /// + /// + public static Operation reader_restore_state(Tensor reader_handle, Tensor state, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + throw new RuntimeError("reader_restore_state op does not support eager execution. Arg reader_handle is a ref."); + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + keywords["state"] = state; + var _op = tf.OpDefLib._apply_op_helper("ReaderRestoreState", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderRestoreState", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation reader_restore_state_eager_fallback(Tensor reader_handle, Tensor state, string name, Context ctx) + { + throw new RuntimeError($"reader_restore_state op does not support eager execution. Arg 'reader_handle' is a ref."); + } + /// + /// Restore a reader to a previously saved state. + /// + /// + /// + /// Not all Readers support being restored, so this can produce an + /// Unimplemented error. + /// + /// + /// + /// + /// + public static Operation reader_restore_state_v2(Tensor reader_handle, Tensor state, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReaderRestoreStateV2", name) { args = new object[] { reader_handle, state }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reader_restore_state_v2_eager_fallback(reader_handle, state, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + keywords["state"] = state; + var _op = tf.OpDefLib._apply_op_helper("ReaderRestoreStateV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderRestoreStateV2", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation reader_restore_state_v2_eager_fallback(Tensor reader_handle, Tensor state, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { reader_handle, state }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("ReaderRestoreStateV2", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReaderRestoreStateV2", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Produce a string tensor that encodes the state of a Reader. + /// + /// + /// + /// Not all Readers support being serialized, so this can produce an + /// Unimplemented error. + /// + /// + /// + /// + public static Tensor reader_serialize_state(Tensor reader_handle, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + throw new RuntimeError("reader_serialize_state op does not support eager execution. Arg reader_handle is a ref."); + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + var _op = tf.OpDefLib._apply_op_helper("ReaderSerializeState", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderSerializeState", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor reader_serialize_state_eager_fallback(Tensor reader_handle, string name, Context ctx) + { + throw new RuntimeError($"reader_serialize_state op does not support eager execution. Arg 'reader_handle' is a ref."); + } + /// + /// Produce a string tensor that encodes the state of a Reader. + /// + /// + /// + /// Not all Readers support being serialized, so this can produce an + /// Unimplemented error. + /// + /// + /// + /// + public static Tensor reader_serialize_state_v2(Tensor reader_handle, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReaderSerializeStateV2", name) { args = new object[] { reader_handle }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reader_serialize_state_v2_eager_fallback(reader_handle, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["reader_handle"] = reader_handle; + var _op = tf.OpDefLib._apply_op_helper("ReaderSerializeStateV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ReaderSerializeStateV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor reader_serialize_state_v2_eager_fallback(Tensor reader_handle, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { reader_handle }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("ReaderSerializeStateV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReaderSerializeStateV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Restores a tensor from checkpoint files. + /// + /// + /// + /// Reads a tensor stored in one or several files. If there are several files (for + /// instance because a tensor was saved as slices), `file_pattern` may contain + /// wildcard symbols (`*` and `?`) in the filename portion only, not in the + /// directory portion. + /// + /// If a `file_pattern` matches several files, `preferred_shard` can be used to hint + /// in which file the requested tensor is likely to be found. This op will first + /// open the file at index `preferred_shard` in the list of matching files and try + /// to restore tensors from that file. Only if some tensors or tensor slices are + /// not found in that first file, then the Op opens all the files. Setting + /// `preferred_shard` to match the value passed as the `shard` input + /// of a matching `Save` Op may speed up Restore. This attribute only affects + /// performance, not correctness. The default value -1 means files are processed in + /// order. + /// + /// See also `RestoreSlice`. + /// + /// + /// + /// + /// + /// + /// The type of the tensor to be restored. + /// + /// + /// + /// + /// Index of file to open first if multiple files match + /// `file_pattern`. + /// + /// + /// + public static Tensor restore(Tensor file_pattern, Tensor tensor_name, TF_DataType dt, int preferred_shard = -1, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Restore", name) { args = new object[] { file_pattern, tensor_name }, attrs = new Dictionary() { ["dt"] = dt, ["preferred_shard"] = preferred_shard } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return restore_eager_fallback(file_pattern, tensor_name, dt: dt, preferred_shard: preferred_shard, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["file_pattern"] = file_pattern; + keywords["tensor_name"] = tensor_name; + keywords["dt"] = dt; + keywords["preferred_shard"] = preferred_shard; + var _op = tf.OpDefLib._apply_op_helper("Restore", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dt", _op._get_attr_type("dt"), "preferred_shard", _op._get_attr_int("preferred_shard") }; + _execute.record_gradient("Restore", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor restore_eager_fallback(Tensor file_pattern, Tensor tensor_name, TF_DataType dt, int preferred_shard, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { file_pattern, tensor_name }; + object[] _attrs = new object[] { "dt", dt, "preferred_shard", preferred_shard }; + var _result = _execute.execute("Restore", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Restore", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Restores a tensor from checkpoint files. + /// + /// + /// + /// This is like `Restore` except that restored tensor can be listed as filling + /// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the + /// larger tensor and the slice that the restored tensor covers. + /// + /// The `shape_and_slice` input has the same format as the + /// elements of the `shapes_and_slices` input of the `SaveSlices` op. + /// + /// + /// + /// + /// + /// + /// + /// The type of the tensor to be restored. + /// + /// + /// + /// + /// Index of file to open first if multiple files match + /// `file_pattern`. See the documentation for `Restore`. + /// + /// + /// + public static Tensor restore_slice(Tensor file_pattern, Tensor tensor_name, Tensor shape_and_slice, TF_DataType dt, int preferred_shard = -1, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "RestoreSlice", name) { args = new object[] { file_pattern, tensor_name, shape_and_slice }, attrs = new Dictionary() { ["dt"] = dt, ["preferred_shard"] = preferred_shard } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return restore_slice_eager_fallback(file_pattern, tensor_name, shape_and_slice, dt: dt, preferred_shard: preferred_shard, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["file_pattern"] = file_pattern; + keywords["tensor_name"] = tensor_name; + keywords["shape_and_slice"] = shape_and_slice; + keywords["dt"] = dt; + keywords["preferred_shard"] = preferred_shard; + var _op = tf.OpDefLib._apply_op_helper("RestoreSlice", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dt", _op._get_attr_type("dt"), "preferred_shard", _op._get_attr_int("preferred_shard") }; + _execute.record_gradient("RestoreSlice", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor restore_slice_eager_fallback(Tensor file_pattern, Tensor tensor_name, Tensor shape_and_slice, TF_DataType dt, int preferred_shard, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { file_pattern, tensor_name, shape_and_slice }; + object[] _attrs = new object[] { "dt", dt, "preferred_shard", preferred_shard }; + var _result = _execute.execute("RestoreSlice", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("RestoreSlice", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Restores tensors from a V2 checkpoint. + /// + /// + /// + /// For backward compatibility with the V1 format, this Op currently allows + /// restoring from a V1 checkpoint as well: + /// - This Op first attempts to find the V2 index file pointed to by "prefix", and + /// if found proceed to read it as a V2 checkpoint; + /// - Otherwise the V1 read path is invoked. + /// Relying on this behavior is not recommended, as the ability to fall back to read + /// V1 might be deprecated and eventually removed. + /// + /// By default, restores the named tensors in full. If the caller wishes to restore + /// specific slices of stored tensors, "shape_and_slices" should be non-empty + /// strings and correspondingly well-formed. + /// + /// Callers must ensure all the named tensors are indeed stored in the checkpoint. + /// + /// + /// + /// + /// + /// + /// + /// shape {N}. The list of expected dtype for the tensors. Must match + /// those stored in the checkpoint. + /// + /// + /// + public static Tensor[] restore_v2(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, TF_DataType[] dtypes, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "RestoreV2", name) { args = new object[] { prefix, tensor_names, shape_and_slices }, attrs = new Dictionary() { ["dtypes"] = dtypes } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return restore_v2_eager_fallback(prefix, tensor_names, shape_and_slices, dtypes: dtypes, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["prefix"] = prefix; + keywords["tensor_names"] = tensor_names; + keywords["shape_and_slices"] = shape_and_slices; + keywords["dtypes"] = dtypes; + var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtypes", _op.get_attr("dtypes") }; + _execute.record_gradient("RestoreV2", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] restore_v2_eager_fallback(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, TF_DataType[] dtypes, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { prefix, tensor_names, shape_and_slices }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("RestoreV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("RestoreV2", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Saves the input tensors to disk. + /// + /// + /// + /// The size of `tensor_names` must match the number of tensors in `data`. `data[i]` + /// is written to `filename` with name `tensor_names[i]`. + /// + /// See also `SaveSlices`. + /// + /// + /// + /// + /// + /// + public static Operation save(Tensor filename, Tensor tensor_names, Tensors data, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Save", name) { args = new object[] { filename, tensor_names, data }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return save_eager_fallback(filename, tensor_names, data, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["filename"] = filename; + keywords["tensor_names"] = tensor_names; + keywords["data"] = data; + var _op = tf.OpDefLib._apply_op_helper("Save", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op.get_attr("T") }; + _execute.record_gradient("Save", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation save_eager_fallback(Tensor filename, Tensor tensor_names, Tensor data, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { filename, tensor_names, data }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("Save", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Save", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Saves input tensors slices to disk. + /// + /// + /// + /// This is like `Save` except that tensors can be listed in the saved file as being + /// a slice of a larger tensor. `shapes_and_slices` specifies the shape of the + /// larger tensor and the slice that this tensor covers. `shapes_and_slices` must + /// have as many elements as `tensor_names`. + /// + /// Elements of the `shapes_and_slices` input must either be: + /// + /// * The empty string, in which case the corresponding tensor is + /// saved normally. + /// * A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the + /// `dimI` are the dimensions of the larger tensor and `slice-spec` + /// specifies what part is covered by the tensor to save. + /// + /// `slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1` + /// where each `sliceI` is either: + /// + /// * The string `-` meaning that the slice covers all indices of this dimension + /// * `start,length` where `start` and `length` are integers. In that + /// case the slice covers `length` indices starting at `start`. + /// + /// See also `Save`. + /// + /// + /// + /// + /// + /// + /// + public static Operation save_slices(Tensor filename, Tensor tensor_names, Tensor shapes_and_slices, Tensors data, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SaveSlices", name) { args = new object[] { filename, tensor_names, shapes_and_slices, data }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return save_slices_eager_fallback(filename, tensor_names, shapes_and_slices, data, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["filename"] = filename; + keywords["tensor_names"] = tensor_names; + keywords["shapes_and_slices"] = shapes_and_slices; + keywords["data"] = data; + var _op = tf.OpDefLib._apply_op_helper("SaveSlices", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op.get_attr("T") }; + _execute.record_gradient("SaveSlices", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation save_slices_eager_fallback(Tensor filename, Tensor tensor_names, Tensor shapes_and_slices, Tensor data, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { filename, tensor_names, shapes_and_slices, data }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("SaveSlices", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SaveSlices", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Saves tensors in V2 checkpoint format. + /// + /// + /// + /// By default, saves the named tensors in full. If the caller wishes to save + /// specific slices of full tensors, "shape_and_slices" should be non-empty strings + /// and correspondingly well-formed. + /// + /// + /// + /// + /// + /// + /// + public static Operation save_v2(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, Tensors tensors, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SaveV2", name) { args = new object[] { prefix, tensor_names, shape_and_slices, tensors }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return save_v2_eager_fallback(prefix, tensor_names, shape_and_slices, tensors, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["prefix"] = prefix; + keywords["tensor_names"] = tensor_names; + keywords["shape_and_slices"] = shape_and_slices; + keywords["tensors"] = tensors; + var _op = tf.OpDefLib._apply_op_helper("SaveV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtypes", _op.get_attr("dtypes") }; + _execute.record_gradient("SaveV2", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation save_v2_eager_fallback(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, Tensor tensors, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { prefix, tensor_names, shape_and_slices, tensors }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("SaveV2", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SaveV2", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Generate a sharded filename. The filename is printf formatted as + /// + /// + /// + /// %s-%05d-of-%05d, basename, shard, num_shards. + /// + /// + /// + /// + /// + /// + public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ShardedFilename", name) { args = new object[] { basename, shard, num_shards }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sharded_filename_eager_fallback(basename, shard, num_shards, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["basename"] = basename; + keywords["shard"] = shard; + keywords["num_shards"] = num_shards; + var _op = tf.OpDefLib._apply_op_helper("ShardedFilename", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ShardedFilename", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sharded_filename_eager_fallback(Tensor basename, Tensor shard, Tensor num_shards, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { basename, shard, num_shards }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("ShardedFilename", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ShardedFilename", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Generate a glob pattern matching all sharded file names. + /// + /// + /// + /// + public static Tensor sharded_filespec(Tensor basename, Tensor num_shards, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ShardedFilespec", name) { args = new object[] { basename, num_shards }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sharded_filespec_eager_fallback(basename, num_shards, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["basename"] = basename; + keywords["num_shards"] = num_shards; + var _op = tf.OpDefLib._apply_op_helper("ShardedFilespec", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ShardedFilespec", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sharded_filespec_eager_fallback(Tensor basename, Tensor num_shards, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { basename, num_shards }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("ShardedFilespec", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ShardedFilespec", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// A Reader that outputs the lines of a file delimited by '\n'. + /// + /// + /// + /// Number of lines to skip from the beginning of every file. + /// + /// + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// + public static Tensor text_line_reader(int skip_header_lines = 0, string container = "", string shared_name = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TextLineReader", name) { args = new object[] { }, attrs = new Dictionary() { ["skip_header_lines"] = skip_header_lines, ["container"] = container, ["shared_name"] = shared_name } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return text_line_reader_eager_fallback(skip_header_lines: skip_header_lines, container: container, shared_name: shared_name, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (container is null) + { + container = ""; + } + if (shared_name is null) + { + shared_name = ""; + } + Dictionary keywords = new(); + keywords["skip_header_lines"] = skip_header_lines; + keywords["container"] = container; + keywords["shared_name"] = shared_name; + var _op = tf.OpDefLib._apply_op_helper("TextLineReader", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "skip_header_lines", _op._get_attr_int("skip_header_lines"), "container", _op.get_attr("container"), "shared_name", _op.get_attr("shared_name") }; + _execute.record_gradient("TextLineReader", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor text_line_reader_eager_fallback(int skip_header_lines, string container, string shared_name, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "skip_header_lines", skip_header_lines, "container", container, "shared_name", shared_name }; + var _result = _execute.execute("TextLineReader", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TextLineReader", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// A Reader that outputs the lines of a file delimited by '\n'. + /// + /// + /// + /// Number of lines to skip from the beginning of every file. + /// + /// + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// + public static Tensor text_line_reader_v2(int skip_header_lines = 0, string container = "", string shared_name = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TextLineReaderV2", name) { args = new object[] { }, attrs = new Dictionary() { ["skip_header_lines"] = skip_header_lines, ["container"] = container, ["shared_name"] = shared_name } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return text_line_reader_v2_eager_fallback(skip_header_lines: skip_header_lines, container: container, shared_name: shared_name, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (container is null) + { + container = ""; + } + if (shared_name is null) + { + shared_name = ""; + } + Dictionary keywords = new(); + keywords["skip_header_lines"] = skip_header_lines; + keywords["container"] = container; + keywords["shared_name"] = shared_name; + var _op = tf.OpDefLib._apply_op_helper("TextLineReaderV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "skip_header_lines", _op._get_attr_int("skip_header_lines"), "container", _op.get_attr("container"), "shared_name", _op.get_attr("shared_name") }; + _execute.record_gradient("TextLineReaderV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor text_line_reader_v2_eager_fallback(int skip_header_lines, string container, string shared_name, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "skip_header_lines", skip_header_lines, "container", container, "shared_name", shared_name }; + var _result = _execute.execute("TextLineReaderV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TextLineReaderV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// A Reader that outputs the entire contents of a file as a value. + /// + /// + /// + /// To use, enqueue filenames in a Queue. The output of ReaderRead will + /// be a filename (key) and the contents of that file (value). + /// + /// + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// + public static Tensor whole_file_reader(string container = "", string shared_name = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "WholeFileReader", name) { args = new object[] { }, attrs = new Dictionary() { ["container"] = container, ["shared_name"] = shared_name } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return whole_file_reader_eager_fallback(container: container, shared_name: shared_name, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (container is null) + { + container = ""; + } + if (shared_name is null) + { + shared_name = ""; + } + Dictionary keywords = new(); + keywords["container"] = container; + keywords["shared_name"] = shared_name; + var _op = tf.OpDefLib._apply_op_helper("WholeFileReader", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "container", _op.get_attr("container"), "shared_name", _op.get_attr("shared_name") }; + _execute.record_gradient("WholeFileReader", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor whole_file_reader_eager_fallback(string container, string shared_name, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "container", container, "shared_name", shared_name }; + var _result = _execute.execute("WholeFileReader", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("WholeFileReader", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// A Reader that outputs the entire contents of a file as a value. + /// + /// + /// + /// To use, enqueue filenames in a Queue. The output of ReaderRead will + /// be a filename (key) and the contents of that file (value). + /// + /// + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// + public static Tensor whole_file_reader_v2(string container = "", string shared_name = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "WholeFileReaderV2", name) { args = new object[] { }, attrs = new Dictionary() { ["container"] = container, ["shared_name"] = shared_name } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return whole_file_reader_v2_eager_fallback(container: container, shared_name: shared_name, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (container is null) + { + container = ""; + } + if (shared_name is null) + { + shared_name = ""; + } + Dictionary keywords = new(); + keywords["container"] = container; + keywords["shared_name"] = shared_name; + var _op = tf.OpDefLib._apply_op_helper("WholeFileReaderV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "container", _op.get_attr("container"), "shared_name", _op.get_attr("shared_name") }; + _execute.record_gradient("WholeFileReaderV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor whole_file_reader_v2_eager_fallback(string container, string shared_name, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "container", container, "shared_name", shared_name }; + var _result = _execute.execute("WholeFileReaderV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("WholeFileReaderV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Writes `contents` to the file at input `filename`. + /// + /// + /// + /// Creates the file and recursively creates directory if it does not exist. + /// + /// + /// + /// + /// + public static Operation write_file(Tensor filename, Tensor contents, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "WriteFile", name) { args = new object[] { filename, contents }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return write_file_eager_fallback(filename, contents, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["filename"] = filename; + keywords["contents"] = contents; + var _op = tf.OpDefLib._apply_op_helper("WriteFile", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("WriteFile", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation write_file_eager_fallback(Tensor filename, Tensor contents, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { filename, contents }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("WriteFile", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("WriteFile", _inputs_flat, _attrs, _result); + } + return null; + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_list_ops.cs b/src/TensorFlowNET.Core/Operations/gen_list_ops.cs new file mode 100644 index 000000000..59c783b24 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_list_ops.cs @@ -0,0 +1,1308 @@ +/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/ + +using Tensorflow.Eager; +using Tensorflow.Contexts; +using Tensorflow.Exceptions; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static class gen_list_ops +{ + /// + /// Creates and returns an empty tensor list. + /// + /// + /// + /// All list elements must be tensors of dtype element_dtype and shape compatible + /// with element_shape. + /// + /// handle: an empty tensor list. + /// element_dtype: the type of elements in the list. + /// element_shape: a shape compatible with that of elements in the list. + /// + /// + /// + /// + /// + /// + public static Tensor empty_tensor_list(Tensor element_shape, Tensor max_num_elements, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "EmptyTensorList", name) { args = new object[] { element_shape, max_num_elements }, attrs = new Dictionary() { ["element_dtype"] = element_dtype } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return empty_tensor_list_eager_fallback(element_shape, max_num_elements, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["element_shape"] = element_shape; + keywords["max_num_elements"] = max_num_elements; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("EmptyTensorList", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("EmptyTensorList", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor empty_tensor_list_eager_fallback(Tensor element_shape, Tensor max_num_elements, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { element_shape, max_num_elements }; + object[] _attrs = new object[] { "element_dtype", element_dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("EmptyTensorList", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("EmptyTensorList", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Concats all tensors in the list along the 0th dimension. + /// + /// + /// + /// Requires that all tensors have the same shape except the first dimension. + /// + /// input_handle: The input list. + /// tensor: The concated result. + /// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] tensor_list_concat(Tensor input_handle, TF_DataType element_dtype, Shape element_shape = null, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListConcat", name) { args = new object[] { input_handle }, attrs = new Dictionary() { ["element_dtype"] = element_dtype, ["element_shape"] = element_shape } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_concat_eager_fallback(input_handle, element_dtype: element_dtype, element_shape: element_shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["element_dtype"] = element_dtype; + keywords["element_shape"] = element_shape; + var _op = tf.OpDefLib._apply_op_helper("TensorListConcat", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "element_shape", _op.get_attr("element_shape") }; + _execute.record_gradient("TensorListConcat", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] tensor_list_concat_eager_fallback(Tensor input_handle, TF_DataType element_dtype, Shape element_shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle }; + object[] _attrs = new object[] { "element_dtype", element_dtype, "element_shape", element_shape }; + var _result = _execute.execute("TensorListConcat", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListConcat", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_concat_lists(Tensor input_a, Tensor input_b, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListConcatLists", name) { args = new object[] { input_a, input_b }, attrs = new Dictionary() { ["element_dtype"] = element_dtype } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_concat_lists_eager_fallback(input_a, input_b, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_a"] = input_a; + keywords["input_b"] = input_b; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("TensorListConcatLists", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListConcatLists", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_concat_lists_eager_fallback(Tensor input_a, Tensor input_b, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_a, input_b }; + object[] _attrs = new object[] { "element_dtype", element_dtype }; + var _result = _execute.execute("TensorListConcatLists", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListConcatLists", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Concats all tensors in the list along the 0th dimension. + /// + /// + /// + /// Requires that all tensors have the same shape except the first dimension. + /// + /// input_handle: The input list. + /// element_shape: The shape of the uninitialized elements in the list. If the first + /// dimension is not -1, it is assumed that all list elements have the same + /// leading dim. + /// leading_dims: The list of leading dims of uninitialized list elements. Used if + /// the leading dim of input_handle.element_shape or the element_shape input arg + /// is not already set. + /// tensor: The concated result. + /// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] tensor_list_concat_v2(Tensor input_handle, Tensor element_shape, Tensor leading_dims, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListConcatV2", name) { args = new object[] { input_handle, element_shape, leading_dims }, attrs = new Dictionary() { ["element_dtype"] = element_dtype } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_concat_v2_eager_fallback(input_handle, element_shape, leading_dims, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["element_shape"] = element_shape; + keywords["leading_dims"] = leading_dims; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("TensorListConcatV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListConcatV2", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] tensor_list_concat_v2_eager_fallback(Tensor input_handle, Tensor element_shape, Tensor leading_dims, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, element_shape, leading_dims }; + object[] _attrs = new object[] { "element_dtype", element_dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("TensorListConcatV2", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListConcatV2", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// The shape of the elements of the given list, as a tensor. + /// + /// + /// + /// input_handle: the list + /// element_shape: the shape of elements of the list + /// + /// + /// + /// + /// + public static Tensor tensor_list_element_shape(Tensor input_handle, TF_DataType shape_type, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListElementShape", name) { args = new object[] { input_handle }, attrs = new Dictionary() { ["shape_type"] = shape_type } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_element_shape_eager_fallback(input_handle, shape_type: shape_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["shape_type"] = shape_type; + var _op = tf.OpDefLib._apply_op_helper("TensorListElementShape", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListElementShape", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_element_shape_eager_fallback(Tensor input_handle, TF_DataType shape_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle }; + object[] _attrs = new object[] { "shape_type", shape_type }; + var _result = _execute.execute("TensorListElementShape", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListElementShape", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Creates a TensorList which, when stacked, has the value of `tensor`. + /// + /// + /// + /// Each tensor in the result list corresponds to one row of the input tensor. + /// + /// tensor: The input tensor. + /// output_handle: The list. + /// + /// + /// + /// + /// + public static Tensor tensor_list_from_tensor(Tensor tensor, Tensor element_shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListFromTensor", name) { args = new object[] { tensor, element_shape }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_from_tensor_eager_fallback(tensor, element_shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["element_shape"] = element_shape; + var _op = tf.OpDefLib._apply_op_helper("TensorListFromTensor", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListFromTensor", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_from_tensor_eager_fallback(Tensor tensor, Tensor element_shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, element_shape }; + object[] _attrs = new object[] { "element_dtype", tensor.dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("TensorListFromTensor", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListFromTensor", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Creates a Tensor by indexing into the TensorList. + /// + /// + /// + /// Each row in the produced Tensor corresponds to the element in the TensorList + /// specified by the given index (see `tf.gather`). + /// + /// input_handle: The input tensor list. + /// indices: The indices used to index into the list. + /// values: The tensor. + /// + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_gather(Tensor input_handle, Tensor indices, Tensor element_shape, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListGather", name) { args = new object[] { input_handle, indices, element_shape }, attrs = new Dictionary() { ["element_dtype"] = element_dtype } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_gather_eager_fallback(input_handle, indices, element_shape, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["indices"] = indices; + keywords["element_shape"] = element_shape; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("TensorListGather", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListGather", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_gather_eager_fallback(Tensor input_handle, Tensor indices, Tensor element_shape, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, indices, element_shape }; + object[] _attrs = new object[] { "element_dtype", element_dtype }; + var _result = _execute.execute("TensorListGather", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListGather", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_get_item(Tensor input_handle, Tensor index, Tensor element_shape, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListGetItem", name) { args = new object[] { input_handle, index, element_shape }, attrs = new Dictionary() { ["element_dtype"] = element_dtype } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_get_item_eager_fallback(input_handle, index, element_shape, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["index"] = index; + keywords["element_shape"] = element_shape; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("TensorListGetItem", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListGetItem", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_get_item_eager_fallback(Tensor input_handle, Tensor index, Tensor element_shape, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, index, element_shape }; + object[] _attrs = new object[] { "element_dtype", element_dtype }; + var _result = _execute.execute("TensorListGetItem", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListGetItem", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the number of tensors in the input tensor list. + /// + /// + /// + /// input_handle: the input list + /// length: the number of tensors in the list + /// + /// + /// + /// + public static Tensor tensor_list_length(Tensor input_handle, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListLength", name) { args = new object[] { input_handle }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_length_eager_fallback(input_handle, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + var _op = tf.OpDefLib._apply_op_helper("TensorListLength", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("TensorListLength", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_length_eager_fallback(Tensor input_handle, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("TensorListLength", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListLength", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the last element of the input list as well as a list with all but that element. + /// + /// + /// + /// Fails if the list is empty. + /// + /// input_handle: the input list + /// tensor: the withdrawn last element of the list + /// element_dtype: the type of elements in the list + /// element_shape: the shape of the output tensor + /// + /// + /// + /// + /// + /// + public static Tensor[] tensor_list_pop_back(Tensor input_handle, Tensor element_shape, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListPopBack", name) { args = new object[] { input_handle, element_shape }, attrs = new Dictionary() { ["element_dtype"] = element_dtype } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_pop_back_eager_fallback(input_handle, element_shape, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["element_shape"] = element_shape; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("TensorListPopBack", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListPopBack", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] tensor_list_pop_back_eager_fallback(Tensor input_handle, Tensor element_shape, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, element_shape }; + object[] _attrs = new object[] { "element_dtype", element_dtype }; + var _result = _execute.execute("TensorListPopBack", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListPopBack", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Returns a list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`. + /// + /// + /// + /// tensor: The tensor to put on the list. + /// input_handle: The old list. + /// output_handle: A list with the elements of the old list followed by tensor. + /// element_dtype: the type of elements in the list. + /// element_shape: a shape compatible with that of elements in the list. + /// + /// + /// + /// + /// + public static Tensor tensor_list_push_back(Tensor input_handle, Tensor tensor, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListPushBack", name) { args = new object[] { input_handle, tensor }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_push_back_eager_fallback(input_handle, tensor, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["tensor"] = tensor; + var _op = tf.OpDefLib._apply_op_helper("TensorListPushBack", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListPushBack", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_push_back_eager_fallback(Tensor input_handle, Tensor tensor, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, tensor }; + object[] _attrs = new object[] { "element_dtype", tensor.dtype }; + var _result = _execute.execute("TensorListPushBack", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListPushBack", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_push_back_batch(Tensor input_handles, Tensor tensor, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListPushBackBatch", name) { args = new object[] { input_handles, tensor }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_push_back_batch_eager_fallback(input_handles, tensor, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handles"] = input_handles; + keywords["tensor"] = tensor; + var _op = tf.OpDefLib._apply_op_helper("TensorListPushBackBatch", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListPushBackBatch", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_push_back_batch_eager_fallback(Tensor input_handles, Tensor tensor, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handles, tensor }; + object[] _attrs = new object[] { "element_dtype", tensor.dtype }; + var _result = _execute.execute("TensorListPushBackBatch", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListPushBackBatch", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// List of the given size with empty elements. + /// + /// + /// + /// element_shape: the shape of the future elements of the list + /// num_elements: the number of elements to reserve + /// handle: the output list + /// element_dtype: the desired type of elements in the list. + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_reserve(Tensor element_shape, Tensor num_elements, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListReserve", name) { args = new object[] { element_shape, num_elements }, attrs = new Dictionary() { ["element_dtype"] = element_dtype } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_reserve_eager_fallback(element_shape, num_elements, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["element_shape"] = element_shape; + keywords["num_elements"] = num_elements; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("TensorListReserve", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListReserve", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_reserve_eager_fallback(Tensor element_shape, Tensor num_elements, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { element_shape, num_elements }; + object[] _attrs = new object[] { "element_dtype", element_dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("TensorListReserve", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListReserve", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Resizes the list. + /// + /// + /// + /// + /// input_handle: the input list + /// size: size of the output list + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_resize(Tensor input_handle, Tensor size, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListResize", name) { args = new object[] { input_handle, size }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_resize_eager_fallback(input_handle, size, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["size"] = size; + var _op = tf.OpDefLib._apply_op_helper("TensorListResize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("TensorListResize", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_resize_eager_fallback(Tensor input_handle, Tensor size, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, size }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("TensorListResize", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListResize", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Creates a TensorList by indexing into a Tensor. + /// + /// + /// + /// Each member of the TensorList corresponds to one row of the input tensor, + /// specified by the given index (see `tf.gather`). + /// + /// tensor: The input tensor. + /// indices: The indices used to index into the list. + /// element_shape: The shape of the elements in the list (can be less specified than + /// the shape of the tensor). + /// output_handle: The TensorList. + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_scatter(Tensor tensor, Tensor indices, Tensor element_shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListScatter", name) { args = new object[] { tensor, indices, element_shape }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_scatter_eager_fallback(tensor, indices, element_shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["indices"] = indices; + keywords["element_shape"] = element_shape; + var _op = tf.OpDefLib._apply_op_helper("TensorListScatter", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListScatter", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_scatter_eager_fallback(Tensor tensor, Tensor indices, Tensor element_shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, indices, element_shape }; + object[] _attrs = new object[] { "element_dtype", tensor.dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("TensorListScatter", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListScatter", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Scatters tensor at indices in an input list. + /// + /// + /// + /// Each member of the TensorList corresponds to one row of the input tensor, + /// specified by the given index (see `tf.gather`). + /// + /// input_handle: The list to scatter into. + /// tensor: The input tensor. + /// indices: The indices used to index into the list. + /// output_handle: The TensorList. + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_scatter_into_existing_list(Tensor input_handle, Tensor tensor, Tensor indices, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListScatterIntoExistingList", name) { args = new object[] { input_handle, tensor, indices }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_scatter_into_existing_list_eager_fallback(input_handle, tensor, indices, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["tensor"] = tensor; + keywords["indices"] = indices; + var _op = tf.OpDefLib._apply_op_helper("TensorListScatterIntoExistingList", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListScatterIntoExistingList", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_scatter_into_existing_list_eager_fallback(Tensor input_handle, Tensor tensor, Tensor indices, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, tensor, indices }; + object[] _attrs = new object[] { "element_dtype", tensor.dtype }; + var _result = _execute.execute("TensorListScatterIntoExistingList", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListScatterIntoExistingList", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Creates a TensorList by indexing into a Tensor. + /// + /// + /// + /// Each member of the TensorList corresponds to one row of the input tensor, + /// specified by the given index (see `tf.gather`). + /// + /// tensor: The input tensor. + /// indices: The indices used to index into the list. + /// element_shape: The shape of the elements in the list (can be less specified than + /// the shape of the tensor). + /// num_elements: The size of the output list. Must be large enough to accommodate + /// the largest index in indices. If -1, the list is just large enough to include + /// the largest index in indices. + /// output_handle: The TensorList. + /// + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_scatter_v2(Tensor tensor, Tensor indices, Tensor element_shape, Tensor num_elements, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListScatterV2", name) { args = new object[] { tensor, indices, element_shape, num_elements }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_scatter_v2_eager_fallback(tensor, indices, element_shape, num_elements, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["indices"] = indices; + keywords["element_shape"] = element_shape; + keywords["num_elements"] = num_elements; + var _op = tf.OpDefLib._apply_op_helper("TensorListScatterV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListScatterV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_scatter_v2_eager_fallback(Tensor tensor, Tensor indices, Tensor element_shape, Tensor num_elements, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, indices, element_shape, num_elements }; + object[] _attrs = new object[] { "element_dtype", tensor.dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("TensorListScatterV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListScatterV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_set_item(Tensor input_handle, Tensor index, Tensor item, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListSetItem", name) { args = new object[] { input_handle, index, item }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_set_item_eager_fallback(input_handle, index, item, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["index"] = index; + keywords["item"] = item; + var _op = tf.OpDefLib._apply_op_helper("TensorListSetItem", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListSetItem", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_set_item_eager_fallback(Tensor input_handle, Tensor index, Tensor item, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, index, item }; + object[] _attrs = new object[] { "element_dtype", item.dtype }; + var _result = _execute.execute("TensorListSetItem", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListSetItem", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Splits a tensor into a list. + /// + /// + /// + /// list[i] corresponds to lengths[i] tensors from the input tensor. + /// The tensor must have rank at least 1 and contain exactly sum(lengths) elements. + /// + /// tensor: The input tensor. + /// element_shape: A shape compatible with that of elements in the tensor. + /// lengths: Vector of sizes of the 0th dimension of tensors in the list. + /// output_handle: The list. + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_split(Tensor tensor, Tensor element_shape, Tensor lengths, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListSplit", name) { args = new object[] { tensor, element_shape, lengths }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_split_eager_fallback(tensor, element_shape, lengths, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["element_shape"] = element_shape; + keywords["lengths"] = lengths; + var _op = tf.OpDefLib._apply_op_helper("TensorListSplit", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListSplit", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_split_eager_fallback(Tensor tensor, Tensor element_shape, Tensor lengths, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, element_shape, lengths }; + object[] _attrs = new object[] { "element_dtype", tensor.dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("TensorListSplit", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListSplit", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Stacks all tensors in the list. + /// + /// + /// + /// Requires that all tensors have the same shape. + /// + /// input_handle: the input list + /// tensor: the gathered result + /// num_elements: optional. If not -1, the number of elements in the list. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_stack(Tensor input_handle, Tensor element_shape, TF_DataType element_dtype, int num_elements = -1, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListStack", name) { args = new object[] { input_handle, element_shape }, attrs = new Dictionary() { ["element_dtype"] = element_dtype, ["num_elements"] = num_elements } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tensor_list_stack_eager_fallback(input_handle, element_shape, element_dtype: element_dtype, num_elements: num_elements, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["element_shape"] = element_shape; + keywords["element_dtype"] = element_dtype; + keywords["num_elements"] = num_elements; + var _op = tf.OpDefLib._apply_op_helper("TensorListStack", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "num_elements", _op._get_attr_int("num_elements") }; + _execute.record_gradient("TensorListStack", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_stack_eager_fallback(Tensor input_handle, Tensor element_shape, TF_DataType element_dtype, int num_elements, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, element_shape }; + object[] _attrs = new object[] { "element_dtype", element_dtype, "num_elements", num_elements }; + var _result = _execute.execute("TensorListStack", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListStack", _inputs_flat, _attrs, _result); + } + return _result[0]; + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs b/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs new file mode 100644 index 000000000..d2907f090 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs @@ -0,0 +1,107 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class gen_logging_ops + { + public static Operation assert(Tensor condition, object[] data, long summarize = 3, string name = null) + { + if (tf.Context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( + tf.Context, "Assert", name, + new object[] { condition, data, summarize })); + + return results[0]; + } + + var _op = tf.OpDefLib._apply_op_helper("Assert", name, args: new { condition, data, summarize }); + + return _op; + } + + public static Tensor histogram_summary(string tag, Tensor values, string name = null) + { + var dict = new Dictionary(); + var op = tf.OpDefLib._apply_op_helper("HistogramSummary", name: name, args: new { tag, values }); + return op.output; + } + + /// + /// Outputs a Summary protocol buffer with scalar values. + /// + /// + /// Tags for the summary. + /// + /// + /// Same shape as tags. Values for the summary. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScalarSummary'. + /// + /// + /// Scalar. Serialized Summary protocol buffer. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The input tags and values must have the same shape. The generated summary + /// has a summary value for each tag-value pair in tags and values. + /// + public static Tensor scalar_summary(string tags, Tensor values, string name = "ScalarSummary") + { + var dict = new Dictionary(); + dict["tags"] = tags; + dict["values"] = values; + var op = tf.OpDefLib._apply_op_helper("ScalarSummary", name: name, keywords: dict); + return op.output; + } + + /// + /// Merges summaries. + /// + /// + /// Can be of any shape. Each must contain serialized Summary protocol + /// buffers. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MergeSummary'. + /// + /// + /// Scalar. Serialized Summary protocol buffer. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op creates a + /// [Summary](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) + /// protocol buffer that contains the union of all the values in the input + /// summaries. + /// + /// When the Op is run, it reports an InvalidArgument error if multiple values + /// in the summaries to merge use the same tag. + /// + public static Tensor merge_summary(Tensor[] inputs, string name = "MergeSummary") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + var op = tf.OpDefLib._apply_op_helper("MergeSummary", name: name, keywords: dict); + return op.output; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 6aae72bde..a8152a11e 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -1,36 +1,10072 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Text; +/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/ -namespace Tensorflow +using Tensorflow.Eager; +using Tensorflow.Contexts; +using Tensorflow.Exceptions; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static class gen_math_ops { - public static class gen_math_ops + /// + /// Computes the absolute value of a tensor. + /// + /// + /// + /// Given a tensor `x`, this operation returns a tensor containing the absolute + /// value of each element in `x`. For example, if x is an input element and y is + /// an output element, this operation computes \(y = |x|\). + /// + /// + /// + /// + public static Tensor abs(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Abs", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return abs_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Abs", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Abs", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor abs_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Abs", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Abs", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the element-wise sum of a list of tensors. + /// + /// + /// + /// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not + /// wait for all of its inputs to be ready before beginning to sum. This can + /// save memory if inputs are ready at different times, since minimum temporary + /// storage is proportional to the output size rather than the inputs size. + /// + /// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. + /// + /// Returns a `Tensor` of same shape and type as the elements of `inputs`. + /// + /// + /// + /// + /// + /// Shape of elements of `inputs`. + /// + /// + /// + public static Tensor accumulate_nv2(Tensors inputs, Shape shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AccumulateNV2", name) { args = new object[] { inputs }, attrs = new Dictionary() { ["shape"] = shape } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return accumulate_nv2_eager_fallback(inputs, shape: shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["inputs"] = inputs; + keywords["shape"] = shape; + var _op = tf.OpDefLib._apply_op_helper("AccumulateNV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "N", _op._get_attr_int("N"), "T", _op._get_attr_type("T"), "shape", _op.get_attr("shape") }; + _execute.record_gradient("AccumulateNV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor accumulate_nv2_eager_fallback(Tensors inputs, Shape shape, string name, Context ctx) + { + List _inputs_flat_list = new(); + _inputs_flat_list.AddRange(inputs); + var _inputs_flat = _inputs_flat_list.ToArray(); + object[] _attrs = new object[] { "N", inputs.Length, "T", inputs.dtype, "shape", shape }; + var _result = _execute.execute("AccumulateNV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AccumulateNV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes acos of x element-wise. + /// + /// + /// + /// + /// Provided an input tensor, the `tf.math.acos` operation returns the inverse cosine of each element of the tensor. If `y = tf.math.cos(x)` then, `x = tf.math.acos(y)`. + /// + /// Input range is `[-1, 1]` and the output has a range of `[0, pi]`. + /// + /// + /// + /// + /// + public static Tensor acos(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Acos", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return acos_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Acos", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Acos", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor acos_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Acos", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Acos", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes inverse hyperbolic cosine of x element-wise. + /// + /// + /// + /// Given an input tensor, the function computes inverse hyperbolic cosine of every element. + /// Input range is `[1, inf]`. It returns `nan` if the input lies outside the range. + /// + /// ```python + /// x = tf.constant([-2, -0.5, 1, 1.2, 200, 10000, float("inf")]) + /// tf.math.acosh(x) ==> [nan nan 0. 0.62236255 5.9914584 9.903487 inf] + /// ``` + /// + /// + /// + /// + public static Tensor acosh(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Acosh", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return acosh_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Acosh", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Acosh", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor acosh_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Acosh", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Acosh", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns x + y element-wise. + /// + /// + /// + /// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// Given two input tensors, the `tf.add` operation computes the sum for every element in the tensor. + /// + /// Both input and output have a range `(-inf, inf)`. + /// + /// + /// + /// + /// + /// + public static Tensor add(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Add", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return add_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Add", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Add", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor add_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Add", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Add", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Add all input tensors element wise. + /// + /// + /// + /// Inputs must be of same size and shape. + /// + /// ```python + /// x = [9, 7, 10] + /// tf.math.add_n(x) ==> 26 + /// ``` + /// + /// + /// + /// + public static Tensor add_n(Tensors inputs, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AddN", name) { args = new object[] { inputs }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return add_n_eager_fallback(inputs, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["inputs"] = inputs; + var _op = tf.OpDefLib._apply_op_helper("AddN", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "N", _op._get_attr_int("N"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("AddN", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor add_n_eager_fallback(Tensors inputs, string name, Context ctx) + { + List _inputs_flat_list = new(); + _inputs_flat_list.AddRange(inputs); + var _inputs_flat = _inputs_flat_list.ToArray(); + object[] _attrs = new object[] { "N", inputs.Length, "T", inputs.dtype }; + var _result = _execute.execute("AddN", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AddN", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns x + y element-wise. + /// + /// + /// + /// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor add_v2(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AddV2", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return add_v2_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("AddV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("AddV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor add_v2_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("AddV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AddV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the "logical and" of elements across dimensions of a tensor. + /// + /// + /// + /// Reduces `input` along the dimensions given in `reduction_indices`. Unless + /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in + /// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are + /// retained with length 1. + /// + /// + /// + /// + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// + public static Tensor all(Tensor input, Tensor reduction_indices, bool keep_dims = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "All", name) { args = new object[] { input, reduction_indices }, attrs = new Dictionary() { ["keep_dims"] = keep_dims } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return all_eager_fallback(input, reduction_indices, keep_dims: keep_dims, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["reduction_indices"] = reduction_indices; + keywords["keep_dims"] = keep_dims; + var _op = tf.OpDefLib._apply_op_helper("All", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "keep_dims", _op._get_attr_bool("keep_dims"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("All", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor all_eager_fallback(Tensor input, Tensor reduction_indices, bool keep_dims, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, reduction_indices }; + object[] _attrs = new object[] { "keep_dims", keep_dims, "Tidx", reduction_indices.dtype }; + var _result = _execute.execute("All", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("All", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the argument of a complex number. + /// + /// + /// + /// Given a tensor `input` of complex numbers, this operation returns a tensor of + /// type `float` that is the argument of each element in `input`. All elements in + /// `input` must be complex numbers of the form \(a + bj\), where *a* + /// is the real part and *b* is the imaginary part. + /// + /// The argument returned by this operation is of the form \(atan2(b, a)\). + /// + /// For example: + /// + /// ``` + /// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] + /// tf.angle(input) ==> [2.0132, 1.056] + /// ``` + /// + /// @compatibility(numpy) + /// Equivalent to np.angle. + /// @end_compatibility + /// + /// + /// + /// + /// + public static Tensor angle(Tensor input, TF_DataType Tout = TF_DataType.TF_FLOAT, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Angle", name) { args = new object[] { input }, attrs = new Dictionary() { ["Tout"] = Tout } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return angle_eager_fallback(input, Tout: Tout, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["Tout"] = Tout; + var _op = tf.OpDefLib._apply_op_helper("Angle", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tout", _op._get_attr_type("Tout") }; + _execute.record_gradient("Angle", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor angle_eager_fallback(Tensor input, TF_DataType Tout, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "Tout", Tout }; + var _result = _execute.execute("Angle", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Angle", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the "logical or" of elements across dimensions of a tensor. + /// + /// + /// + /// Reduces `input` along the dimensions given in `reduction_indices`. Unless + /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in + /// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are + /// retained with length 1. + /// + /// + /// + /// + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// + public static Tensor any(Tensor input, Tensor reduction_indices, bool keep_dims = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Any", name) { args = new object[] { input, reduction_indices }, attrs = new Dictionary() { ["keep_dims"] = keep_dims } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return any_eager_fallback(input, reduction_indices, keep_dims: keep_dims, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["reduction_indices"] = reduction_indices; + keywords["keep_dims"] = keep_dims; + var _op = tf.OpDefLib._apply_op_helper("Any", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "keep_dims", _op._get_attr_bool("keep_dims"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("Any", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor any_eager_fallback(Tensor input, Tensor reduction_indices, bool keep_dims, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, reduction_indices }; + object[] _attrs = new object[] { "keep_dims", keep_dims, "Tidx", reduction_indices.dtype }; + var _result = _execute.execute("Any", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Any", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the truth value of abs(x-y) < tolerance element-wise. + /// + /// + /// + /// + /// + public static Tensor approximate_equal(Tensor x, Tensor y, float tolerance = 1E-05f, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ApproximateEqual", name) { args = new object[] { x, y }, attrs = new Dictionary() { ["tolerance"] = tolerance } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return approximate_equal_eager_fallback(x, y, tolerance: tolerance, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + keywords["tolerance"] = tolerance; + var _op = tf.OpDefLib._apply_op_helper("ApproximateEqual", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "tolerance", _op.get_attr("tolerance") }; + _execute.record_gradient("ApproximateEqual", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor approximate_equal_eager_fallback(Tensor x, Tensor y, float tolerance, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype, "tolerance", tolerance }; + var _result = _execute.execute("ApproximateEqual", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ApproximateEqual", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the index with the largest value across dimensions of a tensor. + /// + /// + /// + /// Note that in case of ties the identity of the return value is not guaranteed. + /// + /// Usage: + /// ```python + /// import tensorflow as tf + /// a = [1, 10, 26.9, 2.8, 166.32, 62.3] + /// b = tf.math.argmax(input = a) + /// c = tf.keras.backend.eval(b) + /// # c = 4 + /// # here a[4] = 166.32 which is the largest element of a across axis 0 + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor arg_max(Tensor input, Tensor dimension, TF_DataType output_type = TF_DataType.TF_INT64, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ArgMax", name) { args = new object[] { input, dimension }, attrs = new Dictionary() { ["output_type"] = output_type } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return arg_max_eager_fallback(input, dimension, output_type: output_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["dimension"] = dimension; + keywords["output_type"] = output_type; + var _op = tf.OpDefLib._apply_op_helper("ArgMax", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx"), "output_type", _op._get_attr_type("output_type") }; + _execute.record_gradient("ArgMax", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor arg_max_eager_fallback(Tensor input, Tensor dimension, TF_DataType output_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, dimension }; + object[] _attrs = new object[] { "T", input.dtype, "Tidx", dimension.dtype, "output_type", output_type }; + var _result = _execute.execute("ArgMax", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ArgMax", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the index with the smallest value across dimensions of a tensor. + /// + /// + /// + /// Note that in case of ties the identity of the return value is not guaranteed. + /// + /// Usage: + /// ```python + /// import tensorflow as tf + /// a = [1, 10, 26.9, 2.8, 166.32, 62.3] + /// b = tf.math.argmin(input = a) + /// c = tf.keras.backend.eval(b) + /// # c = 0 + /// # here a[0] = 1 which is the smallest element of a across axis 0 + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor arg_min(Tensor input, Tensor dimension, TF_DataType output_type = TF_DataType.TF_INT64, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ArgMin", name) { args = new object[] { input, dimension }, attrs = new Dictionary() { ["output_type"] = output_type } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return arg_min_eager_fallback(input, dimension, output_type: output_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["dimension"] = dimension; + keywords["output_type"] = output_type; + var _op = tf.OpDefLib._apply_op_helper("ArgMin", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx"), "output_type", _op._get_attr_type("output_type") }; + _execute.record_gradient("ArgMin", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor arg_min_eager_fallback(Tensor input, Tensor dimension, TF_DataType output_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, dimension }; + object[] _attrs = new object[] { "T", input.dtype, "Tidx", dimension.dtype, "output_type", output_type }; + var _result = _execute.execute("ArgMin", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ArgMin", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the trignometric inverse sine of x element-wise. + /// + /// + /// + /// The `tf.math.asin` operation returns the inverse of `tf.math.sin`, such that + /// if `y = tf.math.sin(x)` then, `x = tf.math.asin(y)`. + /// + /// **Note**: The output of `tf.math.asin` will lie within the invertible range + /// of sine, i.e [-pi/2, pi/2]. + /// + /// For example: + /// + /// ```python + /// # Note: [1.047, 0.785] ~= [(pi/3), (pi/4)] + /// x = tf.constant([1.047, 0.785]) + /// y = tf.math.sin(x) # [0.8659266, 0.7068252] + /// + /// tf.math.asin(y) # [1.047, 0.785] = x + /// ``` + /// + /// + /// + /// + /// + public static Tensor asin(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Asin", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return asin_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Asin", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Asin", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor asin_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Asin", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Asin", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes inverse hyperbolic sine of x element-wise. + /// + /// + /// + /// Given an input tensor, this function computes inverse hyperbolic sine + /// for every element in the tensor. Both input and output has a range of + /// `[-inf, inf]`. + /// + /// ```python + /// x = tf.constant([-float("inf"), -2, -0.5, 1, 1.2, 200, 10000, float("inf")]) + /// tf.math.asinh(x) ==> [-inf -1.4436355 -0.4812118 0.8813736 1.0159732 5.991471 9.903487 inf] + /// ``` + /// + /// + /// + /// + public static Tensor asinh(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Asinh", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return asinh_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Asinh", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Asinh", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor asinh_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Asinh", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Asinh", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the trignometric inverse tangent of x element-wise. + /// + /// + /// + /// The `tf.math.atan` operation returns the inverse of `tf.math.tan`, such that + /// if `y = tf.math.tan(x)` then, `x = tf.math.atan(y)`. + /// + /// **Note**: The output of `tf.math.atan` will lie within the invertible range + /// of tan, i.e (-pi/2, pi/2). + /// + /// For example: + /// + /// ```python + /// # Note: [1.047, 0.785] ~= [(pi/3), (pi/4)] + /// x = tf.constant([1.047, 0.785]) + /// y = tf.math.tan(x) # [1.731261, 0.99920404] + /// + /// tf.math.atan(y) # [1.047, 0.785] = x + /// ``` + /// + /// + /// + /// + /// + public static Tensor atan(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Atan", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return atan_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Atan", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Atan", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor atan_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Atan", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Atan", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes arctangent of `y/x` element-wise, respecting signs of the arguments. + /// + /// + /// + /// This is the angle \( heta in [-pi, pi] \) such that + /// \[ x = r cos( heta) \] + /// and + /// \[ y = r sin( heta) \] + /// where \(r = sqrt{x^2 + y^2} \). + /// + /// For example: + /// + /// >>> x = [1., 1.] + /// >>> y = [1., -1.] + /// >>> print((tf.math.atan2(y,x) * (180 / np.pi)).numpy()) + /// [ 45. -45.] + /// + /// + /// + /// + /// + /// + /// + public static Tensor atan2(Tensor y, Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Atan2", name) { args = new object[] { y, x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return atan2_eager_fallback(y, x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["y"] = y; + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Atan2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Atan2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor atan2_eager_fallback(Tensor y, Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { y, x }; + object[] _attrs = new object[] { "T", y.dtype }; + var _result = _execute.execute("Atan2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Atan2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes inverse hyperbolic tangent of x element-wise. + /// + /// + /// + /// Given an input tensor, this function computes inverse hyperbolic tangent + /// for every element in the tensor. Input range is `[-1,1]` and output range is + /// `[-inf, inf]`. If input is `-1`, output will be `-inf` and if the + /// input is `1`, output will be `inf`. Values outside the range will have + /// `nan` as output. + /// + /// ```python + /// x = tf.constant([-float("inf"), -1, -0.5, 1, 0, 0.5, 10, float("inf")]) + /// tf.math.atanh(x) ==> [nan -inf -0.54930615 inf 0. 0.54930615 nan nan] + /// ``` + /// + /// + /// + /// + public static Tensor atanh(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Atanh", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return atanh_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Atanh", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Atanh", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor atanh_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Atanh", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Atanh", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Multiplies slices of two tensors in batches. + /// + /// + /// + /// Multiplies all slices of `Tensor` `x` and `y` (each slice can be + /// viewed as an element of a batch), and arranges the individual results + /// in a single output tensor of the same batch size. Each of the + /// individual slices can optionally be adjointed (to adjoint a matrix + /// means to transpose and conjugate it) before multiplication by setting + /// the `adj_x` or `adj_y` flag to `True`, which are by default `False`. + /// + /// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` + /// and `[..., r_y, c_y]`. + /// + /// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: + /// + /// r_o = c_x if adj_x else r_x + /// c_o = r_y if adj_y else c_y + /// + /// It is computed as: + /// + /// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) + /// + /// + /// + /// + /// + /// + /// If `True`, adjoint the slices of `x`. Defaults to `False`. + /// + /// + /// + /// + /// If `True`, adjoint the slices of `y`. Defaults to `False`. + /// + /// + /// + public static Tensor batch_mat_mul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string? name = null) { - public static OpDefLibrary _op_def_lib = _InitOpDefLibrary(); + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BatchMatMul", name) { args = new object[] { x, y }, attrs = new Dictionary() { ["adj_x"] = adj_x, ["adj_y"] = adj_y } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return batch_mat_mul_eager_fallback(x, y, adj_x: adj_x, adj_y: adj_y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + keywords["adj_x"] = adj_x; + keywords["adj_y"] = adj_y; + var _op = tf.OpDefLib._apply_op_helper("BatchMatMul", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "adj_x", _op._get_attr_bool("adj_x"), "adj_y", _op._get_attr_bool("adj_y") }; + _execute.record_gradient("BatchMatMul", _op.inputs, _attrs, _result); + } + return _result[0]; + } - public static Tensor add(Tensor a, Tensor b) + public static Tensor batch_mat_mul_eager_fallback(Tensor x, Tensor y, bool adj_x, bool adj_y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype, "adj_x", adj_x, "adj_y", adj_y }; + var _result = _execute.execute("BatchMatMul", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BatchMatMul", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Multiplies slices of two tensors in batches. + /// + /// + /// + /// Multiplies all slices of `Tensor` `x` and `y` (each slice can be + /// viewed as an element of a batch), and arranges the individual results + /// in a single output tensor of the same batch size. Each of the + /// individual slices can optionally be adjointed (to adjoint a matrix + /// means to transpose and conjugate it) before multiplication by setting + /// the `adj_x` or `adj_y` flag to `True`, which are by default `False`. + /// + /// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` + /// and `[..., r_y, c_y]`. + /// + /// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: + /// + /// r_o = c_x if adj_x else r_x + /// c_o = r_y if adj_y else c_y + /// + /// It is computed as: + /// + /// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) + /// + /// *NOTE*: `BatchMatMulV2` supports broadcasting in the batch dimensions. More + /// about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). + /// + /// + /// + /// + /// + /// + /// + /// If `True`, adjoint the slices of `x`. Defaults to `False`. + /// + /// + /// + /// + /// If `True`, adjoint the slices of `y`. Defaults to `False`. + /// + /// + /// + public static Tensor batch_mat_mul_v2(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BatchMatMulV2", name) { args = new object[] { x, y }, attrs = new Dictionary() { ["adj_x"] = adj_x, ["adj_y"] = adj_y } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return batch_mat_mul_v2_eager_fallback(x, y, adj_x: adj_x, adj_y: adj_y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + keywords["adj_x"] = adj_x; + keywords["adj_y"] = adj_y; + var _op = tf.OpDefLib._apply_op_helper("BatchMatMulV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) { - var keywords = new Dictionary(); - keywords.Add("x", a); - keywords.Add("y", b); + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "adj_x", _op._get_attr_bool("adj_x"), "adj_y", _op._get_attr_bool("adj_y") }; + _execute.record_gradient("BatchMatMulV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } - var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); + public static Tensor batch_mat_mul_v2_eager_fallback(Tensor x, Tensor y, bool adj_x, bool adj_y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype, "adj_x", adj_x, "adj_y", adj_y }; + var _result = _execute.execute("BatchMatMulV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BatchMatMulV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Multiplies slices of two tensors in batches. + /// + /// + /// + /// Multiplies all slices of `Tensor` `x` and `y` (each slice can be + /// viewed as an element of a batch), and arranges the individual results + /// in a single output tensor of the same batch size. Each of the + /// individual slices can optionally be adjointed (to adjoint a matrix + /// means to transpose and conjugate it) before multiplication by setting + /// the `adj_x` or `adj_y` flag to `True`, which are by default `False`. + /// + /// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` + /// and `[..., r_y, c_y]`. + /// + /// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: + /// + /// r_o = c_x if adj_x else r_x + /// c_o = r_y if adj_y else c_y + /// + /// It is computed as: + /// + /// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) + /// + /// *NOTE*: `BatchMatMulV3` supports broadcasting in the batch dimensions. More + /// about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). + /// + /// + /// + /// + /// + /// + /// + /// If not spcified, Tout is the same type to input type. + /// + /// + /// + /// + /// If `True`, adjoint the slices of `x`. Defaults to `False`. + /// + /// + /// + /// + /// If `True`, adjoint the slices of `y`. Defaults to `False`. + /// + /// + /// + public static Tensor batch_mat_mul_v3(Tensor x, Tensor y, TF_DataType Tout, bool adj_x = false, bool adj_y = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BatchMatMulV3", name) { args = new object[] { x, y }, attrs = new Dictionary() { ["Tout"] = Tout, ["adj_x"] = adj_x, ["adj_y"] = adj_y } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return batch_mat_mul_v3_eager_fallback(x, y, Tout: Tout, adj_x: adj_x, adj_y: adj_y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + keywords["Tout"] = Tout; + keywords["adj_x"] = adj_x; + keywords["adj_y"] = adj_y; + var _op = tf.OpDefLib._apply_op_helper("BatchMatMulV3", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Ta", _op._get_attr_type("Ta"), "Tb", _op._get_attr_type("Tb"), "Tout", _op._get_attr_type("Tout"), "adj_x", _op._get_attr_bool("adj_x"), "adj_y", _op._get_attr_bool("adj_y") }; + _execute.record_gradient("BatchMatMulV3", _op.inputs, _attrs, _result); + } + return _result[0]; + } - var tensor = new Tensor(_op, 0, TF_DataType.TF_FLOAT); + public static Tensor batch_mat_mul_v3_eager_fallback(Tensor x, Tensor y, TF_DataType Tout, bool adj_x, bool adj_y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "Ta", x.dtype, "Tb", y.dtype, "Tout", Tout, "adj_x", adj_x, "adj_y", adj_y }; + var _result = _execute.execute("BatchMatMulV3", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BatchMatMulV3", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Compute the regularized incomplete beta integral \\(I_x(a, b)\\). + /// + /// + /// + /// The regularized incomplete beta integral is defined as: + /// + /// + /// \(I_x(a, b) = rac{B(x; a, b)}{B(a, b)}\) + /// + /// where + /// + /// + /// \(B(x; a, b) = int_0^x t^{a-1} (1 - t)^{b-1} dt\) + /// + /// + /// is the incomplete beta function and \(B(a, b)\) is the *complete* + /// beta function. + /// + /// + /// + /// + /// + /// + public static Tensor betainc(Tensor a, Tensor b, Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Betainc", name) { args = new object[] { a, b, x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return betainc_eager_fallback(a, b, x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["b"] = b; + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Betainc", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Betainc", _op.inputs, _attrs, _result); + } + return _result[0]; + } - return tensor; + public static Tensor betainc_eager_fallback(Tensor a, Tensor b, Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, b, x }; + object[] _attrs = new object[] { "T", a.dtype }; + var _result = _execute.execute("Betainc", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Betainc", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Counts the number of occurrences of each value in an integer array. + /// + /// + /// + /// Outputs a vector with length `size` and the same dtype as `weights`. If + /// `weights` are empty, then index `i` stores the number of times the value `i` is + /// counted in `arr`. If `weights` are non-empty, then index `i` stores the sum of + /// the value in `weights` at each index where the corresponding value in `arr` is + /// `i`. + /// + /// Values in `arr` outside of the range [0, size) are ignored. + /// + /// + /// + /// + /// + /// + public static Tensor bincount(Tensor arr, Tensor size, Tensor weights, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Bincount", name) { args = new object[] { arr, size, weights }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return bincount_eager_fallback(arr, size, weights, name: name, ctx: _ctx); + } + catch (Exception) + { + } } + Dictionary keywords = new(); + keywords["arr"] = arr; + keywords["size"] = size; + keywords["weights"] = weights; + var _op = tf.OpDefLib._apply_op_helper("Bincount", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Bincount", _op.inputs, _attrs, _result); + } + return _result[0]; + } - private static OpDefLibrary _InitOpDefLibrary() + public static Tensor bincount_eager_fallback(Tensor arr, Tensor size, Tensor weights, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { arr, size, weights }; + object[] _attrs = new object[] { "T", weights.dtype }; + var _result = _execute.execute("Bincount", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Bincount", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Bucketizes 'input' based on 'boundaries'. + /// + /// + /// + /// For example, if the inputs are + /// boundaries = [0, 10, 100] + /// input = [[-5, 10000] + /// [150, 10] + /// [5, 100]] + /// + /// then the output will be + /// output = [[0, 3] + /// [3, 2] + /// [1, 3]] + /// + /// + /// + /// + /// + /// A sorted list of floats gives the boundary of the buckets. + /// + /// + /// + public static Tensor bucketize(Tensor input, float[] boundaries, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) { - // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); - var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_math.bin"); - var op_list = OpList.Parser.ParseFrom(bytes); - var op_def_lib = new OpDefLibrary(); - op_def_lib.add_op_list(op_list); + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Bucketize", name) { args = new object[] { input }, attrs = new Dictionary() { ["boundaries"] = boundaries } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return bucketize_eager_fallback(input, boundaries: boundaries, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["boundaries"] = boundaries; + var _op = tf.OpDefLib._apply_op_helper("Bucketize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "boundaries", _op.get_attr("boundaries") }; + _execute.record_gradient("Bucketize", _op.inputs, _attrs, _result); + } + return _result[0]; + } - return op_def_lib; + public static Tensor bucketize_eager_fallback(Tensor input, float[] boundaries, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "boundaries", boundaries }; + var _result = _execute.execute("Bucketize", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Bucketize", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Cast x of type SrcT to y of DstT. + /// + /// + /// + /// + /// + public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Cast", name) { args = new object[] { x }, attrs = new Dictionary() { ["DstT"] = DstT, ["Truncate"] = Truncate } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return cast_eager_fallback(x, DstT: DstT, Truncate: Truncate, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["DstT"] = DstT; + keywords["Truncate"] = Truncate; + var _op = tf.OpDefLib._apply_op_helper("Cast", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "SrcT", _op._get_attr_type("SrcT"), "DstT", _op._get_attr_type("DstT"), "Truncate", _op._get_attr_bool("Truncate") }; + _execute.record_gradient("Cast", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor cast_eager_fallback(Tensor x, TF_DataType DstT, bool Truncate, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "SrcT", x.dtype, "DstT", DstT, "Truncate", Truncate }; + var _result = _execute.execute("Cast", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Cast", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns element-wise smallest integer not less than x. + /// + /// + /// + public static Tensor ceil(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Ceil", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return ceil_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Ceil", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Ceil", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor ceil_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Ceil", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Ceil", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Clips tensor values to a specified min and max. + /// + /// + /// + /// Given a tensor `t`, this operation returns a tensor of the same type and + /// shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`. + /// Any values less than `clip_value_min` are set to `clip_value_min`. Any values + /// greater than `clip_value_max` are set to `clip_value_max`. + /// + /// + /// + /// + /// + /// + public static Tensor clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ClipByValue", name) { args = new object[] { t, clip_value_min, clip_value_max }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return clip_by_value_eager_fallback(t, clip_value_min, clip_value_max, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["t"] = t; + keywords["clip_value_min"] = clip_value_min; + keywords["clip_value_max"] = clip_value_max; + var _op = tf.OpDefLib._apply_op_helper("ClipByValue", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("ClipByValue", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor clip_by_value_eager_fallback(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { t, clip_value_min, clip_value_max }; + object[] _attrs = new object[] { "T", t.dtype }; + var _result = _execute.execute("ClipByValue", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ClipByValue", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Converts two real numbers to a complex number. + /// + /// + /// + /// Given a tensor `real` representing the real part of a complex number, and a + /// tensor `imag` representing the imaginary part of a complex number, this + /// operation returns complex numbers elementwise of the form \(a + bj\), where + /// *a* represents the `real` part and *b* represents the `imag` part. + /// + /// The input tensors `real` and `imag` must have the same shape. + /// + /// For example: + /// + /// ``` + /// # tensor 'real' is [2.25, 3.25] + /// # tensor `imag` is [4.75, 5.75] + /// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor complex(Tensor real, Tensor imag, TF_DataType Tout = TF_DataType.TF_COMPLEX64, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Complex", name) { args = new object[] { real, imag }, attrs = new Dictionary() { ["Tout"] = Tout } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return complex_eager_fallback(real, imag, Tout: Tout, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["real"] = real; + keywords["imag"] = imag; + keywords["Tout"] = Tout; + var _op = tf.OpDefLib._apply_op_helper("Complex", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tout", _op._get_attr_type("Tout") }; + _execute.record_gradient("Complex", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor complex_eager_fallback(Tensor real, Tensor imag, TF_DataType Tout, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { real, imag }; + object[] _attrs = new object[] { "T", real.dtype, "Tout", Tout }; + var _result = _execute.execute("Complex", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Complex", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the complex absolute value of a tensor. + /// + /// + /// + /// Given a tensor `x` of complex numbers, this operation returns a tensor of type + /// `float` or `double` that is the absolute value of each element in `x`. All + /// elements in `x` must be complex numbers of the form \(a + bj\). The absolute + /// value is computed as \( sqrt{a^2 + b^2}\). + /// + /// For example: + /// + /// >>> x = tf.complex(3.0, 4.0) + /// >>> print((tf.raw_ops.ComplexAbs(x=x, Tout=tf.dtypes.float32, name=None)).numpy()) + /// 5.0 + /// + /// + /// + /// + /// + /// + public static Tensor complex_abs(Tensor x, TF_DataType Tout = TF_DataType.TF_FLOAT, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ComplexAbs", name) { args = new object[] { x }, attrs = new Dictionary() { ["Tout"] = Tout } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return complex_abs_eager_fallback(x, Tout: Tout, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["Tout"] = Tout; + var _op = tf.OpDefLib._apply_op_helper("ComplexAbs", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tout", _op._get_attr_type("Tout") }; + _execute.record_gradient("ComplexAbs", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor complex_abs_eager_fallback(Tensor x, TF_DataType Tout, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype, "Tout", Tout }; + var _result = _execute.execute("ComplexAbs", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ComplexAbs", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the complex conjugate of a complex number. + /// + /// + /// + /// Given a tensor `input` of complex numbers, this operation returns a tensor of + /// complex numbers that are the complex conjugate of each element in `input`. The + /// complex numbers in `input` must be of the form \(a + bj\), where *a* is the + /// real part and *b* is the imaginary part. + /// + /// The complex conjugate returned by this operation is of the form \(a - bj\). + /// + /// For example: + /// + /// ``` + /// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] + /// tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] + /// ``` + /// + /// + /// + /// + public static Tensor conj(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Conj", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return conj_eager_fallback(input, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("Conj", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Conj", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor conj_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("Conj", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Conj", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes cos of x element-wise. + /// + /// + /// + /// Given an input tensor, this function computes cosine of every + /// element in the tensor. Input range is `(-inf, inf)` and + /// output range is `[-1,1]`. If input lies outside the boundary, `nan` + /// is returned. + /// + /// ```python + /// x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 200, 10000, float("inf")]) + /// tf.math.cos(x) ==> [nan -0.91113025 0.87758255 0.5403023 0.36235774 0.48718765 -0.95215535 nan] + /// ``` + /// + /// + /// + /// + public static Tensor cos(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Cos", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return cos_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Cos", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Cos", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor cos_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Cos", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Cos", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes hyperbolic cosine of x element-wise. + /// + /// + /// + /// Given an input tensor, this function computes hyperbolic cosine of every + /// element in the tensor. Input range is `[-inf, inf]` and output range + /// is `[1, inf]`. + /// + /// ```python + /// x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 2, 10, float("inf")]) + /// tf.math.cosh(x) ==> [inf 4.0515420e+03 1.1276259e+00 1.5430807e+00 1.8106556e+00 3.7621956e+00 1.1013233e+04 inf] + /// ``` + /// + /// + /// + /// + public static Tensor cosh(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Cosh", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return cosh_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Cosh", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Cosh", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor cosh_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Cosh", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Cosh", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Compute the pairwise cross product. + /// + /// + /// + /// `a` and `b` must be the same shape; they can either be simple 3-element vectors, + /// or any shape where the innermost dimension is 3. In the latter case, each pair + /// of corresponding 3-element vectors is cross-multiplied independently. + /// + /// + /// + /// + /// + public static Tensor cross(Tensor a, Tensor b, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Cross", name) { args = new object[] { a, b }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return cross_eager_fallback(a, b, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["b"] = b; + var _op = tf.OpDefLib._apply_op_helper("Cross", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Cross", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor cross_eager_fallback(Tensor a, Tensor b, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, b }; + object[] _attrs = new object[] { "T", a.dtype }; + var _result = _execute.execute("Cross", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Cross", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Compute the cumulative product of the tensor `x` along `axis`. + /// + /// + /// + /// By default, this op performs an inclusive cumprod, which means that the first + /// element of the input is identical to the first element of the output: + /// + /// ```python + /// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] + /// ``` + /// + /// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is + /// performed instead: + /// + /// ```python + /// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] + /// ``` + /// + /// By setting the `reverse` kwarg to `True`, the cumprod is performed in the + /// opposite direction: + /// + /// ```python + /// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] + /// ``` + /// + /// This is more efficient than using separate `tf.reverse` ops. + /// + /// The `reverse` and `exclusive` kwargs can also be combined: + /// + /// ```python + /// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] + /// ``` + /// + /// + /// + /// + /// + /// + /// If `True`, perform exclusive cumprod. + /// + /// + /// + /// + /// A `bool` (default: False). + /// + /// + /// + public static Tensor cumprod(Tensor x, Tensor axis, bool exclusive = false, bool reverse = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Cumprod", name) { args = new object[] { x, axis }, attrs = new Dictionary() { ["exclusive"] = exclusive, ["reverse"] = reverse } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return cumprod_eager_fallback(x, axis, exclusive: exclusive, reverse: reverse, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["axis"] = axis; + keywords["exclusive"] = exclusive; + keywords["reverse"] = reverse; + var _op = tf.OpDefLib._apply_op_helper("Cumprod", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "exclusive", _op._get_attr_bool("exclusive"), "reverse", _op._get_attr_bool("reverse"), "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("Cumprod", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor cumprod_eager_fallback(Tensor x, Tensor axis, bool exclusive, bool reverse, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, axis }; + object[] _attrs = new object[] { "exclusive", exclusive, "reverse", reverse, "T", x.dtype, "Tidx", axis.dtype }; + var _result = _execute.execute("Cumprod", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Cumprod", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Compute the cumulative sum of the tensor `x` along `axis`. + /// + /// + /// + /// By default, this op performs an inclusive cumsum, which means that the first + /// element of the input is identical to the first element of the output: + /// + /// ```python + /// tf.cumsum([a, b, c]) # => [a, a + b, a + b + c] + /// ``` + /// + /// By setting the `exclusive` kwarg to `True`, an exclusive cumsum is + /// performed instead: + /// + /// ```python + /// tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b] + /// ``` + /// + /// By setting the `reverse` kwarg to `True`, the cumsum is performed in the + /// opposite direction: + /// + /// ```python + /// tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c] + /// ``` + /// + /// This is more efficient than using separate `tf.reverse` ops. + /// + /// The `reverse` and `exclusive` kwargs can also be combined: + /// + /// ```python + /// tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] + /// ``` + /// + /// + /// + /// + /// + /// + /// If `True`, perform exclusive cumsum. + /// + /// + /// + /// + /// A `bool` (default: False). + /// + /// + /// + public static Tensor cumsum(Tensor x, Tensor axis, bool exclusive = false, bool reverse = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Cumsum", name) { args = new object[] { x, axis }, attrs = new Dictionary() { ["exclusive"] = exclusive, ["reverse"] = reverse } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return cumsum_eager_fallback(x, axis, exclusive: exclusive, reverse: reverse, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["axis"] = axis; + keywords["exclusive"] = exclusive; + keywords["reverse"] = reverse; + var _op = tf.OpDefLib._apply_op_helper("Cumsum", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "exclusive", _op._get_attr_bool("exclusive"), "reverse", _op._get_attr_bool("reverse"), "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("Cumsum", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor cumsum_eager_fallback(Tensor x, Tensor axis, bool exclusive, bool reverse, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, axis }; + object[] _attrs = new object[] { "exclusive", exclusive, "reverse", reverse, "T", x.dtype, "Tidx", axis.dtype }; + var _result = _execute.execute("Cumsum", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Cumsum", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Compute the cumulative product of the tensor `x` along `axis`. + /// + /// + /// + /// By default, this op performs an inclusive cumulative log-sum-exp, + /// which means that the first + /// element of the input is identical to the first element of the output: + /// ```python + /// tf.math.cumulative_logsumexp([a, b, c]) # => [a, log(exp(a) + exp(b)), log(exp(a) + exp(b) + exp(c))] + /// ``` + /// + /// By setting the `exclusive` kwarg to `True`, an exclusive cumulative log-sum-exp is + /// performed instead: + /// ```python + /// tf.cumulative_logsumexp([a, b, c], exclusive=True) # => [-inf, a, log(exp(a) * exp(b))] + /// ``` + /// Note that the neutral element of the log-sum-exp operation is `-inf`, + /// however, for performance reasons, the minimal value representable by the + /// floating point type is used instead. + /// + /// By setting the `reverse` kwarg to `True`, the cumulative log-sum-exp is performed in the + /// opposite direction. + /// + /// + /// + /// + /// + /// + /// If `True`, perform exclusive cumulative log-sum-exp. + /// + /// + /// + /// + /// A `bool` (default: False). + /// + /// + /// + public static Tensor cumulative_logsumexp(Tensor x, Tensor axis, bool exclusive = false, bool reverse = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "CumulativeLogsumexp", name) { args = new object[] { x, axis }, attrs = new Dictionary() { ["exclusive"] = exclusive, ["reverse"] = reverse } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return cumulative_logsumexp_eager_fallback(x, axis, exclusive: exclusive, reverse: reverse, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["axis"] = axis; + keywords["exclusive"] = exclusive; + keywords["reverse"] = reverse; + var _op = tf.OpDefLib._apply_op_helper("CumulativeLogsumexp", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "exclusive", _op._get_attr_bool("exclusive"), "reverse", _op._get_attr_bool("reverse"), "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("CumulativeLogsumexp", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor cumulative_logsumexp_eager_fallback(Tensor x, Tensor axis, bool exclusive, bool reverse, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, axis }; + object[] _attrs = new object[] { "exclusive", exclusive, "reverse", reverse, "T", x.dtype, "Tidx", axis.dtype }; + var _result = _execute.execute("CumulativeLogsumexp", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("CumulativeLogsumexp", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Counts the number of occurrences of each value in an integer array. + /// + /// + /// + /// Outputs a vector with length `size` and the same dtype as `weights`. If + /// `weights` are empty, then index `i` stores the number of times the value `i` is + /// counted in `arr`. If `weights` are non-empty, then index `i` stores the sum of + /// the value in `weights` at each index where the corresponding value in `arr` is + /// `i`. + /// + /// Values in `arr` outside of the range [0, size) are ignored. + /// + /// + /// + /// + /// + /// + /// + /// bool; Whether the kernel should count the appearance or number of occurrences. + /// + /// + /// + public static Tensor dense_bincount(Tensor input, Tensor size, Tensor weights, bool binary_output = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DenseBincount", name) { args = new object[] { input, size, weights }, attrs = new Dictionary() { ["binary_output"] = binary_output } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return dense_bincount_eager_fallback(input, size, weights, binary_output: binary_output, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["size"] = size; + keywords["weights"] = weights; + keywords["binary_output"] = binary_output; + var _op = tf.OpDefLib._apply_op_helper("DenseBincount", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tidx", _op._get_attr_type("Tidx"), "T", _op._get_attr_type("T"), "binary_output", _op._get_attr_bool("binary_output") }; + _execute.record_gradient("DenseBincount", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor dense_bincount_eager_fallback(Tensor input, Tensor size, Tensor weights, bool binary_output, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, size, weights }; + object[] _attrs = new object[] { "Tidx", input.dtype, "T", weights.dtype, "binary_output", binary_output }; + var _result = _execute.execute("DenseBincount", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DenseBincount", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes Psi, the derivative of Lgamma (the log of the absolute value of + /// + /// + /// + /// `Gamma(x)`), element-wise. + /// + /// + /// + /// + public static Tensor digamma(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Digamma", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return digamma_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Digamma", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Digamma", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor digamma_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Digamma", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Digamma", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns x / y element-wise. + /// + /// + /// + /// *NOTE*: `Div` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor div(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Div", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return div_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Div", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Div", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor div_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Div", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Div", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns 0 if the denominator is zero. + /// + /// + /// + /// + /// *NOTE*: `DivNoNan` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor div_no_nan(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DivNoNan", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return div_no_nan_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("DivNoNan", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("DivNoNan", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor div_no_nan_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("DivNoNan", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DivNoNan", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the truth value of (x == y) element-wise. + /// + /// + /// + /// *NOTE*: `Equal` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// ```python + /// x = tf.constant([2, 4]) + /// y = tf.constant(2) + /// tf.math.equal(x, y) ==> array([True, False]) + /// + /// x = tf.constant([2, 4]) + /// y = tf.constant([2, 4]) + /// tf.math.equal(x, y) ==> array([True, True]) + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor equal(Tensor x, Tensor y, bool incompatible_shape_error = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Equal", name) { args = new object[] { x, y }, attrs = new Dictionary() { ["incompatible_shape_error"] = incompatible_shape_error } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return equal_eager_fallback(x, y, incompatible_shape_error: incompatible_shape_error, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + keywords["incompatible_shape_error"] = incompatible_shape_error; + var _op = tf.OpDefLib._apply_op_helper("Equal", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "incompatible_shape_error", _op._get_attr_bool("incompatible_shape_error") }; + _execute.record_gradient("Equal", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor equal_eager_fallback(Tensor x, Tensor y, bool incompatible_shape_error, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype, "incompatible_shape_error", incompatible_shape_error }; + var _result = _execute.execute("Equal", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Equal", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the [Gauss error function](https://en.wikipedia.org/wiki/Error_function) of `x` element-wise. In statistics, for non-negative values of $x$, the error function has the following interpretation: for a random variable $Y$ that is normally distributed with mean 0 and variance $1/\sqrt{2}$, $erf(x)$ is the probability that $Y$ falls in the range $[−x, x]$. + /// + /// + /// + public static Tensor erf(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Erf", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return erf_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Erf", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Erf", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor erf_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Erf", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Erf", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the complementary error function of `x` element-wise. + /// + /// + /// + public static Tensor erfc(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Erfc", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return erfc_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Erfc", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Erfc", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor erfc_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Erfc", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Erfc", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + public static Tensor erfinv(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Erfinv", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return erfinv_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Erfinv", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Erfinv", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor erfinv_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Erfinv", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Erfinv", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the euclidean norm of elements across dimensions of a tensor. + /// + /// + /// + /// Reduces `input` along the dimensions given in `reduction_indices`. Unless + /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in + /// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are + /// retained with length 1. + /// + /// + /// + /// + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// + public static Tensor euclidean_norm(Tensor input, Tensor reduction_indices, bool keep_dims = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "EuclideanNorm", name) { args = new object[] { input, reduction_indices }, attrs = new Dictionary() { ["keep_dims"] = keep_dims } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return euclidean_norm_eager_fallback(input, reduction_indices, keep_dims: keep_dims, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["reduction_indices"] = reduction_indices; + keywords["keep_dims"] = keep_dims; + var _op = tf.OpDefLib._apply_op_helper("EuclideanNorm", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "keep_dims", _op._get_attr_bool("keep_dims"), "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("EuclideanNorm", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor euclidean_norm_eager_fallback(Tensor input, Tensor reduction_indices, bool keep_dims, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, reduction_indices }; + object[] _attrs = new object[] { "keep_dims", keep_dims, "T", input.dtype, "Tidx", reduction_indices.dtype }; + var _result = _execute.execute("EuclideanNorm", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("EuclideanNorm", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes exponential of x element-wise. \\(y = e^x\\). + /// + /// + /// + /// This function computes the exponential of every element in the input tensor. + /// i.e. `exp(x)` or `e^(x)`, where `x` is the input tensor. + /// `e` denotes Euler's number and is approximately equal to 2.718281. + /// Output is positive for any real input. + /// + /// ```python + /// x = tf.constant(2.0) + /// tf.math.exp(x) ==> 7.389056 + /// + /// x = tf.constant([2.0, 8.0]) + /// tf.math.exp(x) ==> array([7.389056, 2980.958], dtype=float32) + /// ``` + /// + /// For complex numbers, the exponential value is calculated as follows: + /// + /// ``` + /// e^(x+iy) = e^x * e^iy = e^x * (cos y + i sin y) + /// ``` + /// + /// Let's consider complex number 1+1j as an example. + /// e^1 * (cos 1 + i sin 1) = 2.7182818284590 * (0.54030230586+0.8414709848j) + /// + /// ```python + /// x = tf.constant(1 + 1j) + /// tf.math.exp(x) ==> 1.4686939399158851+2.2873552871788423j + /// ``` + /// + /// + /// + /// + public static Tensor exp(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Exp", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return exp_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Exp", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Exp", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor exp_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Exp", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Exp", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes `exp(x) - 1` element-wise. + /// + /// + /// + /// i.e. `exp(x) - 1` or `e^(x) - 1`, where `x` is the input tensor. + /// `e` denotes Euler's number and is approximately equal to 2.718281. + /// + /// ```python + /// x = tf.constant(2.0) + /// tf.math.expm1(x) ==> 6.389056 + /// + /// x = tf.constant([2.0, 8.0]) + /// tf.math.expm1(x) ==> array([6.389056, 2979.958], dtype=float32) + /// + /// x = tf.constant(1 + 1j) + /// tf.math.expm1(x) ==> (0.46869393991588515+2.2873552871788423j) + /// ``` + /// + /// + /// + /// + public static Tensor expm1(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Expm1", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return expm1_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Expm1", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Expm1", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor expm1_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Expm1", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Expm1", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns element-wise largest integer not greater than x. + /// + /// + /// + public static Tensor floor(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Floor", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return floor_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Floor", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Floor", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor floor_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Floor", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Floor", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns x // y element-wise. + /// + /// + /// + /// *NOTE*: `FloorDiv` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor floor_div(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FloorDiv", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return floor_div_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("FloorDiv", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("FloorDiv", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor floor_div_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("FloorDiv", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FloorDiv", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns element-wise remainder of division. + /// + /// + /// + /// This follows Python semantics in that the + /// result here is consistent with a flooring divide. E.g. + /// `floor(x / y) * y + floormod(x, y) = x`, regardless of the signs of x and y. + /// + /// *NOTE*: `FloorMod` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor floor_mod(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FloorMod", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return floor_mod_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("FloorMod", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("FloorMod", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor floor_mod_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("FloorMod", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FloorMod", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the truth value of (x > y) element-wise. + /// + /// + /// + /// *NOTE*: `Greater` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// Example: + /// + /// ```python + /// x = tf.constant([5, 4, 6]) + /// y = tf.constant([5, 2, 5]) + /// tf.math.greater(x, y) ==> [False, True, True] + /// + /// x = tf.constant([5, 4, 6]) + /// y = tf.constant([5]) + /// tf.math.greater(x, y) ==> [False, False, True] + /// ``` + /// + /// + /// + /// + /// + public static Tensor greater(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Greater", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return greater_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Greater", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Greater", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor greater_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Greater", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Greater", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the truth value of (x >= y) element-wise. + /// + /// + /// + /// *NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// Example: + /// + /// ```python + /// x = tf.constant([5, 4, 6, 7]) + /// y = tf.constant([5, 2, 5, 10]) + /// tf.math.greater_equal(x, y) ==> [True, True, True, False] + /// + /// x = tf.constant([5, 4, 6, 7]) + /// y = tf.constant([5]) + /// tf.math.greater_equal(x, y) ==> [True, False, True, True] + /// ``` + /// + /// + /// + /// + /// + public static Tensor greater_equal(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "GreaterEqual", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return greater_equal_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("GreaterEqual", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("GreaterEqual", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor greater_equal_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("GreaterEqual", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("GreaterEqual", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Return histogram of values. + /// + /// + /// + /// Given the tensor `values`, this operation returns a rank 1 histogram counting + /// the number of entries in `values` that fall into every bin. The bins are + /// equal width and determined by the arguments `value_range` and `nbins`. + /// + /// ```python + /// # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) + /// nbins = 5 + /// value_range = [0.0, 5.0] + /// new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] + /// + /// with tf.get_default_session() as sess: + /// hist = tf.histogram_fixed_width(new_values, value_range, nbins=5) + /// variables.global_variables_initializer().run() + /// sess.run(hist) => [2, 1, 1, 0, 2] + /// ``` + /// + /// + /// + /// + /// + /// + /// + public static Tensor histogram_fixed_width(Tensor values, Tensor value_range, Tensor nbins, TF_DataType dtype = TF_DataType.TF_INT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "HistogramFixedWidth", name) { args = new object[] { values, value_range, nbins }, attrs = new Dictionary() { ["dtype"] = dtype } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return histogram_fixed_width_eager_fallback(values, value_range, nbins, dtype: dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["values"] = values; + keywords["value_range"] = value_range; + keywords["nbins"] = nbins; + keywords["dtype"] = dtype; + var _op = tf.OpDefLib._apply_op_helper("HistogramFixedWidth", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "dtype", _op._get_attr_type("dtype") }; + _execute.record_gradient("HistogramFixedWidth", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor histogram_fixed_width_eager_fallback(Tensor values, Tensor value_range, Tensor nbins, TF_DataType dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { values, value_range, nbins }; + object[] _attrs = new object[] { "T", values.dtype, "dtype", dtype }; + var _result = _execute.execute("HistogramFixedWidth", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("HistogramFixedWidth", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Compute the lower regularized incomplete Gamma function `P(a, x)`. + /// + /// + /// + /// The lower regularized incomplete Gamma function is defined as: + /// + /// + /// \(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\) + /// + /// where + /// + /// \(gamma(a, x) = \int_{0}^{x} t^{a-1} exp(-t) dt\) + /// + /// is the lower incomplete Gamma function. + /// + /// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete + /// Gamma function. + /// + /// + /// + /// + /// + public static Tensor igamma(Tensor a, Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Igamma", name) { args = new object[] { a, x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return igamma_eager_fallback(a, x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Igamma", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Igamma", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor igamma_eager_fallback(Tensor a, Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, x }; + object[] _attrs = new object[] { "T", a.dtype }; + var _result = _execute.execute("Igamma", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Igamma", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradient of `igamma(a, x)` wrt `a`. + /// + /// + /// + /// + public static Tensor igamma_grad_a(Tensor a, Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "IgammaGradA", name) { args = new object[] { a, x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return igamma_grad_a_eager_fallback(a, x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("IgammaGradA", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("IgammaGradA", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor igamma_grad_a_eager_fallback(Tensor a, Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, x }; + object[] _attrs = new object[] { "T", a.dtype }; + var _result = _execute.execute("IgammaGradA", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("IgammaGradA", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Compute the upper regularized incomplete Gamma function `Q(a, x)`. + /// + /// + /// + /// The upper regularized incomplete Gamma function is defined as: + /// + /// \(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\) + /// + /// where + /// + /// \(Gamma(a, x) = int_{x}^{infty} t^{a-1} exp(-t) dt\) + /// + /// is the upper incomplete Gamma function. + /// + /// Note, above `P(a, x)` (`Igamma`) is the lower regularized complete + /// Gamma function. + /// + /// + /// + /// + /// + public static Tensor igammac(Tensor a, Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Igammac", name) { args = new object[] { a, x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return igammac_eager_fallback(a, x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Igammac", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Igammac", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor igammac_eager_fallback(Tensor a, Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, x }; + object[] _attrs = new object[] { "T", a.dtype }; + var _result = _execute.execute("Igammac", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Igammac", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the imaginary part of a complex number. + /// + /// + /// + /// Given a tensor `input` of complex numbers, this operation returns a tensor of + /// type `float` that is the imaginary part of each element in `input`. All + /// elements in `input` must be complex numbers of the form \(a + bj\), where *a* + /// is the real part and *b* is the imaginary part returned by this operation. + /// + /// For example: + /// + /// ``` + /// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] + /// tf.imag(input) ==> [4.75, 5.75] + /// ``` + /// + /// + /// + /// + /// + public static Tensor imag(Tensor input, TF_DataType Tout = TF_DataType.TF_FLOAT, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Imag", name) { args = new object[] { input }, attrs = new Dictionary() { ["Tout"] = Tout } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return imag_eager_fallback(input, Tout: Tout, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["Tout"] = Tout; + var _op = tf.OpDefLib._apply_op_helper("Imag", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tout", _op._get_attr_type("Tout") }; + _execute.record_gradient("Imag", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor imag_eager_fallback(Tensor input, TF_DataType Tout, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "Tout", Tout }; + var _result = _execute.execute("Imag", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Imag", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the reciprocal of x element-wise. + /// + /// + /// + /// I.e., \(y = 1 / x\). + /// + /// + /// + /// + public static Tensor inv(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Inv", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return inv_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Inv", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Inv", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor inv_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Inv", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Inv", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradient for the inverse of `x` wrt its input. + /// + /// + /// + /// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` + /// is the corresponding input gradient. + /// + /// + /// + /// + /// + public static Tensor inv_grad(Tensor y, Tensor dy, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "InvGrad", name) { args = new object[] { y, dy }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return inv_grad_eager_fallback(y, dy, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["y"] = y; + keywords["dy"] = dy; + var _op = tf.OpDefLib._apply_op_helper("InvGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("InvGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor inv_grad_eager_fallback(Tensor y, Tensor dy, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { y, dy }; + object[] _attrs = new object[] { "T", y.dtype }; + var _result = _execute.execute("InvGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("InvGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns which elements of x are finite. + /// + /// + /// + /// @compatibility(numpy) + /// Equivalent to np.isfinite + /// @end_compatibility + /// + /// Example: + /// + /// ```python + /// x = tf.constant([5.0, 4.8, 6.8, np.inf, np.nan]) + /// tf.math.is_finite(x) ==> [True, True, True, False, False] + /// ``` + /// + /// + /// + /// + public static Tensor is_finite(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "IsFinite", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return is_finite_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("IsFinite", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("IsFinite", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor is_finite_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("IsFinite", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("IsFinite", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns which elements of x are Inf. + /// + /// + /// + /// @compatibility(numpy) + /// Equivalent to np.isinf + /// @end_compatibility + /// + /// Example: + /// + /// ```python + /// x = tf.constant([5.0, np.inf, 6.8, np.inf]) + /// tf.math.is_inf(x) ==> [False, True, False, True] + /// ``` + /// + /// + /// + /// + public static Tensor is_inf(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "IsInf", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return is_inf_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("IsInf", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("IsInf", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor is_inf_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("IsInf", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("IsInf", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns which elements of x are NaN. + /// + /// + /// + /// @compatibility(numpy) + /// Equivalent to np.isnan + /// @end_compatibility + /// + /// Example: + /// + /// ```python + /// x = tf.constant([5.0, np.nan, 6.8, np.nan, np.inf]) + /// tf.math.is_nan(x) ==> [False, True, False, True, False] + /// ``` + /// + /// + /// + /// + public static Tensor is_nan(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "IsNan", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return is_nan_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("IsNan", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("IsNan", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor is_nan_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("IsNan", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("IsNan", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the truth value of (x < y) element-wise. + /// + /// + /// + /// *NOTE*: `Less` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// Example: + /// + /// ```python + /// x = tf.constant([5, 4, 6]) + /// y = tf.constant([5]) + /// tf.math.less(x, y) ==> [False, True, False] + /// + /// x = tf.constant([5, 4, 6]) + /// y = tf.constant([5, 6, 7]) + /// tf.math.less(x, y) ==> [False, True, True] + /// ``` + /// + /// + /// + /// + /// + public static Tensor less(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Less", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return less_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Less", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Less", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor less_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Less", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Less", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the truth value of (x <= y) element-wise. + /// + /// + /// + /// *NOTE*: `LessEqual` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// Example: + /// + /// ```python + /// x = tf.constant([5, 4, 6]) + /// y = tf.constant([5]) + /// tf.math.less_equal(x, y) ==> [True, True, False] + /// + /// x = tf.constant([5, 4, 6]) + /// y = tf.constant([5, 6, 6]) + /// tf.math.less_equal(x, y) ==> [True, True, True] + /// ``` + /// + /// + /// + /// + /// + public static Tensor less_equal(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "LessEqual", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return less_equal_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("LessEqual", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("LessEqual", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor less_equal_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("LessEqual", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("LessEqual", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the log of the absolute value of `Gamma(x)` element-wise. + /// + /// + /// + /// For positive numbers, this function computes log((input - 1)!) for every element in the tensor. + /// `lgamma(5) = log((5-1)!) = log(4!) = log(24) = 3.1780539` + /// + /// Example: + /// + /// ```python + /// x = tf.constant([0, 0.5, 1, 4.5, -4, -5.6]) + /// tf.math.lgamma(x) ==> [inf, 0.5723649, 0., 2.4537368, inf, -4.6477685] + /// ``` + /// + /// + /// + /// + public static Tensor lgamma(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Lgamma", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return lgamma_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Lgamma", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Lgamma", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor lgamma_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Lgamma", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Lgamma", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Generates values in an interval. + /// + /// + /// + /// A sequence of `num` evenly-spaced values are generated beginning at `start`. + /// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, + /// so that the last one is exactly `stop`. + /// + /// For example: + /// + /// ``` + /// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor lin_space(Tensor start, Tensor stop, Tensor num, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "LinSpace", name) { args = new object[] { start, stop, num }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return lin_space_eager_fallback(start, stop, num, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["start"] = start; + keywords["stop"] = stop; + keywords["num"] = num; + var _op = tf.OpDefLib._apply_op_helper("LinSpace", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("LinSpace", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor lin_space_eager_fallback(Tensor start, Tensor stop, Tensor num, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { start, stop, num }; + object[] _attrs = new object[] { "T", start.dtype, "Tidx", num.dtype }; + var _result = _execute.execute("LinSpace", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("LinSpace", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes natural logarithm of x element-wise. + /// + /// + /// + /// I.e., \(y = log_e x\). + /// + /// Example: + /// + /// ```python + /// x = tf.constant([0, 0.5, 1, 5]) + /// tf.math.log(x) ==> [-inf, -0.6931472, 0. , 1.609438] + /// ``` + /// + /// + /// + /// + public static Tensor log(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Log", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return log_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Log", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Log", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor log_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Log", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Log", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes natural logarithm of (1 + x) element-wise. + /// + /// + /// + /// I.e., \(y = log_e (1 + x)\). + /// + /// Example: + /// + /// ```python + /// x = tf.constant([0, 0.5, 1, 5]) + /// tf.math.log1p(x) ==> [0., 0.4054651, 0.6931472, 1.7917595] + /// ``` + /// + /// + /// + /// + public static Tensor log1p(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Log1p", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return log1p_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Log1p", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Log1p", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor log1p_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Log1p", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Log1p", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the truth value of x AND y element-wise. + /// + /// + /// + /// *NOTE*: `LogicalAnd` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor logical_and(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "LogicalAnd", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return logical_and_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("LogicalAnd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("LogicalAnd", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor logical_and_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("LogicalAnd", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("LogicalAnd", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the truth value of `NOT x` element-wise. + /// + /// + /// + public static Tensor logical_not(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "LogicalNot", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return logical_not_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("LogicalNot", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("LogicalNot", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor logical_not_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("LogicalNot", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("LogicalNot", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the truth value of x OR y element-wise. + /// + /// + /// + /// *NOTE*: `LogicalOr` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor logical_or(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "LogicalOr", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return logical_or_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("LogicalOr", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("LogicalOr", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor logical_or_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("LogicalOr", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("LogicalOr", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Multiply the matrix "a" by the matrix "b". + /// + /// + /// + /// The inputs must be two-dimensional matrices and the inner dimension of + /// "a" (after being transposed if transpose_a is true) must match the + /// outer dimension of "b" (after being transposed if transposed_b is + /// true). + /// + /// *Note*: The default kernel implementation for MatMul on GPUs uses + /// cublas. + /// + /// + /// + /// + /// + /// + /// If true, "a" is transposed before multiplication. + /// + /// + /// + /// + /// If true, "b" is transposed before multiplication. + /// + /// + /// + public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatMul", name) { args = new object[] { a, b }, attrs = new Dictionary() { ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return mat_mul_eager_fallback(a, b, transpose_a: transpose_a, transpose_b: transpose_b, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["b"] = b; + keywords["transpose_a"] = transpose_a; + keywords["transpose_b"] = transpose_b; + var _op = tf.OpDefLib._apply_op_helper("MatMul", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "transpose_a", _op._get_attr_bool("transpose_a"), "transpose_b", _op._get_attr_bool("transpose_b"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("MatMul", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor mat_mul_eager_fallback(Tensor a, Tensor b, bool transpose_a, bool transpose_b, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, b }; + object[] _attrs = new object[] { "transpose_a", transpose_a, "transpose_b", transpose_b, "T", a.dtype }; + var _result = _execute.execute("MatMul", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MatMul", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the maximum of elements across dimensions of a tensor. + /// + /// + /// + /// Reduces `input` along the dimensions given in `reduction_indices`. Unless + /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in + /// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are + /// retained with length 1. + /// + /// + /// + /// + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// + public static Tensor max(Tensor input, Tensor reduction_indices, bool keep_dims = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Max", name) { args = new object[] { input, reduction_indices }, attrs = new Dictionary() { ["keep_dims"] = keep_dims } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_eager_fallback(input, reduction_indices, keep_dims: keep_dims, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["reduction_indices"] = reduction_indices; + keywords["keep_dims"] = keep_dims; + var _op = tf.OpDefLib._apply_op_helper("Max", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "keep_dims", _op._get_attr_bool("keep_dims"), "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("Max", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor max_eager_fallback(Tensor input, Tensor reduction_indices, bool keep_dims, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, reduction_indices }; + object[] _attrs = new object[] { "keep_dims", keep_dims, "T", input.dtype, "Tidx", reduction_indices.dtype }; + var _result = _execute.execute("Max", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Max", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the max of x and y (i.e. x > y ? x : y) element-wise. + /// + /// + /// + /// *NOTE*: `Maximum` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor maximum(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Maximum", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return maximum_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Maximum", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Maximum", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor maximum_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Maximum", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Maximum", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the mean of elements across dimensions of a tensor. + /// + /// + /// + /// Reduces `input` along the dimensions given in `reduction_indices`. Unless + /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in + /// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are + /// retained with length 1. + /// + /// + /// + /// + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// + public static Tensor mean(Tensor input, Tensor reduction_indices, bool keep_dims = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Mean", name) { args = new object[] { input, reduction_indices }, attrs = new Dictionary() { ["keep_dims"] = keep_dims } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return mean_eager_fallback(input, reduction_indices, keep_dims: keep_dims, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["reduction_indices"] = reduction_indices; + keywords["keep_dims"] = keep_dims; + var _op = tf.OpDefLib._apply_op_helper("Mean", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "keep_dims", _op._get_attr_bool("keep_dims"), "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("Mean", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor mean_eager_fallback(Tensor input, Tensor reduction_indices, bool keep_dims, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, reduction_indices }; + object[] _attrs = new object[] { "keep_dims", keep_dims, "T", input.dtype, "Tidx", reduction_indices.dtype }; + var _result = _execute.execute("Mean", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Mean", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the minimum of elements across dimensions of a tensor. + /// + /// + /// + /// Reduces `input` along the dimensions given in `reduction_indices`. Unless + /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in + /// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are + /// retained with length 1. + /// + /// + /// + /// + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// + public static Tensor min(Tensor input, Tensor reduction_indices, bool keep_dims = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Min", name) { args = new object[] { input, reduction_indices }, attrs = new Dictionary() { ["keep_dims"] = keep_dims } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return min_eager_fallback(input, reduction_indices, keep_dims: keep_dims, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["reduction_indices"] = reduction_indices; + keywords["keep_dims"] = keep_dims; + var _op = tf.OpDefLib._apply_op_helper("Min", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "keep_dims", _op._get_attr_bool("keep_dims"), "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("Min", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor min_eager_fallback(Tensor input, Tensor reduction_indices, bool keep_dims, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, reduction_indices }; + object[] _attrs = new object[] { "keep_dims", keep_dims, "T", input.dtype, "Tidx", reduction_indices.dtype }; + var _result = _execute.execute("Min", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Min", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the min of x and y (i.e. x < y ? x : y) element-wise. + /// + /// + /// + /// *NOTE*: `Minimum` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor minimum(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Minimum", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return minimum_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Minimum", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Minimum", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor minimum_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Minimum", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Minimum", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns element-wise remainder of division. This emulates C semantics in that + /// + /// + /// + /// the result here is consistent with a truncating divide. E.g. + /// `tf.truncatediv(x, y) * y + truncate_mod(x, y) = x`. + /// + /// *NOTE*: `Mod` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor mod(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Mod", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return mod_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Mod", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Mod", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor mod_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Mod", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Mod", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns x * y element-wise. + /// + /// + /// + /// *NOTE*: `Mul` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor mul(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Mul", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return mul_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Mul", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Mul", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor mul_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Mul", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Mul", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or NaN. + /// + /// + /// + /// *NOTE*: `MulNoNan` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor mul_no_nan(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MulNoNan", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return mul_no_nan_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("MulNoNan", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("MulNoNan", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor mul_no_nan_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("MulNoNan", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MulNoNan", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + public static Tensor ndtri(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Ndtri", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return ndtri_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Ndtri", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Ndtri", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor ndtri_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Ndtri", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Ndtri", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes numerical negative value element-wise. + /// + /// + /// + /// I.e., \(y = -x\). + /// + /// + /// + /// + public static Tensor neg(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Neg", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return neg_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Neg", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Neg", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor neg_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Neg", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Neg", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the next representable value of `x1` in the direction of `x2`, element-wise. + /// + /// + /// + /// This operation returns the same result as the C++ std::nextafter function. + /// + /// It can also return a subnormal number. + /// + /// @compatibility(cpp) + /// Equivalent to C++ std::nextafter function. + /// @end_compatibility + /// + /// + /// + /// + /// + public static Tensor next_after(Tensor x1, Tensor x2, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "NextAfter", name) { args = new object[] { x1, x2 }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return next_after_eager_fallback(x1, x2, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x1"] = x1; + keywords["x2"] = x2; + var _op = tf.OpDefLib._apply_op_helper("NextAfter", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("NextAfter", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor next_after_eager_fallback(Tensor x1, Tensor x2, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x1, x2 }; + object[] _attrs = new object[] { "T", x1.dtype }; + var _result = _execute.execute("NextAfter", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("NextAfter", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the truth value of (x != y) element-wise. + /// + /// + /// + /// *NOTE*: `NotEqual` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + /// + public static Tensor not_equal(Tensor x, Tensor y, bool incompatible_shape_error = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "NotEqual", name) { args = new object[] { x, y }, attrs = new Dictionary() { ["incompatible_shape_error"] = incompatible_shape_error } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return not_equal_eager_fallback(x, y, incompatible_shape_error: incompatible_shape_error, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + keywords["incompatible_shape_error"] = incompatible_shape_error; + var _op = tf.OpDefLib._apply_op_helper("NotEqual", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "incompatible_shape_error", _op._get_attr_bool("incompatible_shape_error") }; + _execute.record_gradient("NotEqual", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor not_equal_eager_fallback(Tensor x, Tensor y, bool incompatible_shape_error, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype, "incompatible_shape_error", incompatible_shape_error }; + var _result = _execute.execute("NotEqual", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("NotEqual", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Compute the polygamma function \\(\psi^{(n)}(x)\\). + /// + /// + /// + /// The polygamma function is defined as: + /// + /// + /// \(psi^{(a)}(x) = rac{d^a}{dx^a} psi(x)\) + /// + /// where \(psi(x)\) is the digamma function. + /// The polygamma function is defined only for non-negative integer orders \a\. + /// + /// + /// + /// + /// + public static Tensor polygamma(Tensor a, Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Polygamma", name) { args = new object[] { a, x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return polygamma_eager_fallback(a, x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Polygamma", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Polygamma", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor polygamma_eager_fallback(Tensor a, Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, x }; + object[] _attrs = new object[] { "T", a.dtype }; + var _result = _execute.execute("Polygamma", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Polygamma", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the power of one value to another. + /// + /// + /// + /// Given a tensor `x` and a tensor `y`, this operation computes \(x^y\) for + /// corresponding elements in `x` and `y`. For example: + /// + /// ``` + /// # tensor 'x' is [[2, 2]], [3, 3]] + /// # tensor 'y' is [[8, 16], [2, 3]] + /// tf.pow(x, y) ==> [[256, 65536], [9, 27]] + /// ``` + /// + /// + /// + /// + /// + public static Tensor pow(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Pow", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return pow_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Pow", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Pow", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor pow_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Pow", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Pow", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the product of elements across dimensions of a tensor. + /// + /// + /// + /// Reduces `input` along the dimensions given in `reduction_indices`. Unless + /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in + /// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are + /// retained with length 1. + /// + /// + /// + /// + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// + public static Tensor prod(Tensor input, Tensor reduction_indices, bool keep_dims = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Prod", name) { args = new object[] { input, reduction_indices }, attrs = new Dictionary() { ["keep_dims"] = keep_dims } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return prod_eager_fallback(input, reduction_indices, keep_dims: keep_dims, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["reduction_indices"] = reduction_indices; + keywords["keep_dims"] = keep_dims; + var _op = tf.OpDefLib._apply_op_helper("Prod", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "keep_dims", _op._get_attr_bool("keep_dims"), "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("Prod", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor prod_eager_fallback(Tensor input, Tensor reduction_indices, bool keep_dims, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, reduction_indices }; + object[] _attrs = new object[] { "keep_dims", keep_dims, "T", input.dtype, "Tidx", reduction_indices.dtype }; + var _result = _execute.execute("Prod", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Prod", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Convert the quantized 'input' tensor into a lower-precision 'output', using the + /// + /// + /// + /// actual distribution of the values to maximize the usage of the lower bit depth + /// and adjusting the output min and max ranges accordingly. + /// + /// [input_min, input_max] are scalar floats that specify the range for the float + /// interpretation of the 'input' data. For example, if input_min is -1.0f and + /// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 + /// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. + /// + /// This operator tries to squeeze as much precision as possible into an output with + /// a lower bit depth by calculating the actual min and max values found in the + /// data. For example, maybe that quint16 input has no values lower than 16,384 and + /// none higher than 49,152. That means only half the range is actually needed, all + /// the float interpretations are between -0.5f and 0.5f, so if we want to compress + /// the data into a quint8 output, we can use that range rather than the theoretical + /// -1.0f to 1.0f that is suggested by the input min and max. + /// + /// In practice, this is most useful for taking output from operations like + /// QuantizedMatMul that can produce higher bit-depth outputs than their inputs and + /// may have large potential output ranges, but in practice have a distribution of + /// input values that only uses a small fraction of the possible range. By feeding + /// that output into this operator, we can reduce it from 32 bits down to 8 with + /// minimal loss of accuracy. + /// + /// + /// + /// + /// + /// + /// + /// The type of the output. Should be a lower bit depth than Tinput. + /// + /// + /// + public static Tensor[] quantize_down_and_shrink_range(Tensor input, Tensor input_min, Tensor input_max, TF_DataType out_type, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizeDownAndShrinkRange", name) { args = new object[] { input, input_min, input_max }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantize_down_and_shrink_range_eager_fallback(input, input_min, input_max, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["input_min"] = input_min; + keywords["input_max"] = input_max; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("QuantizeDownAndShrinkRange", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("QuantizeDownAndShrinkRange", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantize_down_and_shrink_range_eager_fallback(Tensor input, Tensor input_min, Tensor input_max, TF_DataType out_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, input_min, input_max }; + object[] _attrs = new object[] { "Tinput", input.dtype, "out_type", out_type }; + var _result = _execute.execute("QuantizeDownAndShrinkRange", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizeDownAndShrinkRange", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Returns x + y element-wise, working on quantized buffers. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_add(Tensor x, Tensor y, Tensor min_x, Tensor max_x, Tensor min_y, Tensor max_y, TF_DataType Toutput = TF_DataType.TF_QINT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedAdd", name) { args = new object[] { x, y, min_x, max_x, min_y, max_y }, attrs = new Dictionary() { ["Toutput"] = Toutput } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_add_eager_fallback(x, y, min_x, max_x, min_y, max_y, Toutput: Toutput, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + keywords["min_x"] = min_x; + keywords["max_x"] = max_x; + keywords["min_y"] = min_y; + keywords["max_y"] = max_y; + keywords["Toutput"] = Toutput; + var _op = tf.OpDefLib._apply_op_helper("QuantizedAdd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T1", _op._get_attr_type("T1"), "T2", _op._get_attr_type("T2"), "Toutput", _op._get_attr_type("Toutput") }; + _execute.record_gradient("QuantizedAdd", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_add_eager_fallback(Tensor x, Tensor y, Tensor min_x, Tensor max_x, Tensor min_y, Tensor max_y, TF_DataType Toutput, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y, min_x, max_x, min_y, max_y }; + object[] _attrs = new object[] { "T1", x.dtype, "T2", y.dtype, "Toutput", Toutput }; + var _result = _execute.execute("QuantizedAdd", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedAdd", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Perform a quantized matrix multiplication of `a` by the matrix `b`. + /// + /// + /// + /// The inputs must be two-dimensional matrices and the inner dimension of + /// `a` (after being transposed if `transpose_a` is non-zero) must match the + /// outer dimension of `b` (after being transposed if `transposed_b` is + /// non-zero). + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If true, `a` is transposed before multiplication. + /// + /// + /// + /// + /// If true, `b` is transposed before multiplication. + /// + /// + /// + /// + /// The type of output produced by activation function + /// following this operation. + /// + /// + /// + public static Tensor[] quantized_mat_mul(Tensor a, Tensor b, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, TF_DataType Toutput = TF_DataType.TF_QINT32, bool transpose_a = false, bool transpose_b = false, TF_DataType Tactivation = TF_DataType.TF_QUINT8, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedMatMul", name) { args = new object[] { a, b, min_a, max_a, min_b, max_b }, attrs = new Dictionary() { ["Toutput"] = Toutput, ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b, ["Tactivation"] = Tactivation } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_mat_mul_eager_fallback(a, b, min_a, max_a, min_b, max_b, Toutput: Toutput, transpose_a: transpose_a, transpose_b: transpose_b, Tactivation: Tactivation, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["b"] = b; + keywords["min_a"] = min_a; + keywords["max_a"] = max_a; + keywords["min_b"] = min_b; + keywords["max_b"] = max_b; + keywords["Toutput"] = Toutput; + keywords["transpose_a"] = transpose_a; + keywords["transpose_b"] = transpose_b; + keywords["Tactivation"] = Tactivation; + var _op = tf.OpDefLib._apply_op_helper("QuantizedMatMul", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T1", _op._get_attr_type("T1"), "T2", _op._get_attr_type("T2"), "Toutput", _op._get_attr_type("Toutput"), "transpose_a", _op._get_attr_bool("transpose_a"), "transpose_b", _op._get_attr_bool("transpose_b"), "Tactivation", _op._get_attr_type("Tactivation") }; + _execute.record_gradient("QuantizedMatMul", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_mat_mul_eager_fallback(Tensor a, Tensor b, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, TF_DataType Toutput, bool transpose_a, bool transpose_b, TF_DataType Tactivation, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, b, min_a, max_a, min_b, max_b }; + object[] _attrs = new object[] { "T1", a.dtype, "T2", b.dtype, "Toutput", Toutput, "transpose_a", transpose_a, "transpose_b", transpose_b, "Tactivation", Tactivation }; + var _result = _execute.execute("QuantizedMatMul", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedMatMul", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Returns x * y element-wise, working on quantized buffers. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_mul(Tensor x, Tensor y, Tensor min_x, Tensor max_x, Tensor min_y, Tensor max_y, TF_DataType Toutput = TF_DataType.TF_QINT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedMul", name) { args = new object[] { x, y, min_x, max_x, min_y, max_y }, attrs = new Dictionary() { ["Toutput"] = Toutput } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_mul_eager_fallback(x, y, min_x, max_x, min_y, max_y, Toutput: Toutput, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + keywords["min_x"] = min_x; + keywords["max_x"] = max_x; + keywords["min_y"] = min_y; + keywords["max_y"] = max_y; + keywords["Toutput"] = Toutput; + var _op = tf.OpDefLib._apply_op_helper("QuantizedMul", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T1", _op._get_attr_type("T1"), "T2", _op._get_attr_type("T2"), "Toutput", _op._get_attr_type("Toutput") }; + _execute.record_gradient("QuantizedMul", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_mul_eager_fallback(Tensor x, Tensor y, Tensor min_x, Tensor max_x, Tensor min_y, Tensor max_y, TF_DataType Toutput, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y, min_x, max_x, min_y, max_y }; + object[] _attrs = new object[] { "T1", x.dtype, "T2", y.dtype, "Toutput", Toutput }; + var _result = _execute.execute("QuantizedMul", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedMul", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Counts the number of occurrences of each value in an integer array. + /// + /// + /// + /// Outputs a vector with length `size` and the same dtype as `weights`. If + /// `weights` are empty, then index `i` stores the number of times the value `i` is + /// counted in `arr`. If `weights` are non-empty, then index `i` stores the sum of + /// the value in `weights` at each index where the corresponding value in `arr` is + /// `i`. + /// + /// Values in `arr` outside of the range [0, size) are ignored. + /// + /// + /// + /// + /// + /// + /// + /// + /// bool; Whether the kernel should count the appearance or number of occurrences. + /// + /// + /// + public static Tensor ragged_bincount(Tensor splits, Tensor values, Tensor size, Tensor weights, bool binary_output = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "RaggedBincount", name) { args = new object[] { splits, values, size, weights }, attrs = new Dictionary() { ["binary_output"] = binary_output } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return ragged_bincount_eager_fallback(splits, values, size, weights, binary_output: binary_output, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["splits"] = splits; + keywords["values"] = values; + keywords["size"] = size; + keywords["weights"] = weights; + keywords["binary_output"] = binary_output; + var _op = tf.OpDefLib._apply_op_helper("RaggedBincount", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tidx", _op._get_attr_type("Tidx"), "T", _op._get_attr_type("T"), "binary_output", _op._get_attr_bool("binary_output") }; + _execute.record_gradient("RaggedBincount", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor ragged_bincount_eager_fallback(Tensor splits, Tensor values, Tensor size, Tensor weights, bool binary_output, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { splits, values, size, weights }; + object[] _attrs = new object[] { "Tidx", values.dtype, "T", weights.dtype, "binary_output", binary_output }; + var _result = _execute.execute("RaggedBincount", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("RaggedBincount", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Creates a sequence of numbers. + /// + /// + /// + /// This operation creates a sequence of numbers that begins at `start` and + /// extends by increments of `delta` up to but not including `limit`. + /// + /// For example: + /// + /// ``` + /// # 'start' is 3 + /// # 'limit' is 18 + /// # 'delta' is 3 + /// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor range(Tensor start, Tensor limit, Tensor delta, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Range", name) { args = new object[] { start, limit, delta }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return range_eager_fallback(start, limit, delta, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["start"] = start; + keywords["limit"] = limit; + keywords["delta"] = delta; + var _op = tf.OpDefLib._apply_op_helper("Range", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("Range", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor range_eager_fallback(Tensor start, Tensor limit, Tensor delta, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { start, limit, delta }; + object[] _attrs = new object[] { "Tidx", start.dtype }; + var _result = _execute.execute("Range", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Range", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the real part of a complex number. + /// + /// + /// + /// Given a tensor `input` of complex numbers, this operation returns a tensor of + /// type `float` that is the real part of each element in `input`. All elements in + /// `input` must be complex numbers of the form \(a + bj\), where *a* is the real + /// part returned by this operation and *b* is the imaginary part. + /// + /// For example: + /// + /// ``` + /// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] + /// tf.real(input) ==> [-2.25, 3.25] + /// ``` + /// + /// + /// + /// + /// + public static Tensor real(Tensor input, TF_DataType Tout = TF_DataType.TF_FLOAT, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Real", name) { args = new object[] { input }, attrs = new Dictionary() { ["Tout"] = Tout } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return real_eager_fallback(input, Tout: Tout, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["Tout"] = Tout; + var _op = tf.OpDefLib._apply_op_helper("Real", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tout", _op._get_attr_type("Tout") }; + _execute.record_gradient("Real", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor real_eager_fallback(Tensor input, TF_DataType Tout, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "Tout", Tout }; + var _result = _execute.execute("Real", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Real", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns x / y element-wise for real types. + /// + /// + /// + /// If `x` and `y` are reals, this will return the floating-point division. + /// + /// *NOTE*: `Div` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor real_div(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "RealDiv", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return real_div_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("RealDiv", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("RealDiv", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor real_div_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("RealDiv", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("RealDiv", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the reciprocal of x element-wise. + /// + /// + /// + /// I.e., \(y = 1 / x\). + /// + /// + /// + /// + public static Tensor reciprocal(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Reciprocal", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reciprocal_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Reciprocal", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Reciprocal", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor reciprocal_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Reciprocal", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Reciprocal", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradient for the inverse of `x` wrt its input. + /// + /// + /// + /// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` + /// is the corresponding input gradient. + /// + /// + /// + /// + /// + public static Tensor reciprocal_grad(Tensor y, Tensor dy, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReciprocalGrad", name) { args = new object[] { y, dy }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return reciprocal_grad_eager_fallback(y, dy, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["y"] = y; + keywords["dy"] = dy; + var _op = tf.OpDefLib._apply_op_helper("ReciprocalGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("ReciprocalGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor reciprocal_grad_eager_fallback(Tensor y, Tensor dy, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { y, dy }; + object[] _attrs = new object[] { "T", y.dtype }; + var _result = _execute.execute("ReciprocalGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReciprocalGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes a range that covers the actual values present in a quantized tensor. + /// + /// + /// + /// Given a quantized tensor described by `(input, input_min, input_max)`, outputs a + /// range that covers the actual values present in that tensor. This op is typically + /// used to produce the `requested_output_min` and `requested_output_max` for + /// `Requantize`. + /// + /// + /// + /// + /// + /// + public static Tensor[] requantization_range(Tensor input, Tensor input_min, Tensor input_max, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "RequantizationRange", name) { args = new object[] { input, input_min, input_max }, attrs = new Dictionary() { } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return requantization_range_eager_fallback(input, input_min, input_max, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["input_min"] = input_min; + keywords["input_max"] = input_max; + var _op = tf.OpDefLib._apply_op_helper("RequantizationRange", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput") }; + _execute.record_gradient("RequantizationRange", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] requantization_range_eager_fallback(Tensor input, Tensor input_min, Tensor input_max, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, input_min, input_max }; + object[] _attrs = new object[] { "Tinput", input.dtype }; + var _result = _execute.execute("RequantizationRange", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("RequantizationRange", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes requantization range per channel. + /// + /// + /// + /// + /// + /// + /// The maximum value of the output that needs to be clipped. + /// Example: set this to 6 for Relu6. + /// + /// + /// + public static Tensor[] requantization_range_per_channel(Tensor input, Tensor input_min, Tensor input_max, float clip_value_max, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "RequantizationRangePerChannel", name) { args = new object[] { input, input_min, input_max }, attrs = new Dictionary() { ["clip_value_max"] = clip_value_max } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return requantization_range_per_channel_eager_fallback(input, input_min, input_max, clip_value_max: clip_value_max, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["input_min"] = input_min; + keywords["input_max"] = input_max; + keywords["clip_value_max"] = clip_value_max; + var _op = tf.OpDefLib._apply_op_helper("RequantizationRangePerChannel", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "clip_value_max", _op.get_attr("clip_value_max") }; + _execute.record_gradient("RequantizationRangePerChannel", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] requantization_range_per_channel_eager_fallback(Tensor input, Tensor input_min, Tensor input_max, float clip_value_max, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, input_min, input_max }; + object[] _attrs = new object[] { "T", input.dtype, "clip_value_max", clip_value_max }; + var _result = _execute.execute("RequantizationRangePerChannel", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("RequantizationRangePerChannel", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Converts the quantized `input` tensor into a lower-precision `output`. + /// + /// + /// + /// Converts the quantized `input` tensor into a lower-precision `output`, using the + /// output range specified with `requested_output_min` and `requested_output_max`. + /// + /// `[input_min, input_max]` are scalar floats that specify the range for the float + /// interpretation of the `input` data. For example, if `input_min` is -1.0f and + /// `input_max` is 1.0f, and we are dealing with `quint16` quantized data, then a 0 + /// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The type of the output. Should be a lower bit depth than Tinput. + /// + /// + /// + public static Tensor[] requantize(Tensor input, Tensor input_min, Tensor input_max, Tensor requested_output_min, Tensor requested_output_max, TF_DataType out_type, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Requantize", name) { args = new object[] { input, input_min, input_max, requested_output_min, requested_output_max }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return requantize_eager_fallback(input, input_min, input_max, requested_output_min, requested_output_max, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["input_min"] = input_min; + keywords["input_max"] = input_max; + keywords["requested_output_min"] = requested_output_min; + keywords["requested_output_max"] = requested_output_max; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("Requantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("Requantize", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] requantize_eager_fallback(Tensor input, Tensor input_min, Tensor input_max, Tensor requested_output_min, Tensor requested_output_max, TF_DataType out_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, input_min, input_max, requested_output_min, requested_output_max }; + object[] _attrs = new object[] { "Tinput", input.dtype, "out_type", out_type }; + var _result = _execute.execute("Requantize", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Requantize", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Requantizes input with min and max values known per channel. + /// + /// + /// + /// + /// + /// + /// + /// + /// The quantized type of output tensor that needs to be converted. + /// + /// + /// + public static Tensor[] requantize_per_channel(Tensor input, Tensor input_min, Tensor input_max, Tensor requested_output_min, Tensor requested_output_max, TF_DataType out_type = TF_DataType.TF_QUINT8, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "RequantizePerChannel", name) { args = new object[] { input, input_min, input_max, requested_output_min, requested_output_max }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return requantize_per_channel_eager_fallback(input, input_min, input_max, requested_output_min, requested_output_max, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["input_min"] = input_min; + keywords["input_max"] = input_max; + keywords["requested_output_min"] = requested_output_min; + keywords["requested_output_max"] = requested_output_max; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("RequantizePerChannel", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("RequantizePerChannel", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] requantize_per_channel_eager_fallback(Tensor input, Tensor input_min, Tensor input_max, Tensor requested_output_min, Tensor requested_output_max, TF_DataType out_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, input_min, input_max, requested_output_min, requested_output_max }; + object[] _attrs = new object[] { "T", input.dtype, "out_type", out_type }; + var _result = _execute.execute("RequantizePerChannel", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("RequantizePerChannel", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Returns element-wise integer closest to x. + /// + /// + /// + /// If the result is midway between two representable values, + /// the even representable is chosen. + /// For example: + /// + /// ``` + /// rint(-1.5) ==> -2.0 + /// rint(0.5000001) ==> 1.0 + /// rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] + /// ``` + /// + /// + /// + /// + public static Tensor rint(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Rint", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return rint_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Rint", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Rint", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor rint_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Rint", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Rint", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Rounds the values of a tensor to the nearest integer, element-wise. + /// + /// + /// + /// Rounds half to even. Also known as bankers rounding. If you want to round + /// according to the current system rounding mode use std::cint. + /// + /// + /// + /// + public static Tensor round(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Round", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return round_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Round", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Round", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor round_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Round", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Round", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes reciprocal of square root of x element-wise. + /// + /// + /// + /// I.e., \(y = 1 / sqrt{x}\). + /// + /// + /// + /// + public static Tensor rsqrt(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Rsqrt", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return rsqrt_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Rsqrt", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Rsqrt", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor rsqrt_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Rsqrt", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Rsqrt", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradient for the rsqrt of `x` wrt its input. + /// + /// + /// + /// Specifically, `grad = dy * -0.5 * y^3`, where `y = rsqrt(x)`, and `dy` + /// is the corresponding input gradient. + /// + /// + /// + /// + /// + public static Tensor rsqrt_grad(Tensor y, Tensor dy, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "RsqrtGrad", name) { args = new object[] { y, dy }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return rsqrt_grad_eager_fallback(y, dy, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["y"] = y; + keywords["dy"] = dy; + var _op = tf.OpDefLib._apply_op_helper("RsqrtGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("RsqrtGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor rsqrt_grad_eager_fallback(Tensor y, Tensor dy, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { y, dy }; + object[] _attrs = new object[] { "T", y.dtype }; + var _result = _execute.execute("RsqrtGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("RsqrtGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the maximum along segments of a tensor. + /// + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) + /// for an explanation of segments. + /// + /// Computes a tensor such that + /// \(output_i = max_j(data_j)\) where `max` is over `j` such + /// that `segment_ids[j] == i`. + /// + /// If the max is empty for a given segment ID `i`, `output[i] = 0`. + /// + /// Caution: On CPU, values in `segment_ids` are always validated to be sorted, + /// and an error is thrown for indices that are not increasing. On GPU, this + /// does not throw an error for unsorted indices. On GPU, out-of-order indices + /// result in safe but unspecified behavior, which may include treating + /// out-of-order indices as the same as a smaller following index. + /// + ///
+ /// + ///
+ /// + /// For example: + /// + /// >>> c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) + /// >>> tf.math.segment_max(c, tf.constant([0, 0, 1])).numpy() + /// array([[4, 3, 3, 4], + /// [5, 6, 7, 8]], dtype=int32) + /// + /// + ///
+ /// + /// + /// + public static Tensor segment_max(Tensor data, Tensor segment_ids, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SegmentMax", name) { args = new object[] { data, segment_ids }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return segment_max_eager_fallback(data, segment_ids, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["segment_ids"] = segment_ids; + var _op = tf.OpDefLib._apply_op_helper("SegmentMax", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("SegmentMax", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor segment_max_eager_fallback(Tensor data, Tensor segment_ids, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, segment_ids }; + object[] _attrs = new object[] { "T", data.dtype, "Tindices", segment_ids.dtype }; + var _result = _execute.execute("SegmentMax", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SegmentMax", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the mean along segments of a tensor. + /// + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) + /// for an explanation of segments. + /// + /// Computes a tensor such that + /// \(output_i = rac{sum_j data_j}{N}\) where `mean` is + /// over `j` such that `segment_ids[j] == i` and `N` is the total number of + /// values summed. + /// + /// If the mean is empty for a given segment ID `i`, `output[i] = 0`. + /// + /// Caution: On CPU, values in `segment_ids` are always validated to be sorted, + /// and an error is thrown for indices that are not increasing. On GPU, this + /// does not throw an error for unsorted indices. On GPU, out-of-order indices + /// result in safe but unspecified behavior, which may include treating + /// out-of-order indices as a smaller following index when computing the numerator + /// of the mean. + /// + ///
+ /// + ///
+ /// + /// For example: + /// + /// >>> c = tf.constant([[1.0,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) + /// >>> tf.math.segment_mean(c, tf.constant([0, 0, 1])).numpy() + /// array([[2.5, 2.5, 2.5, 2.5], + /// [5., 6., 7., 8.]], dtype=float32) + /// + /// + ///
+ /// + /// + /// + public static Tensor segment_mean(Tensor data, Tensor segment_ids, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SegmentMean", name) { args = new object[] { data, segment_ids }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return segment_mean_eager_fallback(data, segment_ids, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["segment_ids"] = segment_ids; + var _op = tf.OpDefLib._apply_op_helper("SegmentMean", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("SegmentMean", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor segment_mean_eager_fallback(Tensor data, Tensor segment_ids, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, segment_ids }; + object[] _attrs = new object[] { "T", data.dtype, "Tindices", segment_ids.dtype }; + var _result = _execute.execute("SegmentMean", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SegmentMean", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the minimum along segments of a tensor. + /// + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) + /// for an explanation of segments. + /// + /// Computes a tensor such that + /// \(output_i = min_j(data_j)\) where `min` is over `j` such + /// that `segment_ids[j] == i`. + /// + /// If the min is empty for a given segment ID `i`, `output[i] = 0`. + /// + /// Caution: On CPU, values in `segment_ids` are always validated to be sorted, + /// and an error is thrown for indices that are not increasing. On GPU, this + /// does not throw an error for unsorted indices. On GPU, out-of-order indices + /// result in safe but unspecified behavior, which may include treating + /// out-of-order indices as the same as a smaller following index. + /// + ///
+ /// + ///
+ /// + /// For example: + /// + /// >>> c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) + /// >>> tf.math.segment_min(c, tf.constant([0, 0, 1])).numpy() + /// array([[1, 2, 2, 1], + /// [5, 6, 7, 8]], dtype=int32) + /// + /// + ///
+ /// + /// + /// + public static Tensor segment_min(Tensor data, Tensor segment_ids, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SegmentMin", name) { args = new object[] { data, segment_ids }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return segment_min_eager_fallback(data, segment_ids, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["segment_ids"] = segment_ids; + var _op = tf.OpDefLib._apply_op_helper("SegmentMin", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("SegmentMin", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor segment_min_eager_fallback(Tensor data, Tensor segment_ids, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, segment_ids }; + object[] _attrs = new object[] { "T", data.dtype, "Tindices", segment_ids.dtype }; + var _result = _execute.execute("SegmentMin", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SegmentMin", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the product along segments of a tensor. + /// + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) + /// for an explanation of segments. + /// + /// Computes a tensor such that + /// \(output_i = prod_j data_j\) where the product is over `j` such + /// that `segment_ids[j] == i`. + /// + /// If the product is empty for a given segment ID `i`, `output[i] = 1`. + /// + /// Caution: On CPU, values in `segment_ids` are always validated to be sorted, + /// and an error is thrown for indices that are not increasing. On GPU, this + /// does not throw an error for unsorted indices. On GPU, out-of-order indices + /// result in safe but unspecified behavior, which may include treating + /// out-of-order indices as the same as a smaller following index. + /// + ///
+ /// + ///
+ /// + /// For example: + /// + /// >>> c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) + /// >>> tf.math.segment_prod(c, tf.constant([0, 0, 1])).numpy() + /// array([[4, 6, 6, 4], + /// [5, 6, 7, 8]], dtype=int32) + /// + /// + ///
+ /// + /// + /// + public static Tensor segment_prod(Tensor data, Tensor segment_ids, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SegmentProd", name) { args = new object[] { data, segment_ids }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return segment_prod_eager_fallback(data, segment_ids, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["segment_ids"] = segment_ids; + var _op = tf.OpDefLib._apply_op_helper("SegmentProd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("SegmentProd", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor segment_prod_eager_fallback(Tensor data, Tensor segment_ids, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, segment_ids }; + object[] _attrs = new object[] { "T", data.dtype, "Tindices", segment_ids.dtype }; + var _result = _execute.execute("SegmentProd", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SegmentProd", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the sum along segments of a tensor. + /// + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) + /// for an explanation of segments. + /// + /// Computes a tensor such that + /// \(output_i = sum_j data_j\) where sum is over `j` such + /// that `segment_ids[j] == i`. + /// + /// If the sum is empty for a given segment ID `i`, `output[i] = 0`. + /// + /// Caution: On CPU, values in `segment_ids` are always validated to be sorted, + /// and an error is thrown for indices that are not increasing. On GPU, this + /// does not throw an error for unsorted indices. On GPU, out-of-order indices + /// result in safe but unspecified behavior, which may include treating + /// out-of-order indices as the same as a smaller following index. + /// + ///
+ /// + ///
+ /// + /// For example: + /// + /// >>> c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) + /// >>> tf.math.segment_sum(c, tf.constant([0, 0, 1])).numpy() + /// array([[5, 5, 5, 5], + /// [5, 6, 7, 8]], dtype=int32) + /// + /// + ///
+ /// + /// + /// + public static Tensor segment_sum(Tensor data, Tensor segment_ids, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SegmentSum", name) { args = new object[] { data, segment_ids }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return segment_sum_eager_fallback(data, segment_ids, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["segment_ids"] = segment_ids; + var _op = tf.OpDefLib._apply_op_helper("SegmentSum", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("SegmentSum", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor segment_sum_eager_fallback(Tensor data, Tensor segment_ids, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, segment_ids }; + object[] _attrs = new object[] { "T", data.dtype, "Tindices", segment_ids.dtype }; + var _result = _execute.execute("SegmentSum", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SegmentSum", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Selects elements from `t` or `e`, depending on `condition`. + /// + /// + /// + /// The `t`, and `e` tensors must all have the same shape, and the + /// output will also have that shape. + /// + /// The `condition` tensor must be a scalar if `t` and `e` are scalars. + /// If `t` and `e` are vectors or higher rank, then `condition` must be either a + /// scalar, a vector with size matching the first dimension of `t`, or must have + /// the same shape as `t`. + /// + /// The `condition` tensor acts as a mask that chooses, based on the value at each + /// element, whether the corresponding element / row in the output should be + /// taken from `t` (if true) or `e` (if false). + /// + /// If `condition` is a vector and `t` and `e` are higher rank matrices, then + /// it chooses which row (outer dimension) to copy from `t` and `e`. + /// If `condition` has the same shape as `t` and `e`, then it chooses which + /// element to copy from `t` and `e`. + /// + /// For example: + /// + /// ```python + /// # 'condition' tensor is [[True, False] + /// # [False, True]] + /// # 't' is [[1, 2], + /// # [3, 4]] + /// # 'e' is [[5, 6], + /// # [7, 8]] + /// select(condition, t, e) # => [[1, 6], [7, 4]] + /// + /// + /// # 'condition' tensor is [True, False] + /// # 't' is [[1, 2], + /// # [3, 4]] + /// # 'e' is [[5, 6], + /// # [7, 8]] + /// select(condition, t, e) ==> [[1, 2], + /// [7, 8]] + /// + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor select(Tensor condition, Tensor t, Tensor e, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Select", name) { args = new object[] { condition, t, e }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return select_eager_fallback(condition, t, e, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["condition"] = condition; + keywords["t"] = t; + keywords["e"] = e; + var _op = tf.OpDefLib._apply_op_helper("Select", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Select", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor select_eager_fallback(Tensor condition, Tensor t, Tensor e, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { condition, t, e }; + object[] _attrs = new object[] { "T", t.dtype }; + var _result = _execute.execute("Select", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Select", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + /// + /// + public static Tensor select_v2(Tensor condition, Tensor t, Tensor e, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SelectV2", name) { args = new object[] { condition, t, e }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return select_v2_eager_fallback(condition, t, e, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["condition"] = condition; + keywords["t"] = t; + keywords["e"] = e; + var _op = tf.OpDefLib._apply_op_helper("SelectV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("SelectV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor select_v2_eager_fallback(Tensor condition, Tensor t, Tensor e, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { condition, t, e }; + object[] _attrs = new object[] { "T", t.dtype }; + var _result = _execute.execute("SelectV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SelectV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes sigmoid of `x` element-wise. + /// + /// + /// + /// Specifically, `y = 1 / (1 + exp(-x))`. + /// + /// + /// + /// + public static Tensor sigmoid(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Sigmoid", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sigmoid_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Sigmoid", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Sigmoid", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sigmoid_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Sigmoid", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Sigmoid", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradient of the sigmoid of `x` wrt its input. + /// + /// + /// + /// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and + /// `dy` is the corresponding input gradient. + /// + /// + /// + /// + /// + public static Tensor sigmoid_grad(Tensor y, Tensor dy, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SigmoidGrad", name) { args = new object[] { y, dy }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sigmoid_grad_eager_fallback(y, dy, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["y"] = y; + keywords["dy"] = dy; + var _op = tf.OpDefLib._apply_op_helper("SigmoidGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("SigmoidGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sigmoid_grad_eager_fallback(Tensor y, Tensor dy, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { y, dy }; + object[] _attrs = new object[] { "T", y.dtype }; + var _result = _execute.execute("SigmoidGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SigmoidGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns an element-wise indication of the sign of a number. + /// + /// + /// + /// `y = sign(x) = -1` if `x < 0`; 0 if `x == 0`; 1 if `x > 0`. + /// + /// For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`. + /// + /// Example usage: + /// >>> tf.math.sign([0., 2., -3.]) + /// + /// + /// + /// + /// + public static Tensor sign(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Sign", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sign_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Sign", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Sign", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sign_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Sign", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Sign", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes sine of x element-wise. + /// + /// + /// + /// Given an input tensor, this function computes sine of every + /// element in the tensor. Input range is `(-inf, inf)` and + /// output range is `[-1,1]`. + /// + /// ```python + /// x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 200, 10, float("inf")]) + /// tf.math.sin(x) ==> [nan -0.4121185 -0.47942555 0.84147096 0.9320391 -0.87329733 -0.54402107 nan] + /// ``` + /// + /// + /// + /// + public static Tensor sin(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Sin", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sin_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Sin", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Sin", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sin_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Sin", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Sin", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes hyperbolic sine of x element-wise. + /// + /// + /// + /// Given an input tensor, this function computes hyperbolic sine of every + /// element in the tensor. Input range is `[-inf,inf]` and output range + /// is `[-inf,inf]`. + /// + /// ```python + /// x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 2, 10, float("inf")]) + /// tf.math.sinh(x) ==> [-inf -4.0515420e+03 -5.2109528e-01 1.1752012e+00 1.5094614e+00 3.6268604e+00 1.1013232e+04 inf] + /// ``` + /// + /// + /// + /// + public static Tensor sinh(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Sinh", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sinh_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Sinh", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Sinh", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sinh_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Sinh", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Sinh", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Generates points from the Sobol sequence. + /// + /// + /// + /// Creates a Sobol sequence with `num_results` samples. Each sample has dimension + /// `dim`. Skips the first `skip` samples. + /// + /// + /// + /// + /// + /// + /// + /// The type of the sample. One of: `float32` or `float64`. + /// + /// + /// + public static Tensor sobol_sample(Tensor dim, Tensor num_results, Tensor skip, TF_DataType dtype = TF_DataType.TF_FLOAT, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SobolSample", name) { args = new object[] { dim, num_results, skip }, attrs = new Dictionary() { ["dtype"] = dtype } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sobol_sample_eager_fallback(dim, num_results, skip, dtype: dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["dim"] = dim; + keywords["num_results"] = num_results; + keywords["skip"] = skip; + keywords["dtype"] = dtype; + var _op = tf.OpDefLib._apply_op_helper("SobolSample", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype") }; + _execute.record_gradient("SobolSample", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sobol_sample_eager_fallback(Tensor dim, Tensor num_results, Tensor skip, TF_DataType dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { dim, num_results, skip }; + object[] _attrs = new object[] { "dtype", dtype }; + var _result = _execute.execute("SobolSample", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SobolSample", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Counts the number of occurrences of each value in an integer array. + /// + /// + /// + /// Outputs a vector with length `size` and the same dtype as `weights`. If + /// `weights` are empty, then index `i` stores the number of times the value `i` is + /// counted in `arr`. If `weights` are non-empty, then index `i` stores the sum of + /// the value in `weights` at each index where the corresponding value in `arr` is + /// `i`. + /// + /// Values in `arr` outside of the range [0, size) are ignored. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// bool; Whether the kernel should count the appearance or number of occurrences. + /// + /// + /// + public static Tensor sparse_bincount(Tensor indices, Tensor values, Tensor dense_shape, Tensor size, Tensor weights, bool binary_output = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SparseBincount", name) { args = new object[] { indices, values, dense_shape, size, weights }, attrs = new Dictionary() { ["binary_output"] = binary_output } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sparse_bincount_eager_fallback(indices, values, dense_shape, size, weights, binary_output: binary_output, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["indices"] = indices; + keywords["values"] = values; + keywords["dense_shape"] = dense_shape; + keywords["size"] = size; + keywords["weights"] = weights; + keywords["binary_output"] = binary_output; + var _op = tf.OpDefLib._apply_op_helper("SparseBincount", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tidx", _op._get_attr_type("Tidx"), "T", _op._get_attr_type("T"), "binary_output", _op._get_attr_bool("binary_output") }; + _execute.record_gradient("SparseBincount", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sparse_bincount_eager_fallback(Tensor indices, Tensor values, Tensor dense_shape, Tensor size, Tensor weights, bool binary_output, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { indices, values, dense_shape, size, weights }; + object[] _attrs = new object[] { "Tidx", values.dtype, "T", weights.dtype, "binary_output", binary_output }; + var _result = _execute.execute("SparseBincount", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SparseBincount", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Multiply matrix "a" by matrix "b". + /// + /// + /// + /// The inputs must be two-dimensional matrices and the inner dimension of "a" must + /// match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not + /// `SparseTensor`s. This op is optimized for the case where at least one of "a" or + /// "b" is sparse, in the sense that they have a large proportion of zero values. + /// The breakeven for using this versus a dense matrix multiply on one platform was + /// 30% zero values in the sparse matrix. + /// + /// The gradient computation of this operation will only take advantage of sparsity + /// in the input gradient when that gradient comes from a Relu. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor sparse_mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false, bool a_is_sparse = false, bool b_is_sparse = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SparseMatMul", name) { args = new object[] { a, b }, attrs = new Dictionary() { ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b, ["a_is_sparse"] = a_is_sparse, ["b_is_sparse"] = b_is_sparse } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sparse_mat_mul_eager_fallback(a, b, transpose_a: transpose_a, transpose_b: transpose_b, a_is_sparse: a_is_sparse, b_is_sparse: b_is_sparse, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["b"] = b; + keywords["transpose_a"] = transpose_a; + keywords["transpose_b"] = transpose_b; + keywords["a_is_sparse"] = a_is_sparse; + keywords["b_is_sparse"] = b_is_sparse; + var _op = tf.OpDefLib._apply_op_helper("SparseMatMul", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "transpose_a", _op._get_attr_bool("transpose_a"), "transpose_b", _op._get_attr_bool("transpose_b"), "a_is_sparse", _op._get_attr_bool("a_is_sparse"), "b_is_sparse", _op._get_attr_bool("b_is_sparse"), "Ta", _op._get_attr_type("Ta"), "Tb", _op._get_attr_type("Tb") }; + _execute.record_gradient("SparseMatMul", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sparse_mat_mul_eager_fallback(Tensor a, Tensor b, bool transpose_a, bool transpose_b, bool a_is_sparse, bool b_is_sparse, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, b }; + object[] _attrs = new object[] { "transpose_a", transpose_a, "transpose_b", transpose_b, "a_is_sparse", a_is_sparse, "b_is_sparse", b_is_sparse, "Ta", a.dtype, "Tb", b.dtype }; + var _result = _execute.execute("SparseMatMul", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SparseMatMul", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the mean along sparse segments of a tensor. + /// + /// + /// + /// See `tf.sparse.segment_sum` for usage examples. + /// + /// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first + /// dimension, selecting a subset of dimension 0, specified by `indices`. + /// + /// + /// + /// + /// + /// + public static Tensor sparse_segment_mean(Tensor data, Tensor indices, Tensor segment_ids, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SparseSegmentMean", name) { args = new object[] { data, indices, segment_ids }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sparse_segment_mean_eager_fallback(data, indices, segment_ids, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["indices"] = indices; + keywords["segment_ids"] = segment_ids; + var _op = tf.OpDefLib._apply_op_helper("SparseSegmentMean", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx"), "Tsegmentids", _op._get_attr_type("Tsegmentids") }; + _execute.record_gradient("SparseSegmentMean", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sparse_segment_mean_eager_fallback(Tensor data, Tensor indices, Tensor segment_ids, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, indices, segment_ids }; + object[] _attrs = new object[] { "T", data.dtype, "Tidx", indices.dtype, "Tsegmentids", segment_ids.dtype }; + var _result = _execute.execute("SparseSegmentMean", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SparseSegmentMean", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes gradients for SparseSegmentMean. + /// + /// + /// + /// Returns tensor "output" with same shape as grad, except for dimension 0 whose + /// value is output_dim0. + /// + /// + /// + /// + /// + /// + /// + public static Tensor sparse_segment_mean_grad(Tensor grad, Tensor indices, Tensor segment_ids, Tensor output_dim0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SparseSegmentMeanGrad", name) { args = new object[] { grad, indices, segment_ids, output_dim0 }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sparse_segment_mean_grad_eager_fallback(grad, indices, segment_ids, output_dim0, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["grad"] = grad; + keywords["indices"] = indices; + keywords["segment_ids"] = segment_ids; + keywords["output_dim0"] = output_dim0; + var _op = tf.OpDefLib._apply_op_helper("SparseSegmentMeanGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx"), "Tsegmentids", _op._get_attr_type("Tsegmentids") }; + _execute.record_gradient("SparseSegmentMeanGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sparse_segment_mean_grad_eager_fallback(Tensor grad, Tensor indices, Tensor segment_ids, Tensor output_dim0, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { grad, indices, segment_ids, output_dim0 }; + object[] _attrs = new object[] { "T", grad.dtype, "Tidx", indices.dtype, "Tsegmentids", segment_ids.dtype }; + var _result = _execute.execute("SparseSegmentMeanGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SparseSegmentMeanGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the mean along sparse segments of a tensor. + /// + /// + /// + /// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is + /// missing, the `output` tensor at that position will be zeroed. + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) + /// for an explanation of segments. + /// + /// + /// + /// + /// + /// + /// + public static Tensor sparse_segment_mean_with_num_segments(Tensor data, Tensor indices, Tensor segment_ids, Tensor num_segments, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SparseSegmentMeanWithNumSegments", name) { args = new object[] { data, indices, segment_ids, num_segments }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sparse_segment_mean_with_num_segments_eager_fallback(data, indices, segment_ids, num_segments, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["indices"] = indices; + keywords["segment_ids"] = segment_ids; + keywords["num_segments"] = num_segments; + var _op = tf.OpDefLib._apply_op_helper("SparseSegmentMeanWithNumSegments", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx"), "Tnumsegments", _op._get_attr_type("Tnumsegments"), "Tsegmentids", _op._get_attr_type("Tsegmentids") }; + _execute.record_gradient("SparseSegmentMeanWithNumSegments", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sparse_segment_mean_with_num_segments_eager_fallback(Tensor data, Tensor indices, Tensor segment_ids, Tensor num_segments, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, indices, segment_ids, num_segments }; + object[] _attrs = new object[] { "T", data.dtype, "Tidx", indices.dtype, "Tnumsegments", num_segments.dtype, "Tsegmentids", segment_ids.dtype }; + var _result = _execute.execute("SparseSegmentMeanWithNumSegments", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SparseSegmentMeanWithNumSegments", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the sum along sparse segments of a tensor divided by the sqrt of N. + /// + /// + /// + /// N is the size of the segment being reduced. + /// + /// See `tf.sparse.segment_sum` for usage examples. + /// + /// + /// + /// + /// + /// + /// + public static Tensor sparse_segment_sqrt_n(Tensor data, Tensor indices, Tensor segment_ids, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SparseSegmentSqrtN", name) { args = new object[] { data, indices, segment_ids }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sparse_segment_sqrt_n_eager_fallback(data, indices, segment_ids, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["indices"] = indices; + keywords["segment_ids"] = segment_ids; + var _op = tf.OpDefLib._apply_op_helper("SparseSegmentSqrtN", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx"), "Tsegmentids", _op._get_attr_type("Tsegmentids") }; + _execute.record_gradient("SparseSegmentSqrtN", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sparse_segment_sqrt_n_eager_fallback(Tensor data, Tensor indices, Tensor segment_ids, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, indices, segment_ids }; + object[] _attrs = new object[] { "T", data.dtype, "Tidx", indices.dtype, "Tsegmentids", segment_ids.dtype }; + var _result = _execute.execute("SparseSegmentSqrtN", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SparseSegmentSqrtN", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the sum along sparse segments of a tensor. + /// + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) + /// for an explanation of segments. + /// + /// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first + /// dimension, selecting a subset of dimension 0, specified by `indices`. + /// + /// For example: + /// + /// ```python + /// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) + /// + /// # Select two rows, one segment. + /// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0])) + /// # => [[0 0 0 0]] + /// + /// # Select two rows, two segment. + /// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1])) + /// # => [[ 1 2 3 4] + /// # [-1 -2 -3 -4]] + /// + /// # Select all rows, two segments. + /// tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1])) + /// # => [[0 0 0 0] + /// # [5 6 7 8]] + /// + /// # Which is equivalent to: + /// tf.segment_sum(c, tf.constant([0, 0, 1])) + /// ``` + /// + /// + /// + /// + /// + /// + public static Tensor sparse_segment_sum(Tensor data, Tensor indices, Tensor segment_ids, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SparseSegmentSum", name) { args = new object[] { data, indices, segment_ids }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sparse_segment_sum_eager_fallback(data, indices, segment_ids, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["indices"] = indices; + keywords["segment_ids"] = segment_ids; + var _op = tf.OpDefLib._apply_op_helper("SparseSegmentSum", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx"), "Tsegmentids", _op._get_attr_type("Tsegmentids") }; + _execute.record_gradient("SparseSegmentSum", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sparse_segment_sum_eager_fallback(Tensor data, Tensor indices, Tensor segment_ids, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, indices, segment_ids }; + object[] _attrs = new object[] { "T", data.dtype, "Tidx", indices.dtype, "Tsegmentids", segment_ids.dtype }; + var _result = _execute.execute("SparseSegmentSum", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SparseSegmentSum", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes gradients for SparseSegmentSum. + /// + /// + /// + /// Returns tensor "output" with same shape as grad, except for dimension 0 whose + /// value is output_dim0. + /// + /// + /// + /// + /// + /// + /// + public static Tensor sparse_segment_sum_grad(Tensor grad, Tensor indices, Tensor segment_ids, Tensor output_dim0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SparseSegmentSumGrad", name) { args = new object[] { grad, indices, segment_ids, output_dim0 }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sparse_segment_sum_grad_eager_fallback(grad, indices, segment_ids, output_dim0, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["grad"] = grad; + keywords["indices"] = indices; + keywords["segment_ids"] = segment_ids; + keywords["output_dim0"] = output_dim0; + var _op = tf.OpDefLib._apply_op_helper("SparseSegmentSumGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx"), "Tsegmentids", _op._get_attr_type("Tsegmentids") }; + _execute.record_gradient("SparseSegmentSumGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sparse_segment_sum_grad_eager_fallback(Tensor grad, Tensor indices, Tensor segment_ids, Tensor output_dim0, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { grad, indices, segment_ids, output_dim0 }; + object[] _attrs = new object[] { "T", grad.dtype, "Tidx", indices.dtype, "Tsegmentids", segment_ids.dtype }; + var _result = _execute.execute("SparseSegmentSumGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SparseSegmentSumGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the sum along sparse segments of a tensor. + /// + /// + /// + /// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is + /// missing, the `output` tensor at that position will be zeroed. + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/sparse#Segmentation) + /// for an explanation of segments. + /// + /// For example: + /// + /// ```python + /// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) + /// + /// tf.sparse_segment_sum_with_num_segments( + /// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3) + /// # => [[0 0 0 0] + /// # [0 0 0 0] + /// # [0 0 0 0]] + /// + /// tf.sparse_segment_sum_with_num_segments(c, + /// tf.constant([0, 1]), + /// tf.constant([0, 2], + /// num_segments=4)) + /// # => [[ 1 2 3 4] + /// # [ 0 0 0 0] + /// # [-1 -2 -3 -4] + /// # [ 0 0 0 0]] + /// ``` + /// + /// + /// + /// + /// + /// + /// + public static Tensor sparse_segment_sum_with_num_segments(Tensor data, Tensor indices, Tensor segment_ids, Tensor num_segments, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SparseSegmentSumWithNumSegments", name) { args = new object[] { data, indices, segment_ids, num_segments }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sparse_segment_sum_with_num_segments_eager_fallback(data, indices, segment_ids, num_segments, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["indices"] = indices; + keywords["segment_ids"] = segment_ids; + keywords["num_segments"] = num_segments; + var _op = tf.OpDefLib._apply_op_helper("SparseSegmentSumWithNumSegments", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx"), "Tnumsegments", _op._get_attr_type("Tnumsegments"), "Tsegmentids", _op._get_attr_type("Tsegmentids") }; + _execute.record_gradient("SparseSegmentSumWithNumSegments", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sparse_segment_sum_with_num_segments_eager_fallback(Tensor data, Tensor indices, Tensor segment_ids, Tensor num_segments, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, indices, segment_ids, num_segments }; + object[] _attrs = new object[] { "T", data.dtype, "Tidx", indices.dtype, "Tnumsegments", num_segments.dtype, "Tsegmentids", segment_ids.dtype }; + var _result = _execute.execute("SparseSegmentSumWithNumSegments", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SparseSegmentSumWithNumSegments", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes square root of x element-wise. + /// + /// + /// + /// I.e., \(y = sqrt{x} = x^{1/2}\). + /// + /// + /// + /// + public static Tensor sqrt(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Sqrt", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sqrt_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Sqrt", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Sqrt", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sqrt_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Sqrt", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Sqrt", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradient for the sqrt of `x` wrt its input. + /// + /// + /// + /// Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy` + /// is the corresponding input gradient. + /// + /// + /// + /// + /// + public static Tensor sqrt_grad(Tensor y, Tensor dy, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SqrtGrad", name) { args = new object[] { y, dy }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sqrt_grad_eager_fallback(y, dy, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["y"] = y; + keywords["dy"] = dy; + var _op = tf.OpDefLib._apply_op_helper("SqrtGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("SqrtGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sqrt_grad_eager_fallback(Tensor y, Tensor dy, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { y, dy }; + object[] _attrs = new object[] { "T", y.dtype }; + var _result = _execute.execute("SqrtGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SqrtGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes square of x element-wise. + /// + /// + /// + /// I.e., \(y = x * x = x^2\). + /// + /// + /// + /// + public static Tensor square(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Square", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return square_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Square", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Square", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor square_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Square", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Square", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns conj(x - y)(x - y) element-wise. + /// + /// + /// + /// *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor squared_difference(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SquaredDifference", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return squared_difference_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("SquaredDifference", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("SquaredDifference", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor squared_difference_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("SquaredDifference", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SquaredDifference", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns x - y element-wise. + /// + /// + /// + /// *NOTE*: `Sub` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor sub(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Sub", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sub_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Sub", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Sub", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sub_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Sub", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Sub", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the sum of elements across dimensions of a tensor. + /// + /// + /// + /// Reduces `input` along the dimensions given in `reduction_indices`. Unless + /// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in + /// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are + /// retained with length 1. + /// + /// + /// + /// + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// + public static Tensor sum(Tensor input, Tensor reduction_indices, bool keep_dims = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Sum", name) { args = new object[] { input, reduction_indices }, attrs = new Dictionary() { ["keep_dims"] = keep_dims } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sum_eager_fallback(input, reduction_indices, keep_dims: keep_dims, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["reduction_indices"] = reduction_indices; + keywords["keep_dims"] = keep_dims; + var _op = tf.OpDefLib._apply_op_helper("Sum", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "keep_dims", _op._get_attr_bool("keep_dims"), "T", _op._get_attr_type("T"), "Tidx", _op._get_attr_type("Tidx") }; + _execute.record_gradient("Sum", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor sum_eager_fallback(Tensor input, Tensor reduction_indices, bool keep_dims, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, reduction_indices }; + object[] _attrs = new object[] { "keep_dims", keep_dims, "T", input.dtype, "Tidx", reduction_indices.dtype }; + var _result = _execute.execute("Sum", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Sum", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes tan of x element-wise. + /// + /// + /// + /// Given an input tensor, this function computes tangent of every + /// element in the tensor. Input range is `(-inf, inf)` and + /// output range is `(-inf, inf)`. If input lies outside the boundary, `nan` + /// is returned. + /// + /// ```python + /// x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 200, 10000, float("inf")]) + /// tf.math.tan(x) ==> [nan 0.45231566 -0.5463025 1.5574077 2.572152 -1.7925274 0.32097113 nan] + /// ``` + /// + /// + /// + /// + public static Tensor tan(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Tan", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tan_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Tan", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Tan", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tan_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Tan", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Tan", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes hyperbolic tangent of `x` element-wise. + /// + /// + /// + /// Given an input tensor, this function computes hyperbolic tangent of every + /// element in the tensor. Input range is `[-inf, inf]` and + /// output range is `[-1,1]`. + /// + /// >>> x = tf.constant([-float("inf"), -5, -0.5, 1, 1.2, 2, 3, float("inf")]) + /// >>> tf.math.tanh(x) + /// + /// + /// + /// + /// + /// + public static Tensor tanh(Tensor x, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Tanh", name) { args = new object[] { x }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tanh_eager_fallback(x, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + var _op = tf.OpDefLib._apply_op_helper("Tanh", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Tanh", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tanh_eager_fallback(Tensor x, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Tanh", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Tanh", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradient for the tanh of `x` wrt its input. + /// + /// + /// + /// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy` + /// is the corresponding input gradient. + /// + /// + /// + /// + /// + public static Tensor tanh_grad(Tensor y, Tensor dy, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TanhGrad", name) { args = new object[] { y, dy }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return tanh_grad_eager_fallback(y, dy, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["y"] = y; + keywords["dy"] = dy; + var _op = tf.OpDefLib._apply_op_helper("TanhGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("TanhGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tanh_grad_eager_fallback(Tensor y, Tensor dy, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { y, dy }; + object[] _attrs = new object[] { "T", y.dtype }; + var _result = _execute.execute("TanhGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TanhGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns x / y element-wise for integer types. + /// + /// + /// + /// Truncation designates that negative numbers will round fractional quantities + /// toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different + /// than Python semantics. See `FloorDiv` for a division function that matches + /// Python Semantics. + /// + /// *NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor truncate_div(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TruncateDiv", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return truncate_div_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("TruncateDiv", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("TruncateDiv", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor truncate_div_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("TruncateDiv", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TruncateDiv", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns element-wise remainder of division. This emulates C semantics in that + /// + /// + /// + /// the result here is consistent with a truncating divide. E.g. `truncate(x / y) * + /// y + truncate_mod(x, y) = x`. + /// + /// *NOTE*: `TruncateMod` supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// + /// + /// + /// + public static Tensor truncate_mod(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TruncateMod", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return truncate_mod_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("TruncateMod", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("TruncateMod", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor truncate_mod_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("TruncateMod", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TruncateMod", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the maximum along segments of a tensor. + /// + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) + /// for an explanation of segments. + /// + /// This operator is similar to `tf.math.unsorted_segment_sum`, + /// Instead of computing the sum over segments, it computes the maximum such that: + /// + /// \(output_i = max_{j...} data[j...]\) where max is over tuples `j...` such + /// that `segment_ids[j...] == i`. + /// + /// If the maximum is empty for a given segment ID `i`, it outputs the smallest + /// possible value for the specific numeric type, + /// `output[i] = numeric_limits::lowest()`. + /// + /// If the given segment ID `i` is negative, then the corresponding value is + /// dropped, and will not be included in the result. + /// + /// Caution: On CPU, values in `segment_ids` are always validated to be less than + /// `num_segments`, and an error is thrown for out-of-bound indices. On GPU, this + /// does not throw an error for out-of-bound indices. On Gpu, out-of-bound indices + /// result in safe but unspecified behavior, which may include ignoring + /// out-of-bound indices or outputting a tensor with a 0 stored in the first + /// dimension of its shape if `num_segments` is 0. + /// + ///
+ /// + ///
+ /// + /// For example: + /// + /// >>> c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) + /// >>> tf.math.unsorted_segment_max(c, tf.constant([0, 1, 0]), num_segments=2).numpy() + /// array([[4, 3, 3, 4], + /// [5, 6, 7, 8]], dtype=int32) + /// + /// + ///
+ /// + /// + /// + /// + public static Tensor unsorted_segment_max(Tensor data, Tensor segment_ids, Tensor num_segments, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "UnsortedSegmentMax", name) { args = new object[] { data, segment_ids, num_segments }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return unsorted_segment_max_eager_fallback(data, segment_ids, num_segments, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["segment_ids"] = segment_ids; + keywords["num_segments"] = num_segments; + var _op = tf.OpDefLib._apply_op_helper("UnsortedSegmentMax", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices"), "Tnumsegments", _op._get_attr_type("Tnumsegments") }; + _execute.record_gradient("UnsortedSegmentMax", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor unsorted_segment_max_eager_fallback(Tensor data, Tensor segment_ids, Tensor num_segments, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, segment_ids, num_segments }; + object[] _attrs = new object[] { "T", data.dtype, "Tindices", segment_ids.dtype, "Tnumsegments", num_segments.dtype }; + var _result = _execute.execute("UnsortedSegmentMax", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("UnsortedSegmentMax", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the minimum along segments of a tensor. + /// + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) + /// for an explanation of segments. + /// + /// This operator is similar to `tf.math.unsorted_segment_sum`, + /// Instead of computing the sum over segments, it computes the minimum such that: + /// + /// \(output_i = min_{j...} data_[j...]\) where min is over tuples `j...` such + /// that `segment_ids[j...] == i`. + /// + /// If the minimum is empty for a given segment ID `i`, it outputs the largest + /// possible value for the specific numeric type, + /// `output[i] = numeric_limits::max()`. + /// + /// For example: + /// + /// >>> c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) + /// >>> tf.math.unsorted_segment_min(c, tf.constant([0, 1, 0]), num_segments=2).numpy() + /// array([[1, 2, 2, 1], + /// [5, 6, 7, 8]], dtype=int32) + /// + /// If the given segment ID `i` is negative, then the corresponding value is + /// dropped, and will not be included in the result. + /// + /// Caution: On CPU, values in `segment_ids` are always validated to be less than + /// `num_segments`, and an error is thrown for out-of-bound indices. On GPU, this + /// does not throw an error for out-of-bound indices. On Gpu, out-of-bound indices + /// result in safe but unspecified behavior, which may include ignoring + /// out-of-bound indices or outputting a tensor with a 0 stored in the first + /// dimension of its shape if `num_segments` is 0. + /// + /// + /// + /// + /// + /// + public static Tensor unsorted_segment_min(Tensor data, Tensor segment_ids, Tensor num_segments, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "UnsortedSegmentMin", name) { args = new object[] { data, segment_ids, num_segments }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return unsorted_segment_min_eager_fallback(data, segment_ids, num_segments, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["segment_ids"] = segment_ids; + keywords["num_segments"] = num_segments; + var _op = tf.OpDefLib._apply_op_helper("UnsortedSegmentMin", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices"), "Tnumsegments", _op._get_attr_type("Tnumsegments") }; + _execute.record_gradient("UnsortedSegmentMin", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor unsorted_segment_min_eager_fallback(Tensor data, Tensor segment_ids, Tensor num_segments, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, segment_ids, num_segments }; + object[] _attrs = new object[] { "T", data.dtype, "Tindices", segment_ids.dtype, "Tnumsegments", num_segments.dtype }; + var _result = _execute.execute("UnsortedSegmentMin", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("UnsortedSegmentMin", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the product along segments of a tensor. + /// + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) + /// for an explanation of segments. + /// + /// This operator is similar to `tf.math.unsorted_segment_sum`, + /// Instead of computing the sum over segments, it computes the product of all + /// entries belonging to a segment such that: + /// + /// \(output_i = prod_{j...} data[j...]\) where the product is over tuples + /// `j...` such that `segment_ids[j...] == i`. + /// + /// For example: + /// + /// >>> c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) + /// >>> tf.math.unsorted_segment_prod(c, tf.constant([0, 1, 0]), num_segments=2).numpy() + /// array([[4, 6, 6, 4], + /// [5, 6, 7, 8]], dtype=int32) + /// + /// If there is no entry for a given segment ID `i`, it outputs 1. + /// + /// If the given segment ID `i` is negative, then the corresponding value is + /// dropped, and will not be included in the result. + /// Caution: On CPU, values in `segment_ids` are always validated to be less than + /// `num_segments`, and an error is thrown for out-of-bound indices. On GPU, this + /// does not throw an error for out-of-bound indices. On Gpu, out-of-bound indices + /// result in safe but unspecified behavior, which may include ignoring + /// out-of-bound indices or outputting a tensor with a 0 stored in the first + /// dimension of its shape if `num_segments` is 0. + /// + /// + /// + /// + /// + /// + /// + public static Tensor unsorted_segment_prod(Tensor data, Tensor segment_ids, Tensor num_segments, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "UnsortedSegmentProd", name) { args = new object[] { data, segment_ids, num_segments }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return unsorted_segment_prod_eager_fallback(data, segment_ids, num_segments, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["segment_ids"] = segment_ids; + keywords["num_segments"] = num_segments; + var _op = tf.OpDefLib._apply_op_helper("UnsortedSegmentProd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices"), "Tnumsegments", _op._get_attr_type("Tnumsegments") }; + _execute.record_gradient("UnsortedSegmentProd", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor unsorted_segment_prod_eager_fallback(Tensor data, Tensor segment_ids, Tensor num_segments, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, segment_ids, num_segments }; + object[] _attrs = new object[] { "T", data.dtype, "Tindices", segment_ids.dtype, "Tnumsegments", num_segments.dtype }; + var _result = _execute.execute("UnsortedSegmentProd", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("UnsortedSegmentProd", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the sum along segments of a tensor. + /// + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) + /// for an explanation of segments. + /// + /// Computes a tensor such that + /// \(output[i] = sum_{j...} data[j...]\) where the sum is over tuples `j...` such + /// that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids` + /// need not be sorted and need not cover all values in the full + /// range of valid values. + /// + /// If the sum is empty for a given segment ID `i`, `output[i] = 0`. + /// If the given segment ID `i` is negative, the value is dropped and will not be + /// added to the sum of the segment. + /// + /// `num_segments` should equal the number of distinct segment IDs. + /// + /// Caution: On CPU, values in `segment_ids` are always validated to be less than + /// `num_segments`, and an error is thrown for out-of-bound indices. On GPU, this + /// does not throw an error for out-of-bound indices. On Gpu, out-of-bound indices + /// result in safe but unspecified behavior, which may include ignoring + /// out-of-bound indices or outputting a tensor with a 0 stored in the first + /// dimension of its shape if `num_segments` is 0. + /// + ///
+ /// + ///
+ /// + /// >>> c = [[1,2,3,4], [5,6,7,8], [4,3,2,1]] + /// >>> tf.math.unsorted_segment_sum(c, [0, 1, 0], num_segments=2).numpy() + /// array([[5, 5, 5, 5], + /// [5, 6, 7, 8]], dtype=int32) + /// + /// + /// + ///
+ /// + /// + /// + /// + public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "UnsortedSegmentSum", name) { args = new object[] { data, segment_ids, num_segments }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return unsorted_segment_sum_eager_fallback(data, segment_ids, num_segments, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["data"] = data; + keywords["segment_ids"] = segment_ids; + keywords["num_segments"] = num_segments; + var _op = tf.OpDefLib._apply_op_helper("UnsortedSegmentSum", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tindices", _op._get_attr_type("Tindices"), "Tnumsegments", _op._get_attr_type("Tnumsegments") }; + _execute.record_gradient("UnsortedSegmentSum", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor unsorted_segment_sum_eager_fallback(Tensor data, Tensor segment_ids, Tensor num_segments, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { data, segment_ids, num_segments }; + object[] _attrs = new object[] { "T", data.dtype, "Tindices", segment_ids.dtype, "Tnumsegments", num_segments.dtype }; + var _result = _execute.execute("UnsortedSegmentSum", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("UnsortedSegmentSum", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns 0 if x == 0, and x / y otherwise, elementwise. + /// + /// + /// + /// + public static Tensor xdivy(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Xdivy", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return xdivy_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Xdivy", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Xdivy", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor xdivy_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Xdivy", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Xdivy", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns 0 if x == 0, and x * log1p(y) otherwise, elementwise. + /// + /// + /// + /// + public static Tensor xlog1py(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Xlog1py", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return xlog1py_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Xlog1py", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Xlog1py", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor xlog1py_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Xlog1py", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Xlog1py", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns 0 if x == 0, and x * log(y) otherwise, elementwise. + /// + /// + /// + /// + public static Tensor xlogy(Tensor x, Tensor y, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Xlogy", name) { args = new object[] { x, y }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return xlogy_eager_fallback(x, y, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["y"] = y; + var _op = tf.OpDefLib._apply_op_helper("Xlogy", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Xlogy", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor xlogy_eager_fallback(Tensor x, Tensor y, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, y }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Xlogy", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Xlogy", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Compute the Hurwitz zeta function \\(\zeta(x, q)\\). + /// + /// + /// + /// The Hurwitz zeta function is defined as: + /// + /// + /// \(zeta(x, q) = sum_{n=0}^{infty} (q + n)^{-x}\) + /// + /// + /// + /// + /// + public static Tensor zeta(Tensor x, Tensor q, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Zeta", name) { args = new object[] { x, q }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return zeta_eager_fallback(x, q, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["q"] = q; + var _op = tf.OpDefLib._apply_op_helper("Zeta", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Zeta", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor zeta_eager_fallback(Tensor x, Tensor q, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, q }; + object[] _attrs = new object[] { "T", x.dtype }; + var _result = _execute.execute("Zeta", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Zeta", _inputs_flat, _attrs, _result); } + return _result[0]; } } diff --git a/src/TensorFlowNET.Core/Operations/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/gen_nn_ops.cs new file mode 100644 index 000000000..59c740c46 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_nn_ops.cs @@ -0,0 +1,8493 @@ +/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/ + +using Tensorflow.Eager; +using Tensorflow.Contexts; +using Tensorflow.Exceptions; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static class gen_nn_ops +{ + /// + /// Returns min/max k values and their indices of the input operand in an approximate manner. + /// + /// + /// + /// See https://arxiv.org/abs/2206.14286 for the algorithm details. + /// This op is only optimized on TPU currently. + /// + /// + /// + /// + /// Specifies the number of min/max-k. + /// + /// + /// Integer dimension along which to search. Default: -1. + /// + /// + /// Recall target for the approximation. Range in (0,1] + /// + /// + /// When true, computes max-k; otherwise computes min-k. + /// + /// + /// + /// When set to a positive value, it overrides the size determined by + /// `input[reduction_dim]` for evaluating the recall. This option is useful when + /// the given `input` is only a subset of the overall computation in SPMD or + /// distributed pipelines, where the true input size cannot be deferred by the + /// `input` shape. + /// + /// + /// + /// + /// When true, aggregates approximate results to top-k. When false, returns the + /// approximate results. The number of the approximate results is implementation + /// defined and is greater equals to the specified `k`. + /// + /// + /// + public static Tensor[] approx_top_k(Tensor input, int k = 0, int reduction_dimension = -1, float recall_target = 0.95f, bool is_max_k = true, int reduction_input_size_override = -1, bool aggregate_to_topk = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ApproxTopK", name) { args = new object[] { input }, attrs = new Dictionary() { ["k"] = k, ["reduction_dimension"] = reduction_dimension, ["recall_target"] = recall_target, ["is_max_k"] = is_max_k, ["reduction_input_size_override"] = reduction_input_size_override, ["aggregate_to_topk"] = aggregate_to_topk } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return approx_top_k_eager_fallback(input, k: k, reduction_dimension: reduction_dimension, recall_target: recall_target, is_max_k: is_max_k, reduction_input_size_override: reduction_input_size_override, aggregate_to_topk: aggregate_to_topk, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["k"] = k; + keywords["reduction_dimension"] = reduction_dimension; + keywords["recall_target"] = recall_target; + keywords["is_max_k"] = is_max_k; + keywords["reduction_input_size_override"] = reduction_input_size_override; + keywords["aggregate_to_topk"] = aggregate_to_topk; + var _op = tf.OpDefLib._apply_op_helper("ApproxTopK", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "k", _op._get_attr_int("k"), "reduction_dimension", _op._get_attr_int("reduction_dimension"), "recall_target", _op.get_attr("recall_target"), "is_max_k", _op._get_attr_bool("is_max_k"), "reduction_input_size_override", _op._get_attr_int("reduction_input_size_override"), "aggregate_to_topk", _op._get_attr_bool("aggregate_to_topk"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("ApproxTopK", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] approx_top_k_eager_fallback(Tensor input, int k, int reduction_dimension, float recall_target, bool is_max_k, int reduction_input_size_override, bool aggregate_to_topk, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "k", k, "reduction_dimension", reduction_dimension, "recall_target", recall_target, "is_max_k", is_max_k, "reduction_input_size_override", reduction_input_size_override, "aggregate_to_topk", aggregate_to_topk, "T", input.dtype }; + var _result = _execute.execute("ApproxTopK", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ApproxTopK", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Performs average pooling on the input. + /// + /// + /// + /// Each entry in `output` is the mean of the corresponding size `ksize` + /// window in `value`. + /// + /// + /// + /// + /// + /// The size of the sliding window for each dimension of `value`. + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of `value`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// + public static Tensor avg_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AvgPool", name) { args = new object[] { value }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return avg_pool_eager_fallback(value, ksize: ksize, strides: strides, padding: padding, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["value"] = value; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("AvgPool", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("AvgPool", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor avg_pool_eager_fallback(Tensor value, int[] ksize, int[] strides, string padding, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { value }; + object[] _attrs = new object[] { "ksize", ksize, "strides", strides, "padding", padding, "data_format", data_format, "T", value.dtype }; + var _result = _execute.execute("AvgPool", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AvgPool", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Performs 3D average pooling on the input. + /// + /// + /// + /// Each entry in `output` is the mean of the corresponding size `ksize` window in + /// `value`. + /// + /// + /// + /// + /// + /// 1-D tensor of length 5. The size of the window for each dimension of + /// the input tensor. Must have `ksize[0] = ksize[4] = 1`. + /// + /// + /// + /// + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of `input`. Must have `strides[0] = strides[4] = 1`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// + public static Tensor avg_pool3d(Tensor input, int[] ksize, int[] strides, string padding, string data_format = "NDHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AvgPool3D", name) { args = new object[] { input }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return avg_pool3d_eager_fallback(input, ksize: ksize, strides: strides, padding: padding, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NDHWC"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("AvgPool3D", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("AvgPool3D", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor avg_pool3d_eager_fallback(Tensor input, int[] ksize, int[] strides, string padding, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "ksize", ksize, "strides", strides, "padding", padding, "data_format", data_format, "T", input.dtype }; + var _result = _execute.execute("AvgPool3D", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AvgPool3D", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes gradients of average pooling function. + /// + /// + /// + /// + /// + /// 1-D tensor of length 5. The size of the window for each dimension of + /// the input tensor. Must have `ksize[0] = ksize[4] = 1`. + /// + /// + /// + /// + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of `input`. Must have `strides[0] = strides[4] = 1`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// + public static Tensor avg_pool3d_grad(Tensor orig_input_shape, Tensor grad, int[] ksize, int[] strides, string padding, string data_format = "NDHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AvgPool3DGrad", name) { args = new object[] { orig_input_shape, grad }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return avg_pool3d_grad_eager_fallback(orig_input_shape, grad, ksize: ksize, strides: strides, padding: padding, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NDHWC"; + } + Dictionary keywords = new(); + keywords["orig_input_shape"] = orig_input_shape; + keywords["grad"] = grad; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("AvgPool3DGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("AvgPool3DGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor avg_pool3d_grad_eager_fallback(Tensor orig_input_shape, Tensor grad, int[] ksize, int[] strides, string padding, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { orig_input_shape, grad }; + object[] _attrs = new object[] { "ksize", ksize, "strides", strides, "padding", padding, "data_format", data_format, "T", grad.dtype }; + var _result = _execute.execute("AvgPool3DGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AvgPool3DGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes gradients of the average pooling function. + /// + /// + /// + /// + /// + /// The size of the sliding window for each dimension of the input. + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the input. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// + public static Tensor avg_pool_grad(Tensor orig_input_shape, Tensor grad, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AvgPoolGrad", name) { args = new object[] { orig_input_shape, grad }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return avg_pool_grad_eager_fallback(orig_input_shape, grad, ksize: ksize, strides: strides, padding: padding, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["orig_input_shape"] = orig_input_shape; + keywords["grad"] = grad; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("AvgPoolGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("AvgPoolGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor avg_pool_grad_eager_fallback(Tensor orig_input_shape, Tensor grad, int[] ksize, int[] strides, string padding, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { orig_input_shape, grad }; + object[] _attrs = new object[] { "ksize", ksize, "strides", strides, "padding", padding, "data_format", data_format, "T", grad.dtype }; + var _result = _execute.execute("AvgPoolGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AvgPoolGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Batch normalization. + /// + /// + /// + /// This op is deprecated. Prefer `tf.nn.batch_normalization`. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A small float number to avoid dividing by 0. + /// + /// + /// + /// + /// A bool indicating whether the resulted tensor + /// needs to be multiplied with gamma. + /// + /// + /// + public static Tensor batch_norm_with_global_normalization(Tensor t, Tensor m, Tensor v, Tensor beta, Tensor gamma, float variance_epsilon, bool scale_after_normalization, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BatchNormWithGlobalNormalization", name) { args = new object[] { t, m, v, beta, gamma }, attrs = new Dictionary() { ["variance_epsilon"] = variance_epsilon, ["scale_after_normalization"] = scale_after_normalization } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return batch_norm_with_global_normalization_eager_fallback(t, m, v, beta, gamma, variance_epsilon: variance_epsilon, scale_after_normalization: scale_after_normalization, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["t"] = t; + keywords["m"] = m; + keywords["v"] = v; + keywords["beta"] = beta; + keywords["gamma"] = gamma; + keywords["variance_epsilon"] = variance_epsilon; + keywords["scale_after_normalization"] = scale_after_normalization; + var _op = tf.OpDefLib._apply_op_helper("BatchNormWithGlobalNormalization", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "variance_epsilon", _op.get_attr("variance_epsilon"), "scale_after_normalization", _op._get_attr_bool("scale_after_normalization") }; + _execute.record_gradient("BatchNormWithGlobalNormalization", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor batch_norm_with_global_normalization_eager_fallback(Tensor t, Tensor m, Tensor v, Tensor beta, Tensor gamma, float variance_epsilon, bool scale_after_normalization, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { t, m, v, beta, gamma }; + object[] _attrs = new object[] { "T", t.dtype, "variance_epsilon", variance_epsilon, "scale_after_normalization", scale_after_normalization }; + var _result = _execute.execute("BatchNormWithGlobalNormalization", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BatchNormWithGlobalNormalization", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Gradients for batch normalization. + /// + /// + /// + /// This op is deprecated. See `tf.nn.batch_normalization`. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A small float number to avoid dividing by 0. + /// + /// + /// + /// + /// A bool indicating whether the resulted tensor + /// needs to be multiplied with gamma. + /// + /// + /// + public static Tensor[] batch_norm_with_global_normalization_grad(Tensor t, Tensor m, Tensor v, Tensor gamma, Tensor backprop, float variance_epsilon, bool scale_after_normalization, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BatchNormWithGlobalNormalizationGrad", name) { args = new object[] { t, m, v, gamma, backprop }, attrs = new Dictionary() { ["variance_epsilon"] = variance_epsilon, ["scale_after_normalization"] = scale_after_normalization } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return batch_norm_with_global_normalization_grad_eager_fallback(t, m, v, gamma, backprop, variance_epsilon: variance_epsilon, scale_after_normalization: scale_after_normalization, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["t"] = t; + keywords["m"] = m; + keywords["v"] = v; + keywords["gamma"] = gamma; + keywords["backprop"] = backprop; + keywords["variance_epsilon"] = variance_epsilon; + keywords["scale_after_normalization"] = scale_after_normalization; + var _op = tf.OpDefLib._apply_op_helper("BatchNormWithGlobalNormalizationGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "variance_epsilon", _op.get_attr("variance_epsilon"), "scale_after_normalization", _op._get_attr_bool("scale_after_normalization") }; + _execute.record_gradient("BatchNormWithGlobalNormalizationGrad", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] batch_norm_with_global_normalization_grad_eager_fallback(Tensor t, Tensor m, Tensor v, Tensor gamma, Tensor backprop, float variance_epsilon, bool scale_after_normalization, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { t, m, v, gamma, backprop }; + object[] _attrs = new object[] { "T", t.dtype, "variance_epsilon", variance_epsilon, "scale_after_normalization", scale_after_normalization }; + var _result = _execute.execute("BatchNormWithGlobalNormalizationGrad", 5, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BatchNormWithGlobalNormalizationGrad", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Adds `bias` to `value`. + /// + /// + /// + /// This is a special case of `tf.add` where `bias` is restricted to be 1-D. + /// Broadcasting is supported, so `value` may have any number of dimensions. + /// + /// + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the bias tensor will be added to the last dimension + /// of the value tensor. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// The tensor will be added to "in_channels", the third-to-the-last + /// dimension. + /// + /// + /// + public static Tensor bias_add(Tensor value, Tensor bias, string data_format = "NHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BiasAdd", name) { args = new object[] { value, bias }, attrs = new Dictionary() { ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return bias_add_eager_fallback(value, bias, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["value"] = value; + keywords["bias"] = bias; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("BiasAdd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "data_format", _op.get_attr("data_format") }; + _execute.record_gradient("BiasAdd", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor bias_add_eager_fallback(Tensor value, Tensor bias, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { value, bias }; + object[] _attrs = new object[] { "T", value.dtype, "data_format", data_format }; + var _result = _execute.execute("BiasAdd", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BiasAdd", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// The backward operation for "BiasAdd" on the "bias" tensor. + /// + /// + /// + /// It accumulates all the values from out_backprop into the feature dimension. + /// For NHWC data format, the feature dimension is the last. For NCHW data format, + /// the feature dimension is the third-to-last. + /// + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the bias tensor will be added to the last dimension + /// of the value tensor. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// The tensor will be added to "in_channels", the third-to-the-last + /// dimension. + /// + /// + /// + public static Tensor bias_add_grad(Tensor out_backprop, string data_format = "NHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BiasAddGrad", name) { args = new object[] { out_backprop }, attrs = new Dictionary() { ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return bias_add_grad_eager_fallback(out_backprop, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["out_backprop"] = out_backprop; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("BiasAddGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "data_format", _op.get_attr("data_format") }; + _execute.record_gradient("BiasAddGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor bias_add_grad_eager_fallback(Tensor out_backprop, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { out_backprop }; + object[] _attrs = new object[] { "T", out_backprop.dtype, "data_format", data_format }; + var _result = _execute.execute("BiasAddGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BiasAddGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Adds `bias` to `value`. + /// + /// + /// + /// This is a deprecated version of BiasAdd and will be soon removed. + /// + /// This is a special case of `tf.add` where `bias` is restricted to be 1-D. + /// Broadcasting is supported, so `value` may have any number of dimensions. + /// + /// + /// + /// + /// + public static Tensor bias_add_v1(Tensor value, Tensor bias, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "BiasAddV1", name) { args = new object[] { value, bias }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return bias_add_v1_eager_fallback(value, bias, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["value"] = value; + keywords["bias"] = bias; + var _op = tf.OpDefLib._apply_op_helper("BiasAddV1", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("BiasAddV1", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor bias_add_v1_eager_fallback(Tensor value, Tensor bias, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { value, bias }; + object[] _attrs = new object[] { "T", value.dtype }; + var _result = _execute.execute("BiasAddV1", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("BiasAddV1", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes a 2-D convolution given 4-D `input` and `filter` tensors. + /// + /// + /// + /// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` + /// and a filter / kernel tensor of shape + /// `[filter_height, filter_width, in_channels, out_channels]`, this op + /// performs the following: + /// + /// 1. Flattens the filter to a 2-D matrix with shape + /// `[filter_height * filter_width * in_channels, output_channels]`. + /// 2. Extracts image patches from the input tensor to form a *virtual* + /// tensor of shape `[batch, out_height, out_width, + /// filter_height * filter_width * in_channels]`. + /// 3. For each patch, right-multiplies the filter matrix and the image patch + /// vector. + /// + /// In detail, with the default NHWC format, + /// + /// output[b, i, j, k] = + /// sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * + /// filter[di, dj, q, k] + /// + /// Must have `strides[0] = strides[3] = 1`. For the most common case of the same + /// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. + /// + /// + /// + /// + /// + /// + /// 1-D tensor of length 4. The stride of the sliding window for each + /// dimension of `input`. The dimension order is determined by the value of + /// `data_format`, see below for details. + /// + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith + /// dimension, the amount of padding inserted before and after the dimension is + /// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If + /// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, height, width, channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, channels, height, width]. + /// + /// + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// `input`. If set to k > 1, there will be k-1 skipped cells between each + /// filter element on that dimension. The dimension order is determined by the + /// value of `data_format`, see above for details. Dilations in the batch and + /// depth dimensions must be 1. + /// + /// + /// + public static Tensor conv2d(Tensor input, Tensor filter, int[] strides, string padding, bool use_cudnn_on_gpu = true, int[] explicit_paddings = null, string data_format = "NHWC", int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (explicit_paddings is null) + { + explicit_paddings = new int[] { }; + } + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Conv2D", name) { args = new object[] { input, filter }, attrs = new Dictionary() { ["strides"] = strides, ["use_cudnn_on_gpu"] = use_cudnn_on_gpu, ["padding"] = padding, ["explicit_paddings"] = explicit_paddings, ["data_format"] = data_format, ["dilations"] = dilations } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return conv2d_eager_fallback(input, filter, strides: strides, use_cudnn_on_gpu: use_cudnn_on_gpu, padding: padding, explicit_paddings: explicit_paddings, data_format: data_format, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["strides"] = strides; + keywords["use_cudnn_on_gpu"] = use_cudnn_on_gpu; + keywords["padding"] = padding; + keywords["explicit_paddings"] = explicit_paddings; + keywords["data_format"] = data_format; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("Conv2D", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "use_cudnn_on_gpu", _op._get_attr_bool("use_cudnn_on_gpu"), "padding", _op.get_attr("padding"), "explicit_paddings", _op.get_attr("explicit_paddings"), "data_format", _op.get_attr("data_format"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("Conv2D", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor conv2d_eager_fallback(Tensor input, Tensor filter, int[] strides, bool use_cudnn_on_gpu, string padding, int[] explicit_paddings, string data_format, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter }; + object[] _attrs = new object[] { "T", input.dtype, "strides", strides, "use_cudnn_on_gpu", use_cudnn_on_gpu, "padding", padding, "explicit_paddings", explicit_paddings, "data_format", data_format, "dilations", dilations }; + var _result = _execute.execute("Conv2D", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Conv2D", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradients of convolution with respect to the filter. + /// + /// + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the input + /// of the convolution. Must be in the same order as the dimension specified with + /// format. + /// + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith + /// dimension, the amount of padding inserted before and after the dimension is + /// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If + /// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// `input`. If set to k > 1, there will be k-1 skipped cells between each filter + /// element on that dimension. The dimension order is determined by the value of + /// `data_format`, see above for details. Dilations in the batch and depth + /// dimensions must be 1. + /// + /// + /// + public static Tensor conv2d_backprop_filter(Tensor input, Tensor filter_sizes, Tensor out_backprop, int[] strides, string padding, bool use_cudnn_on_gpu = true, int[] explicit_paddings = null, string data_format = "NHWC", int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (explicit_paddings is null) + { + explicit_paddings = new int[] { }; + } + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Conv2DBackpropFilter", name) { args = new object[] { input, filter_sizes, out_backprop }, attrs = new Dictionary() { ["strides"] = strides, ["use_cudnn_on_gpu"] = use_cudnn_on_gpu, ["padding"] = padding, ["explicit_paddings"] = explicit_paddings, ["data_format"] = data_format, ["dilations"] = dilations } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return conv2d_backprop_filter_eager_fallback(input, filter_sizes, out_backprop, strides: strides, use_cudnn_on_gpu: use_cudnn_on_gpu, padding: padding, explicit_paddings: explicit_paddings, data_format: data_format, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter_sizes"] = filter_sizes; + keywords["out_backprop"] = out_backprop; + keywords["strides"] = strides; + keywords["use_cudnn_on_gpu"] = use_cudnn_on_gpu; + keywords["padding"] = padding; + keywords["explicit_paddings"] = explicit_paddings; + keywords["data_format"] = data_format; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "use_cudnn_on_gpu", _op._get_attr_bool("use_cudnn_on_gpu"), "padding", _op.get_attr("padding"), "explicit_paddings", _op.get_attr("explicit_paddings"), "data_format", _op.get_attr("data_format"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("Conv2DBackpropFilter", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor conv2d_backprop_filter_eager_fallback(Tensor input, Tensor filter_sizes, Tensor out_backprop, int[] strides, bool use_cudnn_on_gpu, string padding, int[] explicit_paddings, string data_format, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter_sizes, out_backprop }; + object[] _attrs = new object[] { "T", input.dtype, "strides", strides, "use_cudnn_on_gpu", use_cudnn_on_gpu, "padding", padding, "explicit_paddings", explicit_paddings, "data_format", data_format, "dilations", dilations }; + var _result = _execute.execute("Conv2DBackpropFilter", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Conv2DBackpropFilter", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradients of convolution with respect to the input. + /// + /// + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the input + /// of the convolution. Must be in the same order as the dimension specified with + /// format. + /// + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith + /// dimension, the amount of padding inserted before and after the dimension is + /// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If + /// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// `input`. If set to k > 1, there will be k-1 skipped cells between each filter + /// element on that dimension. The dimension order is determined by the value of + /// `data_format`, see above for details. Dilations in the batch and depth + /// dimensions must be 1. + /// + /// + /// + public static Tensor conv2d_backprop_input(Tensor input_sizes, Tensor filter, Tensor out_backprop, int[] strides, string padding, bool use_cudnn_on_gpu = true, int[] explicit_paddings = null, string data_format = "NHWC", int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (explicit_paddings is null) + { + explicit_paddings = new int[] { }; + } + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Conv2DBackpropInput", name) { args = new object[] { input_sizes, filter, out_backprop }, attrs = new Dictionary() { ["strides"] = strides, ["use_cudnn_on_gpu"] = use_cudnn_on_gpu, ["padding"] = padding, ["explicit_paddings"] = explicit_paddings, ["data_format"] = data_format, ["dilations"] = dilations } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return conv2d_backprop_input_eager_fallback(input_sizes, filter, out_backprop, strides: strides, use_cudnn_on_gpu: use_cudnn_on_gpu, padding: padding, explicit_paddings: explicit_paddings, data_format: data_format, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["input_sizes"] = input_sizes; + keywords["filter"] = filter; + keywords["out_backprop"] = out_backprop; + keywords["strides"] = strides; + keywords["use_cudnn_on_gpu"] = use_cudnn_on_gpu; + keywords["padding"] = padding; + keywords["explicit_paddings"] = explicit_paddings; + keywords["data_format"] = data_format; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "use_cudnn_on_gpu", _op._get_attr_bool("use_cudnn_on_gpu"), "padding", _op.get_attr("padding"), "explicit_paddings", _op.get_attr("explicit_paddings"), "data_format", _op.get_attr("data_format"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("Conv2DBackpropInput", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor conv2d_backprop_input_eager_fallback(Tensor input_sizes, Tensor filter, Tensor out_backprop, int[] strides, bool use_cudnn_on_gpu, string padding, int[] explicit_paddings, string data_format, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_sizes, filter, out_backprop }; + object[] _attrs = new object[] { "T", filter.dtype, "strides", strides, "use_cudnn_on_gpu", use_cudnn_on_gpu, "padding", padding, "explicit_paddings", explicit_paddings, "data_format", data_format, "dilations", dilations }; + var _result = _execute.execute("Conv2DBackpropInput", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Conv2DBackpropInput", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes a 3-D convolution given 5-D `input` and `filter` tensors. + /// + /// + /// + /// In signal processing, cross-correlation is a measure of similarity of + /// two waveforms as a function of a time-lag applied to one of them. This + /// is also known as a sliding dot product or sliding inner-product. + /// + /// Our Conv3D implements a form of cross-correlation. + /// + /// + /// + /// + /// + /// + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of `input`. Must have `strides[0] = strides[4] = 1`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// + /// + /// 1-D tensor of length 5. The dilation factor for each dimension of + /// `input`. If set to k > 1, there will be k-1 skipped cells between each + /// filter element on that dimension. The dimension order is determined by the + /// value of `data_format`, see above for details. Dilations in the batch and + /// depth dimensions must be 1. + /// + /// + /// + public static Tensor conv3d(Tensor input, Tensor filter, int[] strides, string padding, string data_format = "NDHWC", int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Conv3D", name) { args = new object[] { input, filter }, attrs = new Dictionary() { ["strides"] = strides, ["padding"] = padding, ["data_format"] = data_format, ["dilations"] = dilations } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return conv3d_eager_fallback(input, filter, strides: strides, padding: padding, data_format: data_format, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NDHWC"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("Conv3D", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("Conv3D", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor conv3d_eager_fallback(Tensor input, Tensor filter, int[] strides, string padding, string data_format, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter }; + object[] _attrs = new object[] { "T", input.dtype, "strides", strides, "padding", padding, "data_format", data_format, "dilations", dilations }; + var _result = _execute.execute("Conv3D", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Conv3D", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradients of 3-D convolution with respect to the filter. + /// + /// + /// + /// + /// + /// + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of `input`. Must have `strides[0] = strides[4] = 1`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + public static Tensor conv3d_backprop_filter(Tensor input, Tensor filter, Tensor out_backprop, int[] strides, string padding, int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Conv3DBackpropFilter", name) { args = new object[] { input, filter, out_backprop }, attrs = new Dictionary() { ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return conv3d_backprop_filter_eager_fallback(input, filter, out_backprop, strides: strides, padding: padding, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["out_backprop"] = out_backprop; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("Conv3DBackpropFilter", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("Conv3DBackpropFilter", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor conv3d_backprop_filter_eager_fallback(Tensor input, Tensor filter, Tensor out_backprop, int[] strides, string padding, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, out_backprop }; + object[] _attrs = new object[] { "T", input.dtype, "strides", strides, "padding", padding, "dilations", dilations }; + var _result = _execute.execute("Conv3DBackpropFilter", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Conv3DBackpropFilter", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradients of 3-D convolution with respect to the filter. + /// + /// + /// + /// + /// + /// + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of `input`. Must have `strides[0] = strides[4] = 1`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// + /// + /// 1-D tensor of length 5. The dilation factor for each dimension of + /// `input`. If set to k > 1, there will be k-1 skipped cells between each + /// filter element on that dimension. The dimension order is determined by the + /// value of `data_format`, see above for details. Dilations in the batch and + /// depth dimensions must be 1. + /// + /// + /// + public static Tensor conv3d_backprop_filter_v2(Tensor input, Tensor filter_sizes, Tensor out_backprop, int[] strides, string padding, string data_format = "NDHWC", int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Conv3DBackpropFilterV2", name) { args = new object[] { input, filter_sizes, out_backprop }, attrs = new Dictionary() { ["strides"] = strides, ["padding"] = padding, ["data_format"] = data_format, ["dilations"] = dilations } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return conv3d_backprop_filter_v2_eager_fallback(input, filter_sizes, out_backprop, strides: strides, padding: padding, data_format: data_format, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NDHWC"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter_sizes"] = filter_sizes; + keywords["out_backprop"] = out_backprop; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("Conv3DBackpropFilterV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("Conv3DBackpropFilterV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor conv3d_backprop_filter_v2_eager_fallback(Tensor input, Tensor filter_sizes, Tensor out_backprop, int[] strides, string padding, string data_format, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter_sizes, out_backprop }; + object[] _attrs = new object[] { "T", input.dtype, "strides", strides, "padding", padding, "data_format", data_format, "dilations", dilations }; + var _result = _execute.execute("Conv3DBackpropFilterV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Conv3DBackpropFilterV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradients of 3-D convolution with respect to the input. + /// + /// + /// + /// + /// + /// + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of `input`. Must have `strides[0] = strides[4] = 1`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + public static Tensor conv3d_backprop_input(Tensor input, Tensor filter, Tensor out_backprop, int[] strides, string padding, int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Conv3DBackpropInput", name) { args = new object[] { input, filter, out_backprop }, attrs = new Dictionary() { ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return conv3d_backprop_input_eager_fallback(input, filter, out_backprop, strides: strides, padding: padding, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["out_backprop"] = out_backprop; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("Conv3DBackpropInput", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("Conv3DBackpropInput", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor conv3d_backprop_input_eager_fallback(Tensor input, Tensor filter, Tensor out_backprop, int[] strides, string padding, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, out_backprop }; + object[] _attrs = new object[] { "T", input.dtype, "strides", strides, "padding", padding, "dilations", dilations }; + var _result = _execute.execute("Conv3DBackpropInput", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Conv3DBackpropInput", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradients of 3-D convolution with respect to the input. + /// + /// + /// + /// + /// + /// + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of `input`. Must have `strides[0] = strides[4] = 1`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// + /// + /// 1-D tensor of length 5. The dilation factor for each dimension of + /// `input`. If set to k > 1, there will be k-1 skipped cells between each + /// filter element on that dimension. The dimension order is determined by the + /// value of `data_format`, see above for details. Dilations in the batch and + /// depth dimensions must be 1. + /// + /// + /// + public static Tensor conv3d_backprop_input_v2(Tensor input_sizes, Tensor filter, Tensor out_backprop, int[] strides, string padding, string data_format = "NDHWC", int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Conv3DBackpropInputV2", name) { args = new object[] { input_sizes, filter, out_backprop }, attrs = new Dictionary() { ["strides"] = strides, ["padding"] = padding, ["data_format"] = data_format, ["dilations"] = dilations } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return conv3d_backprop_input_v2_eager_fallback(input_sizes, filter, out_backprop, strides: strides, padding: padding, data_format: data_format, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NDHWC"; + } + Dictionary keywords = new(); + keywords["input_sizes"] = input_sizes; + keywords["filter"] = filter; + keywords["out_backprop"] = out_backprop; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("Conv3DBackpropInputV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "dilations", _op.get_attr("dilations"), "Tshape", _op._get_attr_type("Tshape") }; + _execute.record_gradient("Conv3DBackpropInputV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor conv3d_backprop_input_v2_eager_fallback(Tensor input_sizes, Tensor filter, Tensor out_backprop, int[] strides, string padding, string data_format, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_sizes, filter, out_backprop }; + object[] _attrs = new object[] { "T", filter.dtype, "strides", strides, "padding", padding, "data_format", data_format, "dilations", dilations, "Tshape", input_sizes.dtype }; + var _result = _execute.execute("Conv3DBackpropInputV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Conv3DBackpropInputV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the dimension index in the destination data format given the one in + /// + /// + /// + /// the source data format. + /// + /// + /// + /// + /// + /// source data format. + /// + /// + /// + /// + /// destination data format. + /// + /// + /// + public static Tensor data_format_dim_map(Tensor x, string src_format = "NHWC", string dst_format = "NCHW", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DataFormatDimMap", name) { args = new object[] { x }, attrs = new Dictionary() { ["src_format"] = src_format, ["dst_format"] = dst_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return data_format_dim_map_eager_fallback(x, src_format: src_format, dst_format: dst_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (src_format is null) + { + src_format = "NHWC"; + } + if (dst_format is null) + { + dst_format = "NCHW"; + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["src_format"] = src_format; + keywords["dst_format"] = dst_format; + var _op = tf.OpDefLib._apply_op_helper("DataFormatDimMap", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "src_format", _op.get_attr("src_format"), "dst_format", _op.get_attr("dst_format") }; + _execute.record_gradient("DataFormatDimMap", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor data_format_dim_map_eager_fallback(Tensor x, string src_format, string dst_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype, "src_format", src_format, "dst_format", dst_format }; + var _result = _execute.execute("DataFormatDimMap", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DataFormatDimMap", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Permute input tensor from `src_format` to `dst_format`. + /// + /// + /// + /// Given source and destination format strings of length n=4 or 5, the input + /// tensor must be a vector of size n or n-2, or a 2D tensor of shape + /// (n, 2) or (n-2, 2). + /// + /// If the first dimension of the input tensor is n-2, it is assumed that + /// non-spatial dimensions are omitted (i.e `N`, `C`). + /// + /// For example, with `src_format` of `NHWC`, `dst_format` of `NCHW`, and input: + /// ``` + /// [1, 2, 3, 4] + /// ``` + /// , the output will be: + /// ``` + /// [1, 4, 2, 3] + /// ``` + /// With `src_format` of `NDHWC`, `dst_format` of `NCDHW`, and input: + /// ``` + /// [[1, 6], [2, 7], [3, 8], [4, 9], [5, 10]] + /// ``` + /// , the output will be: + /// ``` + /// [[1, 6], [5, 10], [2, 7], [3, 8], [4, 9]] + /// ``` + /// With `src_format` of `NHWC`, `dst_format` of `NCHW`, and input: + /// ``` + /// [1, 2] + /// ``` + /// , the output will be: + /// ``` + /// [1, 2] + /// ``` + /// + /// + /// + /// + /// + /// source data format. + /// + /// + /// + /// + /// destination data format. + /// + /// + /// + public static Tensor data_format_vec_permute(Tensor x, string src_format = "NHWC", string dst_format = "NCHW", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DataFormatVecPermute", name) { args = new object[] { x }, attrs = new Dictionary() { ["src_format"] = src_format, ["dst_format"] = dst_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return data_format_vec_permute_eager_fallback(x, src_format: src_format, dst_format: dst_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (src_format is null) + { + src_format = "NHWC"; + } + if (dst_format is null) + { + dst_format = "NCHW"; + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["src_format"] = src_format; + keywords["dst_format"] = dst_format; + var _op = tf.OpDefLib._apply_op_helper("DataFormatVecPermute", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "src_format", _op.get_attr("src_format"), "dst_format", _op.get_attr("dst_format") }; + _execute.record_gradient("DataFormatVecPermute", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor data_format_vec_permute_eager_fallback(Tensor x, string src_format, string dst_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x }; + object[] _attrs = new object[] { "T", x.dtype, "src_format", src_format, "dst_format", dst_format }; + var _result = _execute.execute("DataFormatVecPermute", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DataFormatVecPermute", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. + /// + /// + /// + /// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` + /// and a filter / kernel tensor of shape + /// `[filter_height, filter_width, in_channels, channel_multiplier]`, containing + /// `in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies + /// a different filter to each input channel (expanding from 1 channel to + /// `channel_multiplier` channels for each), then concatenates the results + /// together. Thus, the output has `in_channels * channel_multiplier` channels. + /// + /// ``` + /// for k in 0..in_channels-1 + /// for q in 0..channel_multiplier-1 + /// output[b, i, j, k * channel_multiplier + q] = + /// sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] * + /// filter[di, dj, k, q] + /// ``` + /// + /// Must have `strides[0] = strides[3] = 1`. For the most common case of the same + /// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. + /// + /// + /// + /// + /// + /// + /// 1-D of length 4. The stride of the sliding window for each dimension + /// of `input`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, height, width, channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, channels, height, width]. + /// + /// + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// `input`. If set to k > 1, there will be k-1 skipped cells between each filter + /// element on that dimension. The dimension order is determined by the value of + /// `data_format`, see above for details. Dilations in the batch and depth + /// dimensions must be 1. + /// + /// + /// + public static Tensor depthwise_conv2d_native(Tensor input, Tensor filter, int[] strides, string padding, int[] explicit_paddings = null, string data_format = "NHWC", int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (explicit_paddings is null) + { + explicit_paddings = new int[] { }; + } + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DepthwiseConv2dNative", name) { args = new object[] { input, filter }, attrs = new Dictionary() { ["strides"] = strides, ["padding"] = padding, ["explicit_paddings"] = explicit_paddings, ["data_format"] = data_format, ["dilations"] = dilations } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return depthwise_conv2d_native_eager_fallback(input, filter, strides: strides, padding: padding, explicit_paddings: explicit_paddings, data_format: data_format, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["explicit_paddings"] = explicit_paddings; + keywords["data_format"] = data_format; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("DepthwiseConv2dNative", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "explicit_paddings", _op.get_attr("explicit_paddings"), "data_format", _op.get_attr("data_format"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("DepthwiseConv2dNative", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor depthwise_conv2d_native_eager_fallback(Tensor input, Tensor filter, int[] strides, string padding, int[] explicit_paddings, string data_format, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter }; + object[] _attrs = new object[] { "T", input.dtype, "strides", strides, "padding", padding, "explicit_paddings", explicit_paddings, "data_format", data_format, "dilations", dilations }; + var _result = _execute.execute("DepthwiseConv2dNative", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DepthwiseConv2dNative", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradients of depthwise convolution with respect to the filter. + /// + /// + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the input + /// of the convolution. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, height, width, channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, channels, height, width]. + /// + /// + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// `input`. If set to k > 1, there will be k-1 skipped cells between each filter + /// element on that dimension. The dimension order is determined by the value of + /// `data_format`, see above for details. Dilations in the batch and depth + /// dimensions must be 1. + /// + /// + /// + public static Tensor depthwise_conv2d_native_backprop_filter(Tensor input, Tensor filter_sizes, Tensor out_backprop, int[] strides, string padding, int[] explicit_paddings = null, string data_format = "NHWC", int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (explicit_paddings is null) + { + explicit_paddings = new int[] { }; + } + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DepthwiseConv2dNativeBackpropFilter", name) { args = new object[] { input, filter_sizes, out_backprop }, attrs = new Dictionary() { ["strides"] = strides, ["padding"] = padding, ["explicit_paddings"] = explicit_paddings, ["data_format"] = data_format, ["dilations"] = dilations } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return depthwise_conv2d_native_backprop_filter_eager_fallback(input, filter_sizes, out_backprop, strides: strides, padding: padding, explicit_paddings: explicit_paddings, data_format: data_format, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter_sizes"] = filter_sizes; + keywords["out_backprop"] = out_backprop; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["explicit_paddings"] = explicit_paddings; + keywords["data_format"] = data_format; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("DepthwiseConv2dNativeBackpropFilter", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "explicit_paddings", _op.get_attr("explicit_paddings"), "data_format", _op.get_attr("data_format"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("DepthwiseConv2dNativeBackpropFilter", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor depthwise_conv2d_native_backprop_filter_eager_fallback(Tensor input, Tensor filter_sizes, Tensor out_backprop, int[] strides, string padding, int[] explicit_paddings, string data_format, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter_sizes, out_backprop }; + object[] _attrs = new object[] { "T", input.dtype, "strides", strides, "padding", padding, "explicit_paddings", explicit_paddings, "data_format", data_format, "dilations", dilations }; + var _result = _execute.execute("DepthwiseConv2dNativeBackpropFilter", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DepthwiseConv2dNativeBackpropFilter", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradients of depthwise convolution with respect to the input. + /// + /// + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the input + /// of the convolution. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, height, width, channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, channels, height, width]. + /// + /// + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// `input`. If set to k > 1, there will be k-1 skipped cells between each filter + /// element on that dimension. The dimension order is determined by the value of + /// `data_format`, see above for details. Dilations in the batch and depth + /// dimensions must be 1. + /// + /// + /// + public static Tensor depthwise_conv2d_native_backprop_input(Tensor input_sizes, Tensor filter, Tensor out_backprop, int[] strides, string padding, int[] explicit_paddings = null, string data_format = "NHWC", int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (explicit_paddings is null) + { + explicit_paddings = new int[] { }; + } + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DepthwiseConv2dNativeBackpropInput", name) { args = new object[] { input_sizes, filter, out_backprop }, attrs = new Dictionary() { ["strides"] = strides, ["padding"] = padding, ["explicit_paddings"] = explicit_paddings, ["data_format"] = data_format, ["dilations"] = dilations } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return depthwise_conv2d_native_backprop_input_eager_fallback(input_sizes, filter, out_backprop, strides: strides, padding: padding, explicit_paddings: explicit_paddings, data_format: data_format, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["input_sizes"] = input_sizes; + keywords["filter"] = filter; + keywords["out_backprop"] = out_backprop; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["explicit_paddings"] = explicit_paddings; + keywords["data_format"] = data_format; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("DepthwiseConv2dNativeBackpropInput", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "explicit_paddings", _op.get_attr("explicit_paddings"), "data_format", _op.get_attr("data_format"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("DepthwiseConv2dNativeBackpropInput", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor depthwise_conv2d_native_backprop_input_eager_fallback(Tensor input_sizes, Tensor filter, Tensor out_backprop, int[] strides, string padding, int[] explicit_paddings, string data_format, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_sizes, filter, out_backprop }; + object[] _attrs = new object[] { "T", filter.dtype, "strides", strides, "padding", padding, "explicit_paddings", explicit_paddings, "data_format", data_format, "dilations", dilations }; + var _result = _execute.execute("DepthwiseConv2dNativeBackpropInput", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DepthwiseConv2dNativeBackpropInput", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. + /// + /// + /// + /// The `input` tensor has shape `[batch, in_height, in_width, depth]` and the + /// `filter` tensor has shape `[filter_height, filter_width, depth]`, i.e., each + /// input channel is processed independently of the others with its own structuring + /// function. The `output` tensor has shape + /// `[batch, out_height, out_width, depth]`. The spatial dimensions of the output + /// tensor depend on the `padding` algorithm. We currently only support the default + /// "NHWC" `data_format`. + /// + /// In detail, the grayscale morphological 2-D dilation is the max-sum correlation + /// (for consistency with `conv2d`, we use unmirrored filters): + /// + /// output[b, y, x, c] = + /// max_{dy, dx} input[b, + /// strides[1] * y + rates[1] * dy, + /// strides[2] * x + rates[2] * dx, + /// c] + + /// filter[dy, dx, c] + /// + /// Max-pooling is a special case when the filter has size equal to the pooling + /// kernel size and contains all zeros. + /// + /// Note on duality: The dilation of `input` by the `filter` is equal to the + /// negation of the erosion of `-input` by the reflected `filter`. + /// + /// + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the input + /// tensor. Must be: `[1, stride_height, stride_width, 1]`. + /// + /// + /// + /// + /// The input stride for atrous morphological dilation. Must be: + /// `[1, rate_height, rate_width, 1]`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + public static Tensor dilation2d(Tensor input, Tensor filter, int[] strides, int[] rates, string padding, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Dilation2D", name) { args = new object[] { input, filter }, attrs = new Dictionary() { ["strides"] = strides, ["rates"] = rates, ["padding"] = padding } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return dilation2d_eager_fallback(input, filter, strides: strides, rates: rates, padding: padding, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["strides"] = strides; + keywords["rates"] = rates; + keywords["padding"] = padding; + var _op = tf.OpDefLib._apply_op_helper("Dilation2D", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "rates", _op.get_attr("rates"), "padding", _op.get_attr("padding") }; + _execute.record_gradient("Dilation2D", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor dilation2d_eager_fallback(Tensor input, Tensor filter, int[] strides, int[] rates, string padding, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter }; + object[] _attrs = new object[] { "T", input.dtype, "strides", strides, "rates", rates, "padding", padding }; + var _result = _execute.execute("Dilation2D", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Dilation2D", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradient of morphological 2-D dilation with respect to the filter. + /// + /// + /// + /// + /// + /// + /// 1-D of length 4. The stride of the sliding window for each dimension of + /// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. + /// + /// + /// + /// + /// 1-D of length 4. The input stride for atrous morphological dilation. + /// Must be: `[1, rate_height, rate_width, 1]`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + public static Tensor dilation2d_backprop_filter(Tensor input, Tensor filter, Tensor out_backprop, int[] strides, int[] rates, string padding, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Dilation2DBackpropFilter", name) { args = new object[] { input, filter, out_backprop }, attrs = new Dictionary() { ["strides"] = strides, ["rates"] = rates, ["padding"] = padding } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return dilation2d_backprop_filter_eager_fallback(input, filter, out_backprop, strides: strides, rates: rates, padding: padding, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["out_backprop"] = out_backprop; + keywords["strides"] = strides; + keywords["rates"] = rates; + keywords["padding"] = padding; + var _op = tf.OpDefLib._apply_op_helper("Dilation2DBackpropFilter", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "rates", _op.get_attr("rates"), "padding", _op.get_attr("padding") }; + _execute.record_gradient("Dilation2DBackpropFilter", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor dilation2d_backprop_filter_eager_fallback(Tensor input, Tensor filter, Tensor out_backprop, int[] strides, int[] rates, string padding, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, out_backprop }; + object[] _attrs = new object[] { "T", input.dtype, "strides", strides, "rates", rates, "padding", padding }; + var _result = _execute.execute("Dilation2DBackpropFilter", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Dilation2DBackpropFilter", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the gradient of morphological 2-D dilation with respect to the input. + /// + /// + /// + /// + /// + /// + /// 1-D of length 4. The stride of the sliding window for each dimension of + /// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. + /// + /// + /// + /// + /// 1-D of length 4. The input stride for atrous morphological dilation. + /// Must be: `[1, rate_height, rate_width, 1]`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + public static Tensor dilation2d_backprop_input(Tensor input, Tensor filter, Tensor out_backprop, int[] strides, int[] rates, string padding, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Dilation2DBackpropInput", name) { args = new object[] { input, filter, out_backprop }, attrs = new Dictionary() { ["strides"] = strides, ["rates"] = rates, ["padding"] = padding } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return dilation2d_backprop_input_eager_fallback(input, filter, out_backprop, strides: strides, rates: rates, padding: padding, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["out_backprop"] = out_backprop; + keywords["strides"] = strides; + keywords["rates"] = rates; + keywords["padding"] = padding; + var _op = tf.OpDefLib._apply_op_helper("Dilation2DBackpropInput", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "strides", _op.get_attr("strides"), "rates", _op.get_attr("rates"), "padding", _op.get_attr("padding") }; + _execute.record_gradient("Dilation2DBackpropInput", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor dilation2d_backprop_input_eager_fallback(Tensor input, Tensor filter, Tensor out_backprop, int[] strides, int[] rates, string padding, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, out_backprop }; + object[] _attrs = new object[] { "T", input.dtype, "strides", strides, "rates", rates, "padding", padding }; + var _result = _execute.execute("Dilation2DBackpropInput", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Dilation2DBackpropInput", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes the exponential linear function. + /// + /// + /// + /// The ELU function is defined as: + /// + /// * $ e ^ x - 1 $ if $ x < 0 $ + /// * $ x $ if $ x >= 0 $ + /// + /// Examples: + /// + /// >>> tf.nn.elu(1.0) + /// + /// >>> tf.nn.elu(0.0) + /// + /// >>> tf.nn.elu(-1000.0) + /// + /// + /// See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) + /// ](http://arxiv.org/abs/1511.07289) + /// + /// + /// + /// + public static Tensor elu(Tensor features, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Elu", name) { args = new object[] { features }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return elu_eager_fallback(features, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["features"] = features; + var _op = tf.OpDefLib._apply_op_helper("Elu", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Elu", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor elu_eager_fallback(Tensor features, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { features }; + object[] _attrs = new object[] { "T", features.dtype }; + var _result = _execute.execute("Elu", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Elu", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes gradients for the exponential linear (Elu) operation. + /// + /// + /// + /// + public static Tensor elu_grad(Tensor gradients, Tensor outputs, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "EluGrad", name) { args = new object[] { gradients, outputs }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return elu_grad_eager_fallback(gradients, outputs, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["gradients"] = gradients; + keywords["outputs"] = outputs; + var _op = tf.OpDefLib._apply_op_helper("EluGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("EluGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor elu_grad_eager_fallback(Tensor gradients, Tensor outputs, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { gradients, outputs }; + object[] _attrs = new object[] { "T", gradients.dtype }; + var _result = _execute.execute("EluGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("EluGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Performs fractional average pooling on the input. + /// + /// + /// + /// Fractional average pooling is similar to Fractional max pooling in the pooling + /// region generation step. The only difference is that after pooling regions are + /// generated, a mean operation is performed instead of a max operation in each + /// pooling region. + /// + /// + /// + /// + /// + /// Pooling ratio for each dimension of `value`, currently only + /// supports row and col dimension and should be >= 1.0. For example, a valid + /// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements + /// must be 1.0 because we don't allow pooling on batch and channels + /// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions + /// respectively. + /// + /// + /// + /// + /// When set to True, generates the pooling sequence in a + /// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin + /// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for + /// difference between pseudorandom and random. + /// + /// + /// + /// + /// When set to True, it means when pooling, the values at the boundary + /// of adjacent pooling cells are used by both cells. For example: + /// + /// `index 0 1 2 3 4` + /// + /// `value 20 5 16 3 7` + /// + /// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + /// The result would be [41/3, 26/3] for fractional avg pooling. + /// + /// + /// + /// + /// When set to True, a fixed pooling region will be used when + /// iterating over a FractionalAvgPool node in the computation graph. Mainly used + /// in unit test to make FractionalAvgPool deterministic. + /// + /// + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// + /// + /// An second seed to avoid seed collision. + /// + /// + /// + public static Tensor[] fractional_avg_pool(Tensor value, float[] pooling_ratio, bool pseudo_random = false, bool overlapping = false, bool deterministic = false, int seed = 0, int seed2 = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FractionalAvgPool", name) { args = new object[] { value }, attrs = new Dictionary() { ["pooling_ratio"] = pooling_ratio, ["pseudo_random"] = pseudo_random, ["overlapping"] = overlapping, ["deterministic"] = deterministic, ["seed"] = seed, ["seed2"] = seed2 } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fractional_avg_pool_eager_fallback(value, pooling_ratio: pooling_ratio, pseudo_random: pseudo_random, overlapping: overlapping, deterministic: deterministic, seed: seed, seed2: seed2, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["value"] = value; + keywords["pooling_ratio"] = pooling_ratio; + keywords["pseudo_random"] = pseudo_random; + keywords["overlapping"] = overlapping; + keywords["deterministic"] = deterministic; + keywords["seed"] = seed; + keywords["seed2"] = seed2; + var _op = tf.OpDefLib._apply_op_helper("FractionalAvgPool", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "pooling_ratio", _op.get_attr("pooling_ratio"), "pseudo_random", _op._get_attr_bool("pseudo_random"), "overlapping", _op._get_attr_bool("overlapping"), "deterministic", _op._get_attr_bool("deterministic"), "seed", _op._get_attr_int("seed"), "seed2", _op._get_attr_int("seed2"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("FractionalAvgPool", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] fractional_avg_pool_eager_fallback(Tensor value, float[] pooling_ratio, bool pseudo_random, bool overlapping, bool deterministic, int seed, int seed2, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { value }; + object[] _attrs = new object[] { "pooling_ratio", pooling_ratio, "pseudo_random", pseudo_random, "overlapping", overlapping, "deterministic", deterministic, "seed", seed, "seed2", seed2, "T", value.dtype }; + var _result = _execute.execute("FractionalAvgPool", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FractionalAvgPool", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes gradient of the FractionalAvgPool function. + /// + /// + /// + /// Unlike FractionalMaxPoolGrad, we don't need to find arg_max for + /// FractionalAvgPoolGrad, we just need to evenly back-propagate each element of + /// out_backprop to those indices that form the same pooling cell. Therefore, we + /// just need to know the shape of original input tensor, instead of the whole + /// tensor. + /// + /// + /// + /// + /// + /// + /// + /// + /// When set to True, it means when pooling, the values at the boundary + /// of adjacent pooling cells are used by both cells. For example: + /// + /// `index 0 1 2 3 4` + /// + /// `value 20 5 16 3 7` + /// + /// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + /// The result would be [41/3, 26/3] for fractional avg pooling. + /// + /// + /// + public static Tensor fractional_avg_pool_grad(Tensor orig_input_tensor_shape, Tensor out_backprop, Tensor row_pooling_sequence, Tensor col_pooling_sequence, bool overlapping = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FractionalAvgPoolGrad", name) { args = new object[] { orig_input_tensor_shape, out_backprop, row_pooling_sequence, col_pooling_sequence }, attrs = new Dictionary() { ["overlapping"] = overlapping } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fractional_avg_pool_grad_eager_fallback(orig_input_tensor_shape, out_backprop, row_pooling_sequence, col_pooling_sequence, overlapping: overlapping, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["orig_input_tensor_shape"] = orig_input_tensor_shape; + keywords["out_backprop"] = out_backprop; + keywords["row_pooling_sequence"] = row_pooling_sequence; + keywords["col_pooling_sequence"] = col_pooling_sequence; + keywords["overlapping"] = overlapping; + var _op = tf.OpDefLib._apply_op_helper("FractionalAvgPoolGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "overlapping", _op._get_attr_bool("overlapping"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("FractionalAvgPoolGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fractional_avg_pool_grad_eager_fallback(Tensor orig_input_tensor_shape, Tensor out_backprop, Tensor row_pooling_sequence, Tensor col_pooling_sequence, bool overlapping, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { orig_input_tensor_shape, out_backprop, row_pooling_sequence, col_pooling_sequence }; + object[] _attrs = new object[] { "overlapping", overlapping, "T", out_backprop.dtype }; + var _result = _execute.execute("FractionalAvgPoolGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FractionalAvgPoolGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Performs fractional max pooling on the input. + /// + /// + /// + /// Fractional max pooling is slightly different than regular max pooling. In + /// regular max pooling, you downsize an input set by taking the maximum value of + /// smaller N x N subsections of the set (often 2x2), and try to reduce the set by + /// a factor of N, where N is an integer. Fractional max pooling, as you might + /// expect from the word "fractional", means that the overall reduction ratio N + /// does not have to be an integer. + /// + /// The sizes of the pooling regions are generated randomly but are fairly uniform. + /// For example, let's look at the height dimension, and the constraints on the + /// list of rows that will be pool boundaries. + /// + /// First we define the following: + /// + /// 1. input_row_length : the number of rows from the input set + /// 2. output_row_length : which will be smaller than the input + /// 3. alpha = input_row_length / output_row_length : our reduction ratio + /// 4. K = floor(alpha) + /// 5. row_pooling_sequence : this is the result list of pool boundary rows + /// + /// Then, row_pooling_sequence should satisfy: + /// + /// 1. a[0] = 0 : the first value of the sequence is 0 + /// 2. a[end] = input_row_length : the last value of the sequence is the size + /// 3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size + /// 4. length(row_pooling_sequence) = output_row_length+1 + /// + /// For more details on fractional max pooling, see this paper: + /// [Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) + /// + /// + /// + /// + /// + /// Pooling ratio for each dimension of `value`, currently only + /// supports row and col dimension and should be >= 1.0. For example, a valid + /// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements + /// must be 1.0 because we don't allow pooling on batch and channels + /// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions + /// respectively. + /// + /// + /// + /// + /// When set to True, generates the pooling sequence in a + /// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin + /// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for + /// difference between pseudorandom and random. + /// + /// + /// + /// + /// When set to True, it means when pooling, the values at the boundary + /// of adjacent pooling cells are used by both cells. For example: + /// + /// `index 0 1 2 3 4` + /// + /// `value 20 5 16 3 7` + /// + /// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + /// The result would be [20, 16] for fractional max pooling. + /// + /// + /// + /// + /// When set to True, a fixed pooling region will be used when + /// iterating over a FractionalMaxPool node in the computation graph. Mainly used + /// in unit test to make FractionalMaxPool deterministic. + /// + /// + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// + /// + /// An second seed to avoid seed collision. + /// + /// + /// + public static Tensor[] fractional_max_pool(Tensor value, float[] pooling_ratio, bool pseudo_random = false, bool overlapping = false, bool deterministic = false, int seed = 0, int seed2 = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FractionalMaxPool", name) { args = new object[] { value }, attrs = new Dictionary() { ["pooling_ratio"] = pooling_ratio, ["pseudo_random"] = pseudo_random, ["overlapping"] = overlapping, ["deterministic"] = deterministic, ["seed"] = seed, ["seed2"] = seed2 } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fractional_max_pool_eager_fallback(value, pooling_ratio: pooling_ratio, pseudo_random: pseudo_random, overlapping: overlapping, deterministic: deterministic, seed: seed, seed2: seed2, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["value"] = value; + keywords["pooling_ratio"] = pooling_ratio; + keywords["pseudo_random"] = pseudo_random; + keywords["overlapping"] = overlapping; + keywords["deterministic"] = deterministic; + keywords["seed"] = seed; + keywords["seed2"] = seed2; + var _op = tf.OpDefLib._apply_op_helper("FractionalMaxPool", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "pooling_ratio", _op.get_attr("pooling_ratio"), "pseudo_random", _op._get_attr_bool("pseudo_random"), "overlapping", _op._get_attr_bool("overlapping"), "deterministic", _op._get_attr_bool("deterministic"), "seed", _op._get_attr_int("seed"), "seed2", _op._get_attr_int("seed2"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("FractionalMaxPool", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] fractional_max_pool_eager_fallback(Tensor value, float[] pooling_ratio, bool pseudo_random, bool overlapping, bool deterministic, int seed, int seed2, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { value }; + object[] _attrs = new object[] { "pooling_ratio", pooling_ratio, "pseudo_random", pseudo_random, "overlapping", overlapping, "deterministic", deterministic, "seed", seed, "seed2", seed2, "T", value.dtype }; + var _result = _execute.execute("FractionalMaxPool", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FractionalMaxPool", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes gradient of the FractionalMaxPool function. + /// + /// + /// + /// + /// + /// + /// + /// + /// When set to True, it means when pooling, the values at the boundary + /// of adjacent pooling cells are used by both cells. For example: + /// + /// `index 0 1 2 3 4` + /// + /// `value 20 5 16 3 7` + /// + /// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + /// The result would be [20, 16] for fractional max pooling. + /// + /// + /// + public static Tensor fractional_max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor out_backprop, Tensor row_pooling_sequence, Tensor col_pooling_sequence, bool overlapping = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FractionalMaxPoolGrad", name) { args = new object[] { orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence }, attrs = new Dictionary() { ["overlapping"] = overlapping } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fractional_max_pool_grad_eager_fallback(orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence, overlapping: overlapping, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["orig_input"] = orig_input; + keywords["orig_output"] = orig_output; + keywords["out_backprop"] = out_backprop; + keywords["row_pooling_sequence"] = row_pooling_sequence; + keywords["col_pooling_sequence"] = col_pooling_sequence; + keywords["overlapping"] = overlapping; + var _op = tf.OpDefLib._apply_op_helper("FractionalMaxPoolGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "overlapping", _op._get_attr_bool("overlapping"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("FractionalMaxPoolGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fractional_max_pool_grad_eager_fallback(Tensor orig_input, Tensor orig_output, Tensor out_backprop, Tensor row_pooling_sequence, Tensor col_pooling_sequence, bool overlapping, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence }; + object[] _attrs = new object[] { "overlapping", overlapping, "T", orig_input.dtype }; + var _result = _execute.execute("FractionalMaxPoolGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FractionalMaxPoolGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Batch normalization. + /// + /// + /// + /// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". + /// The size of 1D Tensors matches the dimension C of the 4D Tensors. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A small float number added to the variance of x. + /// + /// + /// + /// + /// + /// The data format for x and y. Either "NHWC" (default) or "NCHW". + /// + /// + /// + /// + /// A bool value to indicate the operation is for training (default) + /// or inference. + /// + /// + /// + public static Tensor[] fused_batch_norm(Tensor x, Tensor scale, Tensor offset, Tensor mean, Tensor variance, float epsilon = 0.0001f, float exponential_avg_factor = 1f, string data_format = "NHWC", bool is_training = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FusedBatchNorm", name) { args = new object[] { x, scale, offset, mean, variance }, attrs = new Dictionary() { ["epsilon"] = epsilon, ["exponential_avg_factor"] = exponential_avg_factor, ["data_format"] = data_format, ["is_training"] = is_training } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fused_batch_norm_eager_fallback(x, scale, offset, mean, variance, epsilon: epsilon, exponential_avg_factor: exponential_avg_factor, data_format: data_format, is_training: is_training, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["scale"] = scale; + keywords["offset"] = offset; + keywords["mean"] = mean; + keywords["variance"] = variance; + keywords["epsilon"] = epsilon; + keywords["exponential_avg_factor"] = exponential_avg_factor; + keywords["data_format"] = data_format; + keywords["is_training"] = is_training; + var _op = tf.OpDefLib._apply_op_helper("FusedBatchNorm", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "epsilon", _op.get_attr("epsilon"), "exponential_avg_factor", _op.get_attr("exponential_avg_factor"), "data_format", _op.get_attr("data_format"), "is_training", _op._get_attr_bool("is_training") }; + _execute.record_gradient("FusedBatchNorm", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] fused_batch_norm_eager_fallback(Tensor x, Tensor scale, Tensor offset, Tensor mean, Tensor variance, float epsilon, float exponential_avg_factor, string data_format, bool is_training, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, scale, offset, mean, variance }; + object[] _attrs = new object[] { "T", x.dtype, "epsilon", epsilon, "exponential_avg_factor", exponential_avg_factor, "data_format", data_format, "is_training", is_training }; + var _result = _execute.execute("FusedBatchNorm", 5, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FusedBatchNorm", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Gradient for batch normalization. + /// + /// + /// + /// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". + /// The size of 1D Tensors matches the dimension C of the 4D Tensors. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A small float number added to the variance of x. + /// + /// + /// + /// + /// The data format for y_backprop, x, x_backprop. + /// Either "NHWC" (default) or "NCHW". + /// + /// + /// + /// + /// A bool value to indicate the operation is for training (default) + /// or inference. + /// + /// + /// + public static Tensor[] fused_batch_norm_grad(Tensor y_backprop, Tensor x, Tensor scale, Tensor reserve_space_1, Tensor reserve_space_2, float epsilon = 0.0001f, string data_format = "NHWC", bool is_training = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FusedBatchNormGrad", name) { args = new object[] { y_backprop, x, scale, reserve_space_1, reserve_space_2 }, attrs = new Dictionary() { ["epsilon"] = epsilon, ["data_format"] = data_format, ["is_training"] = is_training } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fused_batch_norm_grad_eager_fallback(y_backprop, x, scale, reserve_space_1, reserve_space_2, epsilon: epsilon, data_format: data_format, is_training: is_training, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["y_backprop"] = y_backprop; + keywords["x"] = x; + keywords["scale"] = scale; + keywords["reserve_space_1"] = reserve_space_1; + keywords["reserve_space_2"] = reserve_space_2; + keywords["epsilon"] = epsilon; + keywords["data_format"] = data_format; + keywords["is_training"] = is_training; + var _op = tf.OpDefLib._apply_op_helper("FusedBatchNormGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "epsilon", _op.get_attr("epsilon"), "data_format", _op.get_attr("data_format"), "is_training", _op._get_attr_bool("is_training") }; + _execute.record_gradient("FusedBatchNormGrad", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] fused_batch_norm_grad_eager_fallback(Tensor y_backprop, Tensor x, Tensor scale, Tensor reserve_space_1, Tensor reserve_space_2, float epsilon, string data_format, bool is_training, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { y_backprop, x, scale, reserve_space_1, reserve_space_2 }; + object[] _attrs = new object[] { "T", y_backprop.dtype, "epsilon", epsilon, "data_format", data_format, "is_training", is_training }; + var _result = _execute.execute("FusedBatchNormGrad", 5, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FusedBatchNormGrad", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Gradient for batch normalization. + /// + /// + /// + /// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". + /// The size of 1D Tensors matches the dimension C of the 4D Tensors. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A small float number added to the variance of x. + /// + /// + /// + /// + /// The data format for y_backprop, x, x_backprop. + /// Either "NHWC" (default) or "NCHW". + /// + /// + /// + /// + /// A bool value to indicate the operation is for training (default) + /// or inference. + /// + /// + /// + public static Tensor[] fused_batch_norm_grad_v2(Tensor y_backprop, Tensor x, Tensor scale, Tensor reserve_space_1, Tensor reserve_space_2, float epsilon = 0.0001f, string data_format = "NHWC", bool is_training = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FusedBatchNormGradV2", name) { args = new object[] { y_backprop, x, scale, reserve_space_1, reserve_space_2 }, attrs = new Dictionary() { ["epsilon"] = epsilon, ["data_format"] = data_format, ["is_training"] = is_training } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fused_batch_norm_grad_v2_eager_fallback(y_backprop, x, scale, reserve_space_1, reserve_space_2, epsilon: epsilon, data_format: data_format, is_training: is_training, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["y_backprop"] = y_backprop; + keywords["x"] = x; + keywords["scale"] = scale; + keywords["reserve_space_1"] = reserve_space_1; + keywords["reserve_space_2"] = reserve_space_2; + keywords["epsilon"] = epsilon; + keywords["data_format"] = data_format; + keywords["is_training"] = is_training; + var _op = tf.OpDefLib._apply_op_helper("FusedBatchNormGradV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "U", _op._get_attr_type("U"), "epsilon", _op.get_attr("epsilon"), "data_format", _op.get_attr("data_format"), "is_training", _op._get_attr_bool("is_training") }; + _execute.record_gradient("FusedBatchNormGradV2", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] fused_batch_norm_grad_v2_eager_fallback(Tensor y_backprop, Tensor x, Tensor scale, Tensor reserve_space_1, Tensor reserve_space_2, float epsilon, string data_format, bool is_training, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { y_backprop, x, scale, reserve_space_1, reserve_space_2 }; + object[] _attrs = new object[] { "T", y_backprop.dtype, "U", reserve_space_1.dtype, "epsilon", epsilon, "data_format", data_format, "is_training", is_training }; + var _result = _execute.execute("FusedBatchNormGradV2", 5, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FusedBatchNormGradV2", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Gradient for batch normalization. + /// + /// + /// + /// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". + /// The size of 1D Tensors matches the dimension C of the 4D Tensors. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A small float number added to the variance of x. + /// + /// + /// + /// + /// The data format for y_backprop, x, x_backprop. + /// Either "NHWC" (default) or "NCHW". + /// + /// + /// + /// + /// A bool value to indicate the operation is for training (default) + /// or inference. + /// + /// + /// + public static Tensor[] fused_batch_norm_grad_v3(Tensor y_backprop, Tensor x, Tensor scale, Tensor reserve_space_1, Tensor reserve_space_2, Tensor reserve_space_3, float epsilon = 0.0001f, string data_format = "NHWC", bool is_training = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FusedBatchNormGradV3", name) { args = new object[] { y_backprop, x, scale, reserve_space_1, reserve_space_2, reserve_space_3 }, attrs = new Dictionary() { ["epsilon"] = epsilon, ["data_format"] = data_format, ["is_training"] = is_training } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fused_batch_norm_grad_v3_eager_fallback(y_backprop, x, scale, reserve_space_1, reserve_space_2, reserve_space_3, epsilon: epsilon, data_format: data_format, is_training: is_training, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["y_backprop"] = y_backprop; + keywords["x"] = x; + keywords["scale"] = scale; + keywords["reserve_space_1"] = reserve_space_1; + keywords["reserve_space_2"] = reserve_space_2; + keywords["reserve_space_3"] = reserve_space_3; + keywords["epsilon"] = epsilon; + keywords["data_format"] = data_format; + keywords["is_training"] = is_training; + var _op = tf.OpDefLib._apply_op_helper("FusedBatchNormGradV3", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "U", _op._get_attr_type("U"), "epsilon", _op.get_attr("epsilon"), "data_format", _op.get_attr("data_format"), "is_training", _op._get_attr_bool("is_training") }; + _execute.record_gradient("FusedBatchNormGradV3", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] fused_batch_norm_grad_v3_eager_fallback(Tensor y_backprop, Tensor x, Tensor scale, Tensor reserve_space_1, Tensor reserve_space_2, Tensor reserve_space_3, float epsilon, string data_format, bool is_training, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { y_backprop, x, scale, reserve_space_1, reserve_space_2, reserve_space_3 }; + object[] _attrs = new object[] { "T", y_backprop.dtype, "U", reserve_space_1.dtype, "epsilon", epsilon, "data_format", data_format, "is_training", is_training }; + var _result = _execute.execute("FusedBatchNormGradV3", 5, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FusedBatchNormGradV3", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Batch normalization. + /// + /// + /// + /// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". + /// The size of 1D Tensors matches the dimension C of the 4D Tensors. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A small float number added to the variance of x. + /// + /// + /// + /// + /// + /// The data format for x and y. Either "NHWC" (default) or "NCHW". + /// + /// + /// + /// + /// A bool value to indicate the operation is for training (default) + /// or inference. + /// + /// + /// + public static Tensor[] fused_batch_norm_v2(Tensor x, Tensor scale, Tensor offset, Tensor mean, Tensor variance, float epsilon = 0.0001f, float exponential_avg_factor = 1f, string data_format = "NHWC", bool is_training = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FusedBatchNormV2", name) { args = new object[] { x, scale, offset, mean, variance }, attrs = new Dictionary() { ["epsilon"] = epsilon, ["exponential_avg_factor"] = exponential_avg_factor, ["data_format"] = data_format, ["is_training"] = is_training } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fused_batch_norm_v2_eager_fallback(x, scale, offset, mean, variance, epsilon: epsilon, exponential_avg_factor: exponential_avg_factor, data_format: data_format, is_training: is_training, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["scale"] = scale; + keywords["offset"] = offset; + keywords["mean"] = mean; + keywords["variance"] = variance; + keywords["epsilon"] = epsilon; + keywords["exponential_avg_factor"] = exponential_avg_factor; + keywords["data_format"] = data_format; + keywords["is_training"] = is_training; + var _op = tf.OpDefLib._apply_op_helper("FusedBatchNormV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "U", _op._get_attr_type("U"), "epsilon", _op.get_attr("epsilon"), "exponential_avg_factor", _op.get_attr("exponential_avg_factor"), "data_format", _op.get_attr("data_format"), "is_training", _op._get_attr_bool("is_training") }; + _execute.record_gradient("FusedBatchNormV2", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] fused_batch_norm_v2_eager_fallback(Tensor x, Tensor scale, Tensor offset, Tensor mean, Tensor variance, float epsilon, float exponential_avg_factor, string data_format, bool is_training, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, scale, offset, mean, variance }; + object[] _attrs = new object[] { "T", x.dtype, "U", scale.dtype, "epsilon", epsilon, "exponential_avg_factor", exponential_avg_factor, "data_format", data_format, "is_training", is_training }; + var _result = _execute.execute("FusedBatchNormV2", 5, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FusedBatchNormV2", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Batch normalization. + /// + /// + /// + /// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". + /// The size of 1D Tensors matches the dimension C of the 4D Tensors. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A small float number added to the variance of x. + /// + /// + /// + /// + /// + /// The data format for x and y. Either "NHWC" (default) or "NCHW". + /// + /// + /// + /// + /// A bool value to indicate the operation is for training (default) + /// or inference. + /// + /// + /// + public static Tensor[] fused_batch_norm_v3(Tensor x, Tensor scale, Tensor offset, Tensor mean, Tensor variance, float epsilon = 0.0001f, float exponential_avg_factor = 1f, string data_format = "NHWC", bool is_training = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FusedBatchNormV3", name) { args = new object[] { x, scale, offset, mean, variance }, attrs = new Dictionary() { ["epsilon"] = epsilon, ["exponential_avg_factor"] = exponential_avg_factor, ["data_format"] = data_format, ["is_training"] = is_training } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fused_batch_norm_v3_eager_fallback(x, scale, offset, mean, variance, epsilon: epsilon, exponential_avg_factor: exponential_avg_factor, data_format: data_format, is_training: is_training, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["x"] = x; + keywords["scale"] = scale; + keywords["offset"] = offset; + keywords["mean"] = mean; + keywords["variance"] = variance; + keywords["epsilon"] = epsilon; + keywords["exponential_avg_factor"] = exponential_avg_factor; + keywords["data_format"] = data_format; + keywords["is_training"] = is_training; + var _op = tf.OpDefLib._apply_op_helper("FusedBatchNormV3", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "U", _op._get_attr_type("U"), "epsilon", _op.get_attr("epsilon"), "exponential_avg_factor", _op.get_attr("exponential_avg_factor"), "data_format", _op.get_attr("data_format"), "is_training", _op._get_attr_bool("is_training") }; + _execute.record_gradient("FusedBatchNormV3", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] fused_batch_norm_v3_eager_fallback(Tensor x, Tensor scale, Tensor offset, Tensor mean, Tensor variance, float epsilon, float exponential_avg_factor, string data_format, bool is_training, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { x, scale, offset, mean, variance }; + object[] _attrs = new object[] { "T", x.dtype, "U", scale.dtype, "epsilon", epsilon, "exponential_avg_factor", exponential_avg_factor, "data_format", data_format, "is_training", is_training }; + var _result = _execute.execute("FusedBatchNormV3", 6, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FusedBatchNormV3", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Performs a padding as a preprocess during a convolution. + /// + /// + /// + /// Similar to FusedResizeAndPadConv2d, this op allows for an optimized + /// implementation where the spatial padding transformation stage is fused with the + /// im2col lookup, but in this case without the bilinear filtering required for + /// resizing. Fusing the padding prevents the need to write out the intermediate + /// results as whole tensors, reducing memory pressure, and we can get some latency + /// gains by merging the transformation calculations. + /// The data_format attribute for Conv2D isn't supported by this op, and 'NHWC' + /// order is used instead. + /// Internally this op uses a single per-graph scratch buffer, which means that it + /// will block if multiple versions are being run in parallel. This is because this + /// operator is primarily an optimization to minimize memory usage. + /// + /// + /// + /// + /// + /// + /// + /// + /// 1-D of length 4. The stride of the sliding window for each dimension + /// of `input`. Must be in the same order as the dimension specified with format. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + public static Tensor fused_pad_conv2d(Tensor input, Tensor paddings, Tensor filter, string mode, int[] strides, string padding, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FusedPadConv2D", name) { args = new object[] { input, paddings, filter }, attrs = new Dictionary() { ["mode"] = mode, ["strides"] = strides, ["padding"] = padding } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fused_pad_conv2d_eager_fallback(input, paddings, filter, mode: mode, strides: strides, padding: padding, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["paddings"] = paddings; + keywords["filter"] = filter; + keywords["mode"] = mode; + keywords["strides"] = strides; + keywords["padding"] = padding; + var _op = tf.OpDefLib._apply_op_helper("FusedPadConv2D", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "mode", _op.get_attr("mode"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding") }; + _execute.record_gradient("FusedPadConv2D", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fused_pad_conv2d_eager_fallback(Tensor input, Tensor paddings, Tensor filter, string mode, int[] strides, string padding, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, paddings, filter }; + object[] _attrs = new object[] { "T", input.dtype, "mode", mode, "strides", strides, "padding", padding }; + var _result = _execute.execute("FusedPadConv2D", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FusedPadConv2D", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Performs a resize and padding as a preprocess during a convolution. + /// + /// + /// + /// It's often possible to do spatial transformations more efficiently as part of + /// the packing stage of a convolution, so this op allows for an optimized + /// implementation where these stages are fused together. This prevents the need to + /// write out the intermediate results as whole tensors, reducing memory pressure, + /// and we can get some latency gains by merging the transformation calculations. + /// The data_format attribute for Conv2D isn't supported by this op, and defaults to + /// 'NHWC' order. + /// Internally this op uses a single per-graph scratch buffer, which means that it + /// will block if multiple versions are being run in parallel. This is because this + /// operator is primarily an optimization to minimize memory usage. + /// + /// + /// + /// + /// + /// + /// + /// + /// If true, the centers of the 4 corner pixels of the input and output tensors are + /// aligned, preserving the values at the corner pixels. Defaults to false. + /// + /// + /// + /// + /// + /// 1-D of length 4. The stride of the sliding window for each dimension + /// of `input`. Must be in the same order as the dimension specified with format. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + public static Tensor fused_resize_and_pad_conv2d(Tensor input, Tensor size, Tensor paddings, Tensor filter, string mode, int[] strides, string padding, bool resize_align_corners = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FusedResizeAndPadConv2D", name) { args = new object[] { input, size, paddings, filter }, attrs = new Dictionary() { ["resize_align_corners"] = resize_align_corners, ["mode"] = mode, ["strides"] = strides, ["padding"] = padding } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return fused_resize_and_pad_conv2d_eager_fallback(input, size, paddings, filter, resize_align_corners: resize_align_corners, mode: mode, strides: strides, padding: padding, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["size"] = size; + keywords["paddings"] = paddings; + keywords["filter"] = filter; + keywords["resize_align_corners"] = resize_align_corners; + keywords["mode"] = mode; + keywords["strides"] = strides; + keywords["padding"] = padding; + var _op = tf.OpDefLib._apply_op_helper("FusedResizeAndPadConv2D", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "resize_align_corners", _op._get_attr_bool("resize_align_corners"), "mode", _op.get_attr("mode"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding") }; + _execute.record_gradient("FusedResizeAndPadConv2D", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fused_resize_and_pad_conv2d_eager_fallback(Tensor input, Tensor size, Tensor paddings, Tensor filter, bool resize_align_corners, string mode, int[] strides, string padding, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, size, paddings, filter }; + object[] _attrs = new object[] { "T", input.dtype, "resize_align_corners", resize_align_corners, "mode", mode, "strides", strides, "padding", padding }; + var _result = _execute.execute("FusedResizeAndPadConv2D", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FusedResizeAndPadConv2D", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Says whether the targets are in the top `K` predictions. + /// + /// + /// + /// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the + /// prediction for the target class is among the top `k` predictions among + /// all predictions for example `i`. Note that the behavior of `InTopK` differs + /// from the `TopK` op in its handling of ties; if multiple classes have the + /// same prediction value and straddle the top-`k` boundary, all of those + /// classes are considered to be in the top `k`. + /// + /// More formally, let + /// + /// \(predictions_i\) be the predictions for all classes for example `i`, + /// \(targets_i\) be the target class for example `i`, + /// \(out_i\) be the output for example `i`, + /// + /// $$out_i = predictions_{i, targets_i} in TopKIncludingTies(predictions_i)$$ + /// + /// + /// + /// + /// + /// + /// Number of top elements to look at for computing precision. + /// + /// + /// + public static Tensor in_top_k(Tensor predictions, Tensor targets, int k = 0, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "InTopK", name) { args = new object[] { predictions, targets }, attrs = new Dictionary() { ["k"] = k } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return in_top_k_eager_fallback(predictions, targets, k: k, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["predictions"] = predictions; + keywords["targets"] = targets; + keywords["k"] = k; + var _op = tf.OpDefLib._apply_op_helper("InTopK", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "k", _op._get_attr_int("k"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("InTopK", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor in_top_k_eager_fallback(Tensor predictions, Tensor targets, int k, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { predictions, targets }; + object[] _attrs = new object[] { "k", k, "T", targets.dtype }; + var _result = _execute.execute("InTopK", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("InTopK", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Says whether the targets are in the top `K` predictions. + /// + /// + /// + /// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the + /// prediction for the target class is among the top `k` predictions among + /// all predictions for example `i`. Note that the behavior of `InTopK` differs + /// from the `TopK` op in its handling of ties; if multiple classes have the + /// same prediction value and straddle the top-`k` boundary, all of those + /// classes are considered to be in the top `k`. + /// + /// More formally, let + /// + /// \(predictions_i\) be the predictions for all classes for example `i`, + /// \(targets_i\) be the target class for example `i`, + /// \(out_i\) be the output for example `i`, + /// + /// $$out_i = predictions_{i, targets_i} in TopKIncludingTies(predictions_i)$$ + /// + /// + /// + /// + /// + /// + public static Tensor in_top_kv2(Tensor predictions, Tensor targets, Tensor k, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "InTopKV2", name) { args = new object[] { predictions, targets, k }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return in_top_kv2_eager_fallback(predictions, targets, k, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["predictions"] = predictions; + keywords["targets"] = targets; + keywords["k"] = k; + var _op = tf.OpDefLib._apply_op_helper("InTopKV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("InTopKV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor in_top_kv2_eager_fallback(Tensor predictions, Tensor targets, Tensor k, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { predictions, targets, k }; + object[] _attrs = new object[] { "T", targets.dtype }; + var _result = _execute.execute("InTopKV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("InTopKV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Solves a batch of isotonic regression problems. + /// + /// + /// + /// Dtype of output. + /// + /// + public static Tensor[] isotonic_regression(Tensor input, TF_DataType output_dtype = TF_DataType.TF_FLOAT, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "IsotonicRegression", name) { args = new object[] { input }, attrs = new Dictionary() { ["output_dtype"] = output_dtype } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return isotonic_regression_eager_fallback(input, output_dtype: output_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["output_dtype"] = output_dtype; + var _op = tf.OpDefLib._apply_op_helper("IsotonicRegression", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "output_dtype", _op._get_attr_type("output_dtype") }; + _execute.record_gradient("IsotonicRegression", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] isotonic_regression_eager_fallback(Tensor input, TF_DataType output_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "output_dtype", output_dtype }; + var _result = _execute.execute("IsotonicRegression", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("IsotonicRegression", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Local Response Normalization. + /// + /// + /// + /// The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last + /// dimension), and each vector is normalized independently. Within a given vector, + /// each component is divided by the weighted, squared sum of inputs within + /// `depth_radius`. In detail, + /// + /// sqr_sum[a, b, c, d] = + /// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) + /// output = input / (bias + alpha * sqr_sum) ** beta + /// + /// For details, see [Krizhevsky et al., ImageNet classification with deep + /// convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). + /// + /// + /// + /// + /// + /// 0-D. Half-width of the 1-D normalization window. + /// + /// + /// + /// + /// An offset (usually positive to avoid dividing by 0). + /// + /// + /// + /// + /// A scale factor, usually positive. + /// + /// + /// + /// + /// An exponent. + /// + /// + /// + public static Tensor lrn(Tensor input, int depth_radius = 5, float bias = 1f, float alpha = 1f, float beta = 0.5f, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "LRN", name) { args = new object[] { input }, attrs = new Dictionary() { ["depth_radius"] = depth_radius, ["bias"] = bias, ["alpha"] = alpha, ["beta"] = beta } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return lrn_eager_fallback(input, depth_radius: depth_radius, bias: bias, alpha: alpha, beta: beta, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["depth_radius"] = depth_radius; + keywords["bias"] = bias; + keywords["alpha"] = alpha; + keywords["beta"] = beta; + var _op = tf.OpDefLib._apply_op_helper("LRN", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "depth_radius", _op._get_attr_int("depth_radius"), "bias", _op.get_attr("bias"), "alpha", _op.get_attr("alpha"), "beta", _op.get_attr("beta"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("LRN", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor lrn_eager_fallback(Tensor input, int depth_radius, float bias, float alpha, float beta, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "depth_radius", depth_radius, "bias", bias, "alpha", alpha, "beta", beta, "T", input.dtype }; + var _result = _execute.execute("LRN", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("LRN", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes rectified linear: `max(features, features * alpha)`. + /// + /// + /// + /// + public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "LeakyRelu", name) { args = new object[] { features }, attrs = new Dictionary() { ["alpha"] = alpha } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return leaky_relu_eager_fallback(features, alpha: alpha, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["features"] = features; + keywords["alpha"] = alpha; + var _op = tf.OpDefLib._apply_op_helper("LeakyRelu", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "alpha", _op.get_attr("alpha"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("LeakyRelu", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor leaky_relu_eager_fallback(Tensor features, float alpha, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { features }; + object[] _attrs = new object[] { "alpha", alpha, "T", features.dtype }; + var _result = _execute.execute("LeakyRelu", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("LeakyRelu", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes rectified linear gradients for a LeakyRelu operation. + /// + /// + /// + /// + /// + public static Tensor leaky_relu_grad(Tensor gradients, Tensor features, float alpha = 0.2f, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "LeakyReluGrad", name) { args = new object[] { gradients, features }, attrs = new Dictionary() { ["alpha"] = alpha } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return leaky_relu_grad_eager_fallback(gradients, features, alpha: alpha, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["gradients"] = gradients; + keywords["features"] = features; + keywords["alpha"] = alpha; + var _op = tf.OpDefLib._apply_op_helper("LeakyReluGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "alpha", _op.get_attr("alpha"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("LeakyReluGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor leaky_relu_grad_eager_fallback(Tensor gradients, Tensor features, float alpha, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { gradients, features }; + object[] _attrs = new object[] { "alpha", alpha, "T", gradients.dtype }; + var _result = _execute.execute("LeakyReluGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("LeakyReluGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes log softmax activations. + /// + /// + /// + /// For each batch `i` and class `j` we have + /// + /// logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i]))) + /// + /// + /// + /// + public static Tensor log_softmax(Tensor logits, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "LogSoftmax", name) { args = new object[] { logits }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return log_softmax_eager_fallback(logits, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["logits"] = logits; + var _op = tf.OpDefLib._apply_op_helper("LogSoftmax", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("LogSoftmax", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor log_softmax_eager_fallback(Tensor logits, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { logits }; + object[] _attrs = new object[] { "T", logits.dtype }; + var _result = _execute.execute("LogSoftmax", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("LogSoftmax", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Performs max pooling on the input. + /// + /// + /// + /// + /// The size of the window for each dimension of the input tensor. + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// + public static Tensor max_pool(Tensor input, int[] ksize, int[] strides, string padding, int[] explicit_paddings = null, string data_format = "NHWC", string? name = null) + { + var _ctx = tf.Context; + if (explicit_paddings is null) + { + explicit_paddings = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MaxPool", name) { args = new object[] { input }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding, ["explicit_paddings"] = explicit_paddings, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_pool_eager_fallback(input, ksize: ksize, strides: strides, padding: padding, explicit_paddings: explicit_paddings, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["explicit_paddings"] = explicit_paddings; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("MaxPool", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "explicit_paddings", _op.get_attr("explicit_paddings"), "data_format", _op.get_attr("data_format") }; + _execute.record_gradient("MaxPool", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor max_pool_eager_fallback(Tensor input, int[] ksize, int[] strides, string padding, int[] explicit_paddings, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype, "ksize", ksize, "strides", strides, "padding", padding, "explicit_paddings", explicit_paddings, "data_format", data_format }; + var _result = _execute.execute("MaxPool", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MaxPool", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Performs 3D max pooling on the input. + /// + /// + /// + /// + /// 1-D tensor of length 5. The size of the window for each dimension of + /// the input tensor. Must have `ksize[0] = ksize[4] = 1`. + /// + /// + /// + /// + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of `input`. Must have `strides[0] = strides[4] = 1`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// + public static Tensor max_pool3d(Tensor input, int[] ksize, int[] strides, string padding, string data_format = "NDHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MaxPool3D", name) { args = new object[] { input }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_pool3d_eager_fallback(input, ksize: ksize, strides: strides, padding: padding, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NDHWC"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("MaxPool3D", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("MaxPool3D", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor max_pool3d_eager_fallback(Tensor input, int[] ksize, int[] strides, string padding, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "ksize", ksize, "strides", strides, "padding", padding, "data_format", data_format, "T", input.dtype }; + var _result = _execute.execute("MaxPool3D", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MaxPool3D", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes gradients of 3D max pooling function. + /// + /// + /// + /// + /// + /// + /// 1-D tensor of length 5. The size of the window for each dimension of + /// the input tensor. Must have `ksize[0] = ksize[4] = 1`. + /// + /// + /// + /// + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of `input`. Must have `strides[0] = strides[4] = 1`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// + public static Tensor max_pool3d_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, string data_format = "NDHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MaxPool3DGrad", name) { args = new object[] { orig_input, orig_output, grad }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_pool3d_grad_eager_fallback(orig_input, orig_output, grad, ksize: ksize, strides: strides, padding: padding, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NDHWC"; + } + Dictionary keywords = new(); + keywords["orig_input"] = orig_input; + keywords["orig_output"] = orig_output; + keywords["grad"] = grad; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("MaxPool3DGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "T", _op._get_attr_type("T"), "TInput", _op._get_attr_type("TInput") }; + _execute.record_gradient("MaxPool3DGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor max_pool3d_grad_eager_fallback(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { orig_input, orig_output, grad }; + object[] _attrs = new object[] { "ksize", ksize, "strides", strides, "padding", padding, "data_format", data_format, "T", grad.dtype, "TInput", orig_input.dtype }; + var _result = _execute.execute("MaxPool3DGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MaxPool3DGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes second-order gradients of the maxpooling function. + /// + /// + /// + /// + /// + /// + /// 1-D tensor of length 5. The size of the window for each dimension of + /// the input tensor. Must have `ksize[0] = ksize[4] = 1`. + /// + /// + /// + /// + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of `input`. Must have `strides[0] = strides[4] = 1`. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// + public static Tensor max_pool3d_grad_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, string data_format = "NDHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MaxPool3DGradGrad", name) { args = new object[] { orig_input, orig_output, grad }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_pool3d_grad_grad_eager_fallback(orig_input, orig_output, grad, ksize: ksize, strides: strides, padding: padding, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NDHWC"; + } + Dictionary keywords = new(); + keywords["orig_input"] = orig_input; + keywords["orig_output"] = orig_output; + keywords["grad"] = grad; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("MaxPool3DGradGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("MaxPool3DGradGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor max_pool3d_grad_grad_eager_fallback(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { orig_input, orig_output, grad }; + object[] _attrs = new object[] { "ksize", ksize, "strides", strides, "padding", padding, "data_format", data_format, "T", orig_input.dtype }; + var _result = _execute.execute("MaxPool3DGradGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MaxPool3DGradGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes gradients of the maxpooling function. + /// + /// + /// + /// + /// + /// + /// The size of the window for each dimension of the input tensor. + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// + public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, int[] explicit_paddings = null, string data_format = "NHWC", string? name = null) + { + var _ctx = tf.Context; + if (explicit_paddings is null) + { + explicit_paddings = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MaxPoolGrad", name) { args = new object[] { orig_input, orig_output, grad }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding, ["explicit_paddings"] = explicit_paddings, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_pool_grad_eager_fallback(orig_input, orig_output, grad, ksize: ksize, strides: strides, padding: padding, explicit_paddings: explicit_paddings, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["orig_input"] = orig_input; + keywords["orig_output"] = orig_output; + keywords["grad"] = grad; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["explicit_paddings"] = explicit_paddings; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("MaxPoolGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "explicit_paddings", _op.get_attr("explicit_paddings"), "data_format", _op.get_attr("data_format"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("MaxPoolGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor max_pool_grad_eager_fallback(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, int[] explicit_paddings, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { orig_input, orig_output, grad }; + object[] _attrs = new object[] { "ksize", ksize, "strides", strides, "padding", padding, "explicit_paddings", explicit_paddings, "data_format", data_format, "T", orig_input.dtype }; + var _result = _execute.execute("MaxPoolGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MaxPoolGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes second-order gradients of the maxpooling function. + /// + /// + /// + /// + /// + /// + /// The size of the window for each dimension of the input tensor. + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// + public static Tensor max_pool_grad_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MaxPoolGradGrad", name) { args = new object[] { orig_input, orig_output, grad }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_pool_grad_grad_eager_fallback(orig_input, orig_output, grad, ksize: ksize, strides: strides, padding: padding, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["orig_input"] = orig_input; + keywords["orig_output"] = orig_output; + keywords["grad"] = grad; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("MaxPoolGradGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("MaxPoolGradGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor max_pool_grad_grad_eager_fallback(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { orig_input, orig_output, grad }; + object[] _attrs = new object[] { "ksize", ksize, "strides", strides, "padding", padding, "data_format", data_format, "T", orig_input.dtype }; + var _result = _execute.execute("MaxPoolGradGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MaxPoolGradGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes second-order gradients of the maxpooling function. + /// + /// + /// + /// + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// + public static Tensor max_pool_grad_grad_v2(Tensor orig_input, Tensor orig_output, Tensor grad, Tensor ksize, Tensor strides, string padding, string data_format = "NHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MaxPoolGradGradV2", name) { args = new object[] { orig_input, orig_output, grad, ksize, strides }, attrs = new Dictionary() { ["padding"] = padding, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_pool_grad_grad_v2_eager_fallback(orig_input, orig_output, grad, ksize, strides, padding: padding, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["orig_input"] = orig_input; + keywords["orig_output"] = orig_output; + keywords["grad"] = grad; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("MaxPoolGradGradV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("MaxPoolGradGradV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor max_pool_grad_grad_v2_eager_fallback(Tensor orig_input, Tensor orig_output, Tensor grad, Tensor ksize, Tensor strides, string padding, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { orig_input, orig_output, grad, ksize, strides }; + object[] _attrs = new object[] { "padding", padding, "data_format", data_format, "T", orig_input.dtype }; + var _result = _execute.execute("MaxPoolGradGradV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MaxPoolGradGradV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes second-order gradients of the maxpooling function. + /// + /// + /// + /// + /// + /// + /// The size of the window for each dimension of the input tensor. + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Whether to include batch dimension in flattened index of `argmax`. + /// + /// + /// + public static Tensor max_pool_grad_grad_with_argmax(Tensor input, Tensor grad, Tensor argmax, int[] ksize, int[] strides, string padding, bool include_batch_in_index = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MaxPoolGradGradWithArgmax", name) { args = new object[] { input, grad, argmax }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding, ["include_batch_in_index"] = include_batch_in_index } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_pool_grad_grad_with_argmax_eager_fallback(input, grad, argmax, ksize: ksize, strides: strides, padding: padding, include_batch_in_index: include_batch_in_index, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["grad"] = grad; + keywords["argmax"] = argmax; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["include_batch_in_index"] = include_batch_in_index; + var _op = tf.OpDefLib._apply_op_helper("MaxPoolGradGradWithArgmax", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "include_batch_in_index", _op._get_attr_bool("include_batch_in_index"), "Targmax", _op._get_attr_type("Targmax"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("MaxPoolGradGradWithArgmax", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor max_pool_grad_grad_with_argmax_eager_fallback(Tensor input, Tensor grad, Tensor argmax, int[] ksize, int[] strides, string padding, bool include_batch_in_index, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, grad, argmax }; + object[] _attrs = new object[] { "ksize", ksize, "strides", strides, "padding", padding, "include_batch_in_index", include_batch_in_index, "Targmax", argmax.dtype, "T", input.dtype }; + var _result = _execute.execute("MaxPoolGradGradWithArgmax", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MaxPoolGradGradWithArgmax", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes gradients of the maxpooling function. + /// + /// + /// + /// + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// + public static Tensor max_pool_grad_v2(Tensor orig_input, Tensor orig_output, Tensor grad, Tensor ksize, Tensor strides, string padding, string data_format = "NHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MaxPoolGradV2", name) { args = new object[] { orig_input, orig_output, grad, ksize, strides }, attrs = new Dictionary() { ["padding"] = padding, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_pool_grad_v2_eager_fallback(orig_input, orig_output, grad, ksize, strides, padding: padding, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["orig_input"] = orig_input; + keywords["orig_output"] = orig_output; + keywords["grad"] = grad; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("MaxPoolGradV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("MaxPoolGradV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor max_pool_grad_v2_eager_fallback(Tensor orig_input, Tensor orig_output, Tensor grad, Tensor ksize, Tensor strides, string padding, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { orig_input, orig_output, grad, ksize, strides }; + object[] _attrs = new object[] { "padding", padding, "data_format", data_format, "T", orig_input.dtype }; + var _result = _execute.execute("MaxPoolGradV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MaxPoolGradV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes gradients of the maxpooling function. + /// + /// + /// + /// + /// + /// + /// The size of the window for each dimension of the input tensor. + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Whether to include batch dimension in flattened index of `argmax`. + /// + /// + /// + public static Tensor max_pool_grad_with_argmax(Tensor input, Tensor grad, Tensor argmax, int[] ksize, int[] strides, string padding, bool include_batch_in_index = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MaxPoolGradWithArgmax", name) { args = new object[] { input, grad, argmax }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding, ["include_batch_in_index"] = include_batch_in_index } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_pool_grad_with_argmax_eager_fallback(input, grad, argmax, ksize: ksize, strides: strides, padding: padding, include_batch_in_index: include_batch_in_index, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["grad"] = grad; + keywords["argmax"] = argmax; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["include_batch_in_index"] = include_batch_in_index; + var _op = tf.OpDefLib._apply_op_helper("MaxPoolGradWithArgmax", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "include_batch_in_index", _op._get_attr_bool("include_batch_in_index"), "Targmax", _op._get_attr_type("Targmax"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("MaxPoolGradWithArgmax", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor max_pool_grad_with_argmax_eager_fallback(Tensor input, Tensor grad, Tensor argmax, int[] ksize, int[] strides, string padding, bool include_batch_in_index, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, grad, argmax }; + object[] _attrs = new object[] { "ksize", ksize, "strides", strides, "padding", padding, "include_batch_in_index", include_batch_in_index, "Targmax", argmax.dtype, "T", input.dtype }; + var _result = _execute.execute("MaxPoolGradWithArgmax", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MaxPoolGradWithArgmax", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Performs max pooling on the input. + /// + /// + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// + public static Tensor max_pool_v2(Tensor input, Tensor ksize, Tensor strides, string padding, string data_format = "NHWC", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MaxPoolV2", name) { args = new object[] { input, ksize, strides }, attrs = new Dictionary() { ["padding"] = padding, ["data_format"] = data_format } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_pool_v2_eager_fallback(input, ksize, strides, padding: padding, data_format: data_format, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (data_format is null) + { + data_format = "NHWC"; + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["data_format"] = data_format; + var _op = tf.OpDefLib._apply_op_helper("MaxPoolV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "padding", _op.get_attr("padding"), "data_format", _op.get_attr("data_format") }; + _execute.record_gradient("MaxPoolV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor max_pool_v2_eager_fallback(Tensor input, Tensor ksize, Tensor strides, string padding, string data_format, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, ksize, strides }; + object[] _attrs = new object[] { "T", input.dtype, "padding", padding, "data_format", data_format }; + var _result = _execute.execute("MaxPoolV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MaxPoolV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Performs max pooling on the input and outputs both max values and indices. + /// + /// + /// + /// The indices in `argmax` are flattened, so that a maximum value at position + /// `[b, y, x, c]` becomes flattened index: + /// `(y * width + x) * channels + c` if `include_batch_in_index` is False; + /// `((b * height + y) * width + x) * channels + c` if `include_batch_in_index` is True. + /// + /// The indices returned are always in `[0, height) x [0, width)` before flattening, + /// even if padding is involved and the mathematically correct answer is outside + /// (either negative or too large). This is a bug, but fixing it is difficult to do + /// in a safe backwards compatible way, especially due to flattening. + /// + /// + /// + /// + /// + /// The size of the window for each dimension of the input tensor. + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Whether to include batch dimension in flattened index of `argmax`. + /// + /// + /// + public static Tensor[] max_pool_with_argmax(Tensor input, int[] ksize, int[] strides, string padding, TF_DataType Targmax = TF_DataType.TF_INT64, bool include_batch_in_index = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MaxPoolWithArgmax", name) { args = new object[] { input }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["Targmax"] = Targmax, ["padding"] = padding, ["include_batch_in_index"] = include_batch_in_index } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return max_pool_with_argmax_eager_fallback(input, ksize: ksize, strides: strides, Targmax: Targmax, padding: padding, include_batch_in_index: include_batch_in_index, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["Targmax"] = Targmax; + keywords["padding"] = padding; + keywords["include_batch_in_index"] = include_batch_in_index; + var _op = tf.OpDefLib._apply_op_helper("MaxPoolWithArgmax", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "Targmax", _op._get_attr_type("Targmax"), "padding", _op.get_attr("padding"), "include_batch_in_index", _op._get_attr_bool("include_batch_in_index"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("MaxPoolWithArgmax", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] max_pool_with_argmax_eager_fallback(Tensor input, int[] ksize, int[] strides, TF_DataType Targmax, string padding, bool include_batch_in_index, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "ksize", ksize, "strides", strides, "Targmax", Targmax, "padding", padding, "include_batch_in_index", include_batch_in_index, "T", input.dtype }; + var _result = _execute.execute("MaxPoolWithArgmax", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MaxPoolWithArgmax", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Finds values of the `n`-th order statistic for the last dimension. + /// + /// + /// + /// If the input is a vector (rank-1), finds the entries which is the nth-smallest + /// value in the vector and outputs their values as scalar tensor. + /// + /// For matrices (resp. higher rank input), computes the entries which is the + /// nth-smallest value in each row (resp. vector along the last dimension). Thus, + /// + /// values.shape = input.shape[:-1] + /// + /// + /// + /// + /// + /// + /// When set to True, find the nth-largest value in the vector and vice + /// versa. + /// + /// + /// + public static Tensor nth_element(Tensor input, Tensor n, bool reverse = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "NthElement", name) { args = new object[] { input, n }, attrs = new Dictionary() { ["reverse"] = reverse } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return nth_element_eager_fallback(input, n, reverse: reverse, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["n"] = n; + keywords["reverse"] = reverse; + var _op = tf.OpDefLib._apply_op_helper("NthElement", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "reverse", _op._get_attr_bool("reverse"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("NthElement", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor nth_element_eager_fallback(Tensor input, Tensor n, bool reverse, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, n }; + object[] _attrs = new object[] { "reverse", reverse, "T", input.dtype }; + var _result = _execute.execute("NthElement", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("NthElement", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Produces the average pool of the input tensor for quantized types. + /// + /// + /// + /// + /// + /// + /// The size of the window for each dimension of the input tensor. + /// The length must be 4 to match the number of dimensions of the input. + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the input + /// tensor. The length must be 4 to match the number of dimensions of the input. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + public static Tensor[] quantized_avg_pool(Tensor input, Tensor min_input, Tensor max_input, int[] ksize, int[] strides, string padding, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedAvgPool", name) { args = new object[] { input, min_input, max_input }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_avg_pool_eager_fallback(input, min_input, max_input, ksize: ksize, strides: strides, padding: padding, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + var _op = tf.OpDefLib._apply_op_helper("QuantizedAvgPool", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding") }; + _execute.record_gradient("QuantizedAvgPool", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_avg_pool_eager_fallback(Tensor input, Tensor min_input, Tensor max_input, int[] ksize, int[] strides, string padding, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, min_input, max_input }; + object[] _attrs = new object[] { "T", input.dtype, "ksize", ksize, "strides", strides, "padding", padding }; + var _result = _execute.execute("QuantizedAvgPool", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedAvgPool", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Quantized Batch normalization. + /// + /// + /// + /// This op is deprecated and will be removed in the future. Prefer + /// `tf.nn.batch_normalization`. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A small float number to avoid dividing by 0. + /// + /// + /// + /// + /// A bool indicating whether the resulted tensor + /// needs to be multiplied with gamma. + /// + /// + /// + public static Tensor[] quantized_batch_norm_with_global_normalization(Tensor t, Tensor t_min, Tensor t_max, Tensor m, Tensor m_min, Tensor m_max, Tensor v, Tensor v_min, Tensor v_max, Tensor beta, Tensor beta_min, Tensor beta_max, Tensor gamma, Tensor gamma_min, Tensor gamma_max, TF_DataType out_type, float variance_epsilon, bool scale_after_normalization, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedBatchNormWithGlobalNormalization", name) { args = new object[] { t, t_min, t_max, m, m_min, m_max, v, v_min, v_max, beta, beta_min, beta_max, gamma, gamma_min, gamma_max }, attrs = new Dictionary() { ["out_type"] = out_type, ["variance_epsilon"] = variance_epsilon, ["scale_after_normalization"] = scale_after_normalization } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_batch_norm_with_global_normalization_eager_fallback(t, t_min, t_max, m, m_min, m_max, v, v_min, v_max, beta, beta_min, beta_max, gamma, gamma_min, gamma_max, out_type: out_type, variance_epsilon: variance_epsilon, scale_after_normalization: scale_after_normalization, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["t"] = t; + keywords["t_min"] = t_min; + keywords["t_max"] = t_max; + keywords["m"] = m; + keywords["m_min"] = m_min; + keywords["m_max"] = m_max; + keywords["v"] = v; + keywords["v_min"] = v_min; + keywords["v_max"] = v_max; + keywords["beta"] = beta; + keywords["beta_min"] = beta_min; + keywords["beta_max"] = beta_max; + keywords["gamma"] = gamma; + keywords["gamma_min"] = gamma_min; + keywords["gamma_max"] = gamma_max; + keywords["out_type"] = out_type; + keywords["variance_epsilon"] = variance_epsilon; + keywords["scale_after_normalization"] = scale_after_normalization; + var _op = tf.OpDefLib._apply_op_helper("QuantizedBatchNormWithGlobalNormalization", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "out_type", _op._get_attr_type("out_type"), "variance_epsilon", _op.get_attr("variance_epsilon"), "scale_after_normalization", _op._get_attr_bool("scale_after_normalization") }; + _execute.record_gradient("QuantizedBatchNormWithGlobalNormalization", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_batch_norm_with_global_normalization_eager_fallback(Tensor t, Tensor t_min, Tensor t_max, Tensor m, Tensor m_min, Tensor m_max, Tensor v, Tensor v_min, Tensor v_max, Tensor beta, Tensor beta_min, Tensor beta_max, Tensor gamma, Tensor gamma_min, Tensor gamma_max, TF_DataType out_type, float variance_epsilon, bool scale_after_normalization, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { t, t_min, t_max, m, m_min, m_max, v, v_min, v_max, beta, beta_min, beta_max, gamma, gamma_min, gamma_max }; + object[] _attrs = new object[] { "Tinput", t.dtype, "out_type", out_type, "variance_epsilon", variance_epsilon, "scale_after_normalization", scale_after_normalization }; + var _result = _execute.execute("QuantizedBatchNormWithGlobalNormalization", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedBatchNormWithGlobalNormalization", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Adds Tensor 'bias' to Tensor 'input' for Quantized types. + /// + /// + /// + /// Broadcasts the values of bias on dimensions 0..N-2 of 'input'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_bias_add(Tensor input, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_bias, Tensor max_bias, TF_DataType out_type, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedBiasAdd", name) { args = new object[] { input, bias, min_input, max_input, min_bias, max_bias }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_bias_add_eager_fallback(input, bias, min_input, max_input, min_bias, max_bias, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["bias"] = bias; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_bias"] = min_bias; + keywords["max_bias"] = max_bias; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("QuantizedBiasAdd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T1", _op._get_attr_type("T1"), "T2", _op._get_attr_type("T2"), "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("QuantizedBiasAdd", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_bias_add_eager_fallback(Tensor input, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_bias, Tensor max_bias, TF_DataType out_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, bias, min_input, max_input, min_bias, max_bias }; + object[] _attrs = new object[] { "T1", input.dtype, "T2", bias.dtype, "out_type", out_type }; + var _result = _execute.execute("QuantizedBiasAdd", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedBiasAdd", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes a 2D convolution given quantized 4D input and filter tensors. + /// + /// + /// + /// The inputs are quantized tensors where the lowest value represents the real + /// number of the associated minimum, and the highest represents the maximum. + /// This means that you can only interpret the quantized output in the same way, by + /// taking the returned minimum and maximum values into account. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the input + /// tensor. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// `input`. If set to k > 1, there will be k-1 skipped cells between each + /// filter element on that dimension. The dimension order is determined by the + /// value of `data_format`, see above for details. Dilations in the batch and + /// depth dimensions must be 1. + /// + /// + /// + public static Tensor[] quantized_conv2d(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QINT32, int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConv2D", name) { args = new object[] { input, filter, min_input, max_input, min_filter, max_filter }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_conv2d_eager_fallback(input, filter, min_input, max_input, min_filter, max_filter, out_type: out_type, strides: strides, padding: padding, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConv2D", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("QuantizedConv2D", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_conv2d_eager_fallback(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, TF_DataType out_type, int[] strides, string padding, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, min_input, max_input, min_filter, max_filter }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations }; + var _result = _execute.execute("QuantizedConv2D", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConv2D", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_conv2d_and_relu(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QINT32, int[] dilations = null, int[] padding_list = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (padding_list is null) + { + padding_list = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConv2DAndRelu", name) { args = new object[] { input, filter, min_input, max_input, min_filter, max_filter }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations, ["padding_list"] = padding_list } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_conv2d_and_relu_eager_fallback(input, filter, min_input, max_input, min_filter, max_filter, out_type: out_type, strides: strides, padding: padding, dilations: dilations, padding_list: padding_list, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + keywords["padding_list"] = padding_list; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConv2DAndRelu", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations"), "padding_list", _op.get_attr("padding_list") }; + _execute.record_gradient("QuantizedConv2DAndRelu", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_conv2d_and_relu_eager_fallback(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, TF_DataType out_type, int[] strides, string padding, int[] dilations, int[] padding_list, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, min_input, max_input, min_filter, max_filter }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations, "padding_list", padding_list }; + var _result = _execute.execute("QuantizedConv2DAndRelu", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConv2DAndRelu", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_conv2d_and_relu_and_requantize(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QUINT8, int[] dilations = null, int[] padding_list = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (padding_list is null) + { + padding_list = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConv2DAndReluAndRequantize", name) { args = new object[] { input, filter, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations, ["padding_list"] = padding_list } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_conv2d_and_relu_and_requantize_eager_fallback(input, filter, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output, out_type: out_type, strides: strides, padding: padding, dilations: dilations, padding_list: padding_list, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["min_freezed_output"] = min_freezed_output; + keywords["max_freezed_output"] = max_freezed_output; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + keywords["padding_list"] = padding_list; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConv2DAndReluAndRequantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations"), "padding_list", _op.get_attr("padding_list") }; + _execute.record_gradient("QuantizedConv2DAndReluAndRequantize", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_conv2d_and_relu_and_requantize_eager_fallback(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, TF_DataType out_type, int[] strides, string padding, int[] dilations, int[] padding_list, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations, "padding_list", padding_list }; + var _result = _execute.execute("QuantizedConv2DAndReluAndRequantize", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConv2DAndReluAndRequantize", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_conv2d_and_requantize(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QINT8, int[] dilations = null, int[] padding_list = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (padding_list is null) + { + padding_list = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConv2DAndRequantize", name) { args = new object[] { input, filter, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations, ["padding_list"] = padding_list } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_conv2d_and_requantize_eager_fallback(input, filter, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output, out_type: out_type, strides: strides, padding: padding, dilations: dilations, padding_list: padding_list, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["min_freezed_output"] = min_freezed_output; + keywords["max_freezed_output"] = max_freezed_output; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + keywords["padding_list"] = padding_list; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConv2DAndRequantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations"), "padding_list", _op.get_attr("padding_list") }; + _execute.record_gradient("QuantizedConv2DAndRequantize", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_conv2d_and_requantize_eager_fallback(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, TF_DataType out_type, int[] strides, string padding, int[] dilations, int[] padding_list, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations, "padding_list", padding_list }; + var _result = _execute.execute("QuantizedConv2DAndRequantize", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConv2DAndRequantize", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes QuantizedConv2D per channel. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The quantized type of output tensor that needs to be converted. + /// + /// + /// + /// list of stride values. + /// + /// + /// + /// list of dilation values. + /// + /// + public static Tensor[] quantized_conv2d_per_channel(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QINT32, int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConv2DPerChannel", name) { args = new object[] { input, filter, min_input, max_input, min_filter, max_filter }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_conv2d_per_channel_eager_fallback(input, filter, min_input, max_input, min_filter, max_filter, out_type: out_type, strides: strides, padding: padding, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConv2DPerChannel", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("QuantizedConv2DPerChannel", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_conv2d_per_channel_eager_fallback(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, TF_DataType out_type, int[] strides, string padding, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, min_input, max_input, min_filter, max_filter }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations }; + var _result = _execute.execute("QuantizedConv2DPerChannel", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConv2DPerChannel", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_conv2d_with_bias(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QINT32, int[] dilations = null, int[] padding_list = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (padding_list is null) + { + padding_list = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConv2DWithBias", name) { args = new object[] { input, filter, bias, min_input, max_input, min_filter, max_filter }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations, ["padding_list"] = padding_list } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_conv2d_with_bias_eager_fallback(input, filter, bias, min_input, max_input, min_filter, max_filter, out_type: out_type, strides: strides, padding: padding, dilations: dilations, padding_list: padding_list, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["bias"] = bias; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + keywords["padding_list"] = padding_list; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConv2DWithBias", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations"), "padding_list", _op.get_attr("padding_list") }; + _execute.record_gradient("QuantizedConv2DWithBias", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_conv2d_with_bias_eager_fallback(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, TF_DataType out_type, int[] strides, string padding, int[] dilations, int[] padding_list, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, bias, min_input, max_input, min_filter, max_filter }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations, "padding_list", padding_list }; + var _result = _execute.execute("QuantizedConv2DWithBias", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConv2DWithBias", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_conv2d_with_bias_and_relu(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QINT32, int[] dilations = null, int[] padding_list = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (padding_list is null) + { + padding_list = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConv2DWithBiasAndRelu", name) { args = new object[] { input, filter, bias, min_input, max_input, min_filter, max_filter }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations, ["padding_list"] = padding_list } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_conv2d_with_bias_and_relu_eager_fallback(input, filter, bias, min_input, max_input, min_filter, max_filter, out_type: out_type, strides: strides, padding: padding, dilations: dilations, padding_list: padding_list, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["bias"] = bias; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + keywords["padding_list"] = padding_list; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConv2DWithBiasAndRelu", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations"), "padding_list", _op.get_attr("padding_list") }; + _execute.record_gradient("QuantizedConv2DWithBiasAndRelu", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_conv2d_with_bias_and_relu_eager_fallback(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, TF_DataType out_type, int[] strides, string padding, int[] dilations, int[] padding_list, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, bias, min_input, max_input, min_filter, max_filter }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations, "padding_list", padding_list }; + var _result = _execute.execute("QuantizedConv2DWithBiasAndRelu", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConv2DWithBiasAndRelu", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_conv2d_with_bias_and_relu_and_requantize(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QUINT8, int[] dilations = null, int[] padding_list = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (padding_list is null) + { + padding_list = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConv2DWithBiasAndReluAndRequantize", name) { args = new object[] { input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations, ["padding_list"] = padding_list } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_conv2d_with_bias_and_relu_and_requantize_eager_fallback(input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output, out_type: out_type, strides: strides, padding: padding, dilations: dilations, padding_list: padding_list, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["bias"] = bias; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["min_freezed_output"] = min_freezed_output; + keywords["max_freezed_output"] = max_freezed_output; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + keywords["padding_list"] = padding_list; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConv2DWithBiasAndReluAndRequantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "Tbias", _op._get_attr_type("Tbias"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations"), "padding_list", _op.get_attr("padding_list") }; + _execute.record_gradient("QuantizedConv2DWithBiasAndReluAndRequantize", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_conv2d_with_bias_and_relu_and_requantize_eager_fallback(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, TF_DataType out_type, int[] strides, string padding, int[] dilations, int[] padding_list, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "Tbias", bias.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations, "padding_list", padding_list }; + var _result = _execute.execute("QuantizedConv2DWithBiasAndReluAndRequantize", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConv2DWithBiasAndReluAndRequantize", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_conv2d_with_bias_and_requantize(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QINT8, int[] dilations = null, int[] padding_list = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (padding_list is null) + { + padding_list = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConv2DWithBiasAndRequantize", name) { args = new object[] { input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations, ["padding_list"] = padding_list } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_conv2d_with_bias_and_requantize_eager_fallback(input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output, out_type: out_type, strides: strides, padding: padding, dilations: dilations, padding_list: padding_list, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["bias"] = bias; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["min_freezed_output"] = min_freezed_output; + keywords["max_freezed_output"] = max_freezed_output; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + keywords["padding_list"] = padding_list; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConv2DWithBiasAndRequantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "Tbias", _op._get_attr_type("Tbias"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations"), "padding_list", _op.get_attr("padding_list") }; + _execute.record_gradient("QuantizedConv2DWithBiasAndRequantize", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_conv2d_with_bias_and_requantize_eager_fallback(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, TF_DataType out_type, int[] strides, string padding, int[] dilations, int[] padding_list, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "Tbias", bias.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations, "padding_list", padding_list }; + var _result = _execute.execute("QuantizedConv2DWithBiasAndRequantize", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConv2DWithBiasAndRequantize", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_conv2d_with_bias_signed_sum_and_relu_and_requantize(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, Tensor summand, Tensor min_summand, Tensor max_summand, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QUINT8, int[] dilations = null, int[] padding_list = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (padding_list is null) + { + padding_list = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize", name) { args = new object[] { input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output, summand, min_summand, max_summand }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations, ["padding_list"] = padding_list } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_conv2d_with_bias_signed_sum_and_relu_and_requantize_eager_fallback(input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output, summand, min_summand, max_summand, out_type: out_type, strides: strides, padding: padding, dilations: dilations, padding_list: padding_list, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["bias"] = bias; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["min_freezed_output"] = min_freezed_output; + keywords["max_freezed_output"] = max_freezed_output; + keywords["summand"] = summand; + keywords["min_summand"] = min_summand; + keywords["max_summand"] = max_summand; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + keywords["padding_list"] = padding_list; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "Tbias", _op._get_attr_type("Tbias"), "Tsummand", _op._get_attr_type("Tsummand"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations"), "padding_list", _op.get_attr("padding_list") }; + _execute.record_gradient("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_conv2d_with_bias_signed_sum_and_relu_and_requantize_eager_fallback(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, Tensor summand, Tensor min_summand, Tensor max_summand, TF_DataType out_type, int[] strides, string padding, int[] dilations, int[] padding_list, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output, summand, min_summand, max_summand }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "Tbias", bias.dtype, "Tsummand", summand.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations, "padding_list", padding_list }; + var _result = _execute.execute("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_conv2d_with_bias_sum_and_relu(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor summand, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QINT32, int[] dilations = null, int[] padding_list = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (padding_list is null) + { + padding_list = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConv2DWithBiasSumAndRelu", name) { args = new object[] { input, filter, bias, min_input, max_input, min_filter, max_filter, summand }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations, ["padding_list"] = padding_list } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_conv2d_with_bias_sum_and_relu_eager_fallback(input, filter, bias, min_input, max_input, min_filter, max_filter, summand, out_type: out_type, strides: strides, padding: padding, dilations: dilations, padding_list: padding_list, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["bias"] = bias; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["summand"] = summand; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + keywords["padding_list"] = padding_list; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConv2DWithBiasSumAndRelu", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations"), "padding_list", _op.get_attr("padding_list") }; + _execute.record_gradient("QuantizedConv2DWithBiasSumAndRelu", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_conv2d_with_bias_sum_and_relu_eager_fallback(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor summand, TF_DataType out_type, int[] strides, string padding, int[] dilations, int[] padding_list, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, bias, min_input, max_input, min_filter, max_filter, summand }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations, "padding_list", padding_list }; + var _result = _execute.execute("QuantizedConv2DWithBiasSumAndRelu", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConv2DWithBiasSumAndRelu", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_conv2d_with_bias_sum_and_relu_and_requantize(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, Tensor summand, Tensor min_summand, Tensor max_summand, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QUINT8, int[] dilations = null, int[] padding_list = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (padding_list is null) + { + padding_list = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedConv2DWithBiasSumAndReluAndRequantize", name) { args = new object[] { input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output, summand, min_summand, max_summand }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations, ["padding_list"] = padding_list } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_conv2d_with_bias_sum_and_relu_and_requantize_eager_fallback(input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output, summand, min_summand, max_summand, out_type: out_type, strides: strides, padding: padding, dilations: dilations, padding_list: padding_list, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["bias"] = bias; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["min_freezed_output"] = min_freezed_output; + keywords["max_freezed_output"] = max_freezed_output; + keywords["summand"] = summand; + keywords["min_summand"] = min_summand; + keywords["max_summand"] = max_summand; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + keywords["padding_list"] = padding_list; + var _op = tf.OpDefLib._apply_op_helper("QuantizedConv2DWithBiasSumAndReluAndRequantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "Tbias", _op._get_attr_type("Tbias"), "Tsummand", _op._get_attr_type("Tsummand"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations"), "padding_list", _op.get_attr("padding_list") }; + _execute.record_gradient("QuantizedConv2DWithBiasSumAndReluAndRequantize", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_conv2d_with_bias_sum_and_relu_and_requantize_eager_fallback(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, Tensor summand, Tensor min_summand, Tensor max_summand, TF_DataType out_type, int[] strides, string padding, int[] dilations, int[] padding_list, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output, summand, min_summand, max_summand }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "Tbias", bias.dtype, "Tsummand", summand.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations, "padding_list", padding_list }; + var _result = _execute.execute("QuantizedConv2DWithBiasSumAndReluAndRequantize", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedConv2DWithBiasSumAndReluAndRequantize", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes quantized depthwise Conv2D. + /// + /// + /// + /// + /// + /// + /// + /// + /// The type of the output. + /// + /// + /// List of stride values. + /// + /// + /// + /// List of dilation values. + /// + /// + public static Tensor[] quantized_depthwise_conv2d(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QINT32, int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedDepthwiseConv2D", name) { args = new object[] { input, filter, min_input, max_input, min_filter, max_filter }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_depthwise_conv2d_eager_fallback(input, filter, min_input, max_input, min_filter, max_filter, out_type: out_type, strides: strides, padding: padding, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("QuantizedDepthwiseConv2D", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("QuantizedDepthwiseConv2D", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_depthwise_conv2d_eager_fallback(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, TF_DataType out_type, int[] strides, string padding, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, min_input, max_input, min_filter, max_filter }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations }; + var _result = _execute.execute("QuantizedDepthwiseConv2D", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedDepthwiseConv2D", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes quantized depthwise Conv2D with Bias. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The type of the output. + /// + /// + /// List of stride values. + /// + /// + /// + /// List of dilation values. + /// + /// + public static Tensor[] quantized_depthwise_conv2d_with_bias(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QINT32, int[] dilations = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedDepthwiseConv2DWithBias", name) { args = new object[] { input, filter, bias, min_input, max_input, min_filter, max_filter }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_depthwise_conv2d_with_bias_eager_fallback(input, filter, bias, min_input, max_input, min_filter, max_filter, out_type: out_type, strides: strides, padding: padding, dilations: dilations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["bias"] = bias; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + var _op = tf.OpDefLib._apply_op_helper("QuantizedDepthwiseConv2DWithBias", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations") }; + _execute.record_gradient("QuantizedDepthwiseConv2DWithBias", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_depthwise_conv2d_with_bias_eager_fallback(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, TF_DataType out_type, int[] strides, string padding, int[] dilations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, bias, min_input, max_input, min_filter, max_filter }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations }; + var _result = _execute.execute("QuantizedDepthwiseConv2DWithBias", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedDepthwiseConv2DWithBias", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes quantized depthwise Conv2D with Bias and Relu. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The type of the output. + /// + /// + /// List of stride values. + /// + /// + /// + /// List of dilation values. + /// + /// + /// + public static Tensor[] quantized_depthwise_conv2d_with_bias_and_relu(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QINT32, int[] dilations = null, int[] padding_list = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (padding_list is null) + { + padding_list = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedDepthwiseConv2DWithBiasAndRelu", name) { args = new object[] { input, filter, bias, min_input, max_input, min_filter, max_filter }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations, ["padding_list"] = padding_list } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_depthwise_conv2d_with_bias_and_relu_eager_fallback(input, filter, bias, min_input, max_input, min_filter, max_filter, out_type: out_type, strides: strides, padding: padding, dilations: dilations, padding_list: padding_list, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["bias"] = bias; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + keywords["padding_list"] = padding_list; + var _op = tf.OpDefLib._apply_op_helper("QuantizedDepthwiseConv2DWithBiasAndRelu", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations"), "padding_list", _op.get_attr("padding_list") }; + _execute.record_gradient("QuantizedDepthwiseConv2DWithBiasAndRelu", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_depthwise_conv2d_with_bias_and_relu_eager_fallback(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, TF_DataType out_type, int[] strides, string padding, int[] dilations, int[] padding_list, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, bias, min_input, max_input, min_filter, max_filter }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations, "padding_list", padding_list }; + var _result = _execute.execute("QuantizedDepthwiseConv2DWithBiasAndRelu", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedDepthwiseConv2DWithBiasAndRelu", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes quantized depthwise Conv2D with Bias, Relu and Requantize. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The type of the output. + /// + /// + /// List of stride values. + /// + /// + /// + /// List of dilation values. + /// + /// + /// + public static Tensor[] quantized_depthwise_conv2d_with_bias_and_relu_and_requantize(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, int[] strides, string padding, TF_DataType out_type = TF_DataType.TF_QUINT8, int[] dilations = null, int[] padding_list = null, string? name = null) + { + var _ctx = tf.Context; + if (dilations is null) + { + dilations = new int[] { 1, 1, 1, 1 }; + } + if (padding_list is null) + { + padding_list = new int[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize", name) { args = new object[] { input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output }, attrs = new Dictionary() { ["out_type"] = out_type, ["strides"] = strides, ["padding"] = padding, ["dilations"] = dilations, ["padding_list"] = padding_list } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_depthwise_conv2d_with_bias_and_relu_and_requantize_eager_fallback(input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output, out_type: out_type, strides: strides, padding: padding, dilations: dilations, padding_list: padding_list, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["filter"] = filter; + keywords["bias"] = bias; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["min_filter"] = min_filter; + keywords["max_filter"] = max_filter; + keywords["min_freezed_output"] = min_freezed_output; + keywords["max_freezed_output"] = max_freezed_output; + keywords["out_type"] = out_type; + keywords["strides"] = strides; + keywords["padding"] = padding; + keywords["dilations"] = dilations; + keywords["padding_list"] = padding_list; + var _op = tf.OpDefLib._apply_op_helper("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "Tfilter", _op._get_attr_type("Tfilter"), "Tbias", _op._get_attr_type("Tbias"), "out_type", _op._get_attr_type("out_type"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding"), "dilations", _op.get_attr("dilations"), "padding_list", _op.get_attr("padding_list") }; + _execute.record_gradient("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_depthwise_conv2d_with_bias_and_relu_and_requantize_eager_fallback(Tensor input, Tensor filter, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, Tensor min_freezed_output, Tensor max_freezed_output, TF_DataType out_type, int[] strides, string padding, int[] dilations, int[] padding_list, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output }; + object[] _attrs = new object[] { "Tinput", input.dtype, "Tfilter", filter.dtype, "Tbias", bias.dtype, "out_type", out_type, "strides", strides, "padding", padding, "dilations", dilations, "padding_list", padding_list }; + var _result = _execute.execute("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// ~~%~~Performs a quantized matrix multiplication of `a` by the matrix `b` with bias~~%~~add.~~%~~ + /// + /// + /// + /// The inputs must be two-dimensional matrices and 1D bias vector. And the inner + /// dimension of `a` (after being transposed if `transpose_a` is non-zero) must + /// match the outer dimension of `b` (after being transposed if `transposed_b` is + /// non-zero). Then do broadcast add operation with bias values on the matrix + /// multiplication result. The bias size must match inner dimension of `b`. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If true, `a` is transposed before multiplication. + /// + /// + /// If true, `b` is transposed before multiplication. + /// + /// + /// + /// Input data quantization mode. Either MIN_FIRST(default) or SCALED. + /// + /// + /// + public static Tensor[] quantized_mat_mul_with_bias(Tensor a, Tensor b, Tensor bias, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, TF_DataType Toutput = TF_DataType.TF_QINT32, bool transpose_a = false, bool transpose_b = false, string input_quant_mode = "MIN_FIRST", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedMatMulWithBias", name) { args = new object[] { a, b, bias, min_a, max_a, min_b, max_b }, attrs = new Dictionary() { ["Toutput"] = Toutput, ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b, ["input_quant_mode"] = input_quant_mode } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_mat_mul_with_bias_eager_fallback(a, b, bias, min_a, max_a, min_b, max_b, Toutput: Toutput, transpose_a: transpose_a, transpose_b: transpose_b, input_quant_mode: input_quant_mode, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (input_quant_mode is null) + { + input_quant_mode = "MIN_FIRST"; + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["b"] = b; + keywords["bias"] = bias; + keywords["min_a"] = min_a; + keywords["max_a"] = max_a; + keywords["min_b"] = min_b; + keywords["max_b"] = max_b; + keywords["Toutput"] = Toutput; + keywords["transpose_a"] = transpose_a; + keywords["transpose_b"] = transpose_b; + keywords["input_quant_mode"] = input_quant_mode; + var _op = tf.OpDefLib._apply_op_helper("QuantizedMatMulWithBias", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T1", _op._get_attr_type("T1"), "T2", _op._get_attr_type("T2"), "Tbias", _op._get_attr_type("Tbias"), "Toutput", _op._get_attr_type("Toutput"), "transpose_a", _op._get_attr_bool("transpose_a"), "transpose_b", _op._get_attr_bool("transpose_b"), "input_quant_mode", _op.get_attr("input_quant_mode") }; + _execute.record_gradient("QuantizedMatMulWithBias", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_mat_mul_with_bias_eager_fallback(Tensor a, Tensor b, Tensor bias, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, TF_DataType Toutput, bool transpose_a, bool transpose_b, string input_quant_mode, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, b, bias, min_a, max_a, min_b, max_b }; + object[] _attrs = new object[] { "T1", a.dtype, "T2", b.dtype, "Tbias", bias.dtype, "Toutput", Toutput, "transpose_a", transpose_a, "transpose_b", transpose_b, "input_quant_mode", input_quant_mode }; + var _result = _execute.execute("QuantizedMatMulWithBias", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedMatMulWithBias", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor quantized_mat_mul_with_bias_and_dequantize(Tensor a, Tensor b, Tensor bias, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, Tensor min_freezed_output, Tensor max_freezed_output, TF_DataType Toutput, bool transpose_a = false, bool transpose_b = false, string input_quant_mode = "MIN_FIRST", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedMatMulWithBiasAndDequantize", name) { args = new object[] { a, b, bias, min_a, max_a, min_b, max_b, min_freezed_output, max_freezed_output }, attrs = new Dictionary() { ["Toutput"] = Toutput, ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b, ["input_quant_mode"] = input_quant_mode } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_mat_mul_with_bias_and_dequantize_eager_fallback(a, b, bias, min_a, max_a, min_b, max_b, min_freezed_output, max_freezed_output, Toutput: Toutput, transpose_a: transpose_a, transpose_b: transpose_b, input_quant_mode: input_quant_mode, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (input_quant_mode is null) + { + input_quant_mode = "MIN_FIRST"; + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["b"] = b; + keywords["bias"] = bias; + keywords["min_a"] = min_a; + keywords["max_a"] = max_a; + keywords["min_b"] = min_b; + keywords["max_b"] = max_b; + keywords["min_freezed_output"] = min_freezed_output; + keywords["max_freezed_output"] = max_freezed_output; + keywords["Toutput"] = Toutput; + keywords["transpose_a"] = transpose_a; + keywords["transpose_b"] = transpose_b; + keywords["input_quant_mode"] = input_quant_mode; + var _op = tf.OpDefLib._apply_op_helper("QuantizedMatMulWithBiasAndDequantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T1", _op._get_attr_type("T1"), "T2", _op._get_attr_type("T2"), "Tbias", _op._get_attr_type("Tbias"), "Toutput", _op._get_attr_type("Toutput"), "transpose_a", _op._get_attr_bool("transpose_a"), "transpose_b", _op._get_attr_bool("transpose_b"), "input_quant_mode", _op.get_attr("input_quant_mode") }; + _execute.record_gradient("QuantizedMatMulWithBiasAndDequantize", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor quantized_mat_mul_with_bias_and_dequantize_eager_fallback(Tensor a, Tensor b, Tensor bias, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, Tensor min_freezed_output, Tensor max_freezed_output, TF_DataType Toutput, bool transpose_a, bool transpose_b, string input_quant_mode, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, b, bias, min_a, max_a, min_b, max_b, min_freezed_output, max_freezed_output }; + object[] _attrs = new object[] { "T1", a.dtype, "T2", b.dtype, "Tbias", bias.dtype, "Toutput", Toutput, "transpose_a", transpose_a, "transpose_b", transpose_b, "input_quant_mode", input_quant_mode }; + var _result = _execute.execute("QuantizedMatMulWithBiasAndDequantize", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedMatMulWithBiasAndDequantize", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// ~~%~~Perform a quantized matrix multiplication of `a` by the matrix `b` with bias~~%~~add and relu fusion.~~%~~ + /// + /// + /// + /// The inputs must be two-dimensional matrices and 1D bias vector. And the inner + /// dimension of `a` (after being transposed if `transpose_a` is non-zero) must + /// match the outer dimension of `b` (after being transposed if `transposed_b` is + /// non-zero). Then do broadcast add operation with bias values on the matrix + /// multiplication result. The bias size must match inner dimension of `b`. Then do + /// relu activation to get non-negative result. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If true, `a` is transposed before multiplication. + /// + /// + /// If true, `b` is transposed before multiplication. + /// + /// + /// + /// Input data quantization mode. Either MIN_FIRST(default) or SCALED. + /// + /// + /// + public static Tensor[] quantized_mat_mul_with_bias_and_relu(Tensor a, Tensor b, Tensor bias, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, TF_DataType Toutput = TF_DataType.TF_QINT32, bool transpose_a = false, bool transpose_b = false, string input_quant_mode = "MIN_FIRST", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedMatMulWithBiasAndRelu", name) { args = new object[] { a, b, bias, min_a, max_a, min_b, max_b }, attrs = new Dictionary() { ["Toutput"] = Toutput, ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b, ["input_quant_mode"] = input_quant_mode } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_mat_mul_with_bias_and_relu_eager_fallback(a, b, bias, min_a, max_a, min_b, max_b, Toutput: Toutput, transpose_a: transpose_a, transpose_b: transpose_b, input_quant_mode: input_quant_mode, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (input_quant_mode is null) + { + input_quant_mode = "MIN_FIRST"; + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["b"] = b; + keywords["bias"] = bias; + keywords["min_a"] = min_a; + keywords["max_a"] = max_a; + keywords["min_b"] = min_b; + keywords["max_b"] = max_b; + keywords["Toutput"] = Toutput; + keywords["transpose_a"] = transpose_a; + keywords["transpose_b"] = transpose_b; + keywords["input_quant_mode"] = input_quant_mode; + var _op = tf.OpDefLib._apply_op_helper("QuantizedMatMulWithBiasAndRelu", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T1", _op._get_attr_type("T1"), "T2", _op._get_attr_type("T2"), "Toutput", _op._get_attr_type("Toutput"), "transpose_a", _op._get_attr_bool("transpose_a"), "transpose_b", _op._get_attr_bool("transpose_b"), "input_quant_mode", _op.get_attr("input_quant_mode") }; + _execute.record_gradient("QuantizedMatMulWithBiasAndRelu", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_mat_mul_with_bias_and_relu_eager_fallback(Tensor a, Tensor b, Tensor bias, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, TF_DataType Toutput, bool transpose_a, bool transpose_b, string input_quant_mode, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, b, bias, min_a, max_a, min_b, max_b }; + object[] _attrs = new object[] { "T1", a.dtype, "T2", b.dtype, "Toutput", Toutput, "transpose_a", transpose_a, "transpose_b", transpose_b, "input_quant_mode", input_quant_mode }; + var _result = _execute.execute("QuantizedMatMulWithBiasAndRelu", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedMatMulWithBiasAndRelu", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// ~~%~~Perform a quantized matrix multiplication of `a` by the matrix `b` with bias~~%~~add and relu and requantize fusion.~~%~~ + /// + /// + /// + /// The inputs must be two-dimensional matrices and 1D bias vector. And the inner + /// dimension of `a` (after being transposed if `transpose_a` is non-zero) must + /// match the outer dimension of `b` (after being transposed if `transposed_b` is + /// non-zero). Then do broadcast add operation with bias values on the matrix + /// multiplication result. The bias size must match inner dimension of `b`. Then do + /// relu activation to get non-negative result. Then do requantize operation to get + /// final uint8 result. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If true, `a` is transposed before multiplication. + /// + /// + /// If true, `b` is transposed before multiplication. + /// + /// + /// + /// Input data quantization mode. Either MIN_FIRST(default) or SCALED. + /// + /// + /// + public static Tensor[] quantized_mat_mul_with_bias_and_relu_and_requantize(Tensor a, Tensor b, Tensor bias, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, Tensor min_freezed_output, Tensor max_freezed_output, TF_DataType Toutput = TF_DataType.TF_QUINT8, bool transpose_a = false, bool transpose_b = false, string input_quant_mode = "MIN_FIRST", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedMatMulWithBiasAndReluAndRequantize", name) { args = new object[] { a, b, bias, min_a, max_a, min_b, max_b, min_freezed_output, max_freezed_output }, attrs = new Dictionary() { ["Toutput"] = Toutput, ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b, ["input_quant_mode"] = input_quant_mode } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_mat_mul_with_bias_and_relu_and_requantize_eager_fallback(a, b, bias, min_a, max_a, min_b, max_b, min_freezed_output, max_freezed_output, Toutput: Toutput, transpose_a: transpose_a, transpose_b: transpose_b, input_quant_mode: input_quant_mode, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (input_quant_mode is null) + { + input_quant_mode = "MIN_FIRST"; + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["b"] = b; + keywords["bias"] = bias; + keywords["min_a"] = min_a; + keywords["max_a"] = max_a; + keywords["min_b"] = min_b; + keywords["max_b"] = max_b; + keywords["min_freezed_output"] = min_freezed_output; + keywords["max_freezed_output"] = max_freezed_output; + keywords["Toutput"] = Toutput; + keywords["transpose_a"] = transpose_a; + keywords["transpose_b"] = transpose_b; + keywords["input_quant_mode"] = input_quant_mode; + var _op = tf.OpDefLib._apply_op_helper("QuantizedMatMulWithBiasAndReluAndRequantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T1", _op._get_attr_type("T1"), "T2", _op._get_attr_type("T2"), "Tbias", _op._get_attr_type("Tbias"), "Toutput", _op._get_attr_type("Toutput"), "transpose_a", _op._get_attr_bool("transpose_a"), "transpose_b", _op._get_attr_bool("transpose_b"), "input_quant_mode", _op.get_attr("input_quant_mode") }; + _execute.record_gradient("QuantizedMatMulWithBiasAndReluAndRequantize", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_mat_mul_with_bias_and_relu_and_requantize_eager_fallback(Tensor a, Tensor b, Tensor bias, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, Tensor min_freezed_output, Tensor max_freezed_output, TF_DataType Toutput, bool transpose_a, bool transpose_b, string input_quant_mode, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, b, bias, min_a, max_a, min_b, max_b, min_freezed_output, max_freezed_output }; + object[] _attrs = new object[] { "T1", a.dtype, "T2", b.dtype, "Tbias", bias.dtype, "Toutput", Toutput, "transpose_a", transpose_a, "transpose_b", transpose_b, "input_quant_mode", input_quant_mode }; + var _result = _execute.execute("QuantizedMatMulWithBiasAndReluAndRequantize", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedMatMulWithBiasAndReluAndRequantize", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_mat_mul_with_bias_and_requantize(Tensor a, Tensor b, Tensor bias, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, Tensor min_freezed_output, Tensor max_freezed_output, TF_DataType Toutput = TF_DataType.TF_QUINT8, bool transpose_a = false, bool transpose_b = false, string input_quant_mode = "MIN_FIRST", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedMatMulWithBiasAndRequantize", name) { args = new object[] { a, b, bias, min_a, max_a, min_b, max_b, min_freezed_output, max_freezed_output }, attrs = new Dictionary() { ["Toutput"] = Toutput, ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b, ["input_quant_mode"] = input_quant_mode } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_mat_mul_with_bias_and_requantize_eager_fallback(a, b, bias, min_a, max_a, min_b, max_b, min_freezed_output, max_freezed_output, Toutput: Toutput, transpose_a: transpose_a, transpose_b: transpose_b, input_quant_mode: input_quant_mode, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (input_quant_mode is null) + { + input_quant_mode = "MIN_FIRST"; + } + Dictionary keywords = new(); + keywords["a"] = a; + keywords["b"] = b; + keywords["bias"] = bias; + keywords["min_a"] = min_a; + keywords["max_a"] = max_a; + keywords["min_b"] = min_b; + keywords["max_b"] = max_b; + keywords["min_freezed_output"] = min_freezed_output; + keywords["max_freezed_output"] = max_freezed_output; + keywords["Toutput"] = Toutput; + keywords["transpose_a"] = transpose_a; + keywords["transpose_b"] = transpose_b; + keywords["input_quant_mode"] = input_quant_mode; + var _op = tf.OpDefLib._apply_op_helper("QuantizedMatMulWithBiasAndRequantize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T1", _op._get_attr_type("T1"), "T2", _op._get_attr_type("T2"), "Tbias", _op._get_attr_type("Tbias"), "Toutput", _op._get_attr_type("Toutput"), "transpose_a", _op._get_attr_bool("transpose_a"), "transpose_b", _op._get_attr_bool("transpose_b"), "input_quant_mode", _op.get_attr("input_quant_mode") }; + _execute.record_gradient("QuantizedMatMulWithBiasAndRequantize", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_mat_mul_with_bias_and_requantize_eager_fallback(Tensor a, Tensor b, Tensor bias, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, Tensor min_freezed_output, Tensor max_freezed_output, TF_DataType Toutput, bool transpose_a, bool transpose_b, string input_quant_mode, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { a, b, bias, min_a, max_a, min_b, max_b, min_freezed_output, max_freezed_output }; + object[] _attrs = new object[] { "T1", a.dtype, "T2", b.dtype, "Tbias", bias.dtype, "Toutput", Toutput, "transpose_a", transpose_a, "transpose_b", transpose_b, "input_quant_mode", input_quant_mode }; + var _result = _execute.execute("QuantizedMatMulWithBiasAndRequantize", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedMatMulWithBiasAndRequantize", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Produces the max pool of the input tensor for quantized types. + /// + /// + /// + /// + /// + /// + /// The size of the window for each dimension of the input tensor. + /// The length must be 4 to match the number of dimensions of the input. + /// + /// + /// + /// + /// The stride of the sliding window for each dimension of the input + /// tensor. The length must be 4 to match the number of dimensions of the input. + /// + /// + /// + /// + /// The type of padding algorithm to use. + /// + /// + /// + public static Tensor[] quantized_max_pool(Tensor input, Tensor min_input, Tensor max_input, int[] ksize, int[] strides, string padding, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedMaxPool", name) { args = new object[] { input, min_input, max_input }, attrs = new Dictionary() { ["ksize"] = ksize, ["strides"] = strides, ["padding"] = padding } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_max_pool_eager_fallback(input, min_input, max_input, ksize: ksize, strides: strides, padding: padding, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["min_input"] = min_input; + keywords["max_input"] = max_input; + keywords["ksize"] = ksize; + keywords["strides"] = strides; + keywords["padding"] = padding; + var _op = tf.OpDefLib._apply_op_helper("QuantizedMaxPool", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "ksize", _op.get_attr("ksize"), "strides", _op.get_attr("strides"), "padding", _op.get_attr("padding") }; + _execute.record_gradient("QuantizedMaxPool", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_max_pool_eager_fallback(Tensor input, Tensor min_input, Tensor max_input, int[] ksize, int[] strides, string padding, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, min_input, max_input }; + object[] _attrs = new object[] { "T", input.dtype, "ksize", ksize, "strides", strides, "padding", padding }; + var _result = _execute.execute("QuantizedMaxPool", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedMaxPool", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes Quantized Rectified Linear: `max(features, 0)` + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_relu(Tensor features, Tensor min_features, Tensor max_features, TF_DataType out_type = TF_DataType.TF_QUINT8, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedRelu", name) { args = new object[] { features, min_features, max_features }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_relu_eager_fallback(features, min_features, max_features, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["features"] = features; + keywords["min_features"] = min_features; + keywords["max_features"] = max_features; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("QuantizedRelu", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("QuantizedRelu", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_relu_eager_fallback(Tensor features, Tensor min_features, Tensor max_features, TF_DataType out_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { features, min_features, max_features }; + object[] _attrs = new object[] { "Tinput", features.dtype, "out_type", out_type }; + var _result = _execute.execute("QuantizedRelu", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedRelu", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)` + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_relu6(Tensor features, Tensor min_features, Tensor max_features, TF_DataType out_type = TF_DataType.TF_QUINT8, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedRelu6", name) { args = new object[] { features, min_features, max_features }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_relu6_eager_fallback(features, min_features, max_features, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["features"] = features; + keywords["min_features"] = min_features; + keywords["max_features"] = max_features; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("QuantizedRelu6", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("QuantizedRelu6", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_relu6_eager_fallback(Tensor features, Tensor min_features, Tensor max_features, TF_DataType out_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { features, min_features, max_features }; + object[] _attrs = new object[] { "Tinput", features.dtype, "out_type", out_type }; + var _result = _execute.execute("QuantizedRelu6", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedRelu6", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes Quantized Rectified Linear X: `min(max(features, 0), max_value)` + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] quantized_relu_x(Tensor features, Tensor max_value, Tensor min_features, Tensor max_features, TF_DataType out_type = TF_DataType.TF_QUINT8, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "QuantizedReluX", name) { args = new object[] { features, max_value, min_features, max_features }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return quantized_relu_x_eager_fallback(features, max_value, min_features, max_features, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["features"] = features; + keywords["max_value"] = max_value; + keywords["min_features"] = min_features; + keywords["max_features"] = max_features; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("QuantizedReluX", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tinput", _op._get_attr_type("Tinput"), "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("QuantizedReluX", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] quantized_relu_x_eager_fallback(Tensor features, Tensor max_value, Tensor min_features, Tensor max_features, TF_DataType out_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { features, max_value, min_features, max_features }; + object[] _attrs = new object[] { "Tinput", features.dtype, "out_type", out_type }; + var _result = _execute.execute("QuantizedReluX", 3, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("QuantizedReluX", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes rectified linear: `max(features, 0)`. + /// + /// + /// + /// See: https://en.wikipedia.org/wiki/Rectifier_(neural_networks) + /// Example usage: + /// >>> tf.nn.relu([-2., 0., 3.]).numpy() + /// array([0., 0., 3.], dtype=float32) + /// + /// + /// + /// + public static Tensor relu(Tensor features, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Relu", name) { args = new object[] { features }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return relu_eager_fallback(features, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["features"] = features; + var _op = tf.OpDefLib._apply_op_helper("Relu", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Relu", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor relu_eager_fallback(Tensor features, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { features }; + object[] _attrs = new object[] { "T", features.dtype }; + var _result = _execute.execute("Relu", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Relu", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes rectified linear 6: `min(max(features, 0), 6)`. + /// + /// + /// + public static Tensor relu6(Tensor features, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Relu6", name) { args = new object[] { features }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return relu6_eager_fallback(features, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["features"] = features; + var _op = tf.OpDefLib._apply_op_helper("Relu6", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Relu6", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor relu6_eager_fallback(Tensor features, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { features }; + object[] _attrs = new object[] { "T", features.dtype }; + var _result = _execute.execute("Relu6", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Relu6", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes rectified linear gradients for a Relu operation. + /// + /// + /// + /// + public static Tensor relu_grad(Tensor gradients, Tensor features, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReluGrad", name) { args = new object[] { gradients, features }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return relu_grad_eager_fallback(gradients, features, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["gradients"] = gradients; + keywords["features"] = features; + var _op = tf.OpDefLib._apply_op_helper("ReluGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("ReluGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor relu_grad_eager_fallback(Tensor gradients, Tensor features, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { gradients, features }; + object[] _attrs = new object[] { "T", gradients.dtype }; + var _result = _execute.execute("ReluGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReluGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` + /// + /// + /// + /// if < 0, `scale * features` otherwise. + /// + /// To be used together with + /// `initializer = tf.variance_scaling_initializer(factor=1.0, mode='FAN_IN')`. + /// For correct dropout, use `tf.contrib.nn.alpha_dropout`. + /// + /// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) + /// + /// + /// + /// + public static Tensor selu(Tensor features, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Selu", name) { args = new object[] { features }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return selu_eager_fallback(features, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["features"] = features; + var _op = tf.OpDefLib._apply_op_helper("Selu", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Selu", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor selu_eager_fallback(Tensor features, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { features }; + object[] _attrs = new object[] { "T", features.dtype }; + var _result = _execute.execute("Selu", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Selu", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes gradients for the scaled exponential linear (Selu) operation. + /// + /// + /// + /// + public static Tensor selu_grad(Tensor gradients, Tensor outputs, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SeluGrad", name) { args = new object[] { gradients, outputs }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return selu_grad_eager_fallback(gradients, outputs, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["gradients"] = gradients; + keywords["outputs"] = outputs; + var _op = tf.OpDefLib._apply_op_helper("SeluGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("SeluGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor selu_grad_eager_fallback(Tensor gradients, Tensor outputs, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { gradients, outputs }; + object[] _attrs = new object[] { "T", gradients.dtype }; + var _result = _execute.execute("SeluGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SeluGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes softmax activations. + /// + /// + /// + /// For each batch `i` and class `j` we have + /// + /// $$softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))$$ + /// + /// + /// + /// + public static Tensor softmax(Tensor logits, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Softmax", name) { args = new object[] { logits }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return softmax_eager_fallback(logits, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["logits"] = logits; + var _op = tf.OpDefLib._apply_op_helper("Softmax", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Softmax", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor softmax_eager_fallback(Tensor logits, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { logits }; + object[] _attrs = new object[] { "T", logits.dtype }; + var _result = _execute.execute("Softmax", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Softmax", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes softmax cross entropy cost and gradients to backpropagate. + /// + /// + /// + /// Inputs are the logits, not probabilities. + /// + /// + /// + /// + /// + public static Tensor[] softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SoftmaxCrossEntropyWithLogits", name) { args = new object[] { features, labels }, attrs = new Dictionary() { } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return softmax_cross_entropy_with_logits_eager_fallback(features, labels, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["features"] = features; + keywords["labels"] = labels; + var _op = tf.OpDefLib._apply_op_helper("SoftmaxCrossEntropyWithLogits", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("SoftmaxCrossEntropyWithLogits", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] softmax_cross_entropy_with_logits_eager_fallback(Tensor features, Tensor labels, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { features, labels }; + object[] _attrs = new object[] { "T", features.dtype }; + var _result = _execute.execute("SoftmaxCrossEntropyWithLogits", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SoftmaxCrossEntropyWithLogits", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// + /// + /// + /// + public static Tensor softplus(Tensor features, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Softplus", name) { args = new object[] { features }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return softplus_eager_fallback(features, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["features"] = features; + var _op = tf.OpDefLib._apply_op_helper("Softplus", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Softplus", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor softplus_eager_fallback(Tensor features, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { features }; + object[] _attrs = new object[] { "T", features.dtype }; + var _result = _execute.execute("Softplus", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Softplus", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes softplus gradients for a softplus operation. + /// + /// + /// + /// + public static Tensor softplus_grad(Tensor gradients, Tensor features, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SoftplusGrad", name) { args = new object[] { gradients, features }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return softplus_grad_eager_fallback(gradients, features, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["gradients"] = gradients; + keywords["features"] = features; + var _op = tf.OpDefLib._apply_op_helper("SoftplusGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("SoftplusGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor softplus_grad_eager_fallback(Tensor gradients, Tensor features, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { gradients, features }; + object[] _attrs = new object[] { "T", gradients.dtype }; + var _result = _execute.execute("SoftplusGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SoftplusGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes softsign: `features / (abs(features) + 1)`. + /// + /// + /// + public static Tensor softsign(Tensor features, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Softsign", name) { args = new object[] { features }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return softsign_eager_fallback(features, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["features"] = features; + var _op = tf.OpDefLib._apply_op_helper("Softsign", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("Softsign", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor softsign_eager_fallback(Tensor features, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { features }; + object[] _attrs = new object[] { "T", features.dtype }; + var _result = _execute.execute("Softsign", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Softsign", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes softsign gradients for a softsign operation. + /// + /// + /// + /// + public static Tensor softsign_grad(Tensor gradients, Tensor features, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SoftsignGrad", name) { args = new object[] { gradients, features }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return softsign_grad_eager_fallback(gradients, features, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["gradients"] = gradients; + keywords["features"] = features; + var _op = tf.OpDefLib._apply_op_helper("SoftsignGrad", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("SoftsignGrad", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor softsign_grad_eager_fallback(Tensor gradients, Tensor features, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { gradients, features }; + object[] _attrs = new object[] { "T", gradients.dtype }; + var _result = _execute.execute("SoftsignGrad", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SoftsignGrad", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Computes softmax cross entropy cost and gradients to backpropagate. + /// + /// + /// + /// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept + /// a matrix of label probabilities, but rather a single label per row + /// of features. This label is considered to have probability 1.0 for the + /// given row. + /// + /// Inputs are the logits, not probabilities. + /// + /// + /// + /// + /// + public static Tensor[] sparse_softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SparseSoftmaxCrossEntropyWithLogits", name) { args = new object[] { features, labels }, attrs = new Dictionary() { } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return sparse_softmax_cross_entropy_with_logits_eager_fallback(features, labels, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["features"] = features; + keywords["labels"] = labels; + var _op = tf.OpDefLib._apply_op_helper("SparseSoftmaxCrossEntropyWithLogits", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "Tlabels", _op._get_attr_type("Tlabels") }; + _execute.record_gradient("SparseSoftmaxCrossEntropyWithLogits", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] sparse_softmax_cross_entropy_with_logits_eager_fallback(Tensor features, Tensor labels, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { features, labels }; + object[] _attrs = new object[] { "T", features.dtype, "Tlabels", labels.dtype }; + var _result = _execute.execute("SparseSoftmaxCrossEntropyWithLogits", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SparseSoftmaxCrossEntropyWithLogits", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Finds values and indices of the `k` largest elements for the last dimension. + /// + /// + /// + /// If the input is a vector (rank-1), finds the `k` largest entries in the vector + /// and outputs their values and indices as vectors. Thus `values[j]` is the + /// `j`-th largest entry in `input`, and its index is `indices[j]`. + /// + /// For matrices (resp. higher rank input), computes the top `k` entries in each + /// row (resp. vector along the last dimension). Thus, + /// + /// values.shape = indices.shape = input.shape[:-1] + [k] + /// + /// If two elements are equal, the lower-index element appears first. + /// + /// If `k` varies dynamically, use `TopKV2` below. + /// + /// + /// + /// + /// + /// Number of top elements to look for along the last dimension (along each + /// row for matrices). + /// + /// + /// + /// + /// If true the resulting `k` elements will be sorted by the values in + /// descending order. + /// + /// + /// + public static Tensor[] top_k(Tensor input, int k = 0, bool sorted = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TopK", name) { args = new object[] { input }, attrs = new Dictionary() { ["k"] = k, ["sorted"] = sorted } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return top_k_eager_fallback(input, k: k, sorted: sorted, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["k"] = k; + keywords["sorted"] = sorted; + var _op = tf.OpDefLib._apply_op_helper("TopK", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "k", _op._get_attr_int("k"), "sorted", _op._get_attr_bool("sorted"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("TopK", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] top_k_eager_fallback(Tensor input, int k, bool sorted, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "k", k, "sorted", sorted, "T", input.dtype }; + var _result = _execute.execute("TopK", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TopK", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Finds values and indices of the `k` largest elements for the last dimension. + /// + /// + /// + /// If the input is a vector (rank-1), finds the `k` largest entries in the vector + /// and outputs their values and indices as vectors. Thus `values[j]` is the + /// `j`-th largest entry in `input`, and its index is `indices[j]`. + /// + /// For matrices (resp. higher rank input), computes the top `k` entries in each + /// row (resp. vector along the last dimension). Thus, + /// + /// values.shape = indices.shape = input.shape[:-1] + [k] + /// + /// If two elements are equal, the lower-index element appears first. + /// + /// + /// + /// + /// + /// + /// If true the resulting `k` elements will be sorted by the values in + /// descending order. + /// + /// + /// + public static Tensor[] top_kv2(Tensor input, Tensor k, bool sorted = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TopKV2", name) { args = new object[] { input, k }, attrs = new Dictionary() { ["sorted"] = sorted } }); + return _fast_path_result; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return top_kv2_eager_fallback(input, k, sorted: sorted, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["k"] = k; + keywords["sorted"] = sorted; + var _op = tf.OpDefLib._apply_op_helper("TopKV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "sorted", _op._get_attr_bool("sorted"), "T", _op._get_attr_type("T") }; + _execute.record_gradient("TopKV2", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] top_kv2_eager_fallback(Tensor input, Tensor k, bool sorted, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input, k }; + object[] _attrs = new object[] { "sorted", sorted, "T", input.dtype }; + var _result = _execute.execute("TopKV2", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TopKV2", _inputs_flat, _attrs, _result); + } + return _result; + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs new file mode 100644 index 000000000..5fa4c97dd --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -0,0 +1,39171 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Xml.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + public class gen_ops + { + /// + /// Raise a exception to abort the process when called. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Abort'. + /// + /// + /// A string which is the message associated with the exception. + /// + /// + /// + /// + /// Returns the description of the operation + /// + /// + /// If exit_without_error is true, the process will exit normally, + /// otherwise it will exit with a SIGABORT signal. + /// + /// Returns nothing but an exception. + /// + public static Operation abort(string error_msg = null, bool? exit_without_error = null, string name = "Abort") + { + var dict = new Dictionary(); + if (error_msg != null) + dict["error_msg"] = error_msg; + if (exit_without_error.HasValue) + dict["exit_without_error"] = exit_without_error.Value; + var op = tf.OpDefLib._apply_op_helper("Abort", name: name, keywords: dict); + return op; + } + + /// + /// Computes the absolute value of a tensor. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Abs'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor x, this operation returns a tensor containing the absolute + /// value of each element in x. For example, if x is an input element and y is + /// an output element, this operation computes \\(y = |x|\\). + /// + public static Tensor abs(Tensor x, string name = "Abs") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Abs", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the element-wise sum of a list of tensors. + /// + /// + /// A list of Tensor objects, each with same shape and type. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AccumulateNV2'. + /// + /// + /// Optional argument + /// Shape of elements of inputs. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// tf.accumulate_n_v2 performs the same operation as tf.add_n, but does not + /// wait for all of its inputs to be ready before beginning to sum. This can + /// save memory if inputs are ready at different times, since minimum temporary + /// storage is proportional to the output size rather than the inputs size. + /// + /// Unlike the original accumulate_n, accumulate_n_v2 is differentiable. + /// + /// Returns a Tensor of same shape and type as the elements of inputs. + /// + public static Tensor accumulate_n_v2(Tensor[] inputs, Shape shape, string name = "AccumulateNV2") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("AccumulateNV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Applies a gradient to a given accumulator. + /// + /// + /// The handle to a accumulator. + /// + /// + /// The local_step value at which the gradient was computed. + /// + /// + /// A tensor of the gradient to be accumulated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AccumulatorApplyGradient'. + /// + /// + /// Returns the description of the operation + /// + /// + /// Does not add if local_step is lesser than the accumulator's global_step. + /// + public static Operation accumulator_apply_gradient(Tensor handle, Tensor local_step, Tensor gradient, string name = "AccumulatorApplyGradient") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["local_step"] = local_step; + dict["gradient"] = gradient; + var op = tf.OpDefLib._apply_op_helper("AccumulatorApplyGradient", name: name, keywords: dict); + return op; + } + + /// + /// Returns the number of gradients aggregated in the given accumulators. + /// + /// + /// The handle to an accumulator. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AccumulatorNumAccumulated'. + /// + /// + /// The number of gradients aggregated in the given accumulator. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor accumulator_num_accumulated(Tensor handle, string name = "AccumulatorNumAccumulated") + { + var dict = new Dictionary(); + dict["handle"] = handle; + var op = tf.OpDefLib._apply_op_helper("AccumulatorNumAccumulated", name: name, keywords: dict); + return op.output; + } + + /// + /// Updates the accumulator with a new value for global_step. + /// + /// + /// The handle to an accumulator. + /// + /// + /// The new global_step value to set. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AccumulatorSetGlobalStep'. + /// + /// + /// Returns the description of the operation + /// + /// + /// Logs warning if the accumulator's value is already higher than + /// new_global_step. + /// + public static Operation accumulator_set_global_step(Tensor handle, Tensor new_global_step, string name = "AccumulatorSetGlobalStep") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["new_global_step"] = new_global_step; + var op = tf.OpDefLib._apply_op_helper("AccumulatorSetGlobalStep", name: name, keywords: dict); + return op; + } + + /// + /// Extracts the average gradient in the given ConditionalAccumulator. + /// + /// + /// The handle to an accumulator. + /// + /// + /// Number of gradients required before we return an aggregate. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AccumulatorTakeGradient'. + /// + /// + /// Optional argument + /// The data type of accumulated gradients. Needs to correspond to the type + /// of the accumulator. + /// + /// + /// The average of the accumulated gradients. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The op blocks until sufficient (i.e., more than num_required) + /// gradients have been accumulated. If the accumulator has already + /// aggregated more than num_required gradients, it returns the average of + /// the accumulated gradients. Also automatically increments the recorded + /// global_step in the accumulator by 1, and resets the aggregate to 0. + /// + public static Tensor accumulator_take_gradient(Tensor handle, Tensor num_required, TF_DataType dtype, string name = "AccumulatorTakeGradient") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["num_required"] = num_required; + dict["dtype"] = dtype; + var op = tf.OpDefLib._apply_op_helper("AccumulatorTakeGradient", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes acos of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Acos'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor acos(Tensor x, string name = "Acos") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Acos", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes inverse hyperbolic cosine of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Acosh'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor acosh(Tensor x, string name = "Acosh") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Acosh", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns x + y element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Add'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: Add supports broadcasting. AddN does not. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor add(Tensor x, Tensor y, string name = "Add") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("Add", name: name, keywords: dict); + return op.output; + } + + /// + /// Add an N-minibatch SparseTensor to a SparseTensorsMap, return N handles. + /// + /// + /// 2-D. The indices of the minibatch SparseTensor. + /// sparse_indices[:, 0] must be ordered values in [0, N). + /// + /// + /// 1-D. The values of the minibatch SparseTensor. + /// + /// + /// 1-D. The shape of the minibatch SparseTensor. + /// The minibatch size N == sparse_shape[0]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AddManySparseToTensorsMap'. + /// + /// + /// The container name for the SparseTensorsMap created by this op. + /// + /// + /// The shared name for the SparseTensorsMap created by this op. + /// If blank, the new Operation's unique name is used. + /// + /// + /// 1-D. The handles of the SparseTensor now stored in the + /// SparseTensorsMap. Shape: [N]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// A SparseTensor of rank R is represented by three tensors: sparse_indices, + /// sparse_values, and sparse_shape, where + /// + /// + /// sparse_indices.shape[1] == sparse_shape.shape[0] == R + /// + /// + /// An N-minibatch of SparseTensor objects is represented as a SparseTensor + /// having a first sparse_indices column taking values between [0, N), where + /// the minibatch size N == sparse_shape[0]. + /// + /// The input SparseTensor must have rank R greater than 1, and the first + /// dimension is treated as the minibatch dimension. Elements of the SparseTensor + /// must be sorted in increasing order of this first dimension. The stored + /// SparseTensor objects pointed to by each row of the output sparse_handles + /// will have rank R-1. + /// + /// The SparseTensor values can then be read out as part of a minibatch by passing + /// the given keys as vector elements to TakeManySparseFromTensorsMap. To ensure + /// the correct SparseTensorsMap is accessed, ensure that the same + /// container and shared_name are passed to that Op. If no shared_name + /// is provided here, instead use the *name* of the Operation created by calling + /// AddManySparseToTensorsMap as the shared_name passed to + /// TakeManySparseFromTensorsMap. Ensure the Operations are colocated. + /// + public static Tensor add_many_sparse_to_tensors_map(Tensor sparse_indices, Tensor sparse_values, Tensor sparse_shape, string container = null, string shared_name = null, string name = "AddManySparseToTensorsMap") + { + var dict = new Dictionary(); + dict["sparse_indices"] = sparse_indices; + dict["sparse_values"] = sparse_values; + dict["sparse_shape"] = sparse_shape; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("AddManySparseToTensorsMap", name: name, keywords: dict); + return op.output; + } + + /// + /// Add all input tensors element wise. + /// + /// + /// Must all be the same size and shape. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AddN'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor add_n(Tensor[] inputs, string name = "AddN") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + var op = tf.OpDefLib._apply_op_helper("AddN", name: name, keywords: dict); + return op.output; + } + + /// + /// Add a SparseTensor to a SparseTensorsMap return its handle. + /// + /// + /// 2-D. The indices of the SparseTensor. + /// + /// + /// 1-D. The values of the SparseTensor. + /// + /// + /// 1-D. The shape of the SparseTensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AddSparseToTensorsMap'. + /// + /// + /// The container name for the SparseTensorsMap created by this op. + /// + /// + /// The shared name for the SparseTensorsMap created by this op. + /// If blank, the new Operation's unique name is used. + /// + /// + /// 0-D. The handle of the SparseTensor now stored in the + /// SparseTensorsMap. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// A SparseTensor is represented by three tensors: sparse_indices, + /// sparse_values, and sparse_shape. + /// + /// This operator takes the given SparseTensor and adds it to a container + /// object (a SparseTensorsMap). A unique key within this container is generated + /// in the form of an int64, and this is the value that is returned. + /// + /// The SparseTensor can then be read out as part of a minibatch by passing + /// the key as a vector element to TakeManySparseFromTensorsMap. To ensure + /// the correct SparseTensorsMap is accessed, ensure that the same + /// container and shared_name are passed to that Op. If no shared_name + /// is provided here, instead use the *name* of the Operation created by calling + /// AddSparseToTensorsMap as the shared_name passed to + /// TakeManySparseFromTensorsMap. Ensure the Operations are colocated. + /// + public static Tensor add_sparse_to_tensors_map(Tensor sparse_indices, Tensor sparse_values, Tensor sparse_shape, string container = null, string shared_name = null, string name = "AddSparseToTensorsMap") + { + var dict = new Dictionary(); + dict["sparse_indices"] = sparse_indices; + dict["sparse_values"] = sparse_values; + dict["sparse_shape"] = sparse_shape; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("AddSparseToTensorsMap", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns x + y element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AddV2'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: Add supports broadcasting. AddN does not. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor add_v2(Tensor x, Tensor y, string name = "AddV2") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("AddV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Deprecated. Disallowed in GraphDef version &gt;= 2. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AdjustContrast'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor adjust_contrast(Tensor images, Tensor contrast_factor, Tensor min_value, Tensor max_value, string name = "AdjustContrast") + { + var dict = new Dictionary(); + dict["images"] = images; + dict["contrast_factor"] = contrast_factor; + dict["min_value"] = min_value; + dict["max_value"] = max_value; + var op = tf.OpDefLib._apply_op_helper("AdjustContrast", name: name, keywords: dict); + return op.output; + } + + /// + /// Adjust the contrast of one or more images. + /// + /// + /// Images to adjust. At least 3-D. + /// + /// + /// A float multiplier for adjusting contrast. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AdjustContrastv2'. + /// + /// + /// The contrast-adjusted image or images. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// images is a tensor of at least 3 dimensions. The last 3 dimensions are + /// interpreted as [height, width, channels]. The other dimensions only + /// represent a collection of images, such as [batch, height, width, channels]. + /// + /// Contrast is adjusted independently for each channel of each image. + /// + /// For each channel, the Op first computes the mean of the image pixels in the + /// channel and then adjusts each component of each pixel to + /// (x - mean) * contrast_factor + mean. + /// + public static Tensor adjust_contrastv2(Tensor images, Tensor contrast_factor, string name = "AdjustContrastv2") + { + var dict = new Dictionary(); + dict["images"] = images; + dict["contrast_factor"] = contrast_factor; + var op = tf.OpDefLib._apply_op_helper("AdjustContrastv2", name: name, keywords: dict); + return op.output; + } + + /// + /// Adjust the hue of one or more images. + /// + /// + /// Images to adjust. At least 3-D. + /// + /// + /// A float delta to add to the hue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AdjustHue'. + /// + /// + /// The hue-adjusted image or images. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// images is a tensor of at least 3 dimensions. The last dimension is + /// interpretted as channels, and must be three. + /// + /// The input image is considered in the RGB colorspace. Conceptually, the RGB + /// colors are first mapped into HSV. A delta is then applied all the hue values, + /// and then remapped back to RGB colorspace. + /// + public static Tensor adjust_hue(Tensor images, Tensor delta, string name = "AdjustHue") + { + var dict = new Dictionary(); + dict["images"] = images; + dict["delta"] = delta; + var op = tf.OpDefLib._apply_op_helper("AdjustHue", name: name, keywords: dict); + return op.output; + } + + /// + /// Adjust the saturation of one or more images. + /// + /// + /// Images to adjust. At least 3-D. + /// + /// + /// A float scale to add to the saturation. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AdjustSaturation'. + /// + /// + /// The hue-adjusted image or images. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// images is a tensor of at least 3 dimensions. The last dimension is + /// interpretted as channels, and must be three. + /// + /// The input image is considered in the RGB colorspace. Conceptually, the RGB + /// colors are first mapped into HSV. A scale is then applied all the saturation + /// values, and then remapped back to RGB colorspace. + /// + public static Tensor adjust_saturation(Tensor images, Tensor scale, string name = "AdjustSaturation") + { + var dict = new Dictionary(); + dict["images"] = images; + dict["scale"] = scale; + var op = tf.OpDefLib._apply_op_helper("AdjustSaturation", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the "logical and" of elements across dimensions of a tensor. + /// + /// + /// The tensor to reduce. + /// + /// + /// The dimensions to reduce. Must be in the range + /// [-rank(input), rank(input)). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'All'. + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// The reduced tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Reduces input along the dimensions given in axis. Unless + /// keep_dims is true, the rank of the tensor is reduced by 1 for each entry in + /// axis. If keep_dims is true, the reduced dimensions are + /// retained with length 1. + /// + public static Tensor all(Tensor input, Tensor reduction_indices, bool? keep_dims = null, string name = "All") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["reduction_indices"] = reduction_indices; + if (keep_dims.HasValue) + dict["keep_dims"] = keep_dims.Value; + var op = tf.OpDefLib._apply_op_helper("All", name: name, keywords: dict); + return op.output; + } + + /// + /// Generates labels for candidate sampling with a learned unigram distribution. + /// + /// + /// A batch_size * num_true matrix, in which each row contains the + /// IDs of the num_true target_classes in the corresponding original label. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AllCandidateSampler'. + /// + /// + /// Optional argument + /// Number of true labels per context. + /// + /// + /// Optional argument + /// Number of candidates to produce. + /// + /// + /// Optional argument + /// If unique is true, we sample with rejection, so that all sampled + /// candidates in a batch are unique. This requires some approximation to + /// estimate the post-rejection sampling probabilities. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// An second seed to avoid seed collision. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sampled_candidates : A vector of length num_sampled, in which each element is + /// the ID of a sampled candidate. + /// true_expected_count : A batch_size * num_true matrix, representing + /// the number of times each candidate is expected to occur in a batch + /// of sampled candidates. If unique=true, then this is a probability. + /// sampled_expected_count : A vector of length num_sampled, for each sampled + /// candidate representing the number of times the candidate is expected + /// to occur in a batch of sampled candidates. If unique=true, then this is a + /// probability. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// See explanations of candidate sampling and the data formats at + /// go/candidate-sampling. + /// + /// For each batch, this op picks a single set of sampled candidate labels. + /// + /// The advantages of sampling candidates per-batch are simplicity and the + /// possibility of efficient dense matrix multiplication. The disadvantage is that + /// the sampled candidates must be chosen independently of the context and of the + /// true labels. + /// + public static (Tensor sampled_candidates, Tensor true_expected_count, Tensor sampled_expected_count) all_candidate_sampler(Tensor true_classes, int num_true, int num_sampled, bool unique, int? seed = null, int? seed2 = null, string name = "AllCandidateSampler") + { + var dict = new Dictionary(); + dict["true_classes"] = true_classes; + dict["num_true"] = num_true; + dict["num_sampled"] = num_sampled; + dict["unique"] = unique; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("AllCandidateSampler", name: name, keywords: dict); + int _idx = 0; + var sampled_candidates = op.outputs[_idx++]; + var true_expected_count = op.outputs[_idx++]; + var sampled_expected_count = op.outputs[_idx++]; + return (sampled_candidates, true_expected_count, sampled_expected_count); + } + + /// + /// Returns the argument of a complex number. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Angle'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor input of complex numbers, this operation returns a tensor of + /// type float that is the argument of each element in input. All elements in + /// input must be complex numbers of the form \\(a + bj\\), where *a* + /// is the real part and *b* is the imaginary part. + /// + /// The argument returned by this operation is of the form \\(atan2(b, a)\\). + /// + /// For example: + /// + /// + /// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] + /// tf.angle(input) ==&gt; [2.0132, 1.056] + /// + /// + /// @compatibility(numpy) + /// Equivalent to np.angle. + /// @end_compatibility + /// + public static Tensor angle(Tensor input, TF_DataType? Tout = null, string name = "Angle") + { + return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(input).SetAttributes(new { Tout = Tout })); + } + + /// + /// A container for an iterator resource. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AnonymousIterator'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// A handle to the iterator that can be passed to a "MakeIterator" or + /// "IteratorGetNext" op. In contrast to Iterator, AnonymousIterator prevents + /// resource sharing by name, and does not keep a reference to the resource + /// container. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor anonymous_iterator(TF_DataType[] output_types, Shape[] output_shapes, string name = "AnonymousIterator") + { + var dict = new Dictionary(); + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("AnonymousIterator", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the "logical or" of elements across dimensions of a tensor. + /// + /// + /// The tensor to reduce. + /// + /// + /// The dimensions to reduce. Must be in the range + /// [-rank(input), rank(input)). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Any'. + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// The reduced tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Reduces input along the dimensions given in axis. Unless + /// keep_dims is true, the rank of the tensor is reduced by 1 for each entry in + /// axis. If keep_dims is true, the reduced dimensions are + /// retained with length 1. + /// + public static Tensor any(Tensor input, Tensor reduction_indices, bool? keep_dims = null, string name = "Any") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["reduction_indices"] = reduction_indices; + if (keep_dims.HasValue) + dict["keep_dims"] = keep_dims.Value; + var op = tf.OpDefLib._apply_op_helper("Any", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the AdaMax algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Must be a scalar. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Momentum factor. Must be a scalar. + /// + /// + /// Momentum factor. Must be a scalar. + /// + /// + /// Ridge term. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyAdaMax'. + /// + /// + /// If True, updating of the var, m, and v tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// m_t &lt;- beta1 * m_{t-1} + (1 - beta1) * g + /// v_t &lt;- max(beta2 * v_{t-1}, abs(g)) + /// variable &lt;- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon) + /// + public static Tensor apply_ada_max(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, bool? use_locking = null, string name = "ApplyAdaMax") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["m"] = m; + dict["v"] = v; + dict["beta1_power"] = beta1_power; + dict["lr"] = lr; + dict["beta1"] = beta1; + dict["beta2"] = beta2; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyAdaMax", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the adadelta scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Decay factor. Must be a scalar. + /// + /// + /// Constant factor. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyAdadelta'. + /// + /// + /// If True, updating of the var, accum and update_accum tensors will be protected by + /// a lock; otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// accum = rho() * accum + (1 - rho()) * grad.square(); + /// update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad; + /// update_accum = rho() * update_accum + (1 - rho()) * update.square(); + /// var -= update; + /// + public static Tensor apply_adadelta(Tensor var, Tensor accum, Tensor accum_update, Tensor lr, Tensor rho, Tensor epsilon, Tensor grad, bool? use_locking = null, string name = "ApplyAdadelta") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["accum_update"] = accum_update; + dict["lr"] = lr; + dict["rho"] = rho; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyAdadelta", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the adagrad scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyAdagrad'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// accum += grad * grad + /// var -= lr * grad * (1 / sqrt(accum)) + /// + public static Tensor apply_adagrad(Tensor var, Tensor accum, Tensor lr, Tensor grad, bool? use_locking = null, bool? update_slots = null, string name = "ApplyAdagrad") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["lr"] = lr; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + if (update_slots.HasValue) + dict["update_slots"] = update_slots.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyAdagrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the proximal adagrad scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// The gradient. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// Training step number. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyAdagradDA'. + /// + /// + /// If True, updating of the var and accum tensors will be protected by + /// a lock; otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor apply_adagrad_d_a(Tensor var, Tensor gradient_accumulator, Tensor gradient_squared_accumulator, Tensor grad, Tensor lr, Tensor l1, Tensor l2, Tensor global_step, bool? use_locking = null, string name = "ApplyAdagradDA") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["gradient_accumulator"] = gradient_accumulator; + dict["gradient_squared_accumulator"] = gradient_squared_accumulator; + dict["grad"] = grad; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["global_step"] = global_step; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyAdagradDA", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the Adam algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Momentum factor. Must be a scalar. + /// + /// + /// Momentum factor. Must be a scalar. + /// + /// + /// Ridge term. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyAdam'. + /// + /// + /// If True, updating of the var, m, and v tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// If True, uses the nesterov update. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ + /// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ + /// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ + /// $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ + /// + public static Tensor apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power, Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, bool? use_locking = null, bool? use_nesterov = null, string name = "ApplyAdam") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["m"] = m; + dict["v"] = v; + dict["beta1_power"] = beta1_power; + dict["beta2_power"] = beta2_power; + dict["lr"] = lr; + dict["beta1"] = beta1; + dict["beta2"] = beta2; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + if (use_nesterov.HasValue) + dict["use_nesterov"] = use_nesterov.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyAdam", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the AddSign update. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyAddSign'. + /// + /// + /// If True, updating of the var and m tensors is + /// protected by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// m_t &lt;- beta1 * m_{t-1} + (1 - beta1) * g + /// update &lt;- (alpha + sign_decay * sign(g) *sign(m)) * g + /// variable &lt;- variable - lr_t * update + /// + public static Tensor apply_add_sign(Tensor var, Tensor m, Tensor lr, Tensor alpha, Tensor sign_decay, Tensor beta, Tensor grad, bool? use_locking = null, string name = "ApplyAddSign") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["m"] = m; + dict["lr"] = lr; + dict["alpha"] = alpha; + dict["sign_decay"] = sign_decay; + dict["beta"] = beta; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyAddSign", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the centered RMSProp algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Decay rate. Must be a scalar. + /// + /// + /// + /// + /// Ridge term. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyCenteredRMSProp'. + /// + /// + /// If True, updating of the var, mg, ms, and mom tensors is + /// protected by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The centered RMSProp algorithm uses an estimate of the centered second moment + /// (i.e., the variance) for normalization, as opposed to regular RMSProp, which + /// uses the (uncentered) second moment. This often helps with training, but is + /// slightly more expensive in terms of computation and memory. + /// + /// Note that in dense implementation of this algorithm, mg, ms, and mom will + /// update even if the grad is zero, but in this sparse implementation, mg, ms, + /// and mom will not update in iterations during which the grad is zero. + /// + /// mean_square = decay * mean_square + (1-decay) * gradient ** 2 + /// mean_grad = decay * mean_grad + (1-decay) * gradient + /// + /// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) + /// + /// mg &lt;- rho * mg_{t-1} + (1-rho) * grad + /// ms &lt;- rho * ms_{t-1} + (1-rho) * grad * grad + /// mom &lt;- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) + /// var &lt;- var - mom + /// + public static Tensor apply_centered_r_m_s_prop(Tensor var, Tensor mg, Tensor ms, Tensor mom, Tensor lr, Tensor rho, Tensor momentum, Tensor epsilon, Tensor grad, bool? use_locking = null, string name = "ApplyCenteredRMSProp") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["mg"] = mg; + dict["ms"] = ms; + dict["mom"] = mom; + dict["lr"] = lr; + dict["rho"] = rho; + dict["momentum"] = momentum; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyCenteredRMSProp", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the Ftrl-proximal scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// The gradient. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regulariation. Must be a scalar. + /// + /// + /// L2 regulariation. Must be a scalar. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyFtrl'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// accum_new = accum + grad * grad + /// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var + /// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 + /// var = (sign(linear) * l1 - linear) / quadratic if |linear| &gt; l1 else 0.0 + /// accum = accum_new + /// + public static Tensor apply_ftrl(Tensor var, Tensor accum, Tensor linear, Tensor grad, Tensor lr, Tensor l1, Tensor l2, Tensor lr_power, bool? use_locking = null, string name = "ApplyFtrl") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["linear"] = linear; + dict["grad"] = grad; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["lr_power"] = lr_power; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyFtrl", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the Ftrl-proximal scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// The gradient. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regulariation. Must be a scalar. + /// + /// + /// L2 shrinkage regulariation. Must be a scalar. + /// + /// + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyFtrlV2'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// grad_with_shrinkage = grad + 2 * l2_shrinkage * var + /// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage + /// linear += grad_with_shrinkage + + /// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var + /// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 + /// var = (sign(linear) * l1 - linear) / quadratic if |linear| &gt; l1 else 0.0 + /// accum = accum_new + /// + public static Tensor apply_ftrl_v2(Tensor var, Tensor accum, Tensor linear, Tensor grad, Tensor lr, Tensor l1, Tensor l2, Tensor l2_shrinkage, Tensor lr_power, bool? use_locking = null, string name = "ApplyFtrlV2") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["linear"] = linear; + dict["grad"] = grad; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["l2_shrinkage"] = l2_shrinkage; + dict["lr_power"] = lr_power; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyFtrlV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' by subtracting 'alpha' * 'delta' from it. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// The change. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyGradientDescent'. + /// + /// + /// If True, the subtraction will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool? use_locking = null, string name = "ApplyGradientDescent") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["alpha"] = alpha; + dict["delta"] = delta; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyGradientDescent", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the momentum scheme. Set use_nesterov = True if you + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// Momentum. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyMomentum'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// If True, the tensor passed to compute grad will be + /// var - lr * momentum * accum, so in the end, the var you get is actually + /// var - lr * momentum * accum. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// want to use Nesterov momentum. + /// + /// accum = accum * momentum + grad + /// var -= lr * accum + /// + public static Tensor apply_momentum(Tensor var, Tensor accum, Tensor lr, Tensor grad, Tensor momentum, bool? use_locking = null, bool? use_nesterov = null, string name = "ApplyMomentum") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["lr"] = lr; + dict["grad"] = grad; + dict["momentum"] = momentum; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + if (use_nesterov.HasValue) + dict["use_nesterov"] = use_nesterov.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyMomentum", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the AddSign update. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyPowerSign'. + /// + /// + /// If True, updating of the var and m tensors is + /// protected by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// m_t &lt;- beta1 * m_{t-1} + (1 - beta1) * g + /// update &lt;- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g + /// variable &lt;- variable - lr_t * update + /// + public static Tensor apply_power_sign(Tensor var, Tensor m, Tensor lr, Tensor logbase, Tensor sign_decay, Tensor beta, Tensor grad, bool? use_locking = null, string name = "ApplyPowerSign") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["m"] = m; + dict["lr"] = lr; + dict["logbase"] = logbase; + dict["sign_decay"] = sign_decay; + dict["beta"] = beta; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyPowerSign", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyProximalAdagrad'. + /// + /// + /// If True, updating of the var and accum tensors will be protected by + /// a lock; otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// accum += grad * grad + /// prox_v = var - lr * grad * (1 / sqrt(accum)) + /// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} + /// + public static Tensor apply_proximal_adagrad(Tensor var, Tensor accum, Tensor lr, Tensor l1, Tensor l2, Tensor grad, bool? use_locking = null, string name = "ApplyProximalAdagrad") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyProximalAdagrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' as FOBOS algorithm with fixed learning rate. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// The change. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyProximalGradientDescent'. + /// + /// + /// If True, the subtraction will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// prox_v = var - alpha * delta + /// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} + /// + public static Tensor apply_proximal_gradient_descent(Tensor var, Tensor alpha, Tensor l1, Tensor l2, Tensor delta, bool? use_locking = null, string name = "ApplyProximalGradientDescent") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["alpha"] = alpha; + dict["l1"] = l1; + dict["l2"] = l2; + dict["delta"] = delta; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyProximalGradientDescent", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the RMSProp algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Decay rate. Must be a scalar. + /// + /// + /// + /// + /// Ridge term. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApplyRMSProp'. + /// + /// + /// If True, updating of the var, ms, and mom tensors is protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Note that in dense implementation of this algorithm, ms and mom will + /// update even if the grad is zero, but in this sparse implementation, ms + /// and mom will not update in iterations during which the grad is zero. + /// + /// mean_square = decay * mean_square + (1-decay) * gradient ** 2 + /// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) + /// + /// ms &lt;- rho * ms_{t-1} + (1-rho) * grad * grad + /// mom &lt;- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) + /// var &lt;- var - mom + /// + public static Tensor apply_r_m_s_prop(Tensor var, Tensor ms, Tensor mom, Tensor lr, Tensor rho, Tensor momentum, Tensor epsilon, Tensor grad, bool? use_locking = null, string name = "ApplyRMSProp") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["ms"] = ms; + dict["mom"] = mom; + dict["lr"] = lr; + dict["rho"] = rho; + dict["momentum"] = momentum; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ApplyRMSProp", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the truth value of abs(x-y) &lt; tolerance element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ApproximateEqual'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor approximate_equal(Tensor x, Tensor y, float? tolerance = null, string name = "ApproximateEqual") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + if (tolerance.HasValue) + dict["tolerance"] = tolerance.Value; + var op = tf.OpDefLib._apply_op_helper("ApproximateEqual", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the index with the largest value across dimensions of a tensor. + /// + /// + /// + /// + /// int32 or int64, must be in the range [-rank(input), rank(input)). + /// Describes which dimension of the input Tensor to reduce across. For vectors, + /// use dimension = 0. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ArgMax'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Note that in case of ties the identity of the return value is not guaranteed. + /// + public static Tensor arg_max(Tensor input, Tensor dimension, TF_DataType? output_type = null, string name = "ArgMax") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["dimension"] = dimension; + if (output_type.HasValue) + dict["output_type"] = output_type.Value; + var op = tf.OpDefLib._apply_op_helper("ArgMax", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the index with the smallest value across dimensions of a tensor. + /// + /// + /// + /// + /// int32 or int64, must be in the range [-rank(input), rank(input)). + /// Describes which dimension of the input Tensor to reduce across. For vectors, + /// use dimension = 0. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ArgMin'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Note that in case of ties the identity of the return value is not guaranteed. + /// + public static Tensor arg_min(Tensor input, Tensor dimension, TF_DataType? output_type = null, string name = "ArgMin") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["dimension"] = dimension; + if (output_type.HasValue) + dict["output_type"] = output_type.Value; + var op = tf.OpDefLib._apply_op_helper("ArgMin", name: name, keywords: dict); + return op.output; + } + + /// + /// Converts each entry in the given tensor to strings. Supports many numeric + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AsString'. + /// + /// + /// The post-decimal precision to use for floating point numbers. + /// Only used if precision &gt; -1. + /// + /// + /// Use scientific notation for floating point numbers. + /// + /// + /// Use shortest representation (either scientific or standard) for + /// floating point numbers. + /// + /// + /// Pad pre-decimal numbers to this width. + /// Applies to both floating point and integer numbers. + /// Only used if width &gt; -1. + /// + /// + /// The value to pad if width &gt; -1. If empty, pads with spaces. + /// Another typical value is '0'. String cannot be longer than 1 character. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// types and boolean. + /// + public static Tensor as_string(Tensor input, int? precision = null, bool? scientific = null, bool? shortest = null, int? width = null, string fill = null, string name = "AsString") + { + var dict = new Dictionary(); + dict["input"] = input; + if (precision.HasValue) + dict["precision"] = precision.Value; + if (scientific.HasValue) + dict["scientific"] = scientific.Value; + if (shortest.HasValue) + dict["shortest"] = shortest.Value; + if (width.HasValue) + dict["width"] = width.Value; + if (fill != null) + dict["fill"] = fill; + var op = tf.OpDefLib._apply_op_helper("AsString", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes asin of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Asin'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor asin(Tensor x, string name = "Asin") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Asin", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes inverse hyperbolic sine of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Asinh'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor asinh(Tensor x, string name = "Asinh") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Asinh", name: name, keywords: dict); + return op.output; + } + + /// + /// Asserts that the given condition is true. + /// + /// + /// The condition to evaluate. + /// + /// + /// The tensors to print out when condition is false. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Assert'. + /// + /// + /// Print this many entries of each tensor. + /// + /// + /// Returns the description of the operation + /// + /// + /// If condition evaluates to false, print the list of tensors in data. + /// summarize determines how many entries of the tensors to print. + /// + public static Operation assert(Tensor condition, Tensor[] data, int? summarize = null, string name = "Assert") + { + var dict = new Dictionary(); + dict["condition"] = condition; + dict["data"] = data; + if (summarize.HasValue) + dict["summarize"] = summarize.Value; + var op = tf.OpDefLib._apply_op_helper("Assert", name: name, keywords: dict); + return op; + } + + /// + /// Update 'ref' by assigning 'value' to it. + /// + /// + /// Should be from a Variable node. May be uninitialized. + /// + /// + /// The value to be assigned to the variable. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Assign'. + /// + /// + /// If true, the operation will validate that the shape + /// of 'value' matches the shape of the Tensor being assigned to. If false, + /// 'ref' will take on the shape of 'value'. + /// + /// + /// If True, the assignment will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// = Same as "ref". Returned as a convenience for operations that want + /// to use the new value after the variable has been reset. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation outputs "ref" after the assignment is done. + /// This makes it easier to chain operations that need to use the reset value. + /// + public static Tensor assign(Tensor referecne, Tensor value, bool? validate_shape = null, bool? use_locking = null, string name = "Assign") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["value"] = value; + if (validate_shape.HasValue) + dict["validate_shape"] = validate_shape.Value; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("Assign", name: name, keywords: dict); + return op.output; + } + + /// + /// Update 'ref' by adding 'value' to it. + /// + /// + /// Should be from a Variable node. + /// + /// + /// The value to be added to the variable. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AssignAdd'. + /// + /// + /// If True, the addition will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// = Same as "ref". Returned as a convenience for operations that want + /// to use the new value after the variable has been updated. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation outputs "ref" after the update is done. + /// This makes it easier to chain operations that need to use the reset value. + /// + public static Tensor assign_add(Tensor referecne, Tensor value, bool? use_locking = null, string name = "AssignAdd") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["value"] = value; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("AssignAdd", name: name, keywords: dict); + return op.output; + } + + /// + /// Adds a value to the current value of a variable. + /// + /// + /// handle to the resource in which to store the variable. + /// + /// + /// the value by which the variable will be incremented. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AssignAddVariableOp'. + /// + /// + /// Returns the description of the operation + /// + /// + /// Any ReadVariableOp with a control dependency on this op is guaranteed to + /// see the incremented value or a subsequent newer one. + /// + public static Operation assign_add_variable_op(Tensor resource, Tensor value, string name = "AssignAddVariableOp") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["value"] = value; + var op = tf.OpDefLib._apply_op_helper("AssignAddVariableOp", name: name, keywords: dict); + return op; + } + + /// + /// Update 'ref' by subtracting 'value' from it. + /// + /// + /// Should be from a Variable node. + /// + /// + /// The value to be subtracted to the variable. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AssignSub'. + /// + /// + /// If True, the subtraction will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// = Same as "ref". Returned as a convenience for operations that want + /// to use the new value after the variable has been updated. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation outputs "ref" after the update is done. + /// This makes it easier to chain operations that need to use the reset value. + /// + public static Tensor assign_sub(Tensor referecne, Tensor value, bool? use_locking = null, string name = "AssignSub") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["value"] = value; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("AssignSub", name: name, keywords: dict); + return op.output; + } + + /// + /// Subtracts a value from the current value of a variable. + /// + /// + /// handle to the resource in which to store the variable. + /// + /// + /// the value by which the variable will be incremented. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AssignSubVariableOp'. + /// + /// + /// Returns the description of the operation + /// + /// + /// Any ReadVariableOp with a control dependency on this op is guaranteed to + /// see the decremented value or a subsequent newer one. + /// + public static Operation assign_sub_variable_op(Tensor resource, Tensor value, string name = "AssignSubVariableOp") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["value"] = value; + var op = tf.OpDefLib._apply_op_helper("AssignSubVariableOp", name: name, keywords: dict); + return op; + } + + /// + /// Assigns a new value to a variable. + /// + /// + /// handle to the resource in which to store the variable. + /// + /// + /// the value to set the new tensor to use. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AssignVariableOp'. + /// + /// + /// Returns the description of the operation + /// + /// + /// Any ReadVariableOp with a control dependency on this op is guaranteed to return + /// this value or a subsequent newer value of the variable. + /// + public static Operation assign_variable_op(Tensor resource, Tensor value, string name = "AssignVariableOp") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["value"] = value; + var op = tf.OpDefLib._apply_op_helper("AssignVariableOp", name: name, keywords: dict); + return op; + } + + /// + /// Computes atan of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Atan'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor atan(Tensor x, string name = "Atan") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Atan", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes arctangent of y/x element-wise, respecting signs of the arguments. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Atan2'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This is the angle \( \theta \in [-\pi, \pi] \) such that + /// \[ x = r \cos(\theta) \] + /// and + /// \[ y = r \sin(\theta) \] + /// where \(r = \sqrt(x^2 + y^2) \). + /// + public static Tensor atan2(Tensor y, Tensor x, string name = "Atan2") + { + var dict = new Dictionary(); + dict["y"] = y; + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Atan2", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes inverse hyperbolic tangent of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Atanh'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor atanh(Tensor x, string name = "Atanh") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Atanh", name: name, keywords: dict); + return op.output; + } + + /// + /// Produces a visualization of audio data over time. + /// + /// + /// Float representation of audio data. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AudioSpectrogram'. + /// + /// + /// Optional argument + /// How wide the input window is in samples. For the highest efficiency + /// this should be a power of two, but other values are accepted. + /// + /// + /// Optional argument + /// How widely apart the center of adjacent sample windows should be. + /// + /// + /// Whether to return the squared magnitude or just the + /// magnitude. Using squared magnitude can avoid extra calculations. + /// + /// + /// 3D representation of the audio frequencies as an image. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Spectrograms are a standard way of representing audio information as a series of + /// slices of frequency information, one slice for each window of time. By joining + /// these together into a sequence, they form a distinctive fingerprint of the sound + /// over time. + /// + /// This op expects to receive audio data as an input, stored as floats in the range + /// -1 to 1, together with a window width in samples, and a stride specifying how + /// far to move the window between slices. From this it generates a three + /// dimensional output. The lowest dimension has an amplitude value for each + /// frequency during that time slice. The next dimension is time, with successive + /// frequency slices. The final dimension is for the channels in the input, so a + /// stereo audio input would have two here for example. + /// + /// This means the layout when converted and saved as an image is rotated 90 degrees + /// clockwise from a typical spectrogram. Time is descending down the Y axis, and + /// the frequency decreases from left to right. + /// + /// Each value in the result represents the square root of the sum of the real and + /// imaginary parts of an FFT on the current window of samples. In this way, the + /// lowest dimension represents the power of each frequency in the current window, + /// and adjacent windows are concatenated in the next dimension. + /// + /// To get a more intuitive and visual look at what this operation does, you can run + /// tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the + /// resulting spectrogram as a PNG image. + /// + public static Tensor audio_spectrogram(Tensor input, int window_size, int stride, bool? magnitude_squared = null, string name = "AudioSpectrogram") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["window_size"] = window_size; + dict["stride"] = stride; + if (magnitude_squared.HasValue) + dict["magnitude_squared"] = magnitude_squared.Value; + var op = tf.OpDefLib._apply_op_helper("AudioSpectrogram", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs a Summary protocol buffer with audio. + /// + /// + /// Scalar. Used to build the tag attribute of the summary values. + /// + /// + /// 2-D of shape [batch_size, frames]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AudioSummary'. + /// + /// + /// Optional argument + /// The sample rate of the signal in hertz. + /// + /// + /// Max number of batch elements to generate audio for. + /// + /// + /// Scalar. Serialized Summary protocol buffer. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The summary has up to max_outputs summary values containing audio. The + /// audio is built from tensor which must be 3-D with shape [batch_size, + /// frames, channels] or 2-D with shape [batch_size, frames]. The values are + /// assumed to be in the range of [-1.0, 1.0] with a sample rate of sample_rate. + /// + /// The tag argument is a scalar Tensor of type string. It is used to + /// build the tag of the summary values: + /// + /// * If max_outputs is 1, the summary value tag is '*tag*/audio'. + /// * If max_outputs is greater than 1, the summary value tags are + /// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. + /// + public static Tensor audio_summary(Tensor tag, Tensor tensor, float sample_rate, int? max_outputs = null, string name = "AudioSummary") + { + var dict = new Dictionary(); + dict["tag"] = tag; + dict["tensor"] = tensor; + dict["sample_rate"] = sample_rate; + if (max_outputs.HasValue) + dict["max_outputs"] = max_outputs.Value; + var op = tf.OpDefLib._apply_op_helper("AudioSummary", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs a Summary protocol buffer with audio. + /// + /// + /// Scalar. Used to build the tag attribute of the summary values. + /// + /// + /// 2-D of shape [batch_size, frames]. + /// + /// + /// The sample rate of the signal in hertz. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AudioSummaryV2'. + /// + /// + /// Max number of batch elements to generate audio for. + /// + /// + /// Scalar. Serialized Summary protocol buffer. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The summary has up to max_outputs summary values containing audio. The + /// audio is built from tensor which must be 3-D with shape [batch_size, + /// frames, channels] or 2-D with shape [batch_size, frames]. The values are + /// assumed to be in the range of [-1.0, 1.0] with a sample rate of sample_rate. + /// + /// The tag argument is a scalar Tensor of type string. It is used to + /// build the tag of the summary values: + /// + /// * If max_outputs is 1, the summary value tag is '*tag*/audio'. + /// * If max_outputs is greater than 1, the summary value tags are + /// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. + /// + public static Tensor audio_summary_v2(Tensor tag, Tensor tensor, Tensor sample_rate, int? max_outputs = null, string name = "AudioSummaryV2") + { + var dict = new Dictionary(); + dict["tag"] = tag; + dict["tensor"] = tensor; + dict["sample_rate"] = sample_rate; + if (max_outputs.HasValue) + dict["max_outputs"] = max_outputs.Value; + var op = tf.OpDefLib._apply_op_helper("AudioSummaryV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Performs average pooling on the input. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AvgPool'. + /// + /// + /// Optional argument + /// The size of the sliding window for each dimension of value. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of value. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// The average pooled output tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Each entry in output is the mean of the corresponding size ksize + /// window in value. + /// + public static Tensor avg_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = null, string name = "AvgPool") + { + var dict = new Dictionary(); + dict["value"] = value; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("AvgPool", name: name, keywords: dict); + return op.output; + } + + /// + /// Performs 3D average pooling on the input. + /// + /// + /// Shape [batch, depth, rows, cols, channels] tensor to pool over. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AvgPool3D'. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The size of the window for each dimension of + /// the input tensor. Must have ksize[0] = ksize[4] = 1. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of input. Must have strides[0] = strides[4] = 1. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// The average pooled output tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor avg_pool3d(Tensor input, int[] ksize, int[] strides, string padding, string data_format = null, string name = "AvgPool3D") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("AvgPool3D", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes gradients of average pooling function. + /// + /// + /// The original input dimensions. + /// + /// + /// Output backprop of shape [batch, depth, rows, cols, channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AvgPool3DGrad'. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The size of the window for each dimension of + /// the input tensor. Must have ksize[0] = ksize[4] = 1. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of input. Must have strides[0] = strides[4] = 1. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// The backprop for input. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor avg_pool3d_grad(Tensor orig_input_shape, Tensor grad, int[] ksize, int[] strides, string padding, string data_format = null, string name = "AvgPool3DGrad") + { + var dict = new Dictionary(); + dict["orig_input_shape"] = orig_input_shape; + dict["grad"] = grad; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("AvgPool3DGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes gradients of the average pooling function. + /// + /// + /// 1-D. Shape of the original input to avg_pool. + /// + /// + /// 4-D with shape [batch, height, width, channels]. Gradients w.r.t. + /// the output of avg_pool. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'AvgPoolGrad'. + /// + /// + /// Optional argument + /// The size of the sliding window for each dimension of the input. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the input. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// 4-D. Gradients w.r.t. the input of avg_pool. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor avg_pool_grad(Tensor orig_input_shape, Tensor grad, int[] ksize, int[] strides, string padding, string data_format = null, string name = "AvgPoolGrad") + { + var dict = new Dictionary(); + dict["orig_input_shape"] = orig_input_shape; + dict["grad"] = grad; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("AvgPoolGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Defines a barrier that persists across different graph executions. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Barrier'. + /// + /// + /// Optional argument + /// The type of each component in a value. + /// + /// + /// The shape of each component in a value. Each shape must be 1 in the + /// first dimension. The length of this attr must be the same as the length of + /// component_types. + /// + /// + /// The capacity of the barrier. The default capacity is MAX_INT32, + /// which is the largest capacity of the underlying queue. + /// + /// + /// If non-empty, this barrier is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this barrier will be shared under the given name + /// across multiple sessions. + /// + /// + /// The handle to the barrier. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// A barrier represents a key-value map, where each key is a string, and + /// each value is a tuple of tensors. + /// + /// At runtime, the barrier contains 'complete' and 'incomplete' + /// elements. A complete element has defined tensors for all components of + /// its value tuple, and may be accessed using BarrierTakeMany. An + /// incomplete element has some undefined components in its value tuple, + /// and may be updated using BarrierInsertMany. + /// + public static Tensor barrier(TF_DataType[] component_types, Shape[] shapes = null, int? capacity = null, string container = null, string shared_name = null, string name = "Barrier") + { + var dict = new Dictionary(); + dict["component_types"] = component_types; + if (shapes != null) + dict["shapes"] = shapes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("Barrier", name: name, keywords: dict); + return op.output; + } + + /// + /// Closes the given barrier. + /// + /// + /// The handle to a barrier. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BarrierClose'. + /// + /// + /// If true, all pending enqueue requests that are + /// blocked on the barrier's queue will be canceled. InsertMany will fail, even + /// if no new key is introduced. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation signals that no more new elements will be inserted in the + /// given barrier. Subsequent InsertMany that try to introduce a new key will fail. + /// Subsequent InsertMany operations that just add missing components to already + /// existing elements will continue to succeed. Subsequent TakeMany operations will + /// continue to succeed if sufficient completed elements remain in the barrier. + /// Subsequent TakeMany operations that would block will fail immediately. + /// + public static Operation barrier_close(Tensor handle, bool? cancel_pending_enqueues = null, string name = "BarrierClose") + { + var dict = new Dictionary(); + dict["handle"] = handle; + if (cancel_pending_enqueues.HasValue) + dict["cancel_pending_enqueues"] = cancel_pending_enqueues.Value; + var op = tf.OpDefLib._apply_op_helper("BarrierClose", name: name, keywords: dict); + return op; + } + + /// + /// Computes the number of incomplete elements in the given barrier. + /// + /// + /// The handle to a barrier. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BarrierIncompleteSize'. + /// + /// + /// The number of incomplete elements (i.e. those with some of their value + /// components not set) in the barrier. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor barrier_incomplete_size(Tensor handle, string name = "BarrierIncompleteSize") + { + var dict = new Dictionary(); + dict["handle"] = handle; + var op = tf.OpDefLib._apply_op_helper("BarrierIncompleteSize", name: name, keywords: dict); + return op.output; + } + + /// + /// For each key, assigns the respective value to the specified component. + /// + /// + /// The handle to a barrier. + /// + /// + /// A one-dimensional tensor of keys, with length n. + /// + /// + /// An any-dimensional tensor of values, which are associated with the + /// respective keys. The 0th dimension must have length n. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BarrierInsertMany'. + /// + /// + /// Optional argument + /// The component of the barrier elements that is being assigned. + /// + /// + /// Returns the description of the operation + /// + /// + /// If a key is not found in the barrier, this operation will create a new + /// incomplete element. If a key is found in the barrier, and the element + /// already has a value at component_index, this operation will fail with + /// INVALID_ARGUMENT, and leave the barrier in an undefined state. + /// + public static Operation barrier_insert_many(Tensor handle, Tensor keys, Tensor values, int component_index, string name = "BarrierInsertMany") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["keys"] = keys; + dict["values"] = values; + dict["component_index"] = component_index; + var op = tf.OpDefLib._apply_op_helper("BarrierInsertMany", name: name, keywords: dict); + return op; + } + + /// + /// Computes the number of complete elements in the given barrier. + /// + /// + /// The handle to a barrier. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BarrierReadySize'. + /// + /// + /// The number of complete elements (i.e. those with all of their value + /// components set) in the barrier. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor barrier_ready_size(Tensor handle, string name = "BarrierReadySize") + { + var dict = new Dictionary(); + dict["handle"] = handle; + var op = tf.OpDefLib._apply_op_helper("BarrierReadySize", name: name, keywords: dict); + return op.output; + } + + /// + /// Takes the given number of completed elements from a barrier. + /// + /// + /// The handle to a barrier. + /// + /// + /// A single-element tensor containing the number of elements to + /// take. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BarrierTakeMany'. + /// + /// + /// Optional argument + /// The type of each component in a value. + /// + /// + /// Allow to return less than num_elements items if barrier is + /// already closed. + /// + /// + /// + /// + /// If the queue is empty, this operation will block for up to + /// timeout_ms milliseconds. + /// Note: This option is not supported yet. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// indices : A one-dimensional tensor of indices, with length num_elems. + /// These indices refer to the batch in which the values were placed into the + /// barrier (starting with MIN_LONG and increasing with each BarrierInsertMany). + /// keys : A one-dimensional tensor of keys, with length num_elements. + /// values : One any-dimensional tensor per component in a barrier element. All + /// values have length num_elements in the 0th dimension. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// This operation concatenates completed-element component tensors along + /// the 0th dimension to make a single component tensor. + /// + /// Elements come out of the barrier when they are complete, and in the order + /// in which they were placed into the barrier. The indices output provides + /// information about the batch in which each element was originally inserted + /// into the barrier. + /// + public static (Tensor indices, Tensor keys, Tensor[] values) barrier_take_many(Tensor handle, Tensor num_elements, TF_DataType[] component_types, bool? allow_small_batch = null, bool? wait_for_incomplete = null, int? timeout_ms = null, string name = "BarrierTakeMany") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["num_elements"] = num_elements; + dict["component_types"] = component_types; + if (allow_small_batch.HasValue) + dict["allow_small_batch"] = allow_small_batch.Value; + if (wait_for_incomplete.HasValue) + dict["wait_for_incomplete"] = wait_for_incomplete.Value; + if (timeout_ms.HasValue) + dict["timeout_ms"] = timeout_ms.Value; + var op = tf.OpDefLib._apply_op_helper("BarrierTakeMany", name: name, keywords: dict); + int _idx = 0; + var indices = op.outputs[_idx++]; + var keys = op.outputs[_idx++]; + var values = Enumerable.Range(0, op.OutputListLength("values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (indices, keys, values); + } + + /// + /// Batches all input tensors nondeterministically. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Batch'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// batched_tensors : + /// batch_index : + /// id : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// When many instances of this Op are being run concurrently with the same + /// container/shared_name in the same device, some will output zero-shaped Tensors + /// and others will output Tensors of size up to max_batch_size. + /// + /// All Tensors in in_tensors are batched together (so, for example, labels and + /// features should be batched with a single instance of this operation. + /// + /// Each invocation of batch emits an id scalar which will be used to identify + /// this particular invocation when doing unbatch or its gradient. + /// + /// Each op which emits a non-empty batch will also emit a non-empty batch_index + /// Tensor, which, is a [K, 3] matrix where each row contains the invocation's id, + /// start, and length of elements of each set of Tensors present in batched_tensors. + /// + /// Batched tensors are concatenated along the first dimension, and all tensors in + /// in_tensors must have the first dimension of the same size. + /// + /// in_tensors: The tensors to be batched. + /// num_batch_threads: Number of scheduling threads for processing batches of work. + /// Determines the number of batches processed in parallel. + /// max_batch_size: Batch sizes will never be bigger than this. + /// batch_timeout_micros: Maximum number of microseconds to wait before outputting + /// an incomplete batch. + /// allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does + /// nothing. Otherwise, supplies a list of batch sizes, causing the op to pad + /// batches up to one of those sizes. The entries must increase monotonically, and + /// the final entry must equal max_batch_size. + /// grad_timeout_micros: The timeout to use for the gradient. See Unbatch. + /// batched_tensors: Either empty tensors or a batch of concatenated Tensors. + /// batch_index: If out_tensors is non-empty, has information to invert it. + /// container: Controls the scope of sharing of this batch. + /// id: always contains a scalar with a unique ID for this invocation of Batch. + /// shared_name: Concurrently running instances of batch in the same device with the + /// same container and shared_name will batch their elements together. If left + /// empty, the op name will be used as the shared name. + /// T: the types of tensors to be batched. + /// + public static (Tensor[] batched_tensors, Tensor batch_index, Tensor id) batch(Tensor[] in_tensors, int num_batch_threads, int max_batch_size, int batch_timeout_micros, int grad_timeout_micros, int? max_enqueued_batches = null, int[] allowed_batch_sizes = null, string container = null, string shared_name = null, string batching_queue = null, string name = "Batch") + { + var dict = new Dictionary(); + dict["in_tensors"] = in_tensors; + dict["num_batch_threads"] = num_batch_threads; + dict["max_batch_size"] = max_batch_size; + dict["batch_timeout_micros"] = batch_timeout_micros; + dict["grad_timeout_micros"] = grad_timeout_micros; + if (max_enqueued_batches.HasValue) + dict["max_enqueued_batches"] = max_enqueued_batches.Value; + if (allowed_batch_sizes != null) + dict["allowed_batch_sizes"] = allowed_batch_sizes; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + if (batching_queue != null) + dict["batching_queue"] = batching_queue; + var op = tf.OpDefLib._apply_op_helper("Batch", name: name, keywords: dict); + int _idx = 0; + var batched_tensors = Enumerable.Range(0, op.OutputListLength("batched_tensors")).Select(_ => op.outputs[_idx++]).ToArray(); + var batch_index = op.outputs[_idx++]; + var id = op.outputs[_idx++]; + return (batched_tensors, batch_index, id); + } + + /// + /// Creates a dataset that batches batch_size elements from input_dataset. + /// + /// + /// + /// + /// A scalar representing the number of elements to accumulate in a + /// batch. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BatchDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor batch_dataset(Tensor input_dataset, Tensor batch_size, TF_DataType[] output_types, Shape[] output_shapes, string name = "BatchDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["batch_size"] = batch_size; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("BatchDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that batches batch_size elements from input_dataset. + /// + /// + /// + /// + /// A scalar representing the number of elements to accumulate in a batch. + /// + /// + /// A scalar representing whether the last batch should be dropped in case its size + /// is smaller than desired. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BatchDatasetV2'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor batch_dataset_v2(Tensor input_dataset, Tensor batch_size, Tensor drop_remainder, TF_DataType[] output_types, Shape[] output_shapes, string name = "BatchDatasetV2") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["batch_size"] = batch_size; + dict["drop_remainder"] = drop_remainder; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("BatchDatasetV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Multiplies slices of two tensors in batches. + /// + /// + /// 2-D or higher with shape [..., r_x, c_x]. + /// + /// + /// 2-D or higher with shape [..., r_y, c_y]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BatchMatMul'. + /// + /// + /// If True, adjoint the slices of x. Defaults to False. + /// + /// + /// If True, adjoint the slices of y. Defaults to False. + /// + /// + /// 3-D or higher with shape [..., r_o, c_o] + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Multiplies all slices of Tensor x and y (each slice can be + /// viewed as an element of a batch), and arranges the individual results + /// in a single output tensor of the same batch size. Each of the + /// individual slices can optionally be adjointed (to adjoint a matrix + /// means to transpose and conjugate it) before multiplication by setting + /// the adj_x or adj_y flag to True, which are by default False. + /// + /// The input tensors x and y are 2-D or higher with shape [..., r_x, c_x] + /// and [..., r_y, c_y]. + /// + /// The output tensor is 2-D or higher with shape [..., r_o, c_o], where: + /// + /// r_o = c_x if adj_x else r_x + /// c_o = r_y if adj_y else c_y + /// + /// It is computed as: + /// + /// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) + /// + public static Tensor batch_mat_mul(Tensor x, Tensor y, bool? adj_x = null, bool? adj_y = null, string name = "BatchMatMul") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + if (adj_x.HasValue) + dict["adj_x"] = adj_x.Value; + if (adj_y.HasValue) + dict["adj_y"] = adj_y.Value; + var op = tf.OpDefLib._apply_op_helper("BatchMatMul", name: name, keywords: dict); + return op.output; + } + + /// + /// Batch normalization. + /// + /// + /// A 4D input Tensor. + /// + /// + /// A 1D mean Tensor with size matching the last dimension of t. + /// This is the first output from tf.nn.moments, + /// or a saved moving average thereof. + /// + /// + /// A 1D variance Tensor with size matching the last dimension of t. + /// This is the second output from tf.nn.moments, + /// or a saved moving average thereof. + /// + /// + /// A 1D beta Tensor with size matching the last dimension of t. + /// An offset to be added to the normalized tensor. + /// + /// + /// A 1D gamma Tensor with size matching the last dimension of t. + /// If "scale_after_normalization" is true, this tensor will be multiplied + /// with the normalized tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BatchNormWithGlobalNormalization'. + /// + /// + /// Optional argument + /// A small float number to avoid dividing by 0. + /// + /// + /// Optional argument + /// A bool indicating whether the resulted tensor + /// needs to be multiplied with gamma. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op is deprecated. Prefer tf.nn.batch_normalization. + /// + public static Tensor batch_norm_with_global_normalization(Tensor t, Tensor m, Tensor v, Tensor beta, Tensor gamma, float variance_epsilon, bool scale_after_normalization, string name = "BatchNormWithGlobalNormalization") + { + var dict = new Dictionary(); + dict["t"] = t; + dict["m"] = m; + dict["v"] = v; + dict["beta"] = beta; + dict["gamma"] = gamma; + dict["variance_epsilon"] = variance_epsilon; + dict["scale_after_normalization"] = scale_after_normalization; + var op = tf.OpDefLib._apply_op_helper("BatchNormWithGlobalNormalization", name: name, keywords: dict); + return op.output; + } + + /// + /// Gradients for batch normalization. + /// + /// + /// A 4D input Tensor. + /// + /// + /// A 1D mean Tensor with size matching the last dimension of t. + /// This is the first output from tf.nn.moments, + /// or a saved moving average thereof. + /// + /// + /// A 1D variance Tensor with size matching the last dimension of t. + /// This is the second output from tf.nn.moments, + /// or a saved moving average thereof. + /// + /// + /// A 1D gamma Tensor with size matching the last dimension of t. + /// If "scale_after_normalization" is true, this Tensor will be multiplied + /// with the normalized Tensor. + /// + /// + /// 4D backprop Tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BatchNormWithGlobalNormalizationGrad'. + /// + /// + /// Optional argument + /// A small float number to avoid dividing by 0. + /// + /// + /// Optional argument + /// A bool indicating whether the resulted tensor + /// needs to be multiplied with gamma. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// dx : 4D backprop tensor for input. + /// dm : 1D backprop tensor for mean. + /// dv : 1D backprop tensor for variance. + /// db : 1D backprop tensor for beta. + /// dg : 1D backprop tensor for gamma. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// This op is deprecated. See tf.nn.batch_normalization. + /// + public static (Tensor dx, Tensor dm, Tensor dv, Tensor db, Tensor dg) batch_norm_with_global_normalization_grad(Tensor t, Tensor m, Tensor v, Tensor gamma, Tensor backprop, float variance_epsilon, bool scale_after_normalization, string name = "BatchNormWithGlobalNormalizationGrad") + { + var dict = new Dictionary(); + dict["t"] = t; + dict["m"] = m; + dict["v"] = v; + dict["gamma"] = gamma; + dict["backprop"] = backprop; + dict["variance_epsilon"] = variance_epsilon; + dict["scale_after_normalization"] = scale_after_normalization; + var op = tf.OpDefLib._apply_op_helper("BatchNormWithGlobalNormalizationGrad", name: name, keywords: dict); + int _idx = 0; + var dx = op.outputs[_idx++]; + var dm = op.outputs[_idx++]; + var dv = op.outputs[_idx++]; + var db = op.outputs[_idx++]; + var dg = op.outputs[_idx++]; + return (dx, dm, dv, db, dg); + } + + /// + /// BatchToSpace for 4-D tensors of type T. + /// + /// + /// 4-D tensor with shape + /// [batch*block_size*block_size, height_pad/block_size, width_pad/block_size, + /// depth]. Note that the batch size of the input tensor must be divisible by + /// block_size * block_size. + /// + /// + /// 2-D tensor of non-negative integers with shape [2, 2]. It specifies + /// how many elements to crop from the intermediate result across the spatial + /// dimensions as follows: + /// + /// crops = [[crop_top, crop_bottom], [crop_left, crop_right]] + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BatchToSpace'. + /// + /// + /// Optional argument + /// + /// + /// 4-D with shape [batch, height, width, depth], where: + /// + /// height = height_pad - crop_top - crop_bottom + /// width = width_pad - crop_left - crop_right + /// + /// The attr block_size must be greater than one. It indicates the block size. + /// + /// Some examples: + /// + /// (1) For the following input of shape [4, 1, 1, 1] and block_size of 2: + /// + /// + /// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] + /// + /// + /// The output tensor has shape [1, 2, 2, 1] and value: + /// + /// + /// x = [[[[1], [2]], [[3], [4]]]] + /// + /// + /// (2) For the following input of shape [4, 1, 1, 3] and block_size of 2: + /// + /// + /// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] + /// + /// + /// The output tensor has shape [1, 2, 2, 3] and value: + /// + /// + /// x = [[[[1, 2, 3], [4, 5, 6]], + /// [[7, 8, 9], [10, 11, 12]]]] + /// + /// + /// (3) For the following input of shape [4, 2, 2, 1] and block_size of 2: + /// + /// + /// x = [[[[1], [3]], [[9], [11]]], + /// [[[2], [4]], [[10], [12]]], + /// [[[5], [7]], [[13], [15]]], + /// [[[6], [8]], [[14], [16]]]] + /// + /// + /// The output tensor has shape [1, 4, 4, 1] and value: + /// + /// + /// x = [[[1], [2], [3], [4]], + /// [[5], [6], [7], [8]], + /// [[9], [10], [11], [12]], + /// [[13], [14], [15], [16]]] + /// + /// + /// (4) For the following input of shape [8, 1, 2, 1] and block_size of 2: + /// + /// + /// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], + /// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] + /// + /// + /// The output tensor has shape [2, 2, 4, 1] and value: + /// + /// + /// x = [[[[1], [3]], [[5], [7]]], + /// [[[2], [4]], [[10], [12]]], + /// [[[5], [7]], [[13], [15]]], + /// [[[6], [8]], [[14], [16]]]] + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This is a legacy version of the more general BatchToSpaceND. + /// + /// Rearranges (permutes) data from batch into blocks of spatial data, followed by + /// cropping. This is the reverse transformation of SpaceToBatch. More specifically, + /// this op outputs a copy of the input tensor where values from the batch + /// dimension are moved in spatial blocks to the height and width dimensions, + /// followed by cropping along the height and width dimensions. + /// + public static Tensor batch_to_space(Tensor input, Tensor crops, int block_size, string name = "BatchToSpace") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["crops"] = crops; + dict["block_size"] = block_size; + var op = tf.OpDefLib._apply_op_helper("BatchToSpace", name: name, keywords: dict); + return op.output; + } + + /// + /// BatchToSpace for N-D tensors of type T. + /// + /// + /// N-D with shape input_shape = [batch] + spatial_shape + remaining_shape, + /// where spatial_shape has M dimensions. + /// + /// + /// 1-D with shape [M], all values must be &gt;= 1. + /// + /// + /// 2-D with shape [M, 2], all values must be &gt;= 0. + /// crops[i] = [crop_start, crop_end] specifies the amount to crop from input + /// dimension i + 1, which corresponds to spatial dimension i. It is + /// required that + /// crop_start[i] + crop_end[i] &lt;= block_shape[i] * input_shape[i + 1]. + /// + /// This operation is equivalent to the following steps: + /// + /// 1. Reshape input to reshaped of shape: + /// [block_shape[0], ..., block_shape[M-1], + /// batch / prod(block_shape), + /// input_shape[1], ..., input_shape[N-1]] + /// + /// 2. Permute dimensions of reshaped to produce permuted of shape + /// [batch / prod(block_shape), + /// + /// input_shape[1], block_shape[0], + /// ..., + /// input_shape[M], block_shape[M-1], + /// + /// input_shape[M+1], ..., input_shape[N-1]] + /// + /// 3. Reshape permuted to produce reshaped_permuted of shape + /// [batch / prod(block_shape), + /// + /// input_shape[1] * block_shape[0], + /// ..., + /// input_shape[M] * block_shape[M-1], + /// + /// input_shape[M+1], + /// ..., + /// input_shape[N-1]] + /// + /// 4. Crop the start and end of dimensions [1, ..., M] of + /// reshaped_permuted according to crops to produce the output of shape: + /// [batch / prod(block_shape), + /// + /// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], + /// ..., + /// input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], + /// + /// input_shape[M+1], ..., input_shape[N-1]] + /// + /// Some examples: + /// + /// (1) For the following input of shape [4, 1, 1, 1], block_shape = [2, 2], and + /// crops = [[0, 0], [0, 0]]: + /// + /// + /// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] + /// + /// + /// The output tensor has shape [1, 2, 2, 1] and value: + /// + /// + /// x = [[[[1], [2]], [[3], [4]]]] + /// + /// + /// (2) For the following input of shape [4, 1, 1, 3], block_shape = [2, 2], and + /// crops = [[0, 0], [0, 0]]: + /// + /// + /// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] + /// + /// + /// The output tensor has shape [1, 2, 2, 3] and value: + /// + /// + /// x = [[[[1, 2, 3], [4, 5, 6]], + /// [[7, 8, 9], [10, 11, 12]]]] + /// + /// + /// (3) For the following input of shape [4, 2, 2, 1], block_shape = [2, 2], and + /// crops = [[0, 0], [0, 0]]: + /// + /// + /// x = [[[[1], [3]], [[9], [11]]], + /// [[[2], [4]], [[10], [12]]], + /// [[[5], [7]], [[13], [15]]], + /// [[[6], [8]], [[14], [16]]]] + /// + /// + /// The output tensor has shape [1, 4, 4, 1] and value: + /// + /// + /// x = [[[1], [2], [3], [4]], + /// [[5], [6], [7], [8]], + /// [[9], [10], [11], [12]], + /// [[13], [14], [15], [16]]] + /// + /// + /// (4) For the following input of shape [8, 1, 3, 1], block_shape = [2, 2], and + /// crops = [[0, 0], [2, 0]]: + /// + /// + /// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], + /// [[[0], [2], [4]]], [[[0], [10], [12]]], + /// [[[0], [5], [7]]], [[[0], [13], [15]]], + /// [[[0], [6], [8]]], [[[0], [14], [16]]]] + /// + /// + /// The output tensor has shape [2, 2, 4, 1] and value: + /// + /// + /// x = [[[[1], [2], [3], [4]], + /// [[5], [6], [7], [8]]], + /// [[[9], [10], [11], [12]], + /// [[13], [14], [15], [16]]]] + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BatchToSpaceND'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation reshapes the "batch" dimension 0 into M + 1 dimensions of shape + /// block_shape + [batch], interleaves these blocks back into the grid defined by + /// the spatial dimensions [1, ..., M], to obtain a result with the same rank as + /// the input. The spatial dimensions of this intermediate result are then + /// optionally cropped according to crops to produce the output. This is the + /// reverse of SpaceToBatch. See below for a precise description. + /// + public static Tensor batch_to_space_n_d(Tensor input, Tensor block_shape, Tensor crops, string name = "BatchToSpaceND") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["block_shape"] = block_shape; + dict["crops"] = crops; + var op = tf.OpDefLib._apply_op_helper("BatchToSpaceND", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the Bessel i0e function of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BesselI0e'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Exponentially scaled modified Bessel function of order 0 defined as + /// bessel_i0e(x) = exp(-abs(x)) bessel_i0(x). + /// + /// This function is faster and numerically stabler than bessel_i0(x). + /// + public static Tensor bessel_i0e(Tensor x, string name = "BesselI0e") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("BesselI0e", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the Bessel i1e function of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BesselI1e'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Exponentially scaled modified Bessel function of order 0 defined as + /// bessel_i1e(x) = exp(-abs(x)) bessel_i1(x). + /// + /// This function is faster and numerically stabler than bessel_i1(x). + /// + public static Tensor bessel_i1e(Tensor x, string name = "BesselI1e") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("BesselI1e", name: name, keywords: dict); + return op.output; + } + + /// + /// Compute the regularized incomplete beta integral \\(I_x(a, b)\\). + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Betainc'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The regularized incomplete beta integral is defined as: + /// + /// + /// \\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\) + /// + /// where + /// + /// + /// \\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\) + /// + /// + /// is the incomplete beta function and \\(B(a, b)\\) is the *complete* + /// beta function. + /// + public static Tensor betainc(Tensor a, Tensor b, Tensor x, string name = "Betainc") + { + var dict = new Dictionary(); + dict["a"] = a; + dict["b"] = b; + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Betainc", name: name, keywords: dict); + return op.output; + } + + /// + /// Adds bias to value. + /// + /// + /// Any number of dimensions. + /// + /// + /// 1-D with size the last dimension of value. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BiasAdd'. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the bias tensor will be added to the last dimension + /// of the value tensor. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// The tensor will be added to "in_channels", the third-to-the-last + /// dimension. + /// + /// + /// Broadcasted sum of value and bias. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This is a special case of tf.add where bias is restricted to be 1-D. + /// Broadcasting is supported, so value may have any number of dimensions. + /// + public static Tensor bias_add(Tensor value, Tensor bias, string data_format = null, string name = "BiasAdd") + { + var dict = new Dictionary(); + dict["value"] = value; + dict["bias"] = bias; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("BiasAdd", name: name, keywords: dict); + return op.output; + } + + /// + /// The backward operation for "BiasAdd" on the "bias" tensor. + /// + /// + /// Any number of dimensions. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BiasAddGrad'. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the bias tensor will be added to the last dimension + /// of the value tensor. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// The tensor will be added to "in_channels", the third-to-the-last + /// dimension. + /// + /// + /// 1-D with size the feature dimension of out_backprop. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// It accumulates all the values from out_backprop into the feature dimension. + /// For NHWC data format, the feature dimension is the last. For NCHW data format, + /// the feature dimension is the third-to-last. + /// + public static Tensor bias_add_grad(Tensor out_backprop, string data_format = null, string name = "BiasAddGrad") + { + var dict = new Dictionary(); + dict["out_backprop"] = out_backprop; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("BiasAddGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Adds bias to value. + /// + /// + /// Any number of dimensions. + /// + /// + /// 1-D with size the last dimension of value. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BiasAddV1'. + /// + /// + /// Broadcasted sum of value and bias. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This is a deprecated version of BiasAdd and will be soon removed. + /// + /// This is a special case of tf.add where bias is restricted to be 1-D. + /// Broadcasting is supported, so value may have any number of dimensions. + /// + public static Tensor bias_add_v1(Tensor value, Tensor bias, string name = "BiasAddV1") + { + var dict = new Dictionary(); + dict["value"] = value; + dict["bias"] = bias; + var op = tf.OpDefLib._apply_op_helper("BiasAddV1", name: name, keywords: dict); + return op.output; + } + + /// + /// Counts the number of occurrences of each value in an integer array. + /// + /// + /// int32 Tensor. + /// + /// + /// non-negative int32 scalar Tensor. + /// + /// + /// is an int32, int64, float32, or float64 Tensor with the same + /// shape as arr, or a length-0 Tensor, in which case it acts as all weights + /// equal to 1. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Bincount'. + /// + /// + /// 1D Tensor with length equal to size. The counts or summed weights for + /// each value in the range [0, size). + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Outputs a vector with length size and the same dtype as weights. If + /// weights are empty, then index i stores the number of times the value i is + /// counted in arr. If weights are non-empty, then index i stores the sum of + /// the value in weights at each index where the corresponding value in arr is + /// i. + /// + /// Values in arr outside of the range [0, size) are ignored. + /// + public static Tensor bincount(Tensor arr, Tensor size, Tensor weights, string name = "Bincount") + { + var dict = new Dictionary(); + dict["arr"] = arr; + dict["size"] = size; + dict["weights"] = weights; + var op = tf.OpDefLib._apply_op_helper("Bincount", name: name, keywords: dict); + return op.output; + } + + /// + /// Bitcasts a tensor from one type to another without copying data. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Bitcast'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor input, this operation returns a tensor that has the same buffer + /// data as input with datatype type. + /// + /// If the input datatype T is larger than the output datatype type then the + /// shape changes from [...] to [..., sizeof(T)/sizeof(type)]. + /// + /// If T is smaller than type, the operator requires that the rightmost + /// dimension be equal to sizeof(type)/sizeof(T). The shape then goes from + /// [..., sizeof(type)/sizeof(T)] to [...]. + /// + /// *NOTE*: Bitcast is implemented as a low-level cast, so machines with different + /// endian orderings will give different results. + /// + public static Tensor bitcast(Tensor input, TF_DataType type, string name = "Bitcast") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["type"] = type; + var op = tf.OpDefLib._apply_op_helper("Bitcast", name: name, keywords: dict); + return op.output; + } + + /// + /// Elementwise computes the bitwise AND of x and y. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BitwiseAnd'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The result will have those bits set, that are set in both x and y. The + /// computation is performed on the underlying representations of x and y. + /// + public static Tensor bitwise_and(Tensor x, Tensor y, string name = "BitwiseAnd") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("BitwiseAnd", name: name, keywords: dict); + return op.output; + } + + /// + /// Elementwise computes the bitwise OR of x and y. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BitwiseOr'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The result will have those bits set, that are set in x, y or both. The + /// computation is performed on the underlying representations of x and y. + /// + public static Tensor bitwise_or(Tensor x, Tensor y, string name = "BitwiseOr") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("BitwiseOr", name: name, keywords: dict); + return op.output; + } + + /// + /// Elementwise computes the bitwise XOR of x and y. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BitwiseXor'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The result will have those bits set, that are different in x and y. The + /// computation is performed on the underlying representations of x and y. + /// + public static Tensor bitwise_xor(Tensor x, Tensor y, string name = "BitwiseXor") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("BitwiseXor", name: name, keywords: dict); + return op.output; + } + + /// + /// Calculates gains for each feature and returns the best possible split information for the feature. + /// + /// + /// A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within stats_summary_list. The nodes are iterated between the two nodes specified by the tensor, as like for node_id in range(node_id_range[0], node_id_range[1]) (Note that the last index node_id_range[1] is exclusive). + /// + /// + /// A list of Rank 3 tensor (#shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per buckets for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used. + /// + /// + /// l1 regularization factor on leaf weights, per instance based. + /// + /// + /// l2 regularization factor on leaf weights, per instance based. + /// + /// + /// adjustment to the gain, per leaf based. + /// + /// + /// mininum avg of hessians in a node before required for the node to be considered for splitting. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BoostedTreesCalculateBestGainsPerFeature'. + /// + /// + /// Optional argument + /// the number of nodes that can be split in the whole tree. Used as a dimension of output tensors. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// node_ids_list : An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes. + /// gains_list : An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes. + /// thresholds_list : An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes. + /// left_node_contribs_list : A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes. + /// right_node_contribs_list : A list of Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature. + /// + /// It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return node_ids_list for each feature, containing the list of nodes that this feature can be used to split. + /// + /// In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features). + /// + /// The length of output lists are all of the same length, num_features. + /// The output shapes are compatible in a way that the first dimension of all tensors of all lists are the same and equal to the number of possible split nodes for each feature. + /// + public static (Tensor[] node_ids_list, Tensor[] gains_list, Tensor[] thresholds_list, Tensor[] left_node_contribs_list, Tensor[] right_node_contribs_list) boosted_trees_calculate_best_gains_per_feature(Tensor node_id_range, Tensor[] stats_summary_list, Tensor l1, Tensor l2, Tensor tree_complexity, Tensor min_node_weight, int max_splits, string name = "BoostedTreesCalculateBestGainsPerFeature") + { + var dict = new Dictionary(); + dict["node_id_range"] = node_id_range; + dict["stats_summary_list"] = stats_summary_list; + dict["l1"] = l1; + dict["l2"] = l2; + dict["tree_complexity"] = tree_complexity; + dict["min_node_weight"] = min_node_weight; + dict["max_splits"] = max_splits; + var op = tf.OpDefLib._apply_op_helper("BoostedTreesCalculateBestGainsPerFeature", name: name, keywords: dict); + int _idx = 0; + var node_ids_list = Enumerable.Range(0, op.OutputListLength("node_ids_list")).Select(_ => op.outputs[_idx++]).ToArray(); + var gains_list = Enumerable.Range(0, op.OutputListLength("gains_list")).Select(_ => op.outputs[_idx++]).ToArray(); + var thresholds_list = Enumerable.Range(0, op.OutputListLength("thresholds_list")).Select(_ => op.outputs[_idx++]).ToArray(); + var left_node_contribs_list = Enumerable.Range(0, op.OutputListLength("left_node_contribs_list")).Select(_ => op.outputs[_idx++]).ToArray(); + var right_node_contribs_list = Enumerable.Range(0, op.OutputListLength("right_node_contribs_list")).Select(_ => op.outputs[_idx++]).ToArray(); + return (node_ids_list, gains_list, thresholds_list, left_node_contribs_list, right_node_contribs_list); + } + + /// + /// Calculates the prior from the training data (the bias) and fills in the first node with the logits' prior. Returns a boolean indicating whether to continue centering. + /// + /// + /// Handle to the tree ensemble. + /// + /// + /// A tensor with shape=[logits_dimension] with mean of gradients for a first node. + /// + /// + /// A tensor with shape=[logits_dimension] mean of hessians for a first node. + /// + /// + /// l1 regularization factor on leaf weights, per instance based. + /// + /// + /// l2 regularization factor on leaf weights, per instance based. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BoostedTreesCenterBias'. + /// + /// + /// Bool, whether to continue bias centering. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor boosted_trees_center_bias(Tensor tree_ensemble_handle, Tensor mean_gradients, Tensor mean_hessians, Tensor l1, Tensor l2, string name = "BoostedTreesCenterBias") + { + var dict = new Dictionary(); + dict["tree_ensemble_handle"] = tree_ensemble_handle; + dict["mean_gradients"] = mean_gradients; + dict["mean_hessians"] = mean_hessians; + dict["l1"] = l1; + dict["l2"] = l2; + var op = tf.OpDefLib._apply_op_helper("BoostedTreesCenterBias", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a tree ensemble model and returns a handle to it. + /// + /// + /// Handle to the tree ensemble resource to be created. + /// + /// + /// Token to use as the initial value of the resource stamp. + /// + /// + /// Serialized proto of the tree ensemble. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BoostedTreesCreateEnsemble'. + /// + /// + /// Returns the description of the operation + /// + public static Operation boosted_trees_create_ensemble(Tensor tree_ensemble_handle, Tensor stamp_token, Tensor tree_ensemble_serialized, string name = "BoostedTreesCreateEnsemble") + { + var dict = new Dictionary(); + dict["tree_ensemble_handle"] = tree_ensemble_handle; + dict["stamp_token"] = stamp_token; + dict["tree_ensemble_serialized"] = tree_ensemble_serialized; + var op = tf.OpDefLib._apply_op_helper("BoostedTreesCreateEnsemble", name: name, keywords: dict); + return op; + } + + /// + /// Deserializes a serialized tree ensemble config and replaces current tree + /// + /// + /// Handle to the tree ensemble. + /// + /// + /// Token to use as the new value of the resource stamp. + /// + /// + /// Serialized proto of the ensemble. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BoostedTreesDeserializeEnsemble'. + /// + /// + /// Returns the description of the operation + /// + /// + /// ensemble. + /// + public static Operation boosted_trees_deserialize_ensemble(Tensor tree_ensemble_handle, Tensor stamp_token, Tensor tree_ensemble_serialized, string name = "BoostedTreesDeserializeEnsemble") + { + var dict = new Dictionary(); + dict["tree_ensemble_handle"] = tree_ensemble_handle; + dict["stamp_token"] = stamp_token; + dict["tree_ensemble_serialized"] = tree_ensemble_serialized; + var op = tf.OpDefLib._apply_op_helper("BoostedTreesDeserializeEnsemble", name: name, keywords: dict); + return op; + } + + /// + /// Creates a handle to a BoostedTreesEnsembleResource + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BoostedTreesEnsembleResourceHandleOp'. + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor boosted_trees_ensemble_resource_handle_op(string container = null, string shared_name = null, string name = "BoostedTreesEnsembleResourceHandleOp") + { + var dict = new Dictionary(); + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("BoostedTreesEnsembleResourceHandleOp", name: name, keywords: dict); + return op.output; + } + + /// + /// Debugging/model interpretability outputs for each example. + /// + /// + /// + /// + /// A list of rank 1 Tensors containing bucket id for each + /// feature. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BoostedTreesExampleDebugOutputs'. + /// + /// + /// Optional argument + /// scalar, dimension of the logits, to be used for constructing the protos in + /// examples_debug_outputs_serialized. + /// + /// + /// Output rank 1 Tensor containing a proto serialized as a string for each example. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// It traverses all the trees and computes debug metrics for individual examples, + /// such as getting split feature ids and logits after each split along the decision + /// path used to compute directional feature contributions. + /// + public static Tensor boosted_trees_example_debug_outputs(Tensor tree_ensemble_handle, Tensor[] bucketized_features, int logits_dimension, string name = "BoostedTreesExampleDebugOutputs") + { + var dict = new Dictionary(); + dict["tree_ensemble_handle"] = tree_ensemble_handle; + dict["bucketized_features"] = bucketized_features; + dict["logits_dimension"] = logits_dimension; + var op = tf.OpDefLib._apply_op_helper("BoostedTreesExampleDebugOutputs", name: name, keywords: dict); + return op.output; + } + + /// + /// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics. + /// + /// + /// Handle to the tree ensemble. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BoostedTreesGetEnsembleStates'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// stamp_token : Stamp token of the tree ensemble resource. + /// num_trees : The number of trees in the tree ensemble resource. + /// num_finalized_trees : The number of trees that were finished successfully. + /// num_attempted_layers : The number of layers we attempted to build (but not necessarily succeeded). + /// last_layer_nodes_range : Rank size 2 tensor that contains start and end ids of the nodes in the latest + /// layer. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor stamp_token, Tensor num_trees, Tensor num_finalized_trees, Tensor num_attempted_layers, Tensor last_layer_nodes_range) boosted_trees_get_ensemble_states(Tensor tree_ensemble_handle, string name = "BoostedTreesGetEnsembleStates") + { + var dict = new Dictionary(); + dict["tree_ensemble_handle"] = tree_ensemble_handle; + var op = tf.OpDefLib._apply_op_helper("BoostedTreesGetEnsembleStates", name: name, keywords: dict); + int _idx = 0; + var stamp_token = op.outputs[_idx++]; + var num_trees = op.outputs[_idx++]; + var num_finalized_trees = op.outputs[_idx++]; + var num_attempted_layers = op.outputs[_idx++]; + var last_layer_nodes_range = op.outputs[_idx++]; + return (stamp_token, num_trees, num_finalized_trees, num_attempted_layers, last_layer_nodes_range); + } + + /// + /// Makes the summary of accumulated stats for the batch. + /// + /// + /// int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer. + /// + /// + /// float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients. + /// + /// + /// float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians. + /// + /// + /// int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BoostedTreesMakeStatsSummary'. + /// + /// + /// Optional argument + /// int; the maximum number of splits possible in the whole tree. + /// + /// + /// Optional argument + /// int; equals to the maximum possible value of bucketized feature. + /// + /// + /// output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example. + /// + public static Tensor boosted_trees_make_stats_summary(Tensor node_ids, Tensor gradients, Tensor hessians, Tensor[] bucketized_features_list, int max_splits, int num_buckets, string name = "BoostedTreesMakeStatsSummary") + { + var dict = new Dictionary(); + dict["node_ids"] = node_ids; + dict["gradients"] = gradients; + dict["hessians"] = hessians; + dict["bucketized_features_list"] = bucketized_features_list; + dict["max_splits"] = max_splits; + dict["num_buckets"] = num_buckets; + var op = tf.OpDefLib._apply_op_helper("BoostedTreesMakeStatsSummary", name: name, keywords: dict); + return op.output; + } + + /// + /// Runs multiple additive regression ensemble predictors on input instances and + /// + /// + /// + /// + /// A list of rank 1 Tensors containing bucket id for each + /// feature. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BoostedTreesPredict'. + /// + /// + /// Optional argument + /// scalar, dimension of the logits, to be used for partial logits + /// shape. + /// + /// + /// Output rank 2 Tensor containing logits for each example. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// computes the logits. It is designed to be used during prediction. + /// It traverses all the trees and calculates the final score for each instance. + /// + public static Tensor boosted_trees_predict(Tensor tree_ensemble_handle, Tensor[] bucketized_features, int logits_dimension, string name = "BoostedTreesPredict") + { + var dict = new Dictionary(); + dict["tree_ensemble_handle"] = tree_ensemble_handle; + dict["bucketized_features"] = bucketized_features; + dict["logits_dimension"] = logits_dimension; + var op = tf.OpDefLib._apply_op_helper("BoostedTreesPredict", name: name, keywords: dict); + return op.output; + } + + /// + /// Serializes the tree ensemble to a proto. + /// + /// + /// Handle to the tree ensemble. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BoostedTreesSerializeEnsemble'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// stamp_token : Stamp token of the tree ensemble resource. + /// tree_ensemble_serialized : Serialized proto of the ensemble. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor stamp_token, Tensor tree_ensemble_serialized) boosted_trees_serialize_ensemble(Tensor tree_ensemble_handle, string name = "BoostedTreesSerializeEnsemble") + { + var dict = new Dictionary(); + dict["tree_ensemble_handle"] = tree_ensemble_handle; + var op = tf.OpDefLib._apply_op_helper("BoostedTreesSerializeEnsemble", name: name, keywords: dict); + int _idx = 0; + var stamp_token = op.outputs[_idx++]; + var tree_ensemble_serialized = op.outputs[_idx++]; + return (stamp_token, tree_ensemble_serialized); + } + + /// + /// Runs multiple additive regression ensemble predictors on input instances and + /// + /// + /// + /// + /// Rank 1 Tensor containing cached tree ids which is the starting + /// tree of prediction. + /// + /// + /// Rank 1 Tensor containing cached node id which is the starting + /// node of prediction. + /// + /// + /// A list of rank 1 Tensors containing bucket id for each + /// feature. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BoostedTreesTrainingPredict'. + /// + /// + /// Optional argument + /// scalar, dimension of the logits, to be used for partial logits + /// shape. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// partial_logits : Rank 2 Tensor containing logits update (with respect to cached + /// values stored) for each example. + /// tree_ids : Rank 1 Tensor containing new tree ids for each example. + /// node_ids : Rank 1 Tensor containing new node ids in the new tree_ids. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// computes the update to cached logits. It is designed to be used during training. + /// It traverses the trees starting from cached tree id and cached node id and + /// calculates the updates to be pushed to the cache. + /// + public static (Tensor partial_logits, Tensor tree_ids, Tensor node_ids) boosted_trees_training_predict(Tensor tree_ensemble_handle, Tensor cached_tree_ids, Tensor cached_node_ids, Tensor[] bucketized_features, int logits_dimension, string name = "BoostedTreesTrainingPredict") + { + var dict = new Dictionary(); + dict["tree_ensemble_handle"] = tree_ensemble_handle; + dict["cached_tree_ids"] = cached_tree_ids; + dict["cached_node_ids"] = cached_node_ids; + dict["bucketized_features"] = bucketized_features; + dict["logits_dimension"] = logits_dimension; + var op = tf.OpDefLib._apply_op_helper("BoostedTreesTrainingPredict", name: name, keywords: dict); + int _idx = 0; + var partial_logits = op.outputs[_idx++]; + var tree_ids = op.outputs[_idx++]; + var node_ids = op.outputs[_idx++]; + return (partial_logits, tree_ids, node_ids); + } + + /// + /// Updates the tree ensemble by either adding a layer to the last tree being grown + /// + /// + /// Handle to the ensemble variable. + /// + /// + /// Rank 1 tensor with ids for each feature. This is the real id of + /// the feature that will be used in the split. + /// + /// + /// List of rank 1 tensors representing the nodes for which this feature + /// has a split. + /// + /// + /// List of rank 1 tensors representing the gains for each of the feature's + /// split. + /// + /// + /// List of rank 1 tensors representing the thesholds for each of the + /// feature's split. + /// + /// + /// List of rank 2 tensors with left leaf contribs for each of + /// the feature's splits. Will be added to the previous node values to constitute + /// the values of the left nodes. + /// + /// + /// List of rank 2 tensors with right leaf contribs for each + /// of the feature's splits. Will be added to the previous node values to constitute + /// the values of the right nodes. + /// + /// + /// Max depth of the tree to build. + /// + /// + /// shrinkage const for each new tree. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BoostedTreesUpdateEnsemble'. + /// + /// + /// Optional argument + /// 0-No pruning, 1-Pre-pruning, 2-Post-pruning. + /// + /// + /// Returns the description of the operation + /// + /// + /// or by starting a new tree. + /// + public static Operation boosted_trees_update_ensemble(Tensor tree_ensemble_handle, Tensor feature_ids, Tensor[] node_ids, Tensor[] gains, Tensor[] thresholds, Tensor[] left_node_contribs, Tensor[] right_node_contribs, Tensor max_depth, Tensor learning_rate, int pruning_mode, string name = "BoostedTreesUpdateEnsemble") + { + var dict = new Dictionary(); + dict["tree_ensemble_handle"] = tree_ensemble_handle; + dict["feature_ids"] = feature_ids; + dict["node_ids"] = node_ids; + dict["gains"] = gains; + dict["thresholds"] = thresholds; + dict["left_node_contribs"] = left_node_contribs; + dict["right_node_contribs"] = right_node_contribs; + dict["max_depth"] = max_depth; + dict["learning_rate"] = learning_rate; + dict["pruning_mode"] = pruning_mode; + var op = tf.OpDefLib._apply_op_helper("BoostedTreesUpdateEnsemble", name: name, keywords: dict); + return op; + } + + /// + /// Return the shape of s0 op s1 with broadcast. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BroadcastArgs'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given s0 and s1, tensors that represent shapes, compute r0, the + /// broadcasted shape. s0, s1 and r0 are all integer vectors. + /// + public static Tensor broadcast_args(Tensor s0, Tensor s1, string name = "BroadcastArgs") + { + var dict = new Dictionary(); + dict["s0"] = s0; + dict["s1"] = s1; + var op = tf.OpDefLib._apply_op_helper("BroadcastArgs", name: name, keywords: dict); + return op.output; + } + + /// + /// Return the reduction indices for computing gradients of s0 op s1 with broadcast. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BroadcastGradientArgs'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// r0 : + /// r1 : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// This is typically used by gradient computations for a broadcasting operation. + /// + public static (Tensor r0, Tensor r1) broadcast_gradient_args(Tensor s0, Tensor s1, string name = "BroadcastGradientArgs") + { + var dict = new Dictionary(); + dict["s0"] = s0; + dict["s1"] = s1; + var op = tf.OpDefLib._apply_op_helper("BroadcastGradientArgs", name: name, keywords: dict); + int _idx = 0; + var r0 = op.outputs[_idx++]; + var r1 = op.outputs[_idx++]; + return (r0, r1); + } + + /// + /// Broadcast an array for a compatible shape. + /// + /// + /// A Tensor to broadcast. + /// + /// + /// An 1-D int Tensor. The shape of the desired output. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BroadcastTo'. + /// + /// + /// A Tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Broadcasting is the process of making arrays to have compatible shapes + /// for arithmetic operations. Two shapes are compatible if for each + /// dimension pair they are either equal or one of them is one. When trying + /// to broadcast a Tensor to a shape, it starts with the trailing dimensions, + /// and works its way forward. + /// + /// For example, + /// + /// &gt;&gt;&gt; x = tf.constant([1, 2, 3]) + /// &gt;&gt;&gt; y = tf.broadcast_to(x, [3, 3]) + /// &gt;&gt;&gt; sess.run(y) + /// array([[1, 2, 3], + /// [1, 2, 3], + /// [1, 2, 3]], dtype=int32) + /// + /// In the above example, the input Tensor with the shape of [1, 3] + /// is broadcasted to output Tensor with shape of [3, 3]. + /// + public static Tensor broadcast_to(Tensor input, Tensor shape, string name = "BroadcastTo") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("BroadcastTo", name: name, keywords: dict); + return op.output; + } + + /// + /// Bucketizes 'input' based on 'boundaries'. + /// + /// + /// Any shape of Tensor contains with int or float type. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Bucketize'. + /// + /// + /// Optional argument + /// A sorted list of floats gives the boundary of the buckets. + /// + /// + /// Same shape with 'input', each value of input replaced with bucket index. + /// + /// @compatibility(numpy) + /// Equivalent to np.digitize. + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// For example, if the inputs are + /// boundaries = [0, 10, 100] + /// input = [[-5, 10000] + /// [150, 10] + /// [5, 100]] + /// + /// then the output will be + /// output = [[0, 3] + /// [3, 2] + /// [1, 3]] + /// + public static Tensor bucketize(Tensor input, float[] boundaries, string name = "Bucketize") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["boundaries"] = boundaries; + var op = tf.OpDefLib._apply_op_helper("Bucketize", name: name, keywords: dict); + return op.output; + } + + /// + /// Records the bytes size of each element of input_dataset in a StatsAggregator. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'BytesProducedStatsDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor bytes_produced_stats_dataset(Tensor input_dataset, Tensor tag, TF_DataType[] output_types, Shape[] output_shapes, string name = "BytesProducedStatsDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["tag"] = tag; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("BytesProducedStatsDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Performs beam search decoding on the logits given in input. + /// + /// + /// 3-D, shape: (max_time x batch_size x num_classes), the logits. + /// + /// + /// A vector containing sequence lengths, size (batch). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CTCBeamSearchDecoder'. + /// + /// + /// Optional argument + /// A scalar &gt;= 0 (beam search beam width). + /// + /// + /// Optional argument + /// A scalar &gt;= 0, &lt;= beam_width (controls output size). + /// + /// + /// If true, merge repeated classes in output. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// decoded_indices : A list (length: top_paths) of indices matrices. Matrix j, + /// size (total_decoded_outputs[j] x 2), has indices of a + /// SparseTensor&lt;int64, 2&gt;. The rows store: [batch, time]. + /// decoded_values : A list (length: top_paths) of values vectors. Vector j, + /// size (length total_decoded_outputs[j]), has the values of a + /// SparseTensor&lt;int64, 2&gt;. The vector stores the decoded classes for beam j. + /// decoded_shape : A list (length: top_paths) of shape vector. Vector j, + /// size (2), stores the shape of the decoded SparseTensor[j]. + /// Its values are: [batch_size, max_decoded_length[j]]. + /// log_probability : A matrix, shaped: (batch_size x top_paths). The + /// sequence log-probabilities. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// A note about the attribute merge_repeated: For the beam search decoder, + /// this means that if consecutive entries in a beam are the same, only + /// the first of these is emitted. That is, when the top path is "A B B B B", + /// "A B" is returned if merge_repeated = True but "A B B B B" is + /// returned if merge_repeated = False. + /// + public static (Tensor[] decoded_indices, Tensor[] decoded_values, Tensor[] decoded_shape, Tensor log_probability) c_t_c_beam_search_decoder(Tensor inputs, Tensor sequence_length, int beam_width, int top_paths, bool? merge_repeated = null, string name = "CTCBeamSearchDecoder") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + dict["sequence_length"] = sequence_length; + dict["beam_width"] = beam_width; + dict["top_paths"] = top_paths; + if (merge_repeated.HasValue) + dict["merge_repeated"] = merge_repeated.Value; + var op = tf.OpDefLib._apply_op_helper("CTCBeamSearchDecoder", name: name, keywords: dict); + int _idx = 0; + var decoded_indices = Enumerable.Range(0, op.OutputListLength("decoded_indices")).Select(_ => op.outputs[_idx++]).ToArray(); + var decoded_values = Enumerable.Range(0, op.OutputListLength("decoded_values")).Select(_ => op.outputs[_idx++]).ToArray(); + var decoded_shape = Enumerable.Range(0, op.OutputListLength("decoded_shape")).Select(_ => op.outputs[_idx++]).ToArray(); + var log_probability = op.outputs[_idx++]; + return (decoded_indices, decoded_values, decoded_shape, log_probability); + } + + /// + /// Performs greedy decoding on the logits given in inputs. + /// + /// + /// 3-D, shape: (max_time x batch_size x num_classes), the logits. + /// + /// + /// A vector containing sequence lengths, size (batch_size). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CTCGreedyDecoder'. + /// + /// + /// If True, merge repeated classes in output. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// decoded_indices : Indices matrix, size (total_decoded_outputs x 2), + /// of a SparseTensor&lt;int64, 2&gt;. The rows store: [batch, time]. + /// decoded_values : Values vector, size: (total_decoded_outputs), + /// of a SparseTensor&lt;int64, 2&gt;. The vector stores the decoded classes. + /// decoded_shape : Shape vector, size (2), of the decoded SparseTensor. + /// Values are: [batch_size, max_decoded_length]. + /// log_probability : Matrix, size (batch_size x 1), containing sequence + /// log-probabilities. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// A note about the attribute merge_repeated: if enabled, when + /// consecutive logits' maximum indices are the same, only the first of + /// these is emitted. Labeling the blank '*', the sequence "A B B * B B" + /// becomes "A B B" if merge_repeated = True and "A B B B B" if + /// merge_repeated = False. + /// + /// Regardless of the value of merge_repeated, if the maximum index of a given + /// time and batch corresponds to the blank, index (num_classes - 1), no new + /// element is emitted. + /// + public static (Tensor decoded_indices, Tensor decoded_values, Tensor decoded_shape, Tensor log_probability) c_t_c_greedy_decoder(Tensor inputs, Tensor sequence_length, bool? merge_repeated = null, string name = "CTCGreedyDecoder") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + dict["sequence_length"] = sequence_length; + if (merge_repeated.HasValue) + dict["merge_repeated"] = merge_repeated.Value; + var op = tf.OpDefLib._apply_op_helper("CTCGreedyDecoder", name: name, keywords: dict); + int _idx = 0; + var decoded_indices = op.outputs[_idx++]; + var decoded_values = op.outputs[_idx++]; + var decoded_shape = op.outputs[_idx++]; + var log_probability = op.outputs[_idx++]; + return (decoded_indices, decoded_values, decoded_shape, log_probability); + } + + /// + /// Calculates the CTC Loss (log probability) for each batch entry. Also calculates + /// + /// + /// 3-D, shape: (max_time x batch_size x num_classes), the logits. + /// + /// + /// The indices of a SparseTensor&lt;int32, 2&gt;. + /// labels_indices(i, :) == [b, t] means labels_values(i) stores the id for + /// (batch b, time t). + /// + /// + /// The values (labels) associated with the given batch and time. + /// + /// + /// A vector containing sequence lengths (batch). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CTCLoss'. + /// + /// + /// Scalar, if true then repeated labels are + /// collapsed prior to the CTC calculation. + /// + /// + /// Scalar. If set to false, *during* CTC calculation + /// repeated non-blank labels will not be merged and are interpreted as + /// individual labels. This is a simplified version of CTC. + /// + /// + /// Scalar. If set to true, during CTC + /// calculation, items that have longer output sequences than input sequences + /// are skipped: they don't contribute to the loss term and have zero-gradient. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// loss : A vector (batch) containing log-probabilities. + /// gradient : The gradient of loss. 3-D, shape: + /// (max_time x batch_size x num_classes). + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// the gradient. This class performs the softmax operation for you, so inputs + /// should be e.g. linear projections of outputs by an LSTM. + /// + public static (Tensor loss, Tensor gradient) c_t_c_loss(Tensor inputs, Tensor labels_indices, Tensor labels_values, Tensor sequence_length, bool? preprocess_collapse_repeated = null, bool? ctc_merge_repeated = null, bool? ignore_longer_outputs_than_inputs = null, string name = "CTCLoss") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + dict["labels_indices"] = labels_indices; + dict["labels_values"] = labels_values; + dict["sequence_length"] = sequence_length; + if (preprocess_collapse_repeated.HasValue) + dict["preprocess_collapse_repeated"] = preprocess_collapse_repeated.Value; + if (ctc_merge_repeated.HasValue) + dict["ctc_merge_repeated"] = ctc_merge_repeated.Value; + if (ignore_longer_outputs_than_inputs.HasValue) + dict["ignore_longer_outputs_than_inputs"] = ignore_longer_outputs_than_inputs.Value; + var op = tf.OpDefLib._apply_op_helper("CTCLoss", name: name, keywords: dict); + int _idx = 0; + var loss = op.outputs[_idx++]; + var gradient = op.outputs[_idx++]; + return (loss, gradient); + } + + /// + /// Creates a dataset that caches elements from input_dataset. + /// + /// + /// + /// + /// A path on the filesystem where we should cache the dataset. Note: this + /// will be a directory. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CacheDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// A CacheDataset will iterate over the input_dataset, and store tensors. If the + /// cache already exists, the cache will be used. If the cache is inappropriate + /// (e.g. cannot be opened, contains tensors of the wrong shape / size), an error + /// will the returned when used. + /// + public static Tensor cache_dataset(Tensor input_dataset, Tensor filename, TF_DataType[] output_types, Shape[] output_shapes, string name = "CacheDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["filename"] = filename; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("CacheDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Cast x of type SrcT to y of DstT. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Cast'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor cast(Tensor x, TF_DataType DstT, bool? Truncate = null, string name = "Cast") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["DstT"] = DstT; + if (Truncate.HasValue) + dict["Truncate"] = Truncate.Value; + var op = tf.OpDefLib._apply_op_helper("Cast", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns element-wise smallest integer not less than x. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Ceil'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor ceil(Tensor x, string name = "Ceil") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Ceil", name: name, keywords: dict); + return op.output; + } + + /// + /// Checks a tensor for NaN and Inf values. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CheckNumerics'. + /// + /// + /// Optional argument + /// Prefix of the error message. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// When run, reports an InvalidArgument error if tensor has any values + /// that are not a number (NaN) or infinity (Inf). Otherwise, passes tensor as-is. + /// + public static Tensor check_numerics(Tensor tensor, string message, string name = "CheckNumerics") + { + var dict = new Dictionary(); + dict["tensor"] = tensor; + dict["message"] = message; + var op = tf.OpDefLib._apply_op_helper("CheckNumerics", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the Cholesky decomposition of one or more square matrices. + /// + /// + /// Shape is [..., M, M]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Cholesky'. + /// + /// + /// Shape is [..., M, M]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The input is a tensor of shape [..., M, M] whose inner-most 2 dimensions + /// form square matrices. + /// + /// The input has to be symmetric and positive definite. Only the lower-triangular + /// part of the input will be used for this operation. The upper-triangular part + /// will not be read. + /// + /// The output is a tensor of the same shape as the input + /// containing the Cholesky decompositions for all input submatrices [..., :, :]. + /// + /// **Note**: The gradient computation on GPU is faster for large matrices but + /// not for large batch dimensions when the submatrices are small. In this + /// case it might be faster to use the CPU. + /// + public static Tensor cholesky(Tensor input, string name = "Cholesky") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("Cholesky", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the reverse mode backpropagated gradient of the Cholesky algorithm. + /// + /// + /// Output of batch Cholesky algorithm l = cholesky(A). Shape is [..., M, M]. + /// Algorithm depends only on lower triangular part of the innermost matrices of + /// this tensor. + /// + /// + /// df/dl where f is some scalar function. Shape is [..., M, M]. + /// Algorithm depends only on lower triangular part of the innermost matrices of + /// this tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CholeskyGrad'. + /// + /// + /// Symmetrized version of df/dA . Shape is [..., M, M] + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// For an explanation see "Differentiation of the Cholesky algorithm" by + /// Iain Murray http://arxiv.org/abs/1602.07527. + /// + public static Tensor cholesky_grad(Tensor l, Tensor grad, string name = "CholeskyGrad") + { + var dict = new Dictionary(); + dict["l"] = l; + dict["grad"] = grad; + var op = tf.OpDefLib._apply_op_helper("CholeskyGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Clips tensor values to a specified min and max. + /// + /// + /// A Tensor. + /// + /// + /// A 0-D (scalar) Tensor, or a Tensor with the same shape + /// as t. The minimum value to clip by. + /// + /// + /// A 0-D (scalar) Tensor, or a Tensor with the same shape + /// as t. The maximum value to clip by. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ClipByValue'. + /// + /// + /// A clipped Tensor with the same shape as input 't'. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor t, this operation returns a tensor of the same type and + /// shape as t with its values clipped to clip_value_min and clip_value_max. + /// Any values less than clip_value_min are set to clip_value_min. Any values + /// greater than clip_value_max are set to clip_value_max. + /// + public static Tensor clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue") + { + var dict = new Dictionary(); + dict["t"] = t; + dict["clip_value_min"] = clip_value_min; + dict["clip_value_max"] = clip_value_max; + var op = tf.OpDefLib._apply_op_helper("ClipByValue", name: name, keywords: dict); + return op.output; + } + + /// + /// Receives a tensor value broadcast from another device. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CollectiveBcastRecv'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor collective_bcast_recv(TF_DataType T, int group_size, int group_key, int instance_key, Shape shape, string name = "CollectiveBcastRecv") + { + var dict = new Dictionary(); + dict["T"] = T; + dict["group_size"] = group_size; + dict["group_key"] = group_key; + dict["instance_key"] = instance_key; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("CollectiveBcastRecv", name: name, keywords: dict); + return op.output; + } + + /// + /// Broadcasts a tensor value to one or more other devices. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CollectiveBcastSend'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor collective_bcast_send(Tensor input, int group_size, int group_key, int instance_key, Shape shape, string name = "CollectiveBcastSend") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["group_size"] = group_size; + dict["group_key"] = group_key; + dict["instance_key"] = instance_key; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("CollectiveBcastSend", name: name, keywords: dict); + return op.output; + } + + /// + /// Mutually reduces multiple tensors of identical type and shape. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CollectiveReduce'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor collective_reduce(Tensor input, int group_size, int group_key, int instance_key, string merge_op, string final_op, int[] subdiv_offsets, string name = "CollectiveReduce") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["group_size"] = group_size; + dict["group_key"] = group_key; + dict["instance_key"] = instance_key; + dict["merge_op"] = merge_op; + dict["final_op"] = final_op; + dict["subdiv_offsets"] = subdiv_offsets; + var op = tf.OpDefLib._apply_op_helper("CollectiveReduce", name: name, keywords: dict); + return op.output; + } + + /// + /// Compare values of input to threshold and pack resulting bits into a uint8. + /// + /// + /// Values to compare against threshold and bitpack. + /// + /// + /// Threshold to compare against. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CompareAndBitpack'. + /// + /// + /// The bitpacked comparisons. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Each comparison returns a boolean true (if input_value &gt; threshold) + /// or and false otherwise. + /// + /// This operation is useful for Locality-Sensitive-Hashing (LSH) and other + /// algorithms that use hashing approximations of cosine and L2 distances; + /// codes can be generated from an input via: + /// + /// + /// codebook_size = 50 + /// codebook_bits = codebook_size * 32 + /// codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], + /// dtype=x.dtype, + /// initializer=tf.orthogonal_initializer()) + /// codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) + /// codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 + /// # now codes has shape x.shape[:-1] + [codebook_size] + /// + /// + /// **NOTE**: Currently, the innermost dimension of the tensor must be divisible + /// by 8. + /// + /// Given an input shaped [s0, s1, ..., s_n], the output is + /// a uint8 tensor shaped [s0, s1, ..., s_n / 8]. + /// + public static Tensor compare_and_bitpack(Tensor input, Tensor threshold, string name = "CompareAndBitpack") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["threshold"] = threshold; + var op = tf.OpDefLib._apply_op_helper("CompareAndBitpack", name: name, keywords: dict); + return op.output; + } + + /// + /// Converts two real numbers to a complex number. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Complex'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor real representing the real part of a complex number, and a + /// tensor imag representing the imaginary part of a complex number, this + /// operation returns complex numbers elementwise of the form \\(a + bj\\), where + /// *a* represents the real part and *b* represents the imag part. + /// + /// The input tensors real and imag must have the same shape. + /// + /// For example: + /// + /// + /// # tensor 'real' is [2.25, 3.25] + /// # tensor imag is [4.75, 5.75] + /// tf.complex(real, imag) ==&gt; [[2.25 + 4.75j], [3.25 + 5.75j]] + /// + /// + public static Tensor complex(Tensor real, Tensor imag, TF_DataType? a_Tout = null, string name = "Complex") + { + TF_DataType Tin = real.GetDataType(); + if (a_Tout is null) + { + a_Tout = (Tin == TF_DataType.TF_DOUBLE)? TF_DataType.TF_COMPLEX128: TF_DataType.TF_COMPLEX64; + } + return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(real, imag).SetAttributes(new { T=Tin, Tout=a_Tout })); + } + + /// + /// Computes the complex absolute value of a tensor. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ComplexAbs'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor x of complex numbers, this operation returns a tensor of type + /// float or double that is the absolute value of each element in x. All + /// elements in x must be complex numbers of the form \\(a + bj\\). The absolute + /// value is computed as \\( \sqrt{a^2 + b^2}\\). + /// + public static Tensor complex_abs(Tensor x, TF_DataType? Tout = null, string name = "ComplexAbs") + { + return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(x).SetAttributes(new { Tout = Tout })); + } + + /// + /// Computes the ids of the positions in sampled_candidates that match true_labels. + /// + /// + /// The true_classes output of UnpackSparseLabels. + /// + /// + /// The sampled_candidates output of CandidateSampler. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ComputeAccidentalHits'. + /// + /// + /// Optional argument + /// Number of true labels per context. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// An second seed to avoid seed collision. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// indices : A vector of indices corresponding to rows of true_candidates. + /// ids : A vector of IDs of positions in sampled_candidates that match a true_label + /// for the row with the corresponding index in indices. + /// weights : A vector of the same length as indices and ids, in which each element + /// is -FLOAT_MAX. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// When doing log-odds NCE, the result of this op should be passed through a + /// SparseToDense op, then added to the logits of the sampled candidates. This has + /// the effect of 'removing' the sampled labels that match the true labels by + /// making the classifier sure that they are sampled labels. + /// + public static (Tensor indices, Tensor ids, Tensor weights) compute_accidental_hits(Tensor true_classes, Tensor sampled_candidates, int num_true, int? seed = null, int? seed2 = null, string name = "ComputeAccidentalHits") + { + var dict = new Dictionary(); + dict["true_classes"] = true_classes; + dict["sampled_candidates"] = sampled_candidates; + dict["num_true"] = num_true; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("ComputeAccidentalHits", name: name, keywords: dict); + int _idx = 0; + var indices = op.outputs[_idx++]; + var ids = op.outputs[_idx++]; + var weights = op.outputs[_idx++]; + return (indices, ids, weights); + } + + /// + /// Concatenates tensors along one dimension. + /// + /// + /// 0-D. The dimension along which to concatenate. Must be in the + /// range [0, rank(values)). + /// + /// + /// The N Tensors to concatenate. Their ranks and types must match, + /// and their sizes must match in all dimensions except concat_dim. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Concat'. + /// + /// + /// A Tensor with the concatenation of values stacked along the + /// concat_dim dimension. This tensor's shape matches that of values except + /// in concat_dim where it has the sum of the sizes. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor concat(Tensor concat_dim, Tensor[] values, string name = "Concat") + { + var dict = new Dictionary(); + dict["concat_dim"] = concat_dim; + dict["values"] = values; + var op = tf.OpDefLib._apply_op_helper("Concat", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes offsets of concat inputs within its output. + /// + /// + /// The dimension along which to concatenate. + /// + /// + /// The N int32 vectors representing shape of tensors being concatenated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ConcatOffset'. + /// + /// + /// The N int32 vectors representing the starting offset + /// of input tensors within the concatenated output. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// For example: + /// + /// + /// # 'x' is [2, 2, 7] + /// # 'y' is [2, 3, 7] + /// # 'z' is [2, 5, 7] + /// concat_offset(2, [x, y, z]) =&gt; [0, 0, 0], [0, 2, 0], [0, 5, 0] + /// + /// + /// This is typically used by gradient computations for a concat operation. + /// + public static Tensor[] concat_offset(Tensor concat_dim, Tensor[] shape, string name = "ConcatOffset") + { + var dict = new Dictionary(); + dict["concat_dim"] = concat_dim; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("ConcatOffset", name: name, keywords: dict); + int _idx = 0; + var offset = Enumerable.Range(0, op.OutputListLength("offset")).Select(_ => op.outputs[_idx++]).ToArray(); + return (offset); + } + + /// + /// Concatenates tensors along one dimension. + /// + /// + /// List of N Tensors to concatenate. Their ranks and types must match, + /// and their sizes must match in all dimensions except concat_dim. + /// + /// + /// 0-D. The dimension along which to concatenate. Must be in the + /// range [-rank(values), rank(values)). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ConcatV2'. + /// + /// + /// A Tensor with the concatenation of values stacked along the + /// concat_dim dimension. This tensor's shape matches that of values except + /// in concat_dim where it has the sum of the sizes. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor concat_v2(Tensor[] values, Tensor axis, string name = "ConcatV2") + { + var dict = new Dictionary(); + dict["values"] = values; + dict["axis"] = axis; + var op = tf.OpDefLib._apply_op_helper("ConcatV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that concatenates input_dataset with another_dataset. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ConcatenateDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor concatenate_dataset(Tensor input_dataset, Tensor another_dataset, TF_DataType[] output_types, Shape[] output_shapes, string name = "ConcatenateDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["another_dataset"] = another_dataset; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("ConcatenateDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// A conditional accumulator for aggregating gradients. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ConditionalAccumulator'. + /// + /// + /// Optional argument + /// The type of the value being accumulated. + /// + /// + /// Optional argument + /// The shape of the values, can be [], in which case shape is unknown. + /// + /// + /// If non-empty, this accumulator is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this accumulator will be shared under the + /// given name across multiple sessions. + /// + /// + /// The handle to the accumulator. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The accumulator accepts gradients marked with local_step greater or + /// equal to the most recent global_step known to the accumulator. The + /// average can be extracted from the accumulator, provided sufficient + /// gradients have been accumulated. Extracting the average automatically + /// resets the aggregate to 0, and increments the global_step recorded by + /// the accumulator. + /// + public static Tensor conditional_accumulator(TF_DataType dtype, Shape shape, string container = null, string shared_name = null, string name = "ConditionalAccumulator") + { + var dict = new Dictionary(); + dict["dtype"] = dtype; + dict["shape"] = shape; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("ConditionalAccumulator", name: name, keywords: dict); + return op.output; + } + + /// + /// An op that sets up the centralized structures for a distributed TPU + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ConfigureDistributedTPU'. + /// + /// + /// Reserved. Do not use. + /// + /// + /// Serialized tensorflow.tpu.TPUEmbeddingConfiguration that + /// describes the embedding lookups of the program. + /// + /// + /// Reserved. Do not use. + /// + /// + /// A serialized tensorflow.tpu.TopologyProto that describes the TPU + /// topology. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// system. + /// + public static Tensor configure_distributed_t_p_u(string embedding_config = null, string tpu_embedding_config = null, bool? is_global_init = null, string name = "ConfigureDistributedTPU") + { + var dict = new Dictionary(); + if (embedding_config != null) + dict["embedding_config"] = embedding_config; + if (tpu_embedding_config != null) + dict["tpu_embedding_config"] = tpu_embedding_config; + if (is_global_init.HasValue) + dict["is_global_init"] = is_global_init.Value; + var op = tf.OpDefLib._apply_op_helper("ConfigureDistributedTPU", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the complex conjugate of a complex number. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Conj'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor input of complex numbers, this operation returns a tensor of + /// complex numbers that are the complex conjugate of each element in input. The + /// complex numbers in input must be of the form \\(a + bj\\), where *a* is the + /// real part and *b* is the imaginary part. + /// + /// The complex conjugate returned by this operation is of the form \\(a - bj\\). + /// + /// For example: + /// + /// + /// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] + /// tf.conj(input) ==&gt; [-2.25 - 4.75j, 3.25 - 5.75j] + /// + /// + public static Tensor conj(Tensor input, string name = "Conj") + { + return tf.Context.ExecuteOp("Conj", name, new ExecuteOpArgs(new object[] { input })); + } + + /// + /// Shuffle dimensions of x according to a permutation and conjugate the result. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ConjugateTranspose'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The output y has the same rank as x. The shapes of x and y satisfy: + /// y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1] + /// y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]]) + /// + public static Tensor conjugate_transpose(Tensor x, Tensor perm, string name = "ConjugateTranspose") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["perm"] = perm; + var op = tf.OpDefLib._apply_op_helper("ConjugateTranspose", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns a constant tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Const'. + /// + /// + /// Optional argument + /// Attr value is the tensor to return. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor constant(Tensor value, TF_DataType dtype, string name = "Const") + { + var dict = new Dictionary(); + dict["value"] = value; + dict["dtype"] = dtype; + var op = tf.OpDefLib._apply_op_helper("Const", name: name, keywords: dict); + return op.output; + } + + /// + /// This op consumes a lock created by MutexLock. + /// + /// + /// A tensor returned by MutexLock. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ConsumeMutexLock'. + /// + /// + /// Returns the description of the operation + /// + /// + /// This op exists to consume a tensor created by MutexLock (other than + /// direct control dependencies). It should be the only that consumes the tensor, + /// and will raise an error if it is not. Its only purpose is to keep the + /// mutex lock tensor alive until it is consumed by this op. + /// + /// **NOTE**: This operation must run on the same device as its input. This may + /// be enforced via the colocate_with mechanism. + /// + public static Operation consume_mutex_lock(Tensor mutex_lock, string name = "ConsumeMutexLock") + { + var dict = new Dictionary(); + dict["mutex_lock"] = mutex_lock; + var op = tf.OpDefLib._apply_op_helper("ConsumeMutexLock", name: name, keywords: dict); + return op; + } + + /// + /// Does nothing. Serves as a control trigger for scheduling. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ControlTrigger'. + /// + /// + /// Returns the description of the operation + /// + /// + /// Only useful as a placeholder for control edges. + /// + public static Operation control_trigger(string name = "ControlTrigger") + { + var dict = new Dictionary(); + var op = tf.OpDefLib._apply_op_helper("ControlTrigger", name: name, keywords: dict); + return op; + } + + /// + /// Computes a 2-D convolution given 4-D input and filter tensors. + /// + /// + /// A 4-D tensor. The dimension order is interpreted according to the value + /// of data_format, see below for details. + /// + /// + /// A 4-D tensor of shape + /// [filter_height, filter_width, in_channels, out_channels] + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Conv2D'. + /// + /// + /// Optional argument + /// 1-D tensor of length 4. The stride of the sliding window for each + /// dimension of input. The dimension order is determined by the value of + /// data_format, see below for details. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, height, width, channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, channels, height, width]. + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// input. If set to k &gt; 1, there will be k-1 skipped cells between each + /// filter element on that dimension. The dimension order is determined by the + /// value of data_format, see above for details. Dilations in the batch and + /// depth dimensions must be 1. + /// + /// + /// A 4-D tensor. The dimension order is determined by the value of + /// data_format, see below for details. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given an input tensor of shape [batch, in_height, in_width, in_channels] + /// and a filter / kernel tensor of shape + /// [filter_height, filter_width, in_channels, out_channels], this op + /// performs the following: + /// + /// 1. Flattens the filter to a 2-D matrix with shape + /// [filter_height * filter_width * in_channels, output_channels]. + /// 2. Extracts image patches from the input tensor to form a *virtual* + /// tensor of shape [batch, out_height, out_width, + /// filter_height * filter_width * in_channels]. + /// 3. For each patch, right-multiplies the filter matrix and the image patch + /// vector. + /// + /// In detail, with the default NHWC format, + /// + /// output[b, i, j, k] = + /// sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * + /// filter[di, dj, q, k] + /// + /// Must have strides[0] = strides[3] = 1. For the most common case of the same + /// horizontal and vertices strides, strides = [1, stride, stride, 1]. + /// + public static Tensor conv2d(Tensor input, Tensor filter, int[] strides, string padding, bool? use_cudnn_on_gpu = null, string data_format = null, int[] dilations = null, string name = "Conv2D") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["filter"] = filter; + dict["strides"] = strides; + dict["padding"] = padding; + if (use_cudnn_on_gpu.HasValue) + dict["use_cudnn_on_gpu"] = use_cudnn_on_gpu.Value; + if (data_format != null) + dict["data_format"] = data_format; + if (dilations != null) + dict["dilations"] = dilations; + var op = tf.OpDefLib._apply_op_helper("Conv2D", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradients of convolution with respect to the filter. + /// + /// + /// 4-D with shape [batch, in_height, in_width, in_channels]. + /// + /// + /// An integer vector representing the tensor shape of filter, + /// where filter is a 4-D + /// [filter_height, filter_width, in_channels, out_channels] tensor. + /// + /// + /// 4-D with shape [batch, out_height, out_width, out_channels]. + /// Gradients w.r.t. the output of the convolution. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Conv2DBackpropFilter'. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the input + /// of the convolution. Must be in the same order as the dimension specified with + /// format. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// input. If set to k &gt; 1, there will be k-1 skipped cells between each filter + /// element on that dimension. The dimension order is determined by the value of + /// data_format, see above for details. Dilations in the batch and depth + /// dimensions must be 1. + /// + /// + /// 4-D with shape + /// [filter_height, filter_width, in_channels, out_channels]. Gradient w.r.t. + /// the filter input of the convolution. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor conv2d_backprop_filter(Tensor input, Tensor filter_sizes, Tensor out_backprop, int[] strides, string padding, bool? use_cudnn_on_gpu = null, string data_format = null, int[] dilations = null, string name = "Conv2DBackpropFilter") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["filter_sizes"] = filter_sizes; + dict["out_backprop"] = out_backprop; + dict["strides"] = strides; + dict["padding"] = padding; + if (use_cudnn_on_gpu.HasValue) + dict["use_cudnn_on_gpu"] = use_cudnn_on_gpu.Value; + if (data_format != null) + dict["data_format"] = data_format; + if (dilations != null) + dict["dilations"] = dilations; + var op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradients of convolution with respect to the input. + /// + /// + /// An integer vector representing the shape of input, + /// where input is a 4-D [batch, height, width, channels] tensor. + /// + /// + /// 4-D with shape + /// [filter_height, filter_width, in_channels, out_channels]. + /// + /// + /// 4-D with shape [batch, out_height, out_width, out_channels]. + /// Gradients w.r.t. the output of the convolution. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Conv2DBackpropInput'. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the input + /// of the convolution. Must be in the same order as the dimension specified with + /// format. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// input. If set to k &gt; 1, there will be k-1 skipped cells between each filter + /// element on that dimension. The dimension order is determined by the value of + /// data_format, see above for details. Dilations in the batch and depth + /// dimensions must be 1. + /// + /// + /// 4-D with shape [batch, in_height, in_width, in_channels]. Gradient + /// w.r.t. the input of the convolution. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor conv2d_backprop_input(Tensor input_sizes, Tensor filter, Tensor out_backprop, int[] strides, string padding, bool? use_cudnn_on_gpu = null, string data_format = null, int[] dilations = null, string name = "Conv2DBackpropInput") + { + var dict = new Dictionary(); + dict["input_sizes"] = input_sizes; + dict["filter"] = filter; + dict["out_backprop"] = out_backprop; + dict["strides"] = strides; + dict["padding"] = padding; + if (use_cudnn_on_gpu.HasValue) + dict["use_cudnn_on_gpu"] = use_cudnn_on_gpu.Value; + if (data_format != null) + dict["data_format"] = data_format; + if (dilations != null) + dict["dilations"] = dilations; + var op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes a 3-D convolution given 5-D input and filter tensors. + /// + /// + /// Shape [batch, in_depth, in_height, in_width, in_channels]. + /// + /// + /// Shape [filter_depth, filter_height, filter_width, in_channels, + /// out_channels]. in_channels must match between input and filter. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Conv3D'. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of input. Must have strides[0] = strides[4] = 1. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// 1-D tensor of length 5. The dilation factor for each dimension of + /// input. If set to k &gt; 1, there will be k-1 skipped cells between each + /// filter element on that dimension. The dimension order is determined by the + /// value of data_format, see above for details. Dilations in the batch and + /// depth dimensions must be 1. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// In signal processing, cross-correlation is a measure of similarity of + /// two waveforms as a function of a time-lag applied to one of them. This + /// is also known as a sliding dot product or sliding inner-product. + /// + /// Our Conv3D implements a form of cross-correlation. + /// + public static Tensor conv3d(Tensor input, Tensor filter, int[] strides, string padding, string data_format = null, int[] dilations = null, string name = "Conv3D") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["filter"] = filter; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + if (dilations != null) + dict["dilations"] = dilations; + var op = tf.OpDefLib._apply_op_helper("Conv3D", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradients of 3-D convolution with respect to the filter. + /// + /// + /// Shape [batch, depth, rows, cols, in_channels]. + /// + /// + /// Shape [depth, rows, cols, in_channels, out_channels]. + /// in_channels must match between input and filter. + /// + /// + /// Backprop signal of shape [batch, out_depth, out_rows, out_cols, + /// out_channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Conv3DBackpropFilter'. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of input. Must have strides[0] = strides[4] = 1. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor conv3d_backprop_filter(Tensor input, Tensor filter, Tensor out_backprop, int[] strides, string padding, int[] dilations = null, string name = "Conv3DBackpropFilter") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["filter"] = filter; + dict["out_backprop"] = out_backprop; + dict["strides"] = strides; + dict["padding"] = padding; + if (dilations != null) + dict["dilations"] = dilations; + var op = tf.OpDefLib._apply_op_helper("Conv3DBackpropFilter", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradients of 3-D convolution with respect to the filter. + /// + /// + /// Shape [batch, depth, rows, cols, in_channels]. + /// + /// + /// An integer vector representing the tensor shape of filter, + /// where filter is a 5-D + /// [filter_depth, filter_height, filter_width, in_channels, out_channels] + /// tensor. + /// + /// + /// Backprop signal of shape [batch, out_depth, out_rows, out_cols, + /// out_channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Conv3DBackpropFilterV2'. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of input. Must have strides[0] = strides[4] = 1. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// 1-D tensor of length 5. The dilation factor for each dimension of + /// input. If set to k &gt; 1, there will be k-1 skipped cells between each + /// filter element on that dimension. The dimension order is determined by the + /// value of data_format, see above for details. Dilations in the batch and + /// depth dimensions must be 1. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor conv3d_backprop_filter_v2(Tensor input, Tensor filter_sizes, Tensor out_backprop, int[] strides, string padding, string data_format = null, int[] dilations = null, string name = "Conv3DBackpropFilterV2") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["filter_sizes"] = filter_sizes; + dict["out_backprop"] = out_backprop; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + if (dilations != null) + dict["dilations"] = dilations; + var op = tf.OpDefLib._apply_op_helper("Conv3DBackpropFilterV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradients of 3-D convolution with respect to the input. + /// + /// + /// Shape [batch, depth, rows, cols, in_channels]. + /// + /// + /// Shape [depth, rows, cols, in_channels, out_channels]. + /// in_channels must match between input and filter. + /// + /// + /// Backprop signal of shape [batch, out_depth, out_rows, out_cols, + /// out_channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Conv3DBackpropInput'. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of input. Must have strides[0] = strides[4] = 1. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor conv3d_backprop_input(Tensor input, Tensor filter, Tensor out_backprop, int[] strides, string padding, int[] dilations = null, string name = "Conv3DBackpropInput") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["filter"] = filter; + dict["out_backprop"] = out_backprop; + dict["strides"] = strides; + dict["padding"] = padding; + if (dilations != null) + dict["dilations"] = dilations; + var op = tf.OpDefLib._apply_op_helper("Conv3DBackpropInput", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradients of 3-D convolution with respect to the input. + /// + /// + /// An integer vector representing the tensor shape of input, + /// where input is a 5-D + /// [batch, depth, rows, cols, in_channels] tensor. + /// + /// + /// Shape [depth, rows, cols, in_channels, out_channels]. + /// in_channels must match between input and filter. + /// + /// + /// Backprop signal of shape [batch, out_depth, out_rows, out_cols, + /// out_channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Conv3DBackpropInputV2'. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of input. Must have strides[0] = strides[4] = 1. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// 1-D tensor of length 5. The dilation factor for each dimension of + /// input. If set to k &gt; 1, there will be k-1 skipped cells between each + /// filter element on that dimension. The dimension order is determined by the + /// value of data_format, see above for details. Dilations in the batch and + /// depth dimensions must be 1. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor conv3d_backprop_input_v2(Tensor input_sizes, Tensor filter, Tensor out_backprop, int[] strides, string padding, string data_format = null, int[] dilations = null, string name = "Conv3DBackpropInputV2") + { + var dict = new Dictionary(); + dict["input_sizes"] = input_sizes; + dict["filter"] = filter; + dict["out_backprop"] = out_backprop; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + if (dilations != null) + dict["dilations"] = dilations; + var op = tf.OpDefLib._apply_op_helper("Conv3DBackpropInputV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Copy Op. + /// + /// + /// Input tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Copy'. + /// + /// + /// The name of the input tensor. + /// + /// + /// A list of debug op spec (op, url, gated_grpc) for attached debug + /// ops. Each element of the list has the format + /// &lt;debug_op&gt;;&lt;grpc_url&gt;;&lt;gated_grpc&gt;, wherein gated_grpc is boolean represented + /// as 0/1. E.g., "DebugIdentity;grpc://foo:3333;1", + /// "DebugIdentity;file:///tmp/tfdbg_1;0". + /// + /// + /// Output tensor, deep-copied from input. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Performs CPU-to-CPU or GPU-to-GPU deep-copying of tensor, depending on the + /// device on which the tensor is allocated. + /// N.B.: If the all downstream attached debug ops are disabled given the current + /// gRPC gating status, the output will simply forward the input tensor without + /// deep-copying. See the documentation of Debug* ops for more details. + /// + /// Unlike the CopyHost Op, this op does not have HostMemory constraint on its + /// input or output. + /// + public static Tensor copy(Tensor input, string tensor_name = null, string[] debug_ops_spec = null, string name = "Copy") + { + var dict = new Dictionary(); + dict["input"] = input; + if (tensor_name != null) + dict["tensor_name"] = tensor_name; + if (debug_ops_spec != null) + dict["debug_ops_spec"] = debug_ops_spec; + var op = tf.OpDefLib._apply_op_helper("Copy", name: name, keywords: dict); + return op.output; + } + + /// + /// Copy Host Op. + /// + /// + /// Input tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CopyHost'. + /// + /// + /// The name of the input tensor. + /// + /// + /// A list of debug op spec (op, url, gated_grpc) for attached debug + /// ops. Each element of the list has the format + /// &lt;debug_op&gt;;&lt;grpc_url&gt;;&lt;gated_grpc&gt;, wherein gated_grpc is boolean represented + /// as 0/1. E.g., "DebugIdentity;grpc://foo:3333;1", + /// "DebugIdentity;file:///tmp/tfdbg_1;0". + /// + /// + /// Output tensor, deep-copied from input. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Performs CPU-to-CPU deep-copying of tensor. + /// N.B.: If the all downstream attached debug ops are disabled given the current + /// gRPC gating status, the output will simply forward the input tensor without + /// deep-copying. See the documentation of Debug* ops for more details. + /// + /// Unlike the Copy Op, this op has HostMemory constraint on its input or output. + /// + public static Tensor copy_host(Tensor input, string tensor_name = null, string[] debug_ops_spec = null, string name = "CopyHost") + { + var dict = new Dictionary(); + dict["input"] = input; + if (tensor_name != null) + dict["tensor_name"] = tensor_name; + if (debug_ops_spec != null) + dict["debug_ops_spec"] = debug_ops_spec; + var op = tf.OpDefLib._apply_op_helper("CopyHost", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes cos of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Cos'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor cos(Tensor x, string name = "Cos") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Cos", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes hyperbolic cosine of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Cosh'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor cosh(Tensor x, string name = "Cosh") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Cosh", name: name, keywords: dict); + return op.output; + } + + /// + /// Increments 'ref' until it reaches 'limit'. + /// + /// + /// Should be from a scalar Variable node. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CountUpTo'. + /// + /// + /// Optional argument + /// If incrementing ref would bring it above limit, instead generates an + /// 'OutOfRange' error. + /// + /// + /// A copy of the input before increment. If nothing else modifies the + /// input, the values produced will all be distinct. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor count_up_to(Tensor referecne, int limit, string name = "CountUpTo") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["limit"] = limit; + var op = tf.OpDefLib._apply_op_helper("CountUpTo", name: name, keywords: dict); + return op.output; + } + + /// + /// Extracts crops from the input image tensor and resizes them. + /// + /// + /// A 4-D tensor of shape [batch, image_height, image_width, depth]. + /// Both image_height and image_width need to be positive. + /// + /// + /// A 2-D tensor of shape [num_boxes, 4]. The i-th row of the tensor + /// specifies the coordinates of a box in the box_ind[i] image and is specified + /// in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of + /// y is mapped to the image coordinate at y * (image_height - 1), so as the + /// [0, 1] interval of normalized image height is mapped to + /// [0, image_height - 1] in image height coordinates. We do allow y1 &gt; y2, in + /// which case the sampled crop is an up-down flipped version of the original + /// image. The width dimension is treated similarly. Normalized coordinates + /// outside the [0, 1] range are allowed, in which case we use + /// extrapolation_value to extrapolate the input image values. + /// + /// + /// A 1-D tensor of shape [num_boxes] with int32 values in [0, batch). + /// The value of box_ind[i] specifies the image that the i-th box refers to. + /// + /// + /// A 1-D tensor of 2 elements, size = [crop_height, crop_width]. All + /// cropped image patches are resized to this size. The aspect ratio of the image + /// content is not preserved. Both crop_height and crop_width need to be + /// positive. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CropAndResize'. + /// + /// + /// A string specifying the sampling method for resizing. It can be either + /// "bilinear" or "nearest" and default to "bilinear". Currently two sampling + /// methods are supported: Bilinear and Nearest Neighbor. + /// + /// + /// Value used for extrapolation, when applicable. + /// + /// + /// A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Extracts crops from the input image tensor and resizes them using bilinear + /// sampling or nearest neighbor sampling (possibly with aspect ratio change) to a + /// common output size specified by crop_size. This is more general than the + /// crop_to_bounding_box op which extracts a fixed size slice from the input image + /// and does not allow resizing or aspect ratio change. + /// + /// Returns a tensor with crops from the input image at positions defined at the + /// bounding box locations in boxes. The cropped boxes are all resized (with + /// bilinear or nearest neighbor interpolation) to a fixed + /// size = [crop_height, crop_width]. The result is a 4-D tensor + /// [num_boxes, crop_height, crop_width, depth]. The resizing is corner aligned. + /// In particular, if boxes = [[0, 0, 1, 1]], the method will give identical + /// results to using tf.image.resize_bilinear() or + /// tf.image.resize_nearest_neighbor()(depends on the method argument) with + /// align_corners=True. + /// + public static Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = null, float? extrapolation_value = null, string name = "CropAndResize") + { + var dict = new Dictionary(); + dict["image"] = image; + dict["boxes"] = boxes; + dict["box_ind"] = box_ind; + dict["crop_size"] = crop_size; + if (method != null) + dict["method"] = method; + if (extrapolation_value.HasValue) + dict["extrapolation_value"] = extrapolation_value.Value; + var op = tf.OpDefLib._apply_op_helper("CropAndResize", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient of the crop_and_resize op wrt the input boxes tensor. + /// + /// + /// A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth]. + /// + /// + /// A 4-D tensor of shape [batch, image_height, image_width, depth]. + /// Both image_height and image_width need to be positive. + /// + /// + /// A 2-D tensor of shape [num_boxes, 4]. The i-th row of the tensor + /// specifies the coordinates of a box in the box_ind[i] image and is specified + /// in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of + /// y is mapped to the image coordinate at y * (image_height - 1), so as the + /// [0, 1] interval of normalized image height is mapped to + /// [0, image_height - 1] in image height coordinates. We do allow y1 &gt; y2, in + /// which case the sampled crop is an up-down flipped version of the original + /// image. The width dimension is treated similarly. Normalized coordinates + /// outside the [0, 1] range are allowed, in which case we use + /// extrapolation_value to extrapolate the input image values. + /// + /// + /// A 1-D tensor of shape [num_boxes] with int32 values in [0, batch). + /// The value of box_ind[i] specifies the image that the i-th box refers to. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CropAndResizeGradBoxes'. + /// + /// + /// A string specifying the interpolation method. Only 'bilinear' is + /// supported for now. + /// + /// + /// A 2-D tensor of shape [num_boxes, 4]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor crop_and_resize_grad_boxes(Tensor grads, Tensor image, Tensor boxes, Tensor box_ind, string method = null, string name = "CropAndResizeGradBoxes") + { + var dict = new Dictionary(); + dict["grads"] = grads; + dict["image"] = image; + dict["boxes"] = boxes; + dict["box_ind"] = box_ind; + if (method != null) + dict["method"] = method; + var op = tf.OpDefLib._apply_op_helper("CropAndResizeGradBoxes", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient of the crop_and_resize op wrt the input image tensor. + /// + /// + /// A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth]. + /// + /// + /// A 2-D tensor of shape [num_boxes, 4]. The i-th row of the tensor + /// specifies the coordinates of a box in the box_ind[i] image and is specified + /// in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of + /// y is mapped to the image coordinate at y * (image_height - 1), so as the + /// [0, 1] interval of normalized image height is mapped to + /// [0, image_height - 1] in image height coordinates. We do allow y1 &gt; y2, in + /// which case the sampled crop is an up-down flipped version of the original + /// image. The width dimension is treated similarly. Normalized coordinates + /// outside the [0, 1] range are allowed, in which case we use + /// extrapolation_value to extrapolate the input image values. + /// + /// + /// A 1-D tensor of shape [num_boxes] with int32 values in [0, batch). + /// The value of box_ind[i] specifies the image that the i-th box refers to. + /// + /// + /// A 1-D tensor with value [batch, image_height, image_width, depth] + /// containing the original image size. Both image_height and image_width need + /// to be positive. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CropAndResizeGradImage'. + /// + /// + /// Optional argument + /// + /// + /// A string specifying the interpolation method. Only 'bilinear' is + /// supported for now. + /// + /// + /// A 4-D tensor of shape [batch, image_height, image_width, depth]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor crop_and_resize_grad_image(Tensor grads, Tensor boxes, Tensor box_ind, Tensor image_size, TF_DataType T, string method = null, string name = "CropAndResizeGradImage") + { + var dict = new Dictionary(); + dict["grads"] = grads; + dict["boxes"] = boxes; + dict["box_ind"] = box_ind; + dict["image_size"] = image_size; + dict["T"] = T; + if (method != null) + dict["method"] = method; + var op = tf.OpDefLib._apply_op_helper("CropAndResizeGradImage", name: name, keywords: dict); + return op.output; + } + + /// + /// Compute the pairwise cross product. + /// + /// + /// A tensor containing 3-element vectors. + /// + /// + /// Another tensor, of same type and shape as a. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Cross'. + /// + /// + /// Pairwise cross product of the vectors in a and b. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// a and b must be the same shape; they can either be simple 3-element vectors, + /// or any shape where the innermost dimension is 3. In the latter case, each pair + /// of corresponding 3-element vectors is cross-multiplied independently. + /// + public static Tensor cross(Tensor a, Tensor b, string name = "Cross") + { + var dict = new Dictionary(); + dict["a"] = a; + dict["b"] = b; + var op = tf.OpDefLib._apply_op_helper("Cross", name: name, keywords: dict); + return op.output; + } + + /// + /// An Op to sum inputs across replicated TPU instances. Each + /// + /// + /// The local input to the sum. + /// + /// + /// An int32 tensor with shape + /// [num_groups, num_replicas_per_group]. group_assignment[i] represents the + /// replica ids in the ith subgroup. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CrossReplicaSum'. + /// + /// + /// The sum of all the distributed inputs. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// instance supplies its own input. If group_assignment is empty, the output of + /// each is the sum of all the inputs, otherwise the output of each is the sum of + /// the inputs belonging to the same group. + /// + /// For example, suppose there are 8 TPU instances: [A, B, C, D, E, F, G, H]. + /// Passing group_assignment=[[0,2,4,6],[1,3,5,7]] sets A, C, E, G as group 0, + /// and B, D, F, H as group 1. Thus we get the outputs: + /// [A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]. + /// + public static Tensor cross_replica_sum(Tensor input, Tensor group_assignment, string name = "CrossReplicaSum") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["group_assignment"] = group_assignment; + var op = tf.OpDefLib._apply_op_helper("CrossReplicaSum", name: name, keywords: dict); + return op.output; + } + + /// + /// A RNN backed by cuDNN. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CudnnRNN'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : + /// output_h : + /// output_c : + /// reserve_space : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Computes the RNN from the input and initial states, with respect to the params + /// buffer. + /// + /// rnn_mode: Indicates the type of the RNN model. + /// input_mode: Indicate whether there is a linear projection between the input and + /// the actual computation before the first layer. 'skip_input' is only allowed + /// when input_size == num_units; 'auto_select' implies 'skip_input' when + /// input_size == num_units; otherwise, it implies 'linear_input'. + /// direction: Indicates whether a bidirectional model will be used. Should be + /// "unidirectional" or "bidirectional". + /// dropout: Dropout probability. When set to 0., dropout is disabled. + /// seed: The 1st part of a seed to initialize dropout. + /// seed2: The 2nd part of a seed to initialize dropout. + /// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. + /// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, + /// num_units]. + /// input_c: For LSTM, a 3-D tensor with the shape of + /// [num_layer * dir, batch, num_units]. For other models, it is ignored. + /// params: A 1-D tensor that contains the weights and biases in an opaque layout. + /// The size must be created through CudnnRNNParamsSize, and initialized + /// separately. Note that they might not be compatible across different + /// generations. So it is a good idea to save and restore + /// output: A 3-D tensor with the shape of [seq_length, batch_size, + /// dir * num_units]. + /// output_h: The same shape has input_h. + /// output_c: The same shape as input_c for LSTM. An empty tensor for other models. + /// is_training: Indicates whether this operation is used for inferenece or + /// training. + /// reserve_space: An opaque tensor that can be used in backprop calculation. It + /// is only produced if is_training is false. + /// + public static (Tensor output, Tensor output_h, Tensor output_c, Tensor reserve_space) cudnn_r_n_n(Tensor input, Tensor input_h, Tensor input_c, Tensor parameters, string rnn_mode = null, string input_mode = null, string direction = null, float? dropout = null, int? seed = null, int? seed2 = null, bool? is_training = null, string name = "CudnnRNN") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["input_h"] = input_h; + dict["input_c"] = input_c; + dict["params"] = parameters; + if (rnn_mode != null) + dict["rnn_mode"] = rnn_mode; + if (input_mode != null) + dict["input_mode"] = input_mode; + if (direction != null) + dict["direction"] = direction; + if (dropout.HasValue) + dict["dropout"] = dropout.Value; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + if (is_training.HasValue) + dict["is_training"] = is_training.Value; + var op = tf.OpDefLib._apply_op_helper("CudnnRNN", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var output_h = op.outputs[_idx++]; + var output_c = op.outputs[_idx++]; + var reserve_space = op.outputs[_idx++]; + return (output, output_h, output_c, reserve_space); + } + + /// + /// Backprop step of CudnnRNN. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CudnnRNNBackprop'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// input_backprop : + /// input_h_backprop : + /// input_c_backprop : + /// params_backprop : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Compute the backprop of both data and weights in a RNN. + /// + /// rnn_mode: Indicates the type of the RNN model. + /// input_mode: Indicate whether there is a linear projection between the input and + /// the actual computation before the first layer. 'skip_input' is only allowed + /// when input_size == num_units; 'auto_select' implies 'skip_input' when + /// input_size == num_units; otherwise, it implies 'linear_input'. + /// direction: Indicates whether a bidirectional model will be used. Should be + /// "unidirectional" or "bidirectional". + /// dropout: Dropout probability. When set to 0., dropout is disabled. + /// seed: The 1st part of a seed to initialize dropout. + /// seed2: The 2nd part of a seed to initialize dropout. + /// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. + /// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, + /// num_units]. + /// input_c: For LSTM, a 3-D tensor with the shape of + /// [num_layer * dir, batch, num_units]. For other models, it is ignored. + /// params: A 1-D tensor that contains the weights and biases in an opaque layout. + /// The size must be created through CudnnRNNParamsSize, and initialized + /// separately. Note that they might not be compatible across different + /// generations. So it is a good idea to save and restore + /// output: A 3-D tensor with the shape of [seq_length, batch_size, + /// dir * num_units]. + /// output_h: The same shape has input_h. + /// output_c: The same shape as input_c for LSTM. An empty tensor for other models. + /// output_backprop: A 3-D tensor with the same shape as output in the forward pass. + /// output_h_backprop: A 3-D tensor with the same shape as output_h in the forward + /// pass. + /// output_c_backprop: A 3-D tensor with the same shape as output_c in the forward + /// pass. + /// reserve_space: The same reserve_space produced in for forward operation. + /// input_backprop: The backprop to input in the forward pass. Has the same shape + /// as input. + /// input_h_backprop: The backprop to input_h in the forward pass. Has the same + /// shape as input_h. + /// input_c_backprop: The backprop to input_c in the forward pass. Has the same + /// shape as input_c. + /// params_backprop: The backprop to the params buffer in the forward pass. Has the + /// same shape as params. + /// + public static (Tensor input_backprop, Tensor input_h_backprop, Tensor input_c_backprop, Tensor params_backprop) cudnn_r_n_n_backprop(Tensor input, Tensor input_h, Tensor input_c, Tensor parameters, Tensor output, Tensor output_h, Tensor output_c, Tensor output_backprop, Tensor output_h_backprop, Tensor output_c_backprop, Tensor reserve_space, string rnn_mode = null, string input_mode = null, string direction = null, float? dropout = null, int? seed = null, int? seed2 = null, string name = "CudnnRNNBackprop") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["input_h"] = input_h; + dict["input_c"] = input_c; + dict["params"] = parameters; + dict["output"] = output; + dict["output_h"] = output_h; + dict["output_c"] = output_c; + dict["output_backprop"] = output_backprop; + dict["output_h_backprop"] = output_h_backprop; + dict["output_c_backprop"] = output_c_backprop; + dict["reserve_space"] = reserve_space; + if (rnn_mode != null) + dict["rnn_mode"] = rnn_mode; + if (input_mode != null) + dict["input_mode"] = input_mode; + if (direction != null) + dict["direction"] = direction; + if (dropout.HasValue) + dict["dropout"] = dropout.Value; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("CudnnRNNBackprop", name: name, keywords: dict); + int _idx = 0; + var input_backprop = op.outputs[_idx++]; + var input_h_backprop = op.outputs[_idx++]; + var input_c_backprop = op.outputs[_idx++]; + var params_backprop = op.outputs[_idx++]; + return (input_backprop, input_h_backprop, input_c_backprop, params_backprop); + } + + /// + /// Backprop step of CudnnRNN. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CudnnRNNBackpropV2'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// input_backprop : + /// input_h_backprop : + /// input_c_backprop : + /// params_backprop : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Compute the backprop of both data and weights in a RNN. Takes an extra + /// "host_reserved" inupt than CudnnRNNBackprop, which is used to determine RNN + /// cudnnRNNAlgo_t and cudnnMathType_t. + /// + /// rnn_mode: Indicates the type of the RNN model. + /// input_mode: Indicates whether there is a linear projection between the input and + /// the actual computation before the first layer. 'skip_input' is only allowed + /// when input_size == num_units; 'auto_select' implies 'skip_input' when + /// input_size == num_units; otherwise, it implies 'linear_input'. + /// direction: Indicates whether a bidirectional model will be used. Should be + /// "unidirectional" or "bidirectional". + /// dropout: Dropout probability. When set to 0., dropout is disabled. + /// seed: The 1st part of a seed to initialize dropout. + /// seed2: The 2nd part of a seed to initialize dropout. + /// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. + /// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, + /// num_units]. + /// input_c: For LSTM, a 3-D tensor with the shape of + /// [num_layer * dir, batch, num_units]. For other models, it is ignored. + /// params: A 1-D tensor that contains the weights and biases in an opaque layout. + /// The size must be created through CudnnRNNParamsSize, and initialized + /// separately. Note that they might not be compatible across different + /// generations. So it is a good idea to save and restore + /// output: A 3-D tensor with the shape of [seq_length, batch_size, + /// dir * num_units]. + /// output_h: The same shape has input_h. + /// output_c: The same shape as input_c for LSTM. An empty tensor for other models. + /// output_backprop: A 3-D tensor with the same shape as output in the forward pass. + /// output_h_backprop: A 3-D tensor with the same shape as output_h in the forward + /// pass. + /// output_c_backprop: A 3-D tensor with the same shape as output_c in the forward + /// pass. + /// reserve_space: The same reserve_space produced in the forward operation. + /// host_reserved: The same host_reserved produced in the forward operation. + /// input_backprop: The backprop to input in the forward pass. Has the same shape + /// as input. + /// input_h_backprop: The backprop to input_h in the forward pass. Has the same + /// shape as input_h. + /// input_c_backprop: The backprop to input_c in the forward pass. Has the same + /// shape as input_c. + /// params_backprop: The backprop to the params buffer in the forward pass. Has the + /// same shape as params. + /// + public static (Tensor input_backprop, Tensor input_h_backprop, Tensor input_c_backprop, Tensor params_backprop) cudnn_r_n_n_backprop_v2(Tensor input, Tensor input_h, Tensor input_c, Tensor parameters, Tensor output, Tensor output_h, Tensor output_c, Tensor output_backprop, Tensor output_h_backprop, Tensor output_c_backprop, Tensor reserve_space, Tensor host_reserved, string rnn_mode = null, string input_mode = null, string direction = null, float? dropout = null, int? seed = null, int? seed2 = null, string name = "CudnnRNNBackpropV2") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["input_h"] = input_h; + dict["input_c"] = input_c; + dict["params"] = parameters; + dict["output"] = output; + dict["output_h"] = output_h; + dict["output_c"] = output_c; + dict["output_backprop"] = output_backprop; + dict["output_h_backprop"] = output_h_backprop; + dict["output_c_backprop"] = output_c_backprop; + dict["reserve_space"] = reserve_space; + dict["host_reserved"] = host_reserved; + if (rnn_mode != null) + dict["rnn_mode"] = rnn_mode; + if (input_mode != null) + dict["input_mode"] = input_mode; + if (direction != null) + dict["direction"] = direction; + if (dropout.HasValue) + dict["dropout"] = dropout.Value; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("CudnnRNNBackpropV2", name: name, keywords: dict); + int _idx = 0; + var input_backprop = op.outputs[_idx++]; + var input_h_backprop = op.outputs[_idx++]; + var input_c_backprop = op.outputs[_idx++]; + var params_backprop = op.outputs[_idx++]; + return (input_backprop, input_h_backprop, input_c_backprop, params_backprop); + } + + /// + /// Converts CudnnRNN params from canonical form to usable form. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CudnnRNNCanonicalToParams'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Writes a set of weights into the opaque params buffer so they can be used in + /// upcoming training or inferences. + /// + /// Note that the params buffer may not be compatible across different GPUs. So any + /// save and restoration should be converted to and from the canonical weights and + /// biases. + /// + /// num_layers: Specifies the number of layers in the RNN model. + /// num_units: Specifies the size of the hidden state. + /// input_size: Specifies the size of the input state. + /// weights: the canonical form of weights that can be used for saving + /// and restoration. They are more likely to be compatible across different + /// generations. + /// biases: the canonical form of biases that can be used for saving + /// and restoration. They are more likely to be compatible across different + /// generations. + /// num_params: number of parameter sets for all layers. + /// Each layer may contain multiple parameter sets, with each set consisting of + /// a weight matrix and a bias vector. + /// rnn_mode: Indicates the type of the RNN model. + /// input_mode: Indicate whether there is a linear projection between the input and + /// The actual computation before the first layer. 'skip_input' is only allowed + /// when input_size == num_units; 'auto_select' implies 'skip_input' when + /// input_size == num_units; otherwise, it implies 'linear_input'. + /// direction: Indicates whether a bidirectional model will be used. + /// dir = (direction == bidirectional) ? 2 : 1 + /// dropout: dropout probability. When set to 0., dropout is disabled. + /// seed: the 1st part of a seed to initialize dropout. + /// seed2: the 2nd part of a seed to initialize dropout. + /// + public static Tensor cudnn_r_n_n_canonical_to_params(Tensor num_layers, Tensor num_units, Tensor input_size, Tensor[] weights, Tensor[] biases, string rnn_mode = null, string input_mode = null, string direction = null, float? dropout = null, int? seed = null, int? seed2 = null, string name = "CudnnRNNCanonicalToParams") + { + var dict = new Dictionary(); + dict["num_layers"] = num_layers; + dict["num_units"] = num_units; + dict["input_size"] = input_size; + dict["weights"] = weights; + dict["biases"] = biases; + if (rnn_mode != null) + dict["rnn_mode"] = rnn_mode; + if (input_mode != null) + dict["input_mode"] = input_mode; + if (direction != null) + dict["direction"] = direction; + if (dropout.HasValue) + dict["dropout"] = dropout.Value; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("CudnnRNNCanonicalToParams", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes size of weights that can be used by a Cudnn RNN model. + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CudnnRNNParamsSize'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Return the params size that can be used by the Cudnn RNN model. Subsequent + /// weight allocation and initialization should use this size. + /// + /// num_layers: Specifies the number of layers in the RNN model. + /// num_units: Specifies the size of the hidden state. + /// input_size: Specifies the size of the input state. + /// rnn_mode: Indicates the type of the RNN model. + /// input_mode: Indicate whether there is a linear projection between the input and + /// The actual computation before the first layer. 'skip_input' is only allowed + /// when input_size == num_units; 'auto_select' implies 'skip_input' when + /// input_size == num_units; otherwise, it implies 'linear_input'. + /// direction: Indicates whether a bidirectional model will be used. + /// dir = (direction == bidirectional) ? 2 : 1 + /// dropout: dropout probability. When set to 0., dropout is disabled. + /// seed: the 1st part of a seed to initialize dropout. + /// seed2: the 2nd part of a seed to initialize dropout. + /// params_size: The size of the params buffer that should be allocated and + /// initialized for this RNN model. Note that this params buffer may not be + /// compatible across GPUs. Please use CudnnRNNParamsWeights and + /// CudnnRNNParamsBiases to save and restore them in a way that is compatible + /// across different runs. + /// + public static Tensor cudnn_r_n_n_params_size(Tensor num_layers, Tensor num_units, Tensor input_size, TF_DataType T, TF_DataType S, string rnn_mode = null, string input_mode = null, string direction = null, float? dropout = null, int? seed = null, int? seed2 = null, string name = "CudnnRNNParamsSize") + { + var dict = new Dictionary(); + dict["num_layers"] = num_layers; + dict["num_units"] = num_units; + dict["input_size"] = input_size; + dict["T"] = T; + dict["S"] = S; + if (rnn_mode != null) + dict["rnn_mode"] = rnn_mode; + if (input_mode != null) + dict["input_mode"] = input_mode; + if (direction != null) + dict["direction"] = direction; + if (dropout.HasValue) + dict["dropout"] = dropout.Value; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("CudnnRNNParamsSize", name: name, keywords: dict); + return op.output; + } + + /// + /// Retrieves CudnnRNN params in canonical form. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CudnnRNNParamsToCanonical'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// weights : + /// biases : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Retrieves a set of weights from the opaque params buffer that can be saved and + /// restored in a way compatible with future runs. + /// + /// Note that the params buffer may not be compatible across different GPUs. So any + /// save and restoration should be converted to and from the canonical weights and + /// biases. + /// + /// num_layers: Specifies the number of layers in the RNN model. + /// num_units: Specifies the size of the hidden state. + /// input_size: Specifies the size of the input state. + /// num_params: number of parameter sets for all layers. + /// Each layer may contain multiple parameter sets, with each set consisting of + /// a weight matrix and a bias vector. + /// weights: the canonical form of weights that can be used for saving + /// and restoration. They are more likely to be compatible across different + /// generations. + /// biases: the canonical form of biases that can be used for saving + /// and restoration. They are more likely to be compatible across different + /// generations. + /// rnn_mode: Indicates the type of the RNN model. + /// input_mode: Indicate whether there is a linear projection between the input and + /// The actual computation before the first layer. 'skip_input' is only allowed + /// when input_size == num_units; 'auto_select' implies 'skip_input' when + /// input_size == num_units; otherwise, it implies 'linear_input'. + /// direction: Indicates whether a bidirectional model will be used. + /// dir = (direction == bidirectional) ? 2 : 1 + /// dropout: dropout probability. When set to 0., dropout is disabled. + /// seed: the 1st part of a seed to initialize dropout. + /// seed2: the 2nd part of a seed to initialize dropout. + /// + public static (Tensor[] weights, Tensor[] biases) cudnn_r_n_n_params_to_canonical(Tensor num_layers, Tensor num_units, Tensor input_size, Tensor parameters, int num_params, string rnn_mode = null, string input_mode = null, string direction = null, float? dropout = null, int? seed = null, int? seed2 = null, string name = "CudnnRNNParamsToCanonical") + { + var dict = new Dictionary(); + dict["num_layers"] = num_layers; + dict["num_units"] = num_units; + dict["input_size"] = input_size; + dict["params"] = parameters; + dict["num_params"] = num_params; + if (rnn_mode != null) + dict["rnn_mode"] = rnn_mode; + if (input_mode != null) + dict["input_mode"] = input_mode; + if (direction != null) + dict["direction"] = direction; + if (dropout.HasValue) + dict["dropout"] = dropout.Value; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("CudnnRNNParamsToCanonical", name: name, keywords: dict); + int _idx = 0; + var weights = Enumerable.Range(0, op.OutputListLength("weights")).Select(_ => op.outputs[_idx++]).ToArray(); + var biases = Enumerable.Range(0, op.OutputListLength("biases")).Select(_ => op.outputs[_idx++]).ToArray(); + return (weights, biases); + } + + /// + /// A RNN backed by cuDNN. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'CudnnRNNV2'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : + /// output_h : + /// output_c : + /// reserve_space : + /// host_reserved : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Computes the RNN from the input and initial states, with respect to the params + /// buffer. Produces one extra output "host_reserved" than CudnnRNN. + /// + /// rnn_mode: Indicates the type of the RNN model. + /// input_mode: Indicates whether there is a linear projection between the input and + /// the actual computation before the first layer. 'skip_input' is only allowed + /// when input_size == num_units; 'auto_select' implies 'skip_input' when + /// input_size == num_units; otherwise, it implies 'linear_input'. + /// direction: Indicates whether a bidirectional model will be used. Should be + /// "unidirectional" or "bidirectional". + /// dropout: Dropout probability. When set to 0., dropout is disabled. + /// seed: The 1st part of a seed to initialize dropout. + /// seed2: The 2nd part of a seed to initialize dropout. + /// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. + /// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, + /// num_units]. + /// input_c: For LSTM, a 3-D tensor with the shape of + /// [num_layer * dir, batch, num_units]. For other models, it is ignored. + /// params: A 1-D tensor that contains the weights and biases in an opaque layout. + /// The size must be created through CudnnRNNParamsSize, and initialized + /// separately. Note that they might not be compatible across different + /// generations. So it is a good idea to save and restore + /// output: A 3-D tensor with the shape of [seq_length, batch_size, + /// dir * num_units]. + /// output_h: The same shape has input_h. + /// output_c: The same shape as input_c for LSTM. An empty tensor for other models. + /// is_training: Indicates whether this operation is used for inferenece or + /// training. + /// reserve_space: An opaque tensor that can be used in backprop calculation. It + /// is only produced if is_training is true. + /// host_reserved: An opaque tensor that can be used in backprop calculation. It is + /// only produced if is_training is true. It is output on host memory rather than + /// device memory. + /// + public static (Tensor output, Tensor output_h, Tensor output_c, Tensor reserve_space, Tensor host_reserved) cudnn_r_n_n_v2(Tensor input, Tensor input_h, Tensor input_c, Tensor parameters, string rnn_mode = null, string input_mode = null, string direction = null, float? dropout = null, int? seed = null, int? seed2 = null, bool? is_training = null, string name = "CudnnRNNV2") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["input_h"] = input_h; + dict["input_c"] = input_c; + dict["params"] = parameters; + if (rnn_mode != null) + dict["rnn_mode"] = rnn_mode; + if (input_mode != null) + dict["input_mode"] = input_mode; + if (direction != null) + dict["direction"] = direction; + if (dropout.HasValue) + dict["dropout"] = dropout.Value; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + if (is_training.HasValue) + dict["is_training"] = is_training.Value; + var op = tf.OpDefLib._apply_op_helper("CudnnRNNV2", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var output_h = op.outputs[_idx++]; + var output_c = op.outputs[_idx++]; + var reserve_space = op.outputs[_idx++]; + var host_reserved = op.outputs[_idx++]; + return (output, output_h, output_c, reserve_space, host_reserved); + } + + /// + /// Compute the cumulative product of the tensor x along axis. + /// + /// + /// A Tensor. Must be one of the following types: float32, float64, + /// int64, int32, uint8, uint16, int16, int8, complex64, + /// complex128, qint8, quint8, qint32, half. + /// + /// + /// A Tensor of type int32 (default: 0). Must be in the range + /// [-rank(x), rank(x)). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Cumprod'. + /// + /// + /// If True, perform exclusive cumprod. + /// + /// + /// A bool (default: False). + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// By default, this op performs an inclusive cumprod, which means that the first + /// element of the input is identical to the first element of the output: + /// + /// + /// tf.cumprod([a, b, c]) # =&gt; [a, a * b, a * b * c] + /// + /// + /// By setting the exclusive kwarg to True, an exclusive cumprod is + /// performed instead: + /// + /// + /// tf.cumprod([a, b, c], exclusive=True) # =&gt; [1, a, a * b] + /// + /// + /// By setting the reverse kwarg to True, the cumprod is performed in the + /// opposite direction: + /// + /// + /// tf.cumprod([a, b, c], reverse=True) # =&gt; [a * b * c, b * c, c] + /// + /// + /// This is more efficient than using separate tf.reverse ops. + /// + /// The reverse and exclusive kwargs can also be combined: + /// + /// + /// tf.cumprod([a, b, c], exclusive=True, reverse=True) # =&gt; [b * c, c, 1] + /// + /// + public static Tensor cumprod(Tensor x, Tensor axis, bool? exclusive = null, bool? reverse = null, string name = "Cumprod") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["axis"] = axis; + if (exclusive.HasValue) + dict["exclusive"] = exclusive.Value; + if (reverse.HasValue) + dict["reverse"] = reverse.Value; + var op = tf.OpDefLib._apply_op_helper("Cumprod", name: name, keywords: dict); + return op.output; + } + + /// + /// Compute the cumulative sum of the tensor x along axis. + /// + /// + /// A Tensor. Must be one of the following types: float32, float64, + /// int64, int32, uint8, uint16, int16, int8, complex64, + /// complex128, qint8, quint8, qint32, half. + /// + /// + /// A Tensor of type int32 (default: 0). Must be in the range + /// [-rank(x), rank(x)). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Cumsum'. + /// + /// + /// If True, perform exclusive cumsum. + /// + /// + /// A bool (default: False). + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// By default, this op performs an inclusive cumsum, which means that the first + /// element of the input is identical to the first element of the output: + /// + /// + /// tf.cumsum([a, b, c]) # =&gt; [a, a + b, a + b + c] + /// + /// + /// By setting the exclusive kwarg to True, an exclusive cumsum is + /// performed instead: + /// + /// + /// tf.cumsum([a, b, c], exclusive=True) # =&gt; [0, a, a + b] + /// + /// + /// By setting the reverse kwarg to True, the cumsum is performed in the + /// opposite direction: + /// + /// + /// tf.cumsum([a, b, c], reverse=True) # =&gt; [a + b + c, b + c, c] + /// + /// + /// This is more efficient than using separate tf.reverse ops. + /// + /// The reverse and exclusive kwargs can also be combined: + /// + /// + /// tf.cumsum([a, b, c], exclusive=True, reverse=True) # =&gt; [b + c, c, 0] + /// + /// + public static Tensor cumsum(Tensor x, Tensor axis, bool? exclusive = null, bool? reverse = null, string name = "Cumsum") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["axis"] = axis; + if (exclusive.HasValue) + dict["exclusive"] = exclusive.Value; + if (reverse.HasValue) + dict["reverse"] = reverse.Value; + var op = tf.OpDefLib._apply_op_helper("Cumsum", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the dimension index in the destination data format given the one in + /// + /// + /// A Tensor with each element as a dimension index in source data format. + /// Must be in the range [-4, 4). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DataFormatDimMap'. + /// + /// + /// source data format. + /// + /// + /// destination data format. + /// + /// + /// A Tensor with each element as a dimension index in destination data format. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// the source data format. + /// + public static Tensor data_format_dim_map(Tensor x, string src_format = null, string dst_format = null, string name = "DataFormatDimMap") + { + var dict = new Dictionary(); + dict["x"] = x; + if (src_format != null) + dict["src_format"] = src_format; + if (dst_format != null) + dict["dst_format"] = dst_format; + var op = tf.OpDefLib._apply_op_helper("DataFormatDimMap", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the permuted vector/tensor in the destination data format given the + /// + /// + /// Vector of size 4 or Tensor of shape (4, 2) in source data format. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DataFormatVecPermute'. + /// + /// + /// source data format. + /// + /// + /// destination data format. + /// + /// + /// Vector of size 4 or Tensor of shape (4, 2) in destination data format. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// one in the source data format. + /// + public static Tensor data_format_vec_permute(Tensor x, string src_format = null, string dst_format = null, string name = "DataFormatVecPermute") + { + var dict = new Dictionary(); + dict["x"] = x; + if (src_format != null) + dict["src_format"] = src_format; + if (dst_format != null) + dict["dst_format"] = dst_format; + var op = tf.OpDefLib._apply_op_helper("DataFormatVecPermute", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns a serialized GraphDef representing input_dataset. + /// + /// + /// A variant tensor representing the dataset to return the graph representation for. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DatasetToGraph'. + /// + /// + /// The graph representation of the dataset (as serialized GraphDef). + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Returns a graph representation for input_dataset. + /// + public static Tensor dataset_to_graph(Tensor input_dataset, string name = "DatasetToGraph") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + var op = tf.OpDefLib._apply_op_helper("DatasetToGraph", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs the single element from the given dataset. + /// + /// + /// A handle to a dataset that contains a single element. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DatasetToSingleElement'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The components of the single element of input. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor[] dataset_to_single_element(Tensor dataset, TF_DataType[] output_types, Shape[] output_shapes, string name = "DatasetToSingleElement") + { + var dict = new Dictionary(); + dict["dataset"] = dataset; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("DatasetToSingleElement", name: name, keywords: dict); + int _idx = 0; + var components = Enumerable.Range(0, op.OutputListLength("components")).Select(_ => op.outputs[_idx++]).ToArray(); + return (components); + } + + /// + /// Writes the given dataset to the given file using the TFRecord format. + /// + /// + /// A variant tensor representing the dataset to write. + /// + /// + /// A scalar string tensor representing the filename to use. + /// + /// + /// A scalar string tensor containing either (i) the empty string (no + /// compression), (ii) "ZLIB", or (iii) "GZIP". + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DatasetToTFRecord'. + /// + /// + /// Returns the description of the operation + /// + public static Operation dataset_to_t_f_record(Tensor input_dataset, Tensor filename, Tensor compression_type, string name = "DatasetToTFRecord") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["filename"] = filename; + dict["compression_type"] = compression_type; + var op = tf.OpDefLib._apply_op_helper("DatasetToTFRecord", name: name, keywords: dict); + return op; + } + + /// + /// Identity op for gradient debugging. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DebugGradientIdentity'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op is hidden from public in Python. It is used by TensorFlow Debugger to + /// register gradient tensors for gradient debugging. + /// This op operates on non-reference-type tensors. + /// + public static Tensor debug_gradient_identity(Tensor input, string name = "DebugGradientIdentity") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("DebugGradientIdentity", name: name, keywords: dict); + return op.output; + } + + /// + /// Identity op for gradient debugging. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DebugGradientRefIdentity'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op is hidden from public in Python. It is used by TensorFlow Debugger to + /// register gradient tensors for gradient debugging. + /// This op operates on reference-type tensors. + /// + public static Tensor debug_gradient_ref_identity(Tensor input, string name = "DebugGradientRefIdentity") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("DebugGradientRefIdentity", name: name, keywords: dict); + return op.output; + } + + /// + /// Debug Identity Op. + /// + /// + /// Input tensor, non-Reference type. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DebugIdentity'. + /// + /// + /// + /// + /// Name of the input tensor. + /// + /// + /// List of URLs to debug targets, e.g., + /// file:///foo/tfdbg_dump, grpc:://localhost:11011 + /// + /// + /// Whether this op will be gated. If any of the debug_urls of this + /// debug node is of the grpc:// scheme, when the value of this attribute is set + /// to True, the data will not actually be sent via the grpc stream unless this + /// debug op has been enabled at the debug_url. If all of the debug_urls of this + /// debug node are of the grpc:// scheme and the debug op is enabled at none of + /// them, the output will be an empty Tensor. + /// + /// + /// Output tensor that equals the input tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Provides an identity mapping of the non-Ref type input tensor for debugging. + /// + public static Tensor debug_identity(Tensor input, string device_name = null, string tensor_name = null, string[] debug_urls = null, bool? gated_grpc = null, string name = "DebugIdentity") + { + var dict = new Dictionary(); + dict["input"] = input; + if (device_name != null) + dict["device_name"] = device_name; + if (tensor_name != null) + dict["tensor_name"] = tensor_name; + if (debug_urls != null) + dict["debug_urls"] = debug_urls; + if (gated_grpc.HasValue) + dict["gated_grpc"] = gated_grpc.Value; + var op = tf.OpDefLib._apply_op_helper("DebugIdentity", name: name, keywords: dict); + return op.output; + } + + /// + /// Debug NaN Value Counter Op + /// + /// + /// Input tensor, non-Reference type. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DebugNanCount'. + /// + /// + /// + /// + /// Name of the input tensor. + /// + /// + /// List of URLs to debug targets, e.g., + /// file:///foo/tfdbg_dump, grpc:://localhost:11011. + /// + /// + /// Whether this op will be gated. If any of the debug_urls of this + /// debug node is of the grpc:// scheme, when the value of this attribute is set + /// to True, the data will not actually be sent via the grpc stream unless this + /// debug op has been enabled at the debug_url. If all of the debug_urls of this + /// debug node are of the grpc:// scheme and the debug op is enabled at none of + /// them, the output will be an empty Tensor. + /// + /// + /// An integer output tensor that is the number of NaNs in the input. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Counts number of NaNs in the input tensor, for debugging. + /// + public static Tensor debug_nan_count(Tensor input, string device_name = null, string tensor_name = null, string[] debug_urls = null, bool? gated_grpc = null, string name = "DebugNanCount") + { + var dict = new Dictionary(); + dict["input"] = input; + if (device_name != null) + dict["device_name"] = device_name; + if (tensor_name != null) + dict["tensor_name"] = tensor_name; + if (debug_urls != null) + dict["debug_urls"] = debug_urls; + if (gated_grpc.HasValue) + dict["gated_grpc"] = gated_grpc.Value; + var op = tf.OpDefLib._apply_op_helper("DebugNanCount", name: name, keywords: dict); + return op.output; + } + + /// + /// Debug Numeric Summary Op. + /// + /// + /// Input tensor, non-Reference type, float or double. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DebugNumericSummary'. + /// + /// + /// + /// + /// Name of the input tensor. + /// + /// + /// List of URLs to debug targets, e.g., + /// file:///foo/tfdbg_dump, grpc:://localhost:11011 + /// + /// + /// (float) The lower bound &lt;= which values will be included in the + /// generalized -inf count. Default: -inf. + /// + /// + /// (float) The upper bound &gt;= which values will be included in the + /// generalized +inf count. Default: +inf. + /// + /// + /// (bool) Do not send data to the debug URLs unless at least one + /// of elements [2], [3] and [7] (i.e., the nan count and the generalized -inf and + /// inf counts) is non-zero. + /// + /// + /// Whether this op will be gated. If any of the debug_urls of this + /// debug node is of the grpc:// scheme, when the value of this attribute is set + /// to True, the data will not actually be sent via the grpc stream unless this + /// debug op has been enabled at the debug_url. If all of the debug_urls of this + /// debug node are of the grpc:// scheme and the debug op is enabled at none of + /// them, the output will be an empty Tensor. + /// + /// + /// A double tensor of shape [14 + nDimensions], where nDimensions is the + /// the number of dimensions of the tensor's shape. The elements of output are: + /// [0]: is initialized (1.0) or not (0.0). + /// [1]: total number of elements + /// [2]: NaN element count + /// [3]: generalized -inf count: elements &lt;= lower_bound. lower_bound is -inf by + /// default. + /// [4]: negative element count (excluding -inf), if lower_bound is the default + /// -inf. Otherwise, this is the count of elements &gt; lower_bound and &lt; 0. + /// [5]: zero element count + /// [6]: positive element count (excluding +inf), if upper_bound is the default + /// -inf. Otherwise, this is the count of elements &lt; upper_bound and &gt; 0. + /// [7]: generalized +inf count, elements &gt;= upper_bound. upper_bound is +inf by + /// default. + /// Output elements [1:8] are all zero, if the tensor is uninitialized. + /// [8]: minimum of all non-inf and non-NaN elements. + /// If uninitialized or no such element exists: +inf. + /// [9]: maximum of all non-inf and non-NaN elements. + /// If uninitialized or no such element exists: -inf. + /// [10]: mean of all non-inf and non-NaN elements. + /// If uninitialized or no such element exists: NaN. + /// [11]: variance of all non-inf and non-NaN elements. + /// If uninitialized or no such element exists: NaN. + /// [12]: Data type of the tensor encoded as an enum integer. See the DataType + /// proto for more details. + /// [13]: Number of dimensions of the tensor (ndims). + /// [14+]: Sizes of the dimensions. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Provide a basic summary of numeric value types, range and distribution. + /// + public static Tensor debug_numeric_summary(Tensor input, string device_name = null, string tensor_name = null, string[] debug_urls = null, float? lower_bound = null, float? upper_bound = null, bool? mute_if_healthy = null, bool? gated_grpc = null, string name = "DebugNumericSummary") + { + var dict = new Dictionary(); + dict["input"] = input; + if (device_name != null) + dict["device_name"] = device_name; + if (tensor_name != null) + dict["tensor_name"] = tensor_name; + if (debug_urls != null) + dict["debug_urls"] = debug_urls; + if (lower_bound.HasValue) + dict["lower_bound"] = lower_bound.Value; + if (upper_bound.HasValue) + dict["upper_bound"] = upper_bound.Value; + if (mute_if_healthy.HasValue) + dict["mute_if_healthy"] = mute_if_healthy.Value; + if (gated_grpc.HasValue) + dict["gated_grpc"] = gated_grpc.Value; + var op = tf.OpDefLib._apply_op_helper("DebugNumericSummary", name: name, keywords: dict); + return op.output; + } + + /// + /// Decode and Crop a JPEG-encoded image to a uint8 tensor. + /// + /// + /// 0-D. The JPEG-encoded image. + /// + /// + /// 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DecodeAndCropJpeg'. + /// + /// + /// Number of color channels for the decoded image. + /// + /// + /// Downscaling ratio. + /// + /// + /// If true use a slower but nicer upscaling of the + /// chroma planes (yuv420/422 only). + /// + /// + /// If true try to recover an image from truncated input. + /// + /// + /// The minimum required fraction of lines before a truncated + /// input is accepted. + /// + /// + /// string specifying a hint about the algorithm used for + /// decompression. Defaults to "" which maps to a system-specific + /// default. Currently valid values are ["INTEGER_FAST", + /// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal + /// jpeg library changes to a version that does not have that specific + /// option.) + /// + /// + /// 3-D with shape [height, width, channels].. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The attr channels indicates the desired number of color channels for the + /// decoded image. + /// + /// Accepted values are: + /// + /// * 0: Use the number of channels in the JPEG-encoded image. + /// * 1: output a grayscale image. + /// * 3: output an RGB image. + /// + /// If needed, the JPEG-encoded image is transformed to match the requested number + /// of color channels. + /// + /// The attr ratio allows downscaling the image by an integer factor during + /// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than + /// downscaling the image later. + /// + /// + /// It is equivalent to a combination of decode and crop, but much faster by only + /// decoding partial jpeg image. + /// + public static Tensor decode_and_crop_jpeg(Tensor contents, Tensor crop_window, int? channels = null, int? ratio = null, bool? fancy_upscaling = null, bool? try_recover_truncated = null, float? acceptable_fraction = null, string dct_method = null, string name = "DecodeAndCropJpeg") + { + var dict = new Dictionary(); + dict["contents"] = contents; + dict["crop_window"] = crop_window; + if (channels.HasValue) + dict["channels"] = channels.Value; + if (ratio.HasValue) + dict["ratio"] = ratio.Value; + if (fancy_upscaling.HasValue) + dict["fancy_upscaling"] = fancy_upscaling.Value; + if (try_recover_truncated.HasValue) + dict["try_recover_truncated"] = try_recover_truncated.Value; + if (acceptable_fraction.HasValue) + dict["acceptable_fraction"] = acceptable_fraction.Value; + if (dct_method != null) + dict["dct_method"] = dct_method; + var op = tf.OpDefLib._apply_op_helper("DecodeAndCropJpeg", name: name, keywords: dict); + return op.output; + } + + /// + /// Decode web-safe base64-encoded strings. + /// + /// + /// Base64 strings to decode. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DecodeBase64'. + /// + /// + /// Decoded strings. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Input may or may not have padding at the end. See EncodeBase64 for padding. + /// Web-safe means that input must use - and _ instead of + and /. + /// + public static Tensor decode_base64(Tensor input, string name = "DecodeBase64") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("DecodeBase64", name: name, keywords: dict); + return op.output; + } + + /// + /// Decode the first frame of a BMP-encoded image to a uint8 tensor. + /// + /// + /// 0-D. The BMP-encoded image. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DecodeBmp'. + /// + /// + /// + /// + /// 3-D with shape [height, width, channels]. RGB order + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The attr channels indicates the desired number of color channels for the + /// decoded image. + /// + /// Accepted values are: + /// + /// * 0: Use the number of channels in the BMP-encoded image. + /// * 3: output an RGB image. + /// * 4: output an RGBA image. + /// + public static Tensor decode_bmp(Tensor contents, int? channels = null, string name = "DecodeBmp") + { + var dict = new Dictionary(); + dict["contents"] = contents; + if (channels.HasValue) + dict["channels"] = channels.Value; + var op = tf.OpDefLib._apply_op_helper("DecodeBmp", name: name, keywords: dict); + return op.output; + } + + /// + /// Convert CSV records to tensors. Each column maps to one tensor. + /// + /// + /// Each string is a record/row in the csv and all records should have + /// the same format. + /// + /// + /// One tensor per column of the input record, with either a + /// scalar default value for that column or an empty vector if the column is + /// required. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DecodeCSV'. + /// + /// + /// char delimiter to separate fields in a record. + /// + /// + /// If false, treats double quotation marks as regular + /// characters inside of the string fields (ignoring RFC 4180, Section 2, + /// Bullet 5). + /// + /// + /// Additional string to recognize as NA/NaN. + /// + /// + /// + /// + /// Each tensor will have the same shape as records. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// RFC 4180 format is expected for the CSV records. + /// (https://tools.ietensorflow.org/html/rfc4180) + /// Note that we allow leading and trailing spaces with int or float field. + /// + public static Tensor[] decode_c_s_v(Tensor records, Tensor[] record_defaults, string field_delim = null, bool? use_quote_delim = null, string na_value = null, int[] select_cols = null, string name = "DecodeCSV") + { + var dict = new Dictionary(); + dict["records"] = records; + dict["record_defaults"] = record_defaults; + if (field_delim != null) + dict["field_delim"] = field_delim; + if (use_quote_delim.HasValue) + dict["use_quote_delim"] = use_quote_delim.Value; + if (na_value != null) + dict["na_value"] = na_value; + if (select_cols != null) + dict["select_cols"] = select_cols; + var op = tf.OpDefLib._apply_op_helper("DecodeCSV", name: name, keywords: dict); + int _idx = 0; + var output = Enumerable.Range(0, op.OutputListLength("output")).Select(_ => op.outputs[_idx++]).ToArray(); + return (output); + } + + /// + /// Decompress strings. + /// + /// + /// A Tensor of string which is compressed. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DecodeCompressed'. + /// + /// + /// A scalar containing either (i) the empty string (no + /// compression), (ii) "ZLIB", or (iii) "GZIP". + /// + /// + /// A Tensor with the same shape as input bytes, uncompressed + /// from bytes. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op decompresses each element of the bytes input Tensor, which + /// is assumed to be compressed using the given compression_type. + /// + /// The output is a string Tensor of the same shape as bytes, + /// each element containing the decompressed data from the corresponding + /// element in bytes. + /// + public static Tensor decode_compressed(Tensor bytes, string compression_type = null, string name = "DecodeCompressed") + { + var dict = new Dictionary(); + dict["bytes"] = bytes; + if (compression_type != null) + dict["compression_type"] = compression_type; + var op = tf.OpDefLib._apply_op_helper("DecodeCompressed", name: name, keywords: dict); + return op.output; + } + + /// + /// Decode the first frame of a GIF-encoded image to a uint8 tensor. + /// + /// + /// 0-D. The GIF-encoded image. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DecodeGif'. + /// + /// + /// 4-D with shape [num_frames, height, width, 3]. RGB order + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// GIF with frame or transparency compression are not supported + /// convert animated GIF from compressed to uncompressed by: + /// + /// convert $src.gif -coalesce $dst.gif + /// + /// This op also supports decoding JPEGs and PNGs, though it is cleaner to use + /// tf.image.decode_image. + /// + public static Tensor decode_gif(Tensor contents, string name = "DecodeGif") + { + var dict = new Dictionary(); + dict["contents"] = contents; + var op = tf.OpDefLib._apply_op_helper("DecodeGif", name: name, keywords: dict); + return op.output; + } + + /// + /// Convert JSON-encoded Example records to binary protocol buffer strings. + /// + /// + /// Each string is a JSON object serialized according to the JSON + /// mapping of the Example proto. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DecodeJSONExample'. + /// + /// + /// Each string is a binary Example protocol buffer corresponding + /// to the respective element of json_examples. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op translates a tensor containing Example records, encoded using + /// the [standard JSON + /// mapping](https://developers.google.com/protocol-buffers/docs/proto3#json), + /// into a tensor containing the same records encoded as binary protocol + /// buffers. The resulting tensor can then be fed to any of the other + /// Example-parsing ops. + /// + public static Tensor decode_j_s_o_n_example(Tensor json_examples, string name = "DecodeJSONExample") + { + var dict = new Dictionary(); + dict["json_examples"] = json_examples; + var op = tf.OpDefLib._apply_op_helper("DecodeJSONExample", name: name, keywords: dict); + return op.output; + } + + /// + /// Decode a JPEG-encoded image to a uint8 tensor. + /// + /// + /// 0-D. The JPEG-encoded image. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DecodeJpeg'. + /// + /// + /// Number of color channels for the decoded image. + /// + /// + /// Downscaling ratio. + /// + /// + /// If true use a slower but nicer upscaling of the + /// chroma planes (yuv420/422 only). + /// + /// + /// If true try to recover an image from truncated input. + /// + /// + /// The minimum required fraction of lines before a truncated + /// input is accepted. + /// + /// + /// string specifying a hint about the algorithm used for + /// decompression. Defaults to "" which maps to a system-specific + /// default. Currently valid values are ["INTEGER_FAST", + /// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal + /// jpeg library changes to a version that does not have that specific + /// option.) + /// + /// + /// 3-D with shape [height, width, channels].. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The attr channels indicates the desired number of color channels for the + /// decoded image. + /// + /// Accepted values are: + /// + /// * 0: Use the number of channels in the JPEG-encoded image. + /// * 1: output a grayscale image. + /// * 3: output an RGB image. + /// + /// If needed, the JPEG-encoded image is transformed to match the requested number + /// of color channels. + /// + /// The attr ratio allows downscaling the image by an integer factor during + /// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than + /// downscaling the image later. + /// + /// + /// This op also supports decoding PNGs and non-animated GIFs since the interface is + /// the same, though it is cleaner to use tf.image.decode_image. + /// + public static Tensor decode_jpeg(Tensor contents, int? channels = null, int? ratio = null, bool? fancy_upscaling = null, bool? try_recover_truncated = null, float? acceptable_fraction = null, string dct_method = null, string name = "DecodeJpeg") + { + var dict = new Dictionary(); + dict["contents"] = contents; + if (channels.HasValue) + dict["channels"] = channels.Value; + if (ratio.HasValue) + dict["ratio"] = ratio.Value; + if (fancy_upscaling.HasValue) + dict["fancy_upscaling"] = fancy_upscaling.Value; + if (try_recover_truncated.HasValue) + dict["try_recover_truncated"] = try_recover_truncated.Value; + if (acceptable_fraction.HasValue) + dict["acceptable_fraction"] = acceptable_fraction.Value; + if (dct_method != null) + dict["dct_method"] = dct_method; + var op = tf.OpDefLib._apply_op_helper("DecodeJpeg", name: name, keywords: dict); + return op.output; + } + + /// + /// Decode a PNG-encoded image to a uint8 or uint16 tensor. + /// + /// + /// 0-D. The PNG-encoded image. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DecodePng'. + /// + /// + /// Number of color channels for the decoded image. + /// + /// + /// + /// + /// 3-D with shape [height, width, channels]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The attr channels indicates the desired number of color channels for the + /// decoded image. + /// + /// Accepted values are: + /// + /// * 0: Use the number of channels in the PNG-encoded image. + /// * 1: output a grayscale image. + /// * 3: output an RGB image. + /// * 4: output an RGBA image. + /// + /// If needed, the PNG-encoded image is transformed to match the requested number + /// of color channels. + /// + /// This op also supports decoding JPEGs and non-animated GIFs since the interface + /// is the same, though it is cleaner to use tf.image.decode_image. + /// + public static Tensor decode_png(Tensor contents, int? channels = null, TF_DataType? dtype = null, string name = "DecodePng") + { + var dict = new Dictionary(); + dict["contents"] = contents; + if (channels.HasValue) + dict["channels"] = channels.Value; + if (dtype.HasValue) + dict["dtype"] = dtype.Value; + var op = tf.OpDefLib._apply_op_helper("DecodePng", name: name, keywords: dict); + return op.output; + } + + /// + /// The op extracts fields from a serialized protocol buffers message into tensors. + /// + /// + /// Tensor of serialized protos with shape batch_shape. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DecodeProtoV2'. + /// + /// + /// Optional argument + /// Name of the proto message type to decode. + /// + /// + /// Optional argument + /// List of strings containing proto field names. + /// + /// + /// Optional argument + /// List of TF types to use for the respective field in field_names. + /// + /// + /// Either the special value local:// or a path to a file containing + /// a serialized FileDescriptorSet. + /// + /// + /// Either binary or text. + /// + /// + /// Whether to sanitize the result or not. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sizes : Tensor of int32 with shape [batch_shape, len(field_names)]. + /// Each entry is the number of values found for the corresponding field. + /// Optional fields may have 0 or 1 values. + /// values : List of tensors containing values for the corresponding field. + /// values[i] has datatype output_types[i] + /// and shape [batch_shape, max(sizes[...,i])]. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The decode_proto op extracts fields from a serialized protocol buffers + /// message into tensors. The fields in field_names are decoded and converted + /// to the corresponding output_types if possible. + /// + /// A message_type name must be provided to give context for the field + /// names. The actual message descriptor can be looked up either in the + /// linked-in descriptor pool or a filename provided by the caller using + /// the descriptor_source attribute. + /// + /// Each output tensor is a dense tensor. This means that it is padded to + /// hold the largest number of repeated elements seen in the input + /// minibatch. (The shape is also padded by one to prevent zero-sized + /// dimensions). The actual repeat counts for each example in the + /// minibatch can be found in the sizes output. In many cases the output + /// of decode_proto is fed immediately into tf.squeeze if missing values + /// are not a concern. When using tf.squeeze, always pass the squeeze + /// dimension explicitly to avoid surprises. + /// + /// For the most part, the mapping between Proto field types and + /// TensorFlow dtypes is straightforward. However, there are a few + /// special cases: + /// + /// - A proto field that contains a submessage or group can only be converted + /// to DT_STRING (the serialized submessage). This is to reduce the + /// complexity of the API. The resulting string can be used as input + /// to another instance of the decode_proto op. + /// + /// - TensorFlow lacks support for unsigned integers. The ops represent uint64 + /// types as a DT_INT64 with the same twos-complement bit pattern + /// (the obvious way). Unsigned int32 values can be represented exactly by + /// specifying type DT_INT64, or using twos-complement if the caller + /// specifies DT_INT32 in the output_types attribute. + /// + /// The descriptor_source attribute selects a source of protocol + /// descriptors to consult when looking up message_type. This may be a + /// filename containing a serialized FileDescriptorSet message, + /// or the special value local://, in which case only descriptors linked + /// into the code will be searched; the filename can be on any filesystem + /// accessible to TensorFlow. + /// + /// You can build a descriptor_source file using the --descriptor_set_out + /// and --include_imports options to the protocol compiler protoc. + /// + /// The local:// database only covers descriptors linked into the + /// code via C++ libraries, not Python imports. You can link in a proto descriptor + /// by creating a cc_library target with alwayslink=1. + /// + /// Both binary and text proto serializations are supported, and can be + /// chosen using the format attribute. + /// + public static (Tensor sizes, Tensor[] values) decode_proto_v2(Tensor bytes, string message_type, string[] field_names, TF_DataType[] output_types, string descriptor_source = null, string message_format = null, bool? sanitize = null, string name = "DecodeProtoV2") + { + var dict = new Dictionary(); + dict["bytes"] = bytes; + dict["message_type"] = message_type; + dict["field_names"] = field_names; + dict["output_types"] = output_types; + if (descriptor_source != null) + dict["descriptor_source"] = descriptor_source; + if (message_format != null) + dict["message_format"] = message_format; + if (sanitize.HasValue) + dict["sanitize"] = sanitize.Value; + var op = tf.OpDefLib._apply_op_helper("DecodeProtoV2", name: name, keywords: dict); + int _idx = 0; + var sizes = op.outputs[_idx++]; + var values = Enumerable.Range(0, op.OutputListLength("values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (sizes, values); + } + + /// + /// Reinterpret the bytes of a string as a vector of numbers. + /// + /// + /// All the elements must have the same length. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DecodeRaw'. + /// + /// + /// Optional argument + /// + /// + /// Whether the input bytes are in little-endian order. + /// Ignored for out_type values that are stored in a single byte like + /// uint8. + /// + /// + /// A Tensor with one more dimension than the input bytes. The + /// added dimension will have size equal to the length of the elements + /// of bytes divided by the number of bytes to represent out_type. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor decode_raw(Tensor bytes, TF_DataType out_type, bool? little_endian = null, string name = "DecodeRaw") + { + var dict = new Dictionary(); + dict["bytes"] = bytes; + dict["out_type"] = out_type; + if (little_endian.HasValue) + dict["little_endian"] = little_endian.Value; + var op = tf.OpDefLib._apply_op_helper("DecodeRaw", name: name, keywords: dict); + return op.output; + } + + /// + /// Decode a 16-bit PCM WAV file to a float tensor. + /// + /// + /// The WAV-encoded audio, usually from a file. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DecodeWav'. + /// + /// + /// Number of sample channels wanted. + /// + /// + /// Length of audio requested. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// audio : 2-D with shape [length, channels]. + /// sample_rate : Scalar holding the sample rate found in the WAV header. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. + /// + /// When desired_channels is set, if the input contains fewer channels than this + /// then the last channel will be duplicated to give the requested number, else if + /// the input has more channels than requested then the additional channels will be + /// ignored. + /// + /// If desired_samples is set, then the audio will be cropped or padded with zeroes + /// to the requested length. + /// + /// The first output contains a Tensor with the content of the audio samples. The + /// lowest dimension will be the number of channels, and the second will be the + /// number of samples. For example, a ten-sample-long stereo WAV file should give an + /// output shape of [10, 2]. + /// + public static (Tensor audio, Tensor sample_rate) decode_wav(Tensor contents, int? desired_channels = null, int? desired_samples = null, string name = "DecodeWav") + { + var dict = new Dictionary(); + dict["contents"] = contents; + if (desired_channels.HasValue) + dict["desired_channels"] = desired_channels.Value; + if (desired_samples.HasValue) + dict["desired_samples"] = desired_samples.Value; + var op = tf.OpDefLib._apply_op_helper("DecodeWav", name: name, keywords: dict); + int _idx = 0; + var audio = op.outputs[_idx++]; + var sample_rate = op.outputs[_idx++]; + return (audio, sample_rate); + } + + /// + /// Makes a copy of x. + /// + /// + /// The source tensor of type T. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DeepCopy'. + /// + /// + /// y: A Tensor of type T. A copy of x. Guaranteed that y + /// is not an alias of x. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor deep_copy(Tensor x, string name = "DeepCopy") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("DeepCopy", name: name, keywords: dict); + return op.output; + } + + /// + /// Delete the tensor specified by its handle in the session. + /// + /// + /// The handle for a tensor stored in the session state. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DeleteSessionTensor'. + /// + /// + /// Returns the description of the operation + /// + public static Operation delete_session_tensor(Tensor handle, string name = "DeleteSessionTensor") + { + var dict = new Dictionary(); + dict["handle"] = handle; + var op = tf.OpDefLib._apply_op_helper("DeleteSessionTensor", name: name, keywords: dict); + return op; + } + + /// + /// Applies set operation along last dimension of 2 Tensor inputs. + /// + /// + /// Tensor with rank n. 1st n-1 dimensions must be the same as set2. + /// Dimension n contains values in a set, duplicates are allowed but ignored. + /// + /// + /// Tensor with rank n. 1st n-1 dimensions must be the same as set1. + /// Dimension n contains values in a set, duplicates are allowed but ignored. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DenseToDenseSetOperation'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// result_indices : 2D indices of a SparseTensor. + /// result_values : 1D values of a SparseTensor. + /// result_shape : 1D Tensor shape of a SparseTensor. result_shape[0...n-1] is + /// the same as the 1st n-1 dimensions of set1 and set2, result_shape[n] + /// is the max result set size across all 0...n-1 dimensions. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// See SetOperationOp::SetOperationFromContext for values of set_operation. + /// + /// Output result is a SparseTensor represented by result_indices, + /// result_values, and result_shape. For set1 and set2 ranked n, this + /// has rank n and the same 1st n-1 dimensions as set1 and set2. The nth + /// dimension contains the result of set_operation applied to the corresponding + /// [0...n-1] dimension of set. + /// + public static (Tensor result_indices, Tensor result_values, Tensor result_shape) dense_to_dense_set_operation(Tensor set1, Tensor set2, string set_operation, bool? validate_indices = null, string name = "DenseToDenseSetOperation") + { + var dict = new Dictionary(); + dict["set1"] = set1; + dict["set2"] = set2; + dict["set_operation"] = set_operation; + if (validate_indices.HasValue) + dict["validate_indices"] = validate_indices.Value; + var op = tf.OpDefLib._apply_op_helper("DenseToDenseSetOperation", name: name, keywords: dict); + int _idx = 0; + var result_indices = op.outputs[_idx++]; + var result_values = op.outputs[_idx++]; + var result_shape = op.outputs[_idx++]; + return (result_indices, result_values, result_shape); + } + + /// + /// Creates a dataset that batches input elements into a SparseTensor. + /// + /// + /// A handle to an input dataset. Must have a single component. + /// + /// + /// A scalar representing the number of elements to accumulate in a + /// batch. + /// + /// + /// A vector representing the dense shape of each row in the produced + /// SparseTensor. The shape may be partially specified, using -1 to indicate + /// that a particular dimension should use the maximum size of all batch elements. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DenseToSparseBatchDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor dense_to_sparse_batch_dataset(Tensor input_dataset, Tensor batch_size, Tensor row_shape, TF_DataType[] output_types, Shape[] output_shapes, string name = "DenseToSparseBatchDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["batch_size"] = batch_size; + dict["row_shape"] = row_shape; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("DenseToSparseBatchDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Applies set operation along last dimension of Tensor and SparseTensor. + /// + /// + /// Tensor with rank n. 1st n-1 dimensions must be the same as set2. + /// Dimension n contains values in a set, duplicates are allowed but ignored. + /// + /// + /// 2D Tensor, indices of a SparseTensor. Must be in row-major + /// order. + /// + /// + /// 1D Tensor, values of a SparseTensor. Must be in row-major + /// order. + /// + /// + /// 1D Tensor, shape of a SparseTensor. set2_shape[0...n-1] must + /// be the same as the 1st n-1 dimensions of set1, result_shape[n] is the + /// max set size across n-1 dimensions. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DenseToSparseSetOperation'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// result_indices : 2D indices of a SparseTensor. + /// result_values : 1D values of a SparseTensor. + /// result_shape : 1D Tensor shape of a SparseTensor. result_shape[0...n-1] is + /// the same as the 1st n-1 dimensions of set1 and set2, result_shape[n] + /// is the max result set size across all 0...n-1 dimensions. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// See SetOperationOp::SetOperationFromContext for values of set_operation. + /// + /// Input set2 is a SparseTensor represented by set2_indices, set2_values, + /// and set2_shape. For set2 ranked n, 1st n-1 dimensions must be the same + /// as set1. Dimension n contains values in a set, duplicates are allowed but + /// ignored. + /// + /// If validate_indices is True, this op validates the order and range of set2 + /// indices. + /// + /// Output result is a SparseTensor represented by result_indices, + /// result_values, and result_shape. For set1 and set2 ranked n, this + /// has rank n and the same 1st n-1 dimensions as set1 and set2. The nth + /// dimension contains the result of set_operation applied to the corresponding + /// [0...n-1] dimension of set. + /// + public static (Tensor result_indices, Tensor result_values, Tensor result_shape) dense_to_sparse_set_operation(Tensor set1, Tensor set2_indices, Tensor set2_values, Tensor set2_shape, string set_operation, bool? validate_indices = null, string name = "DenseToSparseSetOperation") + { + var dict = new Dictionary(); + dict["set1"] = set1; + dict["set2_indices"] = set2_indices; + dict["set2_values"] = set2_values; + dict["set2_shape"] = set2_shape; + dict["set_operation"] = set_operation; + if (validate_indices.HasValue) + dict["validate_indices"] = validate_indices.Value; + var op = tf.OpDefLib._apply_op_helper("DenseToSparseSetOperation", name: name, keywords: dict); + int _idx = 0; + var result_indices = op.outputs[_idx++]; + var result_values = op.outputs[_idx++]; + var result_shape = op.outputs[_idx++]; + return (result_indices, result_values, result_shape); + } + + /// + /// DepthToSpace for tensors of type T. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DepthToSpace'. + /// + /// + /// Optional argument + /// The size of the spatial block, same as in Space2Depth. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Rearranges data from depth into blocks of spatial data. + /// This is the reverse transformation of SpaceToDepth. More specifically, + /// this op outputs a copy of the input tensor where values from the depth + /// dimension are moved in spatial blocks to the height and width dimensions. + /// The attr block_size indicates the input block size and how the data is moved. + /// + /// * Chunks of data of size block_size * block_size from depth are rearranged + /// into non-overlapping blocks of size block_size x block_size + /// * The width the output tensor is input_depth * block_size, whereas the + /// height is input_height * block_size. + /// * The Y, X coordinates within each block of the output image are determined + /// by the high order component of the input channel index. + /// * The depth of the input tensor must be divisible by + /// block_size * block_size. + /// + /// The data_format attr specifies the layout of the input and output tensors + /// with the following options: + /// "NHWC": [ batch, height, width, channels ] + /// "NCHW": [ batch, channels, height, width ] + /// "NCHW_VECT_C": + /// qint8 [ batch, channels / 4, height, width, 4 ] + /// + /// It is useful to consider the operation as transforming a 6-D Tensor. + /// e.g. for data_format = NHWC, + /// Each element in the input tensor can be specified via 6 coordinates, + /// ordered by decreasing memory layout significance as: + /// n,iY,iX,bY,bX,oC (where n=batch index, iX, iY means X or Y coordinates + /// within the input image, bX, bY means coordinates + /// within the output block, oC means output channels). + /// The output would be the input transposed to the following layout: + /// n,iY,bY,iX,bX,oC + /// + /// This operation is useful for resizing the activations between convolutions + /// (but keeping all data), e.g. instead of pooling. It is also useful for training + /// purely convolutional models. + /// + /// For example, given an input of shape [1, 1, 1, 4], data_format = "NHWC" and + /// block_size = 2: + /// + /// + /// x = [[[[1, 2, 3, 4]]]] + /// + /// + /// + /// This operation will output a tensor of shape [1, 2, 2, 1]: + /// + /// + /// [[[[1], [2]], + /// [[3], [4]]]] + /// + /// + /// Here, the input has a batch of 1 and each batch element has shape [1, 1, 4], + /// the corresponding output will have 2x2 elements and will have a depth of + /// 1 channel (1 = 4 / (block_size * block_size)). + /// The output element shape is [2, 2, 1]. + /// + /// For an input tensor with larger depth, here of shape [1, 1, 1, 12], e.g. + /// + /// + /// x = [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] + /// + /// + /// This operation, for block size of 2, will return the following tensor of shape + /// [1, 2, 2, 3] + /// + /// + /// [[[[1, 2, 3], [4, 5, 6]], + /// [[7, 8, 9], [10, 11, 12]]]] + /// + /// + /// + /// Similarly, for the following input of shape [1 2 2 4], and a block size of 2: + /// + /// + /// x = [[[[1, 2, 3, 4], + /// [5, 6, 7, 8]], + /// [[9, 10, 11, 12], + /// [13, 14, 15, 16]]]] + /// + /// + /// the operator will return the following tensor of shape [1 4 4 1]: + /// + /// + /// x = [[[ [1], [2], [5], [6]], + /// [ [3], [4], [7], [8]], + /// [ [9], [10], [13], [14]], + /// [ [11], [12], [15], [16]]]] + /// + /// + /// + public static Tensor depth_to_space(Tensor input, int block_size, string data_format = null, string name = "DepthToSpace") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["block_size"] = block_size; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("DepthToSpace", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes a 2-D depthwise convolution given 4-D input and filter tensors. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DepthwiseConv2dNative'. + /// + /// + /// Optional argument + /// 1-D of length 4. The stride of the sliding window for each dimension + /// of input. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, height, width, channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, channels, height, width]. + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// input. If set to k &gt; 1, there will be k-1 skipped cells between each filter + /// element on that dimension. The dimension order is determined by the value of + /// data_format, see above for details. Dilations in the batch and depth + /// dimensions must be 1. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given an input tensor of shape [batch, in_height, in_width, in_channels] + /// and a filter / kernel tensor of shape + /// [filter_height, filter_width, in_channels, channel_multiplier], containing + /// in_channels convolutional filters of depth 1, depthwise_conv2d applies + /// a different filter to each input channel (expanding from 1 channel to + /// channel_multiplier channels for each), then concatenates the results + /// together. Thus, the output has in_channels * channel_multiplier channels. + /// + /// + /// for k in 0..in_channels-1 + /// for q in 0..channel_multiplier-1 + /// output[b, i, j, k * channel_multiplier + q] = + /// sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] * + /// filter[di, dj, k, q] + /// + /// + /// Must have strides[0] = strides[3] = 1. For the most common case of the same + /// horizontal and vertices strides, strides = [1, stride, stride, 1]. + /// + public static Tensor depthwise_conv2d_native(Tensor input, Tensor filter, int[] strides, string padding, string data_format = null, int[] dilations = null, string name = "DepthwiseConv2dNative") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["filter"] = filter; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + if (dilations != null) + dict["dilations"] = dilations; + var op = tf.OpDefLib._apply_op_helper("DepthwiseConv2dNative", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradients of depthwise convolution with respect to the filter. + /// + /// + /// 4-D with shape based on data_format. For example, if + /// data_format is 'NHWC' then input is a 4-D [batch, in_height, + /// in_width, in_channels] tensor. + /// + /// + /// An integer vector representing the tensor shape of filter, + /// where filter is a 4-D + /// [filter_height, filter_width, in_channels, depthwise_multiplier] tensor. + /// + /// + /// 4-D with shape based on data_format. + /// For example, if data_format is 'NHWC' then + /// out_backprop shape is [batch, out_height, out_width, out_channels]. + /// Gradients w.r.t. the output of the convolution. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DepthwiseConv2dNativeBackpropFilter'. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the input + /// of the convolution. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, height, width, channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, channels, height, width]. + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// input. If set to k &gt; 1, there will be k-1 skipped cells between each filter + /// element on that dimension. The dimension order is determined by the value of + /// data_format, see above for details. Dilations in the batch and depth + /// dimensions must be 1. + /// + /// + /// 4-D with shape + /// [filter_height, filter_width, in_channels, out_channels]. Gradient w.r.t. + /// the filter input of the convolution. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor depthwise_conv2d_native_backprop_filter(Tensor input, Tensor filter_sizes, Tensor out_backprop, int[] strides, string padding, string data_format = null, int[] dilations = null, string name = "DepthwiseConv2dNativeBackpropFilter") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["filter_sizes"] = filter_sizes; + dict["out_backprop"] = out_backprop; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + if (dilations != null) + dict["dilations"] = dilations; + var op = tf.OpDefLib._apply_op_helper("DepthwiseConv2dNativeBackpropFilter", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradients of depthwise convolution with respect to the input. + /// + /// + /// An integer vector representing the shape of input, based + /// on data_format. For example, if data_format is 'NHWC' then + /// input is a 4-D [batch, height, width, channels] tensor. + /// + /// + /// 4-D with shape + /// [filter_height, filter_width, in_channels, depthwise_multiplier]. + /// + /// + /// 4-D with shape based on data_format. + /// For example, if data_format is 'NHWC' then + /// out_backprop shape is [batch, out_height, out_width, out_channels]. + /// Gradients w.r.t. the output of the convolution. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DepthwiseConv2dNativeBackpropInput'. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the input + /// of the convolution. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, height, width, channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, channels, height, width]. + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// input. If set to k &gt; 1, there will be k-1 skipped cells between each filter + /// element on that dimension. The dimension order is determined by the value of + /// data_format, see above for details. Dilations in the batch and depth + /// dimensions must be 1. + /// + /// + /// 4-D with shape according to data_format. For example, if + /// data_format is 'NHWC', output shape is [batch, in_height, + /// in_width, in_channels]. Gradient w.r.t. the input of the + /// convolution. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor depthwise_conv2d_native_backprop_input(Tensor input_sizes, Tensor filter, Tensor out_backprop, int[] strides, string padding, string data_format = null, int[] dilations = null, string name = "DepthwiseConv2dNativeBackpropInput") + { + var dict = new Dictionary(); + dict["input_sizes"] = input_sizes; + dict["filter"] = filter; + dict["out_backprop"] = out_backprop; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + if (dilations != null) + dict["dilations"] = dilations; + var op = tf.OpDefLib._apply_op_helper("DepthwiseConv2dNativeBackpropInput", name: name, keywords: dict); + return op.output; + } + + /// + /// Dequantize the 'input' tensor into a float Tensor. + /// + /// + /// + /// + /// The minimum scalar value possibly produced for the input. + /// + /// + /// The maximum scalar value possibly produced for the input. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Dequantize'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// [min_range, max_range] are scalar floats that specify the range for + /// the 'input' data. The 'mode' attribute controls exactly which calculations are + /// used to convert the float values to their quantized equivalents. + /// + /// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: + /// + /// + /// if T == qint8, in[i] += (range(T) + 1)/ 2.0 + /// out[i] = min_range + (in[i]* (max_range - min_range) / range(T)) + /// + /// here range(T) = numeric_limits&lt;T&gt;::max() - numeric_limits&lt;T&gt;::min() + /// + /// *MIN_COMBINED Mode Example* + /// + /// If the input comes from a QuantizedRelu6, the output type is + /// quint8 (range of 0-255) but the possible range of QuantizedRelu6 is + /// 0-6. The min_range and max_range values are therefore 0.0 and 6.0. + /// Dequantize on quint8 will take each value, cast to float, and multiply + /// by 6 / 255. + /// Note that if quantizedtype is qint8, the operation will additionally add + /// each value by 128 prior to casting. + /// + /// If the mode is 'MIN_FIRST', then this approach is used: + /// + /// + /// num_discrete_values = 1 &lt;&lt; (# of bits in T) + /// range_adjust = num_discrete_values / (num_discrete_values - 1) + /// range = (range_max - range_min) * range_adjust + /// range_scale = range / num_discrete_values + /// const double offset_input = static_cast&lt;double&gt;(input) - lowest_quantized; + /// result = range_min + ((input - numeric_limits&lt;T&gt;::min()) * range_scale) + /// + /// + /// *SCALED mode Example* + /// + /// SCALED mode matches the quantization approach used in + /// QuantizeAndDequantize{V2|V3}. + /// + /// If the mode is SCALED, we do not use the full range of the output type, + /// choosing to elide the lowest possible value for symmetry (e.g., output range is + /// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to + /// 0. + /// + /// We first find the range of values in our tensor. The + /// range we use is always centered on 0, so we find m such that + /// + /// m = max(abs(input_min), abs(input_max)) + /// + /// + /// Our input tensor range is then [-m, m]. + /// + /// Next, we choose our fixed-point quantization buckets, [min_fixed, max_fixed]. + /// If T is signed, this is + /// + /// num_bits = sizeof(T) * 8 + /// [min_fixed, max_fixed] = + /// [-(1 &lt;&lt; (num_bits - 1) - 1), (1 &lt;&lt; (num_bits - 1)) - 1] + /// + /// + /// Otherwise, if T is unsigned, the fixed-point range is + /// + /// [min_fixed, max_fixed] = [0, (1 &lt;&lt; num_bits) - 1] + /// + /// + /// From this we compute our scaling factor, s: + /// + /// s = (2 * m) / (max_fixed - min_fixed) + /// + /// + /// Now we can dequantize the elements of our tensor: + /// + /// result = input * s + /// + /// + public static Tensor dequantize(Tensor input, Tensor min_range, Tensor max_range, string mode = null, string name = "Dequantize") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["min_range"] = min_range; + dict["max_range"] = max_range; + if (mode != null) + dict["mode"] = mode; + var op = tf.OpDefLib._apply_op_helper("Dequantize", name: name, keywords: dict); + return op.output; + } + + /// + /// Converts the given variant tensor to an iterator and stores it in the given resource. + /// + /// + /// A handle to an iterator resource. + /// + /// + /// A variant tensor storing the state of the iterator contained in the + /// resource. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DeserializeIterator'. + /// + /// + /// Returns the description of the operation + /// + public static Operation deserialize_iterator(Tensor resource_handle, Tensor serialized, string name = "DeserializeIterator") + { + var dict = new Dictionary(); + dict["resource_handle"] = resource_handle; + dict["serialized"] = serialized; + var op = tf.OpDefLib._apply_op_helper("DeserializeIterator", name: name, keywords: dict); + return op; + } + + /// + /// Deserialize and concatenate SparseTensors from a serialized minibatch. + /// + /// + /// 2-D, The N serialized SparseTensor objects. + /// Must have 3 columns. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DeserializeManySparse'. + /// + /// + /// Optional argument + /// The dtype of the serialized SparseTensor objects. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sparse_indices : + /// sparse_values : + /// sparse_shape : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The input serialized_sparse must be a string matrix of shape [N x 3] where + /// N is the minibatch size and the rows correspond to packed outputs of + /// SerializeSparse. The ranks of the original SparseTensor objects + /// must all match. When the final SparseTensor is created, it has rank one + /// higher than the ranks of the incoming SparseTensor objects + /// (they have been concatenated along a new row dimension). + /// + /// The output SparseTensor object's shape values for all dimensions but the + /// first are the max across the input SparseTensor objects' shape values + /// for the corresponding dimensions. Its first shape value is N, the minibatch + /// size. + /// + /// The input SparseTensor objects' indices are assumed ordered in + /// standard lexicographic order. If this is not the case, after this + /// step run SparseReorder to restore index ordering. + /// + /// For example, if the serialized input is a [2 x 3] matrix representing two + /// original SparseTensor objects: + /// + /// index = [ 0] + /// [10] + /// [20] + /// values = [1, 2, 3] + /// shape = [50] + /// + /// and + /// + /// index = [ 2] + /// [10] + /// values = [4, 5] + /// shape = [30] + /// + /// then the final deserialized SparseTensor will be: + /// + /// index = [0 0] + /// [0 10] + /// [0 20] + /// [1 2] + /// [1 10] + /// values = [1, 2, 3, 4, 5] + /// shape = [2 50] + /// + public static (Tensor sparse_indices, Tensor sparse_values, Tensor sparse_shape) deserialize_many_sparse(Tensor serialized_sparse, TF_DataType dtype, string name = "DeserializeManySparse") + { + var dict = new Dictionary(); + dict["serialized_sparse"] = serialized_sparse; + dict["dtype"] = dtype; + var op = tf.OpDefLib._apply_op_helper("DeserializeManySparse", name: name, keywords: dict); + int _idx = 0; + var sparse_indices = op.outputs[_idx++]; + var sparse_values = op.outputs[_idx++]; + var sparse_shape = op.outputs[_idx++]; + return (sparse_indices, sparse_values, sparse_shape); + } + + /// + /// Deserialize SparseTensor objects. + /// + /// + /// The serialized SparseTensor objects. The last dimension + /// must have 3 columns. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DeserializeSparse'. + /// + /// + /// Optional argument + /// The dtype of the serialized SparseTensor objects. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sparse_indices : + /// sparse_values : + /// sparse_shape : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The input serialized_sparse must have the shape [?, ?, ..., ?, 3] where + /// the last dimension stores serialized SparseTensor objects and the other N + /// dimensions (N &gt;= 0) correspond to a batch. The ranks of the original + /// SparseTensor objects must all match. When the final SparseTensor is + /// created, its rank is the rank of the incoming SparseTensor objects plus N; + /// the sparse tensors have been concatenated along new dimensions, one for each + /// batch. + /// + /// The output SparseTensor object's shape values for the original dimensions + /// are the max across the input SparseTensor objects' shape values for the + /// corresponding dimensions. The new dimensions match the size of the batch. + /// + /// The input SparseTensor objects' indices are assumed ordered in + /// standard lexicographic order. If this is not the case, after this + /// step run SparseReorder to restore index ordering. + /// + /// For example, if the serialized input is a [2 x 3] matrix representing two + /// original SparseTensor objects: + /// + /// index = [ 0] + /// [10] + /// [20] + /// values = [1, 2, 3] + /// shape = [50] + /// + /// and + /// + /// index = [ 2] + /// [10] + /// values = [4, 5] + /// shape = [30] + /// + /// then the final deserialized SparseTensor will be: + /// + /// index = [0 0] + /// [0 10] + /// [0 20] + /// [1 2] + /// [1 10] + /// values = [1, 2, 3, 4, 5] + /// shape = [2 50] + /// + public static (Tensor sparse_indices, Tensor sparse_values, Tensor sparse_shape) deserialize_sparse(Tensor serialized_sparse, TF_DataType dtype, string name = "DeserializeSparse") + { + var dict = new Dictionary(); + dict["serialized_sparse"] = serialized_sparse; + dict["dtype"] = dtype; + var op = tf.OpDefLib._apply_op_helper("DeserializeSparse", name: name, keywords: dict); + int _idx = 0; + var sparse_indices = op.outputs[_idx++]; + var sparse_values = op.outputs[_idx++]; + var sparse_shape = op.outputs[_idx++]; + return (sparse_indices, sparse_values, sparse_shape); + } + + /// + /// Deletes the resource specified by the handle. + /// + /// + /// handle to the resource to delete. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DestroyResourceOp'. + /// + /// + /// whether to ignore the error when the resource + /// doesn't exist. + /// + /// + /// Returns the description of the operation + /// + /// + /// All subsequent operations using the resource will result in a NotFound + /// error status. + /// + public static Operation destroy_resource_op(Tensor resource, bool? ignore_lookup_error = null, string name = "DestroyResourceOp") + { + var dict = new Dictionary(); + dict["resource"] = resource; + if (ignore_lookup_error.HasValue) + dict["ignore_lookup_error"] = ignore_lookup_error.Value; + var op = tf.OpDefLib._apply_op_helper("DestroyResourceOp", name: name, keywords: dict); + return op; + } + + /// + /// Destroys the temporary variable and returns its final value. + /// + /// + /// A reference to the temporary variable tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DestroyTemporaryVariable'. + /// + /// + /// Optional argument + /// Name of the temporary variable, usually the name of the matching + /// 'TemporaryVariable' op. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Sets output to the value of the Tensor pointed to by 'ref', then destroys + /// the temporary variable called 'var_name'. + /// All other uses of 'ref' *must* have executed before this op. + /// This is typically achieved by chaining the ref through each assign op, or by + /// using control dependencies. + /// + /// Outputs the final value of the tensor pointed to by 'ref'. + /// + public static Tensor destroy_temporary_variable(Tensor referecne, string var_name, string name = "DestroyTemporaryVariable") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["var_name"] = var_name; + var op = tf.OpDefLib._apply_op_helper("DestroyTemporaryVariable", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns a diagonal tensor with a given diagonal values. + /// + /// + /// Rank k tensor where k is at most 1. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Diag'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a diagonal, this operation returns a tensor with the diagonal and + /// everything else padded with zeros. The diagonal is computed as follows: + /// + /// Assume diagonal has dimensions [D1,..., Dk], then the output is a tensor of + /// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: + /// + /// output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik] and 0 everywhere else. + /// + /// For example: + /// + /// + /// # 'diagonal' is [1, 2, 3, 4] + /// tf.diag(diagonal) ==&gt; [[1, 0, 0, 0] + /// [0, 2, 0, 0] + /// [0, 0, 3, 0] + /// [0, 0, 0, 4]] + /// + /// + public static Tensor diag(Tensor diagonal, string name = "Diag") + { + var dict = new Dictionary(); + dict["diagonal"] = diagonal; + var op = tf.OpDefLib._apply_op_helper("Diag", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the diagonal part of the tensor. + /// + /// + /// Rank k tensor where k is even and not zero. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DiagPart'. + /// + /// + /// The extracted diagonal. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation returns a tensor with the diagonal part + /// of the input. The diagonal part is computed as follows: + /// + /// Assume input has dimensions [D1,..., Dk, D1,..., Dk], then the output is a + /// tensor of rank k with dimensions [D1,..., Dk] where: + /// + /// diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]. + /// + /// For example: + /// + /// + /// # 'input' is [[1, 0, 0, 0] + /// [0, 2, 0, 0] + /// [0, 0, 3, 0] + /// [0, 0, 0, 4]] + /// + /// tf.diag_part(input) ==&gt; [1, 2, 3, 4] + /// + /// + public static Tensor diag_part(Tensor input, string name = "DiagPart") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("DiagPart", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes Psi, the derivative of Lgamma (the log of the absolute value of + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Digamma'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Gamma(x)), element-wise. + /// + public static Tensor digamma(Tensor x, string name = "Digamma") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Digamma", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the grayscale dilation of 4-D input and 3-D filter tensors. + /// + /// + /// 4-D with shape [batch, in_height, in_width, depth]. + /// + /// + /// 3-D with shape [filter_height, filter_width, depth]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Dilation2D'. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the input + /// tensor. Must be: [1, stride_height, stride_width, 1]. + /// + /// + /// Optional argument + /// The input stride for atrous morphological dilation. Must be: + /// [1, rate_height, rate_width, 1]. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// 4-D with shape [batch, out_height, out_width, depth]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The input tensor has shape [batch, in_height, in_width, depth] and the + /// filter tensor has shape [filter_height, filter_width, depth], i.e., each + /// input channel is processed independently of the others with its own structuring + /// function. The output tensor has shape + /// [batch, out_height, out_width, depth]. The spatial dimensions of the output + /// tensor depend on the padding algorithm. We currently only support the default + /// "NHWC" data_format. + /// + /// In detail, the grayscale morphological 2-D dilation is the max-sum correlation + /// (for consistency with conv2d, we use unmirrored filters): + /// + /// output[b, y, x, c] = + /// max_{dy, dx} input[b, + /// strides[1] * y + rates[1] * dy, + /// strides[2] * x + rates[2] * dx, + /// c] + + /// filter[dy, dx, c] + /// + /// Max-pooling is a special case when the filter has size equal to the pooling + /// kernel size and contains all zeros. + /// + /// Note on duality: The dilation of input by the filter is equal to the + /// negation of the erosion of -input by the reflected filter. + /// + public static Tensor dilation2d(Tensor input, Tensor filter, int[] strides, int[] rates, string padding, string name = "Dilation2D") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["filter"] = filter; + dict["strides"] = strides; + dict["rates"] = rates; + dict["padding"] = padding; + var op = tf.OpDefLib._apply_op_helper("Dilation2D", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient of morphological 2-D dilation with respect to the filter. + /// + /// + /// 4-D with shape [batch, in_height, in_width, depth]. + /// + /// + /// 3-D with shape [filter_height, filter_width, depth]. + /// + /// + /// 4-D with shape [batch, out_height, out_width, depth]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Dilation2DBackpropFilter'. + /// + /// + /// Optional argument + /// 1-D of length 4. The stride of the sliding window for each dimension of + /// the input tensor. Must be: [1, stride_height, stride_width, 1]. + /// + /// + /// Optional argument + /// 1-D of length 4. The input stride for atrous morphological dilation. + /// Must be: [1, rate_height, rate_width, 1]. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// 3-D with shape [filter_height, filter_width, depth]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor dilation2d_backprop_filter(Tensor input, Tensor filter, Tensor out_backprop, int[] strides, int[] rates, string padding, string name = "Dilation2DBackpropFilter") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["filter"] = filter; + dict["out_backprop"] = out_backprop; + dict["strides"] = strides; + dict["rates"] = rates; + dict["padding"] = padding; + var op = tf.OpDefLib._apply_op_helper("Dilation2DBackpropFilter", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient of morphological 2-D dilation with respect to the input. + /// + /// + /// 4-D with shape [batch, in_height, in_width, depth]. + /// + /// + /// 3-D with shape [filter_height, filter_width, depth]. + /// + /// + /// 4-D with shape [batch, out_height, out_width, depth]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Dilation2DBackpropInput'. + /// + /// + /// Optional argument + /// 1-D of length 4. The stride of the sliding window for each dimension of + /// the input tensor. Must be: [1, stride_height, stride_width, 1]. + /// + /// + /// Optional argument + /// 1-D of length 4. The input stride for atrous morphological dilation. + /// Must be: [1, rate_height, rate_width, 1]. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// 4-D with shape [batch, in_height, in_width, depth]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor dilation2d_backprop_input(Tensor input, Tensor filter, Tensor out_backprop, int[] strides, int[] rates, string padding, string name = "Dilation2DBackpropInput") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["filter"] = filter; + dict["out_backprop"] = out_backprop; + dict["strides"] = strides; + dict["rates"] = rates; + dict["padding"] = padding; + var op = tf.OpDefLib._apply_op_helper("Dilation2DBackpropInput", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns x / y element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Div'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: Div supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor div(Tensor x, Tensor y, string name = "Div") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("Div", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns 0 if the denominator is zero. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DivNoNan'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// + /// *NOTE*: DivNoNan supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor div_no_nan(Tensor x, Tensor y, string name = "DivNoNan") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("DivNoNan", name: name, keywords: dict); + return op.output; + } + + /// + /// Draw bounding boxes on a batch of images. + /// + /// + /// 4-D with shape [batch, height, width, depth]. A batch of images. + /// + /// + /// 3-D with shape [batch, num_bounding_boxes, 4] containing bounding + /// boxes. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DrawBoundingBoxes'. + /// + /// + /// 4-D with the same shape as images. The batch of input images with + /// bounding boxes drawn on the images. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Outputs a copy of images but draws on top of the pixels zero or more bounding + /// boxes specified by the locations in boxes. The coordinates of the each + /// bounding box in boxes are encoded as [y_min, x_min, y_max, x_max]. The + /// bounding box coordinates are floats in [0.0, 1.0] relative to the width and + /// height of the underlying image. + /// + /// For example, if an image is 100 x 200 pixels (height x width) and the bounding + /// box is [0.1, 0.2, 0.5, 0.9], the upper-left and bottom-right coordinates of + /// the bounding box will be (40, 10) to (180, 50) (in (x,y) coordinates). + /// + /// Parts of the bounding box may fall outside the image. + /// + public static Tensor draw_bounding_boxes(Tensor images, Tensor boxes, string name = "DrawBoundingBoxes") + { + var dict = new Dictionary(); + dict["images"] = images; + dict["boxes"] = boxes; + var op = tf.OpDefLib._apply_op_helper("DrawBoundingBoxes", name: name, keywords: dict); + return op.output; + } + + /// + /// Partitions data into num_partitions tensors using indices from partitions. + /// + /// + /// + /// + /// Any shape. Indices in the range [0, num_partitions). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DynamicPartition'. + /// + /// + /// Optional argument + /// The number of partitions to output. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// For each index tuple js of size partitions.ndim, the slice data[js, ...] + /// becomes part of outputs[partitions[js]]. The slices with partitions[js] = i + /// are placed in outputs[i] in lexicographic order of js, and the first + /// dimension of outputs[i] is the number of entries in partitions equal to i. + /// In detail, + /// + /// + /// outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] + /// + /// outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) + /// + /// + /// data.shape must start with partitions.shape. + /// + /// For example: + /// + /// + /// # Scalar partitions. + /// partitions = 1 + /// num_partitions = 2 + /// data = [10, 20] + /// outputs[0] = [] # Empty with shape [0, 2] + /// outputs[1] = [[10, 20]] + /// + /// # Vector partitions. + /// partitions = [0, 0, 1, 1, 0] + /// num_partitions = 2 + /// data = [10, 20, 30, 40, 50] + /// outputs[0] = [10, 20, 50] + /// outputs[1] = [30, 40] + /// + /// + /// See dynamic_stitch for an example on how to merge partitions back. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/DynamicPartition.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor[] dynamic_partition(Tensor data, Tensor partitions, int num_partitions, string name = "DynamicPartition") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["partitions"] = partitions; + dict["num_partitions"] = num_partitions; + var op = tf.OpDefLib._apply_op_helper("DynamicPartition", name: name, keywords: dict); + int _idx = 0; + var outputs = Enumerable.Range(0, op.OutputListLength("outputs")).Select(_ => op.outputs[_idx++]).ToArray(); + return (outputs); + } + + /// + /// Interleave the values from the data tensors into a single tensor. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DynamicStitch'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Builds a merged tensor such that + /// + /// + /// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] + /// + /// + /// For example, if each indices[m] is scalar or vector, we have + /// + /// + /// # Scalar indices: + /// merged[indices[m], ...] = data[m][...] + /// + /// # Vector indices: + /// merged[indices[m][i], ...] = data[m][i, ...] + /// + /// + /// Each data[i].shape must start with the corresponding indices[i].shape, + /// and the rest of data[i].shape must be constant w.r.t. i. That is, we + /// must have data[i].shape = indices[i].shape + constant. In terms of this + /// constant, the output shape is + /// + /// merged.shape = [max(indices)] + constant + /// + /// Values are merged in order, so if an index appears in both indices[m][i] and + /// indices[n][j] for (m,i) &lt; (n,j) the slice data[n][j] will appear in the + /// merged result. If you do not need this guarantee, ParallelDynamicStitch might + /// perform better on some devices. + /// + /// For example: + /// + /// + /// indices[0] = 6 + /// indices[1] = [4, 1] + /// indices[2] = [[5, 2], [0, 3]] + /// data[0] = [61, 62] + /// data[1] = [[41, 42], [11, 12]] + /// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] + /// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], + /// [51, 52], [61, 62]] + /// + /// + /// This method can be used to merge partitions created by dynamic_partition + /// as illustrated on the following example: + /// + /// + /// # Apply function (increments x_i) on elements for which a certain condition + /// # apply (x_i != -1 in this example). + /// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) + /// condition_mask=tf.not_equal(x,tf.constant(-1.)) + /// partitioned_data = tf.dynamic_partition( + /// x, tf.cast(condition_mask, tf.int32) , 2) + /// partitioned_data[1] = partitioned_data[1] + 1.0 + /// condition_indices = tf.dynamic_partition( + /// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) + /// x = tf.dynamic_stitch(condition_indices, partitioned_data) + /// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain + /// # unchanged. + /// + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/DynamicStitch.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = "DynamicStitch") + { + var dict = new Dictionary(); + dict["indices"] = indices; + dict["data"] = data; + var op = tf.OpDefLib._apply_op_helper("DynamicStitch", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the (possibly normalized) Levenshtein Edit Distance. + /// + /// + /// The indices of the hypothesis list SparseTensor. + /// This is an N x R int64 matrix. + /// + /// + /// The values of the hypothesis list SparseTensor. + /// This is an N-length vector. + /// + /// + /// The shape of the hypothesis list SparseTensor. + /// This is an R-length vector. + /// + /// + /// The indices of the truth list SparseTensor. + /// This is an M x R int64 matrix. + /// + /// + /// The values of the truth list SparseTensor. + /// This is an M-length vector. + /// + /// + /// truth indices, vector. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'EditDistance'. + /// + /// + /// boolean (if true, edit distances are normalized by length of truth). + /// + /// The output is: + /// + /// + /// A dense float tensor with rank R - 1. + /// + /// For the example input: + /// + /// // hypothesis represents a 2x1 matrix with variable-length values: + /// // (0,0) = ["a"] + /// // (1,0) = ["b"] + /// hypothesis_indices = [[0, 0, 0], + /// [1, 0, 0]] + /// hypothesis_values = ["a", "b"] + /// hypothesis_shape = [2, 1, 1] + /// + /// // truth represents a 2x2 matrix with variable-length values: + /// // (0,0) = [] + /// // (0,1) = ["a"] + /// // (1,0) = ["b", "c"] + /// // (1,1) = ["a"] + /// truth_indices = [[0, 1, 0], + /// [1, 0, 0], + /// [1, 0, 1], + /// [1, 1, 0]] + /// truth_values = ["a", "b", "c", "a"] + /// truth_shape = [2, 2, 2] + /// normalize = true + /// + /// The output will be: + /// + /// // output is a 2x2 matrix with edit distances normalized by truth lengths. + /// output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis + /// [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The inputs are variable-length sequences provided by SparseTensors + /// (hypothesis_indices, hypothesis_values, hypothesis_shape) + /// and + /// (truth_indices, truth_values, truth_shape). + /// + /// The inputs are: + /// + public static Tensor edit_distance(Tensor hypothesis_indices, Tensor hypothesis_values, Tensor hypothesis_shape, Tensor truth_indices, Tensor truth_values, Tensor truth_shape, bool? normalize = null, string name = "EditDistance") + { + var dict = new Dictionary(); + dict["hypothesis_indices"] = hypothesis_indices; + dict["hypothesis_values"] = hypothesis_values; + dict["hypothesis_shape"] = hypothesis_shape; + dict["truth_indices"] = truth_indices; + dict["truth_values"] = truth_values; + dict["truth_shape"] = truth_shape; + if (normalize.HasValue) + dict["normalize"] = normalize.Value; + var op = tf.OpDefLib._apply_op_helper("EditDistance", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes exponential linear: exp(features) - 1 if &lt; 0, features otherwise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Elu'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) + /// ](http://arxiv.org/abs/1511.07289) + /// + public static Tensor elu(Tensor features, string name = "Elu") + { + var dict = new Dictionary(); + dict["features"] = features; + var op = tf.OpDefLib._apply_op_helper("Elu", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes gradients for the exponential linear (Elu) operation. + /// + /// + /// The backpropagated gradients to the corresponding Elu operation. + /// + /// + /// The outputs of the corresponding Elu operation. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'EluGrad'. + /// + /// + /// The gradients: gradients * (outputs + 1) if outputs &lt; 0, + /// gradients otherwise. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor elu_grad(Tensor gradients, Tensor outputs, string name = "EluGrad") + { + var dict = new Dictionary(); + dict["gradients"] = gradients; + dict["outputs"] = outputs; + var op = tf.OpDefLib._apply_op_helper("EluGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a tensor with the given shape. + /// + /// This operation creates a tensor of shape and dtype. + /// + /// + /// 1-D. Represents the shape of the output tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Empty'. + /// + /// + /// Optional argument + /// + /// + /// If True, initialize the returned tensor with the default value of dtype. Otherwise, the implementation is free not to initializethe tensor's content. + /// + /// + /// A Tensor of type T. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor empty(Tensor shape, TF_DataType dtype, bool? init = null, string name = "Empty") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["dtype"] = dtype; + if (init.HasValue) + dict["init"] = init.Value; + var op = tf.OpDefLib._apply_op_helper("Empty", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates and returns an empty tensor list. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'EmptyTensorList'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// All list elements must be tensors of dtype element_dtype and shape compatible + /// with element_shape. + /// + /// handle: an empty tensor list. + /// element_dtype: the type of elements in the list. + /// element_shape: a shape compatible with that of elements in the list. + /// + public static Tensor empty_tensor_list(Tensor element_shape, TF_DataType element_dtype, string name = "EmptyTensorList") + { + var dict = new Dictionary(); + dict["element_shape"] = element_shape; + dict["element_dtype"] = element_dtype; + var op = tf.OpDefLib._apply_op_helper("EmptyTensorList", name: name, keywords: dict); + return op.output; + } + + /// + /// Encode strings into web-safe base64 format. + /// + /// + /// Strings to be encoded. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'EncodeBase64'. + /// + /// + /// Bool whether padding is applied at the ends. + /// + /// + /// Input strings encoded in base64. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Refer to the following article for more information on base64 format: + /// en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the + /// end so that the encoded has length multiple of 4. See Padding section of the + /// link above. + /// + /// Web-safe means that the encoder uses - and _ instead of + and /. + /// + public static Tensor encode_base64(Tensor input, bool? pad = null, string name = "EncodeBase64") + { + var dict = new Dictionary(); + dict["input"] = input; + if (pad.HasValue) + dict["pad"] = pad.Value; + var op = tf.OpDefLib._apply_op_helper("EncodeBase64", name: name, keywords: dict); + return op.output; + } + + /// + /// JPEG-encode an image. + /// + /// + /// 3-D with shape [height, width, channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'EncodeJpeg'. + /// + /// + /// Per pixel image format. + /// + /// + /// Quality of the compression from 0 to 100 (higher is better and slower). + /// + /// + /// If True, create a JPEG that loads progressively (coarse to fine). + /// + /// + /// If True, spend CPU/RAM to reduce size with no quality change. + /// + /// + /// See http://en.wikipedia.org/wiki/Chroma_subsampling. + /// + /// + /// Unit used to specify x_density and y_density: + /// pixels per inch ('in') or centimeter ('cm'). + /// + /// + /// Horizontal pixels per density unit. + /// + /// + /// Vertical pixels per density unit. + /// + /// + /// If not empty, embed this XMP metadata in the image header. + /// + /// + /// 0-D. JPEG-encoded image. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// image is a 3-D uint8 Tensor of shape [height, width, channels]. + /// + /// The attr format can be used to override the color format of the encoded + /// output. Values can be: + /// + /// * '': Use a default format based on the number of channels in the image. + /// * grayscale: Output a grayscale JPEG image. The channels dimension + /// of image must be 1. + /// * rgb: Output an RGB JPEG image. The channels dimension + /// of image must be 3. + /// + /// If format is not specified or is the empty string, a default format is picked + /// in function of the number of channels in image: + /// + /// * 1: Output a grayscale image. + /// * 3: Output an RGB image. + /// + public static Tensor encode_jpeg(Tensor image, string format = null, int? quality = null, bool? progressive = null, bool? optimize_size = null, bool? chroma_downsampling = null, string density_unit = null, int? x_density = null, int? y_density = null, string xmp_metadata = null, string name = "EncodeJpeg") + { + var dict = new Dictionary(); + dict["image"] = image; + if (format != null) + dict["format"] = format; + if (quality.HasValue) + dict["quality"] = quality.Value; + if (progressive.HasValue) + dict["progressive"] = progressive.Value; + if (optimize_size.HasValue) + dict["optimize_size"] = optimize_size.Value; + if (chroma_downsampling.HasValue) + dict["chroma_downsampling"] = chroma_downsampling.Value; + if (density_unit != null) + dict["density_unit"] = density_unit; + if (x_density.HasValue) + dict["x_density"] = x_density.Value; + if (y_density.HasValue) + dict["y_density"] = y_density.Value; + if (xmp_metadata != null) + dict["xmp_metadata"] = xmp_metadata; + var op = tf.OpDefLib._apply_op_helper("EncodeJpeg", name: name, keywords: dict); + return op.output; + } + + public static Tensor encode_jpeg_variable_quality(Tensor image, Tensor quality) + { + throw new NotImplementedException(""); + } + + /// + /// PNG-encode an image. + /// + /// + /// 3-D with shape [height, width, channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'EncodePng'. + /// + /// + /// Compression level. + /// + /// + /// 0-D. PNG-encoded image. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// image is a 3-D uint8 or uint16 Tensor of shape [height, width, channels] + /// where channels is: + /// + /// * 1: for grayscale. + /// * 2: for grayscale + alpha. + /// * 3: for RGB. + /// * 4: for RGBA. + /// + /// The ZLIB compression level, compression, can be -1 for the PNG-encoder + /// default or a value from 0 to 9. 9 is the highest compression level, generating + /// the smallest output, but is slower. + /// + public static Tensor encode_png(Tensor image, int? compression = null, string name = "EncodePng") + { + var dict = new Dictionary(); + dict["image"] = image; + if (compression.HasValue) + dict["compression"] = compression.Value; + var op = tf.OpDefLib._apply_op_helper("EncodePng", name: name, keywords: dict); + return op.output; + } + + /// + /// The op serializes protobuf messages provided in the input tensors. + /// + /// + /// Tensor of int32 with shape [batch_shape, len(field_names)]. + /// + /// + /// List of tensors containing values for the corresponding field. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'EncodeProto'. + /// + /// + /// Optional argument + /// List of strings containing proto field names. + /// + /// + /// Optional argument + /// Name of the proto message type to decode. + /// + /// + /// + /// + /// Tensor of serialized protos with shape batch_shape. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The types of the tensors in values must match the schema for the + /// fields specified in field_names. All the tensors in values must + /// have a common shape prefix, *batch_shape*. + /// + /// The sizes tensor specifies repeat counts for each field. The repeat + /// count (last dimension) of a each tensor in values must be greater + /// than or equal to corresponding repeat count in sizes. + /// + /// A message_type name must be provided to give context for the field + /// names. The actual message descriptor can be looked up either in the + /// linked-in descriptor pool or a filename provided by the caller using + /// the descriptor_source attribute. + /// + /// The descriptor_source attribute selects a source of protocol + /// descriptors to consult when looking up message_type. This may be a + /// filename containing a serialized FileDescriptorSet message, + /// or the special value local://, in which case only descriptors linked + /// into the code will be searched; the filename can be on any filesystem + /// accessible to TensorFlow. + /// + /// You can build a descriptor_source file using the --descriptor_set_out + /// and --include_imports options to the protocol compiler protoc. + /// + /// The local:// database only covers descriptors linked into the + /// code via C++ libraries, not Python imports. You can link in a proto descriptor + /// by creating a cc_library target with alwayslink=1. + /// + /// There are a few special cases in the value mapping: + /// + /// Submessage and group fields must be pre-serialized as TensorFlow strings. + /// + /// TensorFlow lacks support for unsigned int64s, so they must be + /// represented as tf.int64 with the same twos-complement bit pattern + /// (the obvious way). + /// + /// Unsigned int32 values can be represented exactly with tf.int64, or + /// with sign wrapping if the input is of type tf.int32. + /// + public static Tensor encode_proto(Tensor sizes, Tensor[] values, string[] field_names, string message_type, string descriptor_source = null, string name = "EncodeProto") + { + var dict = new Dictionary(); + dict["sizes"] = sizes; + dict["values"] = values; + dict["field_names"] = field_names; + dict["message_type"] = message_type; + if (descriptor_source != null) + dict["descriptor_source"] = descriptor_source; + var op = tf.OpDefLib._apply_op_helper("EncodeProto", name: name, keywords: dict); + return op.output; + } + + /// + /// Encode audio data using the WAV file format. + /// + /// + /// 2-D with shape [length, channels]. + /// + /// + /// Scalar containing the sample frequency. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'EncodeWav'. + /// + /// + /// 0-D. WAV-encoded file contents. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation will generate a string suitable to be saved out to create a .wav + /// audio file. It will be encoded in the 16-bit PCM format. It takes in float + /// values in the range -1.0f to 1.0f, and any outside that value will be clamped to + /// that range. + /// + /// audio is a 2-D float Tensor of shape [length, channels]. + /// sample_rate is a scalar Tensor holding the rate to use (e.g. 44100). + /// + public static Tensor encode_wav(Tensor audio, Tensor sample_rate, string name = "EncodeWav") + { + var dict = new Dictionary(); + dict["audio"] = audio; + dict["sample_rate"] = sample_rate; + var op = tf.OpDefLib._apply_op_helper("EncodeWav", name: name, keywords: dict); + return op.output; + } + + /// + /// Ensures that the tensor's shape matches the expected shape. + /// + /// + /// A tensor, whose shape is to be validated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'EnsureShape'. + /// + /// + /// Optional argument + /// The expected (possibly partially specified) shape of the input tensor. + /// + /// + /// A tensor with the same shape and contents as the input tensor or value. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Raises an error if the input tensor's shape does not match the specified shape. + /// Returns the input tensor otherwise. + /// + public static Tensor ensure_shape(Tensor input, Shape shape, string name = "EnsureShape") + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "EnsureShape", name, input, shape)); + return _result[0]; + } + catch (Exception) + { + + } + try + { + return ensure_shape_eager_fallback(input, shape, name, ctx); + } + catch (Exception) + { + + } + } + + var dict = new Dictionary(); + dict["input"] = input; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("EnsureShape", name: name, keywords: dict); + if (_execute.must_record_gradient()) + { + throw new NotImplementedException(); + } + return op.output; + } + + public static Tensor ensure_shape_eager_fallback(Tensor input, Shape shape, string name, Context ctx) + { + object[] attrs = new object[4] { "shape", shape, "T", input.dtype.as_datatype_enum() }; + var _result = _execute.execute("EnsureShape", 1, new Tensor[] { input }, + attrs, ctx, name); + if (_execute.must_record_gradient()) + { + throw new NotImplementedException(); + } + return _result[0]; + } + + /// + /// Creates or finds a child frame, and makes data available to the child frame. + /// + /// + /// The tensor to be made available to the child frame. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Enter'. + /// + /// + /// Optional argument + /// The name of the child frame. + /// + /// + /// If true, the output is constant within the child frame. + /// + /// + /// The number of iterations allowed to run in parallel. + /// + /// + /// The same tensor as data. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op is used together with Exit to create loops in the graph. + /// The unique frame_name is used by the Executor to identify frames. If + /// is_constant is true, output is a constant in the child frame; otherwise + /// it may be changed in the child frame. At most parallel_iterations iterations + /// are run in parallel in the child frame. + /// + public static Tensor enter(Tensor data, string frame_name, bool? is_constant = null, int? parallel_iterations = null, string name = "Enter") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["frame_name"] = frame_name; + if (is_constant.HasValue) + dict["is_constant"] = is_constant.Value; + if (parallel_iterations.HasValue) + dict["parallel_iterations"] = parallel_iterations.Value; + var op = tf.OpDefLib._apply_op_helper("Enter", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the truth value of (x == y) element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Equal'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: Equal supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor equal(Tensor x, Tensor y, string name = "Equal") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("Equal", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the Gauss error function of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Erf'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor erf(Tensor x, string name = "Erf") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Erf", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the complementary error function of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Erfc'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor erfc(Tensor x, string name = "Erfc") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Erfc", name: name, keywords: dict); + return op.output; + } + + /// + /// Exits the current frame to its parent frame. + /// + /// + /// The tensor to be made available to the parent frame. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Exit'. + /// + /// + /// The same tensor as data. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Exit makes its input data available to the parent frame. + /// + public static Tensor exit(Tensor data, string name = "Exit") + { + var dict = new Dictionary(); + dict["data"] = data; + var op = tf.OpDefLib._apply_op_helper("Exit", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes exponential of x element-wise. \\(y = e^x\\). + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Exp'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor exp(Tensor x, string name = "Exp") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Exp", name: name, keywords: dict); + return op.output; + } + + /// + /// Inserts a dimension of 1 into a tensor's shape. + /// + /// + /// + /// + /// 0-D (scalar). Specifies the dimension index at which to + /// expand the shape of input. Must be in the range + /// [-rank(input) - 1, rank(input)]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ExpandDims'. + /// + /// + /// Contains the same data as input, but its shape has an additional + /// dimension of size 1 added. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor input, this operation inserts a dimension of 1 at the + /// dimension index axis of input's shape. The dimension index axis starts at + /// zero; if you specify a negative number for axis it is counted backward from + /// the end. + /// + /// This operation is useful if you want to add a batch dimension to a single + /// element. For example, if you have a single image of shape [height, width, + /// channels], you can make it a batch of 1 image with expand_dims(image, 0), + /// which will make the shape [1, height, width, channels]. + /// + /// Other examples: + /// + /// + /// # 't' is a tensor of shape [2] + /// shape(expand_dims(t, 0)) ==&gt; [1, 2] + /// shape(expand_dims(t, 1)) ==&gt; [2, 1] + /// shape(expand_dims(t, -1)) ==&gt; [2, 1] + /// + /// # 't2' is a tensor of shape [2, 3, 5] + /// shape(expand_dims(t2, 0)) ==&gt; [1, 2, 3, 5] + /// shape(expand_dims(t2, 2)) ==&gt; [2, 3, 1, 5] + /// shape(expand_dims(t2, 3)) ==&gt; [2, 3, 5, 1] + /// + /// + /// This operation requires that: + /// + /// -1-input.dims() &lt;= dim &lt;= input.dims() + /// + /// This operation is related to squeeze(), which removes dimensions of + /// size 1. + /// + public static Tensor expand_dims(Tensor input, Tensor dim, string name = "ExpandDims") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["dim"] = dim; + var op = tf.OpDefLib._apply_op_helper("ExpandDims", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes exponential of x - 1 element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Expm1'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// I.e., \\(y = (\exp x) - 1\\). + /// + public static Tensor expm1(Tensor x, string name = "Expm1") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Expm1", name: name, keywords: dict); + return op.output; + } + + /// + /// Extracts a glimpse from the input tensor. + /// + /// + /// A 4-D float tensor of shape [batch_size, height, width, channels]. + /// + /// + /// A 1-D tensor of 2 elements containing the size of the glimpses + /// to extract. The glimpse height must be specified first, following + /// by the glimpse width. + /// + /// + /// A 2-D integer tensor of shape [batch_size, 2] containing + /// the y, x locations of the center of each window. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ExtractGlimpse'. + /// + /// + /// indicates if the offset coordinates are centered relative to + /// the image, in which case the (0, 0) offset is relative to the center + /// of the input images. If false, the (0,0) offset corresponds to the + /// upper left corner of the input images. + /// + /// + /// indicates if the offset coordinates are normalized. + /// + /// + /// indicates if the noise should be generated using a + /// uniform distribution or a Gaussian distribution. + /// + /// + /// A tensor representing the glimpses [batch_size, + /// glimpse_height, glimpse_width, channels]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Returns a set of windows called glimpses extracted at location + /// offsets from the input tensor. If the windows only partially + /// overlaps the inputs, the non overlapping areas will be filled with + /// random noise. + /// + /// The result is a 4-D tensor of shape [batch_size, glimpse_height, + /// glimpse_width, channels]. The channels and batch dimensions are the + /// same as that of the input tensor. The height and width of the output + /// windows are specified in the size parameter. + /// + /// The argument normalized and centered controls how the windows are built: + /// + /// * If the coordinates are normalized but not centered, 0.0 and 1.0 + /// correspond to the minimum and maximum of each height and width + /// dimension. + /// * If the coordinates are both normalized and centered, they range from + /// -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper + /// left corner, the lower right corner is located at (1.0, 1.0) and the + /// center is at (0, 0). + /// * If the coordinates are not normalized they are interpreted as + /// numbers of pixels. + /// + public static Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool? centered = null, bool? normalized = null, bool? uniform_noise = null, string name = "ExtractGlimpse") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["size"] = size; + dict["offsets"] = offsets; + if (centered.HasValue) + dict["centered"] = centered.Value; + if (normalized.HasValue) + dict["normalized"] = normalized.Value; + if (uniform_noise.HasValue) + dict["uniform_noise"] = uniform_noise.Value; + var op = tf.OpDefLib._apply_op_helper("ExtractGlimpse", name: name, keywords: dict); + return op.output; + } + + /// + /// Extract patches from images and put them in the "depth" output dimension. + /// + /// + /// 4-D Tensor with shape [batch, in_rows, in_cols, depth]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ExtractImagePatches'. + /// + /// + /// Optional argument + /// The size of the sliding window for each dimension of images. + /// + /// + /// Optional argument + /// 1-D of length 4. How far the centers of two consecutive patches are in + /// the images. Must be: [1, stride_rows, stride_cols, 1]. + /// + /// + /// Optional argument + /// 1-D of length 4. Must be: [1, rate_rows, rate_cols, 1]. This is the + /// input stride, specifying how far two consecutive patch samples are in the + /// input. Equivalent to extracting patches with + /// patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1), followed by + /// subsampling them spatially by a factor of rates. This is equivalent to + /// rate in dilated (a.k.a. Atrous) convolutions. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// We specify the size-related attributes as: + /// + /// + /// ksizes = [1, ksize_rows, ksize_cols, 1] + /// strides = [1, strides_rows, strides_cols, 1] + /// rates = [1, rates_rows, rates_cols, 1] + /// + /// + /// + /// 4-D Tensor with shape [batch, out_rows, out_cols, ksize_rows * + /// ksize_cols * depth] containing image patches with size + /// ksize_rows x ksize_cols x depth vectorized in the "depth" dimension. Note + /// out_rows and out_cols are the dimensions of the output patches. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor extract_image_patches(Tensor images, int[] ksizes, int[] strides, int[] rates, string padding, string name = "ExtractImagePatches") + { + var dict = new Dictionary(); + dict["images"] = images; + dict["ksizes"] = ksizes; + dict["strides"] = strides; + dict["rates"] = rates; + dict["padding"] = padding; + var op = tf.OpDefLib._apply_op_helper("ExtractImagePatches", name: name, keywords: dict); + return op.output; + } + + /// + /// Extract the shape information of a JPEG-encoded image. + /// + /// + /// 0-D. The JPEG-encoded image. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ExtractJpegShape'. + /// + /// + /// (Optional) The output type of the operation (int32 or int64). + /// Defaults to int32. + /// + /// + /// 1-D. The image shape with format [height, width, channels]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op only parses the image header, so it is much faster than DecodeJpeg. + /// + public static Tensor extract_jpeg_shape(Tensor contents, TF_DataType? output_type = null, string name = "ExtractJpegShape") + { + var dict = new Dictionary(); + dict["contents"] = contents; + if (output_type.HasValue) + dict["output_type"] = output_type.Value; + var op = tf.OpDefLib._apply_op_helper("ExtractJpegShape", name: name, keywords: dict); + return op.output; + } + + /// + /// Fast Fourier transform. + /// + /// + /// A complex64 tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FFT'. + /// + /// + /// A complex64 tensor of the same shape as input. The inner-most + /// dimension of input is replaced with its 1D Fourier transform. + /// + /// @compatibility(numpy) + /// Equivalent to np.fft.fft + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the 1-dimensional discrete Fourier transform over the inner-most + /// dimension of input. + /// + public static Tensor f_f_t(Tensor input, string name = "FFT") + { + return tf.Context.ExecuteOp("FFT", name, new ExecuteOpArgs(input)); + } + + /// + /// 2D fast Fourier transform. + /// + /// + /// A complex64 tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FFT2D'. + /// + /// + /// A complex64 tensor of the same shape as input. The inner-most 2 + /// dimensions of input are replaced with their 2D Fourier transform. + /// + /// @compatibility(numpy) + /// Equivalent to np.fft.fft2 + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the 2-dimensional discrete Fourier transform over the inner-most + /// 2 dimensions of input. + /// + public static Tensor f_f_t2d(Tensor input, string name = "FFT2D") + { + return tf.Context.ExecuteOp("FFT2D", name, new ExecuteOpArgs(input)); + } + + /// + /// 3D fast Fourier transform. + /// + /// + /// A complex64 tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FFT3D'. + /// + /// + /// A complex64 tensor of the same shape as input. The inner-most 3 + /// dimensions of input are replaced with their 3D Fourier transform. + /// + /// @compatibility(numpy) + /// Equivalent to np.fft.fftn with 3 dimensions. + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the 3-dimensional discrete Fourier transform over the inner-most 3 + /// dimensions of input. + /// + public static Tensor f_f_t3d(Tensor input, string name = "FFT3D") + { + return tf.Context.ExecuteOp("FFT3D", name, new ExecuteOpArgs(input)); + } + + /// + /// A queue that produces elements in first-in first-out order. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FIFOQueue'. + /// + /// + /// Optional argument + /// The type of each component in a value. + /// + /// + /// The shape of each component in a value. The length of this attr must + /// be either 0 or the same as the length of component_types. If the length of + /// this attr is 0, the shapes of queue elements are not constrained, and + /// only one element may be dequeued at a time. + /// + /// + /// The upper bound on the number of elements in this queue. + /// Negative numbers mean no limit. + /// + /// + /// If non-empty, this queue is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this queue will be shared under the given name + /// across multiple sessions. + /// + /// + /// The handle to the queue. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor f_i_f_o_queue(TF_DataType[] component_types, Shape[] shapes = null, int? capacity = null, string container = null, string shared_name = null, string name = "FIFOQueue") + { + var dict = new Dictionary(); + dict["component_types"] = component_types; + if (shapes != null) + dict["shapes"] = shapes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("FIFOQueue", name: name, keywords: dict); + return op.output; + } + + /// + /// A queue that produces elements in first-in first-out order. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FIFOQueueV2'. + /// + /// + /// Optional argument + /// The type of each component in a value. + /// + /// + /// The shape of each component in a value. The length of this attr must + /// be either 0 or the same as the length of component_types. If the length of + /// this attr is 0, the shapes of queue elements are not constrained, and + /// only one element may be dequeued at a time. + /// + /// + /// The upper bound on the number of elements in this queue. + /// Negative numbers mean no limit. + /// + /// + /// If non-empty, this queue is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this queue will be shared under the given name + /// across multiple sessions. + /// + /// + /// The handle to the queue. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor f_i_f_o_queue_v2(TF_DataType[] component_types, Shape[] shapes = null, int? capacity = null, string container = null, string shared_name = null, string name = "FIFOQueueV2") + { + var dict = new Dictionary(); + dict["component_types"] = component_types; + if (shapes != null) + dict["shapes"] = shapes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("FIFOQueueV2", name: name, keywords: dict); + return op.output; + } + + /// + /// This op is used as a placeholder in If branch functions. It doesn't provide a + /// valid output when run, so must either be removed (e.g. replaced with a + /// function input) or guaranteed not to be used (e.g. if mirroring an + /// intermediate output needed for the gradient computation of the other branch). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FakeParam'. + /// + /// + /// Optional argument + /// The type of the output. + /// + /// + /// Optional argument + /// The purported shape of the output. This is only used for shape inference; + /// the output will not necessarily have this shape. Can be a partial shape. + /// + /// + /// \"Fake\" output value. This should not be consumed by another op. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor fake_param(TF_DataType dtype, Shape shape, string name = "FakeParam") + { + var dict = new Dictionary(); + dict["dtype"] = dtype; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("FakeParam", name: name, keywords: dict); + return op.output; + } + + /// + /// Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FakeQuantWithMinMaxArgs'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Attributes [min; max] define the clamping range for the inputs data. + /// inputs values are quantized into the quantization range ([0; 2^num_bits - 1] + /// when narrow_range is false and [1; 2^num_bits - 1] when it is true) and + /// then de-quantized and output as floats in [min; max] interval. + /// num_bits is the bitwidth of the quantization; between 2 and 16, inclusive. + /// + /// Quantization is called fake since the output is still in floating point. + /// + public static Tensor fake_quant_with_min_max_args(Tensor inputs, float? min = null, float? max = null, int? num_bits = null, bool? narrow_range = null, string name = "FakeQuantWithMinMaxArgs") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + if (min.HasValue) + dict["min"] = min.Value; + if (max.HasValue) + dict["max"] = max.Value; + if (num_bits.HasValue) + dict["num_bits"] = num_bits.Value; + if (narrow_range.HasValue) + dict["narrow_range"] = narrow_range.Value; + var op = tf.OpDefLib._apply_op_helper("FakeQuantWithMinMaxArgs", name: name, keywords: dict); + return op.output; + } + + /// + /// Compute gradients for a FakeQuantWithMinMaxArgs operation. + /// + /// + /// Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. + /// + /// + /// Values passed as inputs to the FakeQuantWithMinMaxArgs operation. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FakeQuantWithMinMaxArgsGradient'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: + /// gradients * (inputs &gt;= min && inputs &lt;= max). + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor fake_quant_with_min_max_args_gradient(Tensor gradients, Tensor inputs, float? min = null, float? max = null, int? num_bits = null, bool? narrow_range = null, string name = "FakeQuantWithMinMaxArgsGradient") + { + var dict = new Dictionary(); + dict["gradients"] = gradients; + dict["inputs"] = inputs; + if (min.HasValue) + dict["min"] = min.Value; + if (max.HasValue) + dict["max"] = max.Value; + if (num_bits.HasValue) + dict["num_bits"] = num_bits.Value; + if (narrow_range.HasValue) + dict["narrow_range"] = narrow_range.Value; + var op = tf.OpDefLib._apply_op_helper("FakeQuantWithMinMaxArgsGradient", name: name, keywords: dict); + return op.output; + } + + /// + /// Fake-quantize the 'inputs' tensor of type float via global float scalars min + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FakeQuantWithMinMaxVars'. + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// and max to 'outputs' tensor of same shape as inputs. + /// + /// [min; max] define the clamping range for the inputs data. + /// inputs values are quantized into the quantization range ([0; 2^num_bits - 1] + /// when narrow_range is false and [1; 2^num_bits - 1] when it is true) and + /// then de-quantized and output as floats in [min; max] interval. + /// num_bits is the bitwidth of the quantization; between 2 and 16, inclusive. + /// + /// This operation has a gradient and thus allows for training min and max + /// values. + /// + public static Tensor fake_quant_with_min_max_vars(Tensor inputs, Tensor min, Tensor max, int? num_bits = null, bool? narrow_range = null, string name = "FakeQuantWithMinMaxVars") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + dict["min"] = min; + dict["max"] = max; + if (num_bits.HasValue) + dict["num_bits"] = num_bits.Value; + if (narrow_range.HasValue) + dict["narrow_range"] = narrow_range.Value; + var op = tf.OpDefLib._apply_op_helper("FakeQuantWithMinMaxVars", name: name, keywords: dict); + return op.output; + } + + /// + /// Compute gradients for a FakeQuantWithMinMaxVars operation. + /// + /// + /// Backpropagated gradients above the FakeQuantWithMinMaxVars operation. + /// + /// + /// Values passed as inputs to the FakeQuantWithMinMaxVars operation. + /// min, max: Quantization interval, scalar floats. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FakeQuantWithMinMaxVarsGradient'. + /// + /// + /// The bitwidth of the quantization; between 2 and 8, inclusive. + /// + /// + /// Whether to quantize into 2^num_bits - 1 distinct values. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// backprops_wrt_input : Backpropagated gradients w.r.t. inputs: + /// gradients * (inputs &gt;= min && inputs &lt;= max). + /// backprop_wrt_min : Backpropagated gradients w.r.t. min parameter: + /// sum(gradients * (inputs &lt; min)). + /// backprop_wrt_max : Backpropagated gradients w.r.t. max parameter: + /// sum(gradients * (inputs &gt; max)). + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor backprops_wrt_input, Tensor backprop_wrt_min, Tensor backprop_wrt_max) fake_quant_with_min_max_vars_gradient(Tensor gradients, Tensor inputs, Tensor min, Tensor max, int? num_bits = null, bool? narrow_range = null, string name = "FakeQuantWithMinMaxVarsGradient") + { + var dict = new Dictionary(); + dict["gradients"] = gradients; + dict["inputs"] = inputs; + dict["min"] = min; + dict["max"] = max; + if (num_bits.HasValue) + dict["num_bits"] = num_bits.Value; + if (narrow_range.HasValue) + dict["narrow_range"] = narrow_range.Value; + var op = tf.OpDefLib._apply_op_helper("FakeQuantWithMinMaxVarsGradient", name: name, keywords: dict); + int _idx = 0; + var backprops_wrt_input = op.outputs[_idx++]; + var backprop_wrt_min = op.outputs[_idx++]; + var backprop_wrt_max = op.outputs[_idx++]; + return (backprops_wrt_input, backprop_wrt_min, backprop_wrt_max); + } + + /// + /// Fake-quantize the 'inputs' tensor of type float and one of the shapes: [d], + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FakeQuantWithMinMaxVarsPerChannel'. + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// [b, d] [b, h, w, d] via per-channel floats min and max of shape [d] + /// to 'outputs' tensor of same shape as inputs. + /// + /// [min; max] define the clamping range for the inputs data. + /// inputs values are quantized into the quantization range ([0; 2^num_bits - 1] + /// when narrow_range is false and [1; 2^num_bits - 1] when it is true) and + /// then de-quantized and output as floats in [min; max] interval. + /// num_bits is the bitwidth of the quantization; between 2 and 16, inclusive. + /// + /// This operation has a gradient and thus allows for training min and max + /// values. + /// + public static Tensor fake_quant_with_min_max_vars_per_channel(Tensor inputs, Tensor min, Tensor max, int? num_bits = null, bool? narrow_range = null, string name = "FakeQuantWithMinMaxVarsPerChannel") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + dict["min"] = min; + dict["max"] = max; + if (num_bits.HasValue) + dict["num_bits"] = num_bits.Value; + if (narrow_range.HasValue) + dict["narrow_range"] = narrow_range.Value; + var op = tf.OpDefLib._apply_op_helper("FakeQuantWithMinMaxVarsPerChannel", name: name, keywords: dict); + return op.output; + } + + /// + /// Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation. + /// + /// + /// Backpropagated gradients above the FakeQuantWithMinMaxVars operation, + /// shape one of: [d], [b, d], [b, h, w, d]. + /// + /// + /// Values passed as inputs to the FakeQuantWithMinMaxVars operation, shape + /// same as gradients. + /// min, max: Quantization interval, floats of shape [d]. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FakeQuantWithMinMaxVarsPerChannelGradient'. + /// + /// + /// The bitwidth of the quantization; between 2 and 16, inclusive. + /// + /// + /// Whether to quantize into 2^num_bits - 1 distinct values. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// backprops_wrt_input : Backpropagated gradients w.r.t. inputs, shape same as + /// inputs: + /// gradients * (inputs &gt;= min && inputs &lt;= max). + /// backprop_wrt_min : Backpropagated gradients w.r.t. min parameter, shape [d]: + /// sum_per_d(gradients * (inputs &lt; min)). + /// backprop_wrt_max : Backpropagated gradients w.r.t. max parameter, shape [d]: + /// sum_per_d(gradients * (inputs &gt; max)). + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor backprops_wrt_input, Tensor backprop_wrt_min, Tensor backprop_wrt_max) fake_quant_with_min_max_vars_per_channel_gradient(Tensor gradients, Tensor inputs, Tensor min, Tensor max, int? num_bits = null, bool? narrow_range = null, string name = "FakeQuantWithMinMaxVarsPerChannelGradient") + { + var dict = new Dictionary(); + dict["gradients"] = gradients; + dict["inputs"] = inputs; + dict["min"] = min; + dict["max"] = max; + if (num_bits.HasValue) + dict["num_bits"] = num_bits.Value; + if (narrow_range.HasValue) + dict["narrow_range"] = narrow_range.Value; + var op = tf.OpDefLib._apply_op_helper("FakeQuantWithMinMaxVarsPerChannelGradient", name: name, keywords: dict); + int _idx = 0; + var backprops_wrt_input = op.outputs[_idx++]; + var backprop_wrt_min = op.outputs[_idx++]; + var backprop_wrt_max = op.outputs[_idx++]; + return (backprops_wrt_input, backprop_wrt_min, backprop_wrt_max); + } + + /// + /// Deprecated. Do not use. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FakeQueue'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor fake_queue(Tensor resource, string name = "FakeQueue") + { + var dict = new Dictionary(); + dict["resource"] = resource; + var op = tf.OpDefLib._apply_op_helper("FakeQueue", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a tensor filled with a scalar value. + /// + /// + /// 1-D. Represents the shape of the output tensor. + /// + /// + /// 0-D (scalar). Value to fill the returned tensor. + /// + /// @compatibility(numpy) + /// Equivalent to np.full + /// @end_compatibility + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Fill'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation creates a tensor of shape dims and fills it with value. + /// + /// For example: + /// + /// + /// # Output tensor has shape [2, 3]. + /// fill([2, 3], 9) ==&gt; [[9, 9, 9] + /// [9, 9, 9]] + /// + /// + /// tf.fill differs from tf.constant in a few ways: + /// + /// * tf.fill only supports scalar contents, whereas tf.constant supports + /// Tensor values. + /// * tf.fill creates an Op in the computation graph that constructs the actual + /// Tensor value at runtime. This is in contrast to tf.constant which embeds + /// the entire Tensor into the graph with a Const node. + /// * Because tf.fill evaluates at graph runtime, it supports dynamic shapes + /// based on other runtime Tensors, unlike tf.constant. + /// + public static Tensor fill(Tensor dims, Tensor value, string name = "Fill") + { + var dict = new Dictionary(); + dict["dims"] = dims; + dict["value"] = value; + var op = tf.OpDefLib._apply_op_helper("Fill", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset containing elements of first component of input_dataset having true in the last component. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FilterByLastComponentDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor filter_by_last_component_dataset(Tensor input_dataset, TF_DataType[] output_types, Shape[] output_shapes, string name = "FilterByLastComponentDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("FilterByLastComponentDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that emits the records from one or more binary files. + /// + /// + /// A scalar or a vector containing the name(s) of the file(s) to be + /// read. + /// + /// + /// A scalar representing the number of bytes to skip at the + /// beginning of a file. + /// + /// + /// A scalar representing the number of bytes in each record. + /// + /// + /// A scalar representing the number of bytes to skip at the end + /// of a file. + /// + /// + /// A scalar representing the number of bytes to buffer. Must be &gt; 0. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FixedLengthRecordDataset'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor fixed_length_record_dataset(Tensor filenames, Tensor header_bytes, Tensor record_bytes, Tensor footer_bytes, Tensor buffer_size, string name = "FixedLengthRecordDataset") + { + var dict = new Dictionary(); + dict["filenames"] = filenames; + dict["header_bytes"] = header_bytes; + dict["record_bytes"] = record_bytes; + dict["footer_bytes"] = footer_bytes; + dict["buffer_size"] = buffer_size; + var op = tf.OpDefLib._apply_op_helper("FixedLengthRecordDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// A Reader that outputs fixed-length records from a file. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FixedLengthRecordReader'. + /// + /// + /// Optional argument + /// Number of bytes in the record. + /// + /// + /// Number of bytes in the header, defaults to 0. + /// + /// + /// Number of bytes in the footer, defaults to 0. + /// + /// + /// Number of bytes to hop before each read. Default of 0 means using + /// record_bytes. + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// The handle to reference the Reader. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor fixed_length_record_reader(int record_bytes, int? header_bytes = null, int? footer_bytes = null, int? hop_bytes = null, string container = null, string shared_name = null, string name = "FixedLengthRecordReader") + { + var dict = new Dictionary(); + dict["record_bytes"] = record_bytes; + if (header_bytes.HasValue) + dict["header_bytes"] = header_bytes.Value; + if (footer_bytes.HasValue) + dict["footer_bytes"] = footer_bytes.Value; + if (hop_bytes.HasValue) + dict["hop_bytes"] = hop_bytes.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("FixedLengthRecordReader", name: name, keywords: dict); + return op.output; + } + + /// + /// A Reader that outputs fixed-length records from a file. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FixedLengthRecordReaderV2'. + /// + /// + /// Optional argument + /// Number of bytes in the record. + /// + /// + /// Number of bytes in the header, defaults to 0. + /// + /// + /// Number of bytes in the footer, defaults to 0. + /// + /// + /// Number of bytes to hop before each read. Default of 0 means using + /// record_bytes. + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// The type of encoding for the file. Currently ZLIB and GZIP + /// are supported. Defaults to none. + /// + /// + /// The handle to reference the Reader. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor fixed_length_record_reader_v2(int record_bytes, int? header_bytes = null, int? footer_bytes = null, int? hop_bytes = null, string container = null, string shared_name = null, string encoding = null, string name = "FixedLengthRecordReaderV2") + { + var dict = new Dictionary(); + dict["record_bytes"] = record_bytes; + if (header_bytes.HasValue) + dict["header_bytes"] = header_bytes.Value; + if (footer_bytes.HasValue) + dict["footer_bytes"] = footer_bytes.Value; + if (hop_bytes.HasValue) + dict["hop_bytes"] = hop_bytes.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + if (encoding != null) + dict["encoding"] = encoding; + var op = tf.OpDefLib._apply_op_helper("FixedLengthRecordReaderV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Generates labels for candidate sampling with a learned unigram distribution. + /// + /// + /// A batch_size * num_true matrix, in which each row contains the + /// IDs of the num_true target_classes in the corresponding original label. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FixedUnigramCandidateSampler'. + /// + /// + /// Optional argument + /// Number of true labels per context. + /// + /// + /// Optional argument + /// Number of candidates to randomly sample. + /// + /// + /// Optional argument + /// If unique is true, we sample with rejection, so that all sampled + /// candidates in a batch are unique. This requires some approximation to + /// estimate the post-rejection sampling probabilities. + /// + /// + /// Optional argument + /// The sampler will sample integers from the interval [0, range_max). + /// + /// + /// Each valid line in this file (which should have a CSV-like format) + /// corresponds to a valid word ID. IDs are in sequential order, starting from + /// num_reserved_ids. The last entry in each line is expected to be a value + /// corresponding to the count or relative probability. Exactly one of vocab_file + /// and unigrams needs to be passed to this op. + /// + /// + /// The distortion is used to skew the unigram probability distribution. + /// Each weight is first raised to the distortion's power before adding to the + /// internal unigram distribution. As a result, distortion = 1.0 gives regular + /// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives + /// a uniform distribution. + /// + /// + /// Optionally some reserved IDs can be added in the range [0, + /// ..., num_reserved_ids) by the users. One use case is that a special unknown + /// word token is used as ID 0. These IDs will have a sampling probability of 0. + /// + /// + /// A sampler can be used to sample from a subset of the original range + /// in order to speed up the whole computation through parallelism. This parameter + /// (together with 'shard') indicates the number of partitions that are being + /// used in the overall computation. + /// + /// + /// A sampler can be used to sample from a subset of the original range + /// in order to speed up the whole computation through parallelism. This parameter + /// (together with 'num_shards') indicates the particular partition number of a + /// sampler op, when partitioning is being used. + /// + /// + /// A list of unigram counts or probabilities, one per ID in sequential + /// order. Exactly one of vocab_file and unigrams should be passed to this op. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// An second seed to avoid seed collision. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sampled_candidates : A vector of length num_sampled, in which each element is + /// the ID of a sampled candidate. + /// true_expected_count : A batch_size * num_true matrix, representing + /// the number of times each candidate is expected to occur in a batch + /// of sampled candidates. If unique=true, then this is a probability. + /// sampled_expected_count : A vector of length num_sampled, for each sampled + /// candidate representing the number of times the candidate is expected + /// to occur in a batch of sampled candidates. If unique=true, then this is a + /// probability. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// A unigram sampler could use a fixed unigram distribution read from a + /// file or passed in as an in-memory array instead of building up the distribution + /// from data on the fly. There is also an option to skew the distribution by + /// applying a distortion power to the weights. + /// + /// The vocabulary file should be in CSV-like format, with the last field + /// being the weight associated with the word. + /// + /// For each batch, this op picks a single set of sampled candidate labels. + /// + /// The advantages of sampling candidates per-batch are simplicity and the + /// possibility of efficient dense matrix multiplication. The disadvantage is that + /// the sampled candidates must be chosen independently of the context and of the + /// true labels. + /// + public static (Tensor sampled_candidates, Tensor true_expected_count, Tensor sampled_expected_count) fixed_unigram_candidate_sampler(Tensor true_classes, int num_true, int num_sampled, bool unique, int range_max, string vocab_file = null, float? distortion = null, int? num_reserved_ids = null, int? num_shards = null, int? shard = null, float[] unigrams = null, int? seed = null, int? seed2 = null, string name = "FixedUnigramCandidateSampler") + { + var dict = new Dictionary(); + dict["true_classes"] = true_classes; + dict["num_true"] = num_true; + dict["num_sampled"] = num_sampled; + dict["unique"] = unique; + dict["range_max"] = range_max; + if (vocab_file != null) + dict["vocab_file"] = vocab_file; + if (distortion.HasValue) + dict["distortion"] = distortion.Value; + if (num_reserved_ids.HasValue) + dict["num_reserved_ids"] = num_reserved_ids.Value; + if (num_shards.HasValue) + dict["num_shards"] = num_shards.Value; + if (shard.HasValue) + dict["shard"] = shard.Value; + if (unigrams != null) + dict["unigrams"] = unigrams; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("FixedUnigramCandidateSampler", name: name, keywords: dict); + int _idx = 0; + var sampled_candidates = op.outputs[_idx++]; + var true_expected_count = op.outputs[_idx++]; + var sampled_expected_count = op.outputs[_idx++]; + return (sampled_candidates, true_expected_count, sampled_expected_count); + } + + /// + /// Returns element-wise largest integer not greater than x. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Floor'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor floor(Tensor x, string name = "Floor") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Floor", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns x // y element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FloorDiv'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: FloorDiv supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor floor_div(Tensor x, Tensor y, string name = "FloorDiv") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("FloorDiv", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns element-wise remainder of division. When x &lt; 0 xor y &lt; 0 is + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FloorMod'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// true, this follows Python semantics in that the result here is consistent + /// with a flooring divide. E.g. floor(x / y) * y + mod(x, y) = x. + /// + /// *NOTE*: FloorMod supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor floor_mod(Tensor x, Tensor y, string name = "FloorMod") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("FloorMod", name: name, keywords: dict); + return op.output; + } + + /// + /// Performs fractional average pooling on the input. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FractionalAvgPool'. + /// + /// + /// Optional argument + /// Pooling ratio for each dimension of value, currently only + /// supports row and col dimension and should be &gt;= 1.0. For example, a valid + /// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements + /// must be 1.0 because we don't allow pooling on batch and channels + /// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions + /// respectively. + /// + /// + /// When set to True, generates the pooling sequence in a + /// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin + /// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for + /// difference between pseudorandom and random. + /// + /// + /// When set to True, it means when pooling, the values at the boundary + /// of adjacent pooling cells are used by both cells. For example: + /// + /// index 0 1 2 3 4 + /// + /// value 20 5 16 3 7 + /// + /// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + /// The result would be [41/3, 26/3] for fractional avg pooling. + /// + /// + /// When set to True, a fixed pooling region will be used when + /// iterating over a FractionalAvgPool node in the computation graph. Mainly used + /// in unit test to make FractionalAvgPool deterministic. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// An second seed to avoid seed collision. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : output tensor after fractional avg pooling. + /// row_pooling_sequence : row pooling sequence, needed to calculate gradient. + /// col_pooling_sequence : column pooling sequence, needed to calculate gradient. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Fractional average pooling is similar to Fractional max pooling in the pooling + /// region generation step. The only difference is that after pooling regions are + /// generated, a mean operation is performed instead of a max operation in each + /// pooling region. + /// + public static (Tensor output, Tensor row_pooling_sequence, Tensor col_pooling_sequence) fractional_avg_pool(Tensor value, float[] pooling_ratio, bool? pseudo_random = null, bool? overlapping = null, bool? deterministic = null, int? seed = null, int? seed2 = null, string name = "FractionalAvgPool") + { + var dict = new Dictionary(); + dict["value"] = value; + dict["pooling_ratio"] = pooling_ratio; + if (pseudo_random.HasValue) + dict["pseudo_random"] = pseudo_random.Value; + if (overlapping.HasValue) + dict["overlapping"] = overlapping.Value; + if (deterministic.HasValue) + dict["deterministic"] = deterministic.Value; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("FractionalAvgPool", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var row_pooling_sequence = op.outputs[_idx++]; + var col_pooling_sequence = op.outputs[_idx++]; + return (output, row_pooling_sequence, col_pooling_sequence); + } + + /// + /// Computes gradient of the FractionalAvgPool function. + /// + /// + /// Original input tensor shape for fractional_avg_pool + /// + /// + /// 4-D with shape [batch, height, width, channels]. Gradients + /// w.r.t. the output of fractional_avg_pool. + /// + /// + /// row pooling sequence, form pooling region with + /// col_pooling_sequence. + /// + /// + /// column pooling sequence, form pooling region with + /// row_pooling sequence. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FractionalAvgPoolGrad'. + /// + /// + /// When set to True, it means when pooling, the values at the boundary + /// of adjacent pooling cells are used by both cells. For example: + /// + /// index 0 1 2 3 4 + /// + /// value 20 5 16 3 7 + /// + /// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + /// The result would be [41/3, 26/3] for fractional avg pooling. + /// + /// + /// 4-D. Gradients w.r.t. the input of fractional_avg_pool. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Unlike FractionalMaxPoolGrad, we don't need to find arg_max for + /// FractionalAvgPoolGrad, we just need to evenly back-propagate each element of + /// out_backprop to those indices that form the same pooling cell. Therefore, we + /// just need to know the shape of original input tensor, instead of the whole + /// tensor. + /// + public static Tensor fractional_avg_pool_grad(Tensor orig_input_tensor_shape, Tensor out_backprop, Tensor row_pooling_sequence, Tensor col_pooling_sequence, bool? overlapping = null, string name = "FractionalAvgPoolGrad") + { + var dict = new Dictionary(); + dict["orig_input_tensor_shape"] = orig_input_tensor_shape; + dict["out_backprop"] = out_backprop; + dict["row_pooling_sequence"] = row_pooling_sequence; + dict["col_pooling_sequence"] = col_pooling_sequence; + if (overlapping.HasValue) + dict["overlapping"] = overlapping.Value; + var op = tf.OpDefLib._apply_op_helper("FractionalAvgPoolGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Performs fractional max pooling on the input. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FractionalMaxPool'. + /// + /// + /// Optional argument + /// Pooling ratio for each dimension of value, currently only + /// supports row and col dimension and should be &gt;= 1.0. For example, a valid + /// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements + /// must be 1.0 because we don't allow pooling on batch and channels + /// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions + /// respectively. + /// + /// + /// When set to True, generates the pooling sequence in a + /// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin + /// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for + /// difference between pseudorandom and random. + /// + /// + /// When set to True, it means when pooling, the values at the boundary + /// of adjacent pooling cells are used by both cells. For example: + /// + /// index 0 1 2 3 4 + /// + /// value 20 5 16 3 7 + /// + /// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + /// The result would be [20, 16] for fractional max pooling. + /// + /// + /// When set to True, a fixed pooling region will be used when + /// iterating over a FractionalMaxPool node in the computation graph. Mainly used + /// in unit test to make FractionalMaxPool deterministic. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// An second seed to avoid seed collision. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : output tensor after fractional max pooling. + /// row_pooling_sequence : row pooling sequence, needed to calculate gradient. + /// col_pooling_sequence : column pooling sequence, needed to calculate gradient. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Fractional max pooling is slightly different than regular max pooling. In + /// regular max pooling, you downsize an input set by taking the maximum value of + /// smaller N x N subsections of the set (often 2x2), and try to reduce the set by + /// a factor of N, where N is an integer. Fractional max pooling, as you might + /// expect from the word "fractional", means that the overall reduction ratio N + /// does not have to be an integer. + /// + /// The sizes of the pooling regions are generated randomly but are fairly uniform. + /// For example, let's look at the height dimension, and the constraints on the + /// list of rows that will be pool boundaries. + /// + /// First we define the following: + /// + /// 1. input_row_length : the number of rows from the input set + /// 2. output_row_length : which will be smaller than the input + /// 3. alpha = input_row_length / output_row_length : our reduction ratio + /// 4. K = floor(alpha) + /// 5. row_pooling_sequence : this is the result list of pool boundary rows + /// + /// Then, row_pooling_sequence should satisfy: + /// + /// 1. a[0] = 0 : the first value of the sequence is 0 + /// 2. a[end] = input_row_length : the last value of the sequence is the size + /// 3. K &lt;= (a[i+1] - a[i]) &lt;= K+1 : all intervals are K or K+1 size + /// 4. length(row_pooling_sequence) = output_row_length+1 + /// + /// For more details on fractional max pooling, see this paper: + /// [Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) + /// + public static (Tensor output, Tensor row_pooling_sequence, Tensor col_pooling_sequence) fractional_max_pool(Tensor value, float[] pooling_ratio, bool? pseudo_random = null, bool? overlapping = null, bool? deterministic = null, int? seed = null, int? seed2 = null, string name = "FractionalMaxPool") + { + var dict = new Dictionary(); + dict["value"] = value; + dict["pooling_ratio"] = pooling_ratio; + if (pseudo_random.HasValue) + dict["pseudo_random"] = pseudo_random.Value; + if (overlapping.HasValue) + dict["overlapping"] = overlapping.Value; + if (deterministic.HasValue) + dict["deterministic"] = deterministic.Value; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("FractionalMaxPool", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var row_pooling_sequence = op.outputs[_idx++]; + var col_pooling_sequence = op.outputs[_idx++]; + return (output, row_pooling_sequence, col_pooling_sequence); + } + + /// + /// Computes gradient of the FractionalMaxPool function. + /// + /// + /// Original input for fractional_max_pool + /// + /// + /// Original output for fractional_max_pool + /// + /// + /// 4-D with shape [batch, height, width, channels]. Gradients + /// w.r.t. the output of fractional_max_pool. + /// + /// + /// row pooling sequence, form pooling region with + /// col_pooling_sequence. + /// + /// + /// column pooling sequence, form pooling region with + /// row_pooling sequence. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FractionalMaxPoolGrad'. + /// + /// + /// When set to True, it means when pooling, the values at the boundary + /// of adjacent pooling cells are used by both cells. For example: + /// + /// index 0 1 2 3 4 + /// + /// value 20 5 16 3 7 + /// + /// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. + /// The result would be [20, 16] for fractional max pooling. + /// + /// + /// 4-D. Gradients w.r.t. the input of fractional_max_pool. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor fractional_max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor out_backprop, Tensor row_pooling_sequence, Tensor col_pooling_sequence, bool? overlapping = null, string name = "FractionalMaxPoolGrad") + { + var dict = new Dictionary(); + dict["orig_input"] = orig_input; + dict["orig_output"] = orig_output; + dict["out_backprop"] = out_backprop; + dict["row_pooling_sequence"] = row_pooling_sequence; + dict["col_pooling_sequence"] = col_pooling_sequence; + if (overlapping.HasValue) + dict["overlapping"] = overlapping.Value; + var op = tf.OpDefLib._apply_op_helper("FractionalMaxPoolGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Batch normalization. + /// + /// + /// A 4D Tensor for input data. + /// + /// + /// A 1D Tensor for scaling factor, to scale the normalized x. + /// + /// + /// A 1D Tensor for offset, to shift to the normalized x. + /// + /// + /// A 1D Tensor for population mean. Used for inference only; + /// must be empty for training. + /// + /// + /// A 1D Tensor for population variance. Used for inference only; + /// must be empty for training. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FusedBatchNorm'. + /// + /// + /// A small float number added to the variance of x. + /// + /// + /// The data format for x and y. Either "NHWC" (default) or "NCHW". + /// + /// + /// A bool value to indicate the operation is for training (default) + /// or inference. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// y : A 4D Tensor for output data. + /// batch_mean : A 1D Tensor for the computed batch mean, to be used by TensorFlow + /// to compute the running mean. + /// batch_variance : A 1D Tensor for the computed batch variance, to be used by + /// TensorFlow to compute the running variance. + /// reserve_space_1 : A 1D Tensor for the computed batch mean, to be reused + /// in the gradient computation. + /// reserve_space_2 : A 1D Tensor for the computed batch variance (inverted variance + /// in the cuDNN case), to be reused in the gradient computation. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". + /// The size of 1D Tensors matches the dimension C of the 4D Tensors. + /// + public static (Tensor y, Tensor batch_mean, Tensor batch_variance, Tensor reserve_space_1, Tensor reserve_space_2) fused_batch_norm(Tensor x, Tensor scale, Tensor offset, Tensor mean, Tensor variance, float? epsilon = null, string data_format = null, bool? is_training = null, string name = "FusedBatchNorm") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["scale"] = scale; + dict["offset"] = offset; + dict["mean"] = mean; + dict["variance"] = variance; + if (epsilon.HasValue) + dict["epsilon"] = epsilon.Value; + if (data_format != null) + dict["data_format"] = data_format; + if (is_training.HasValue) + dict["is_training"] = is_training.Value; + var op = tf.OpDefLib._apply_op_helper("FusedBatchNorm", name: name, keywords: dict); + int _idx = 0; + var y = op.outputs[_idx++]; + var batch_mean = op.outputs[_idx++]; + var batch_variance = op.outputs[_idx++]; + var reserve_space_1 = op.outputs[_idx++]; + var reserve_space_2 = op.outputs[_idx++]; + return (y, batch_mean, batch_variance, reserve_space_1, reserve_space_2); + } + + /// + /// Gradient for batch normalization. + /// + /// + /// A 4D Tensor for the gradient with respect to y. + /// + /// + /// A 4D Tensor for input data. + /// + /// + /// A 1D Tensor for scaling factor, to scale the normalized x. + /// + /// + /// When is_training is True, a 1D Tensor for the computed batch + /// mean to be reused in gradient computation. When is_training is + /// False, a 1D Tensor for the population mean to be reused in both + /// 1st and 2nd order gradient computation. + /// + /// + /// When is_training is True, a 1D Tensor for the computed batch + /// variance (inverted variance in the cuDNN case) to be reused in + /// gradient computation. When is_training is False, a 1D Tensor + /// for the population variance to be reused in both 1st and 2nd + /// order gradient computation. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FusedBatchNormGrad'. + /// + /// + /// A small float number added to the variance of x. + /// + /// + /// The data format for y_backprop, x, x_backprop. + /// Either "NHWC" (default) or "NCHW". + /// + /// + /// A bool value to indicate the operation is for training (default) + /// or inference. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// x_backprop : A 4D Tensor for the gradient with respect to x. + /// scale_backprop : A 1D Tensor for the gradient with respect to scale. + /// offset_backprop : A 1D Tensor for the gradient with respect to offset. + /// reserve_space_3 : Unused placeholder to match the mean input in FusedBatchNorm. + /// reserve_space_4 : Unused placeholder to match the variance input + /// in FusedBatchNorm. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". + /// The size of 1D Tensors matches the dimension C of the 4D Tensors. + /// + public static (Tensor x_backprop, Tensor scale_backprop, Tensor offset_backprop, Tensor reserve_space_3, Tensor reserve_space_4) fused_batch_norm_grad(Tensor y_backprop, Tensor x, Tensor scale, Tensor reserve_space_1, Tensor reserve_space_2, float? epsilon = null, string data_format = null, bool? is_training = null, string name = "FusedBatchNormGrad") + { + var dict = new Dictionary(); + dict["y_backprop"] = y_backprop; + dict["x"] = x; + dict["scale"] = scale; + dict["reserve_space_1"] = reserve_space_1; + dict["reserve_space_2"] = reserve_space_2; + if (epsilon.HasValue) + dict["epsilon"] = epsilon.Value; + if (data_format != null) + dict["data_format"] = data_format; + if (is_training.HasValue) + dict["is_training"] = is_training.Value; + var op = tf.OpDefLib._apply_op_helper("FusedBatchNormGrad", name: name, keywords: dict); + int _idx = 0; + var x_backprop = op.outputs[_idx++]; + var scale_backprop = op.outputs[_idx++]; + var offset_backprop = op.outputs[_idx++]; + var reserve_space_3 = op.outputs[_idx++]; + var reserve_space_4 = op.outputs[_idx++]; + return (x_backprop, scale_backprop, offset_backprop, reserve_space_3, reserve_space_4); + } + + /// + /// Gradient for batch normalization. + /// + /// + /// A 4D Tensor for the gradient with respect to y. + /// + /// + /// A 4D Tensor for input data. + /// + /// + /// A 1D Tensor for scaling factor, to scale the normalized x. + /// + /// + /// When is_training is True, a 1D Tensor for the computed batch + /// mean to be reused in gradient computation. When is_training is + /// False, a 1D Tensor for the population mean to be reused in both + /// 1st and 2nd order gradient computation. + /// + /// + /// When is_training is True, a 1D Tensor for the computed batch + /// variance (inverted variance in the cuDNN case) to be reused in + /// gradient computation. When is_training is False, a 1D Tensor + /// for the population variance to be reused in both 1st and 2nd + /// order gradient computation. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FusedBatchNormGradV2'. + /// + /// + /// A small float number added to the variance of x. + /// + /// + /// The data format for y_backprop, x, x_backprop. + /// Either "NHWC" (default) or "NCHW". + /// + /// + /// A bool value to indicate the operation is for training (default) + /// or inference. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// x_backprop : A 4D Tensor for the gradient with respect to x. + /// scale_backprop : A 1D Tensor for the gradient with respect to scale. + /// offset_backprop : A 1D Tensor for the gradient with respect to offset. + /// reserve_space_3 : Unused placeholder to match the mean input in FusedBatchNorm. + /// reserve_space_4 : Unused placeholder to match the variance input + /// in FusedBatchNorm. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". + /// The size of 1D Tensors matches the dimension C of the 4D Tensors. + /// + public static (Tensor x_backprop, Tensor scale_backprop, Tensor offset_backprop, Tensor reserve_space_3, Tensor reserve_space_4) fused_batch_norm_grad_v2(Tensor y_backprop, Tensor x, Tensor scale, Tensor reserve_space_1, Tensor reserve_space_2, float? epsilon = null, string data_format = null, bool? is_training = null, string name = "FusedBatchNormGradV2") + { + var dict = new Dictionary(); + dict["y_backprop"] = y_backprop; + dict["x"] = x; + dict["scale"] = scale; + dict["reserve_space_1"] = reserve_space_1; + dict["reserve_space_2"] = reserve_space_2; + if (epsilon.HasValue) + dict["epsilon"] = epsilon.Value; + if (data_format != null) + dict["data_format"] = data_format; + if (is_training.HasValue) + dict["is_training"] = is_training.Value; + var op = tf.OpDefLib._apply_op_helper("FusedBatchNormGradV2", name: name, keywords: dict); + int _idx = 0; + var x_backprop = op.outputs[_idx++]; + var scale_backprop = op.outputs[_idx++]; + var offset_backprop = op.outputs[_idx++]; + var reserve_space_3 = op.outputs[_idx++]; + var reserve_space_4 = op.outputs[_idx++]; + return (x_backprop, scale_backprop, offset_backprop, reserve_space_3, reserve_space_4); + } + + /// + /// Batch normalization. + /// + /// + /// A 4D Tensor for input data. + /// + /// + /// A 1D Tensor for scaling factor, to scale the normalized x. + /// + /// + /// A 1D Tensor for offset, to shift to the normalized x. + /// + /// + /// A 1D Tensor for population mean. Used for inference only; + /// must be empty for training. + /// + /// + /// A 1D Tensor for population variance. Used for inference only; + /// must be empty for training. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FusedBatchNormV2'. + /// + /// + /// A small float number added to the variance of x. + /// + /// + /// The data format for x and y. Either "NHWC" (default) or "NCHW". + /// + /// + /// A bool value to indicate the operation is for training (default) + /// or inference. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// y : A 4D Tensor for output data. + /// batch_mean : A 1D Tensor for the computed batch mean, to be used by TensorFlow + /// to compute the running mean. + /// batch_variance : A 1D Tensor for the computed batch variance, to be used by + /// TensorFlow to compute the running variance. + /// reserve_space_1 : A 1D Tensor for the computed batch mean, to be reused + /// in the gradient computation. + /// reserve_space_2 : A 1D Tensor for the computed batch variance (inverted variance + /// in the cuDNN case), to be reused in the gradient computation. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". + /// The size of 1D Tensors matches the dimension C of the 4D Tensors. + /// + public static (Tensor y, Tensor batch_mean, Tensor batch_variance, Tensor reserve_space_1, Tensor reserve_space_2) fused_batch_norm_v2(Tensor x, Tensor scale, Tensor offset, Tensor mean, Tensor variance, float? epsilon = null, string data_format = null, bool? is_training = null, string name = "FusedBatchNormV2") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["scale"] = scale; + dict["offset"] = offset; + dict["mean"] = mean; + dict["variance"] = variance; + if (epsilon.HasValue) + dict["epsilon"] = epsilon.Value; + if (data_format != null) + dict["data_format"] = data_format; + if (is_training.HasValue) + dict["is_training"] = is_training.Value; + var op = tf.OpDefLib._apply_op_helper("FusedBatchNormV2", name: name, keywords: dict); + int _idx = 0; + var y = op.outputs[_idx++]; + var batch_mean = op.outputs[_idx++]; + var batch_variance = op.outputs[_idx++]; + var reserve_space_1 = op.outputs[_idx++]; + var reserve_space_2 = op.outputs[_idx++]; + return (y, batch_mean, batch_variance, reserve_space_1, reserve_space_2); + } + + /// + /// Performs a padding as a preprocess during a convolution. + /// + /// + /// 4-D with shape [batch, in_height, in_width, in_channels]. + /// + /// + /// A two-column matrix specifying the padding sizes. The number of + /// rows must be the same as the rank of input. + /// + /// + /// 4-D with shape + /// [filter_height, filter_width, in_channels, out_channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FusedPadConv2D'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// 1-D of length 4. The stride of the sliding window for each dimension + /// of input. Must be in the same order as the dimension specified with format. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Similar to FusedResizeAndPadConv2d, this op allows for an optimized + /// implementation where the spatial padding transformation stage is fused with the + /// im2col lookup, but in this case without the bilinear filtering required for + /// resizing. Fusing the padding prevents the need to write out the intermediate + /// results as whole tensors, reducing memory pressure, and we can get some latency + /// gains by merging the transformation calculations. + /// The data_format attribute for Conv2D isn't supported by this op, and 'NHWC' + /// order is used instead. + /// Internally this op uses a single per-graph scratch buffer, which means that it + /// will block if multiple versions are being run in parallel. This is because this + /// operator is primarily an optimization to minimize memory usage. + /// + public static Tensor fused_pad_conv2d(Tensor input, Tensor paddings, Tensor filter, string mode, int[] strides, string padding, string name = "FusedPadConv2D") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["paddings"] = paddings; + dict["filter"] = filter; + dict["mode"] = mode; + dict["strides"] = strides; + dict["padding"] = padding; + var op = tf.OpDefLib._apply_op_helper("FusedPadConv2D", name: name, keywords: dict); + return op.output; + } + + /// + /// Performs a resize and padding as a preprocess during a convolution. + /// + /// + /// 4-D with shape [batch, in_height, in_width, in_channels]. + /// + /// + /// A 1-D int32 Tensor of 2 elements: new_height, new_width. The + /// new size for the images. + /// + /// + /// A two-column matrix specifying the padding sizes. The number of + /// rows must be the same as the rank of input. + /// + /// + /// 4-D with shape + /// [filter_height, filter_width, in_channels, out_channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'FusedResizeAndPadConv2D'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// 1-D of length 4. The stride of the sliding window for each dimension + /// of input. Must be in the same order as the dimension specified with format. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// If true, the centers of the 4 corner pixels of the input and output tensors are + /// aligned, preserving the values at the corner pixels. Defaults to false. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// It's often possible to do spatial transformations more efficiently as part of + /// the packing stage of a convolution, so this op allows for an optimized + /// implementation where these stages are fused together. This prevents the need to + /// write out the intermediate results as whole tensors, reducing memory pressure, + /// and we can get some latency gains by merging the transformation calculations. + /// The data_format attribute for Conv2D isn't supported by this op, and defaults to + /// 'NHWC' order. + /// Internally this op uses a single per-graph scratch buffer, which means that it + /// will block if multiple versions are being run in parallel. This is because this + /// operator is primarily an optimization to minimize memory usage. + /// + public static Tensor fused_resize_and_pad_conv2d(Tensor input, Tensor size, Tensor paddings, Tensor filter, string mode, int[] strides, string padding, bool? resize_align_corners = null, string name = "FusedResizeAndPadConv2D") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["size"] = size; + dict["paddings"] = paddings; + dict["filter"] = filter; + dict["mode"] = mode; + dict["strides"] = strides; + dict["padding"] = padding; + if (resize_align_corners.HasValue) + dict["resize_align_corners"] = resize_align_corners.Value; + var op = tf.OpDefLib._apply_op_helper("FusedResizeAndPadConv2D", name: name, keywords: dict); + return op.output; + } + + /// + /// Gather slices from params according to indices. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Gather'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// indices must be an integer tensor of any dimension (usually 0-D or 1-D). + /// Produces an output tensor with shape indices.shape + params.shape[1:] where: + /// + /// + /// # Scalar indices + /// output[:, ..., :] = params[indices, :, ... :] + /// + /// # Vector indices + /// output[i, :, ..., :] = params[indices[i], :, ... :] + /// + /// # Higher rank indices + /// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] + /// + /// + /// If indices is a permutation and len(indices) == params.shape[0] then + /// this operation will permute params accordingly. + /// + /// validate_indices: DEPRECATED. If this operation is assigned to CPU, values in + /// indices are always validated to be within range. If assigned to GPU, + /// out-of-bound indices result in safe but unspecified behavior, which may include + /// raising an error. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/Gather.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor gather(Tensor parameters, Tensor indices, bool? validate_indices = null, string name = "Gather") + { + var dict = new Dictionary(); + dict["params"] = parameters; + dict["indices"] = indices; + if (validate_indices.HasValue) + dict["validate_indices"] = validate_indices.Value; + var op = tf.OpDefLib._apply_op_helper("Gather", name: name, keywords: dict); + return op.output; + } + + /// + /// Gather slices from params into a Tensor with shape specified by indices. + /// + /// + /// The tensor from which to gather values. + /// + /// + /// Index tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'GatherNd'. + /// + /// + /// Values from params gathered from indices given by indices, with + /// shape indices.shape[:-1] + params.shape[indices.shape[-1]:]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// indices is an K-dimensional integer tensor, best thought of as a + /// (K-1)-dimensional tensor of indices into params, where each element defines a + /// slice of params: + /// + /// output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]] + /// + /// Whereas in tf.gather indices defines slices into the first + /// dimension of params, in tf.gather_nd, indices defines slices into the + /// first N dimensions of params, where N = indices.shape[-1]. + /// + /// The last dimension of indices can be at most the rank of + /// params: + /// + /// indices.shape[-1] &lt;= params.rank + /// + /// The last dimension of indices corresponds to elements + /// (if indices.shape[-1] == params.rank) or slices + /// (if indices.shape[-1] &lt; params.rank) along dimension indices.shape[-1] + /// of params. The output tensor has shape + /// + /// indices.shape[:-1] + params.shape[indices.shape[-1]:] + /// + /// Note that on CPU, if an out of bound index is found, an error is returned. + /// On GPU, if an out of bound index is found, a 0 is stored in the + /// corresponding output value. + /// + /// Some examples below. + /// + /// Simple indexing into a matrix: + /// + /// + /// indices = [[0, 0], [1, 1]] + /// params = [['a', 'b'], ['c', 'd']] + /// output = ['a', 'd'] + /// + /// + /// Slice indexing into a matrix: + /// + /// + /// indices = [[1], [0]] + /// params = [['a', 'b'], ['c', 'd']] + /// output = [['c', 'd'], ['a', 'b']] + /// + /// + /// Indexing into a 3-tensor: + /// + /// + /// indices = [[1]] + /// params = [[['a0', 'b0'], ['c0', 'd0']], + /// [['a1', 'b1'], ['c1', 'd1']]] + /// output = [[['a1', 'b1'], ['c1', 'd1']]] + /// + /// + /// indices = [[0, 1], [1, 0]] + /// params = [[['a0', 'b0'], ['c0', 'd0']], + /// [['a1', 'b1'], ['c1', 'd1']]] + /// output = [['c0', 'd0'], ['a1', 'b1']] + /// + /// + /// indices = [[0, 0, 1], [1, 0, 1]] + /// params = [[['a0', 'b0'], ['c0', 'd0']], + /// [['a1', 'b1'], ['c1', 'd1']]] + /// output = ['b0', 'b1'] + /// + /// + /// Batched indexing into a matrix: + /// + /// + /// indices = [[[0, 0]], [[0, 1]]] + /// params = [['a', 'b'], ['c', 'd']] + /// output = [['a'], ['b']] + /// + /// + /// Batched slice indexing into a matrix: + /// + /// + /// indices = [[[1]], [[0]]] + /// params = [['a', 'b'], ['c', 'd']] + /// output = [[['c', 'd']], [['a', 'b']]] + /// + /// + /// Batched indexing into a 3-tensor: + /// + /// + /// indices = [[[1]], [[0]]] + /// params = [[['a0', 'b0'], ['c0', 'd0']], + /// [['a1', 'b1'], ['c1', 'd1']]] + /// output = [[[['a1', 'b1'], ['c1', 'd1']]], + /// [[['a0', 'b0'], ['c0', 'd0']]]] + /// + /// indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]] + /// params = [[['a0', 'b0'], ['c0', 'd0']], + /// [['a1', 'b1'], ['c1', 'd1']]] + /// output = [[['c0', 'd0'], ['a1', 'b1']], + /// [['a0', 'b0'], ['c1', 'd1']]] + /// + /// + /// indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]] + /// params = [[['a0', 'b0'], ['c0', 'd0']], + /// [['a1', 'b1'], ['c1', 'd1']]] + /// output = [['b0', 'b1'], ['d0', 'c1']] + /// + /// + /// See also tf.gather and tf.batch_gather. + /// + public static Tensor gather_nd(Tensor parameters, Tensor indices, string name = "GatherNd") + { + var dict = new Dictionary(); + dict["params"] = parameters; + dict["indices"] = indices; + var op = tf.OpDefLib._apply_op_helper("GatherNd", name: name, keywords: dict); + return op.output; + } + + /// + /// Gather slices from params axis axis according to indices. + /// + /// + /// The tensor from which to gather values. Must be at least rank + /// axis + 1. + /// + /// + /// Index tensor. Must be in range [0, params.shape[axis]). + /// + /// + /// The axis in params to gather indices from. Defaults to the first + /// dimension. Supports negative indexes. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'GatherV2'. + /// + /// + /// Values from params gathered from indices given by indices, with + /// shape params.shape[:axis] + indices.shape + params.shape[axis + 1:]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// indices must be an integer tensor of any dimension (usually 0-D or 1-D). + /// Produces an output tensor with shape params.shape[:axis] + indices.shape + + /// params.shape[axis + 1:] where: + /// + /// + /// # Scalar indices (output is rank(params) - 1). + /// output[a_0, ..., a_n, b_0, ..., b_n] = + /// params[a_0, ..., a_n, indices, b_0, ..., b_n] + /// + /// # Vector indices (output is rank(params)). + /// output[a_0, ..., a_n, i, b_0, ..., b_n] = + /// params[a_0, ..., a_n, indices[i], b_0, ..., b_n] + /// + /// # Higher rank indices (output is rank(params) + rank(indices) - 1). + /// output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] = + /// params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n] + /// + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/Gather.png" alt&gt; + /// &lt;/div&gt; + /// + /// Note that on CPU, if an out of bound index is found, an error is returned. + /// On GPU, if an out of bound index is found, a 0 is stored in the + /// corresponding output value. + /// + /// See also tf.batch_gather and tf.gather_nd. + /// + public static Tensor gather_v2(Tensor parameters, Tensor indices, Tensor axis, string name = "GatherV2") + { + var dict = new Dictionary(); + dict["params"] = parameters; + dict["indices"] = indices; + dict["axis"] = axis; + var op = tf.OpDefLib._apply_op_helper("GatherV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Given a path to new and old vocabulary files, returns a remapping Tensor of + /// + /// + /// Path to the new vocab file. + /// + /// + /// Path to the old vocab file. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'GenerateVocabRemapping'. + /// + /// + /// Optional argument + /// How many entries into the new vocab file to start reading. + /// + /// + /// Optional argument + /// Number of entries in the new vocab file to remap. + /// + /// + /// Number of entries in the old vocab file to consider. If -1, + /// use the entire old vocabulary. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// remapping : A Tensor of length num_new_vocab where the element at index i + /// is equal to the old ID that maps to the new ID i. This element is -1 for any + /// new ID that is not found in the old vocabulary. + /// num_present : Number of new vocab entries found in old vocab. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// length num_new_vocab, where remapping[i] contains the row number in the old + /// vocabulary that corresponds to row i in the new vocabulary (starting at line + /// new_vocab_offset and up to num_new_vocab entities), or -1 if entry i + /// in the new vocabulary is not in the old vocabulary. The old vocabulary is + /// constrained to the first old_vocab_size entries if old_vocab_size is not the + /// default value of -1. + /// + /// num_vocab_offset enables + /// use in the partitioned variable case, and should generally be set through + /// examining partitioning info. The format of the files should be a text file, + /// with each line containing a single entity within the vocabulary. + /// + /// For example, with new_vocab_file a text file containing each of the following + /// elements on a single line: [f0, f1, f2, f3], old_vocab_file = [f1, f0, f3], + /// num_new_vocab = 3, new_vocab_offset = 1, the returned remapping would be + /// [0, -1, 2]. + /// + /// The op also returns a count of how many entries in the new vocabulary + /// were present in the old vocabulary, which is used to calculate the number of + /// values to initialize in a weight matrix remapping + /// + /// This functionality can be used to remap both row vocabularies (typically, + /// features) and column vocabularies (typically, classes) from TensorFlow + /// checkpoints. Note that the partitioning logic relies on contiguous vocabularies + /// corresponding to div-partitioned variables. Moreover, the underlying remapping + /// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should + /// use the corresponding index_table_from_file() as the FeatureColumn framework + /// does (as opposed to tf.feature_to_id(), which uses a CuckooTable). + /// + public static (Tensor remapping, Tensor num_present) generate_vocab_remapping(Tensor new_vocab_file, Tensor old_vocab_file, int new_vocab_offset, int num_new_vocab, int? old_vocab_size = null, string name = "GenerateVocabRemapping") + { + var dict = new Dictionary(); + dict["new_vocab_file"] = new_vocab_file; + dict["old_vocab_file"] = old_vocab_file; + dict["new_vocab_offset"] = new_vocab_offset; + dict["num_new_vocab"] = num_new_vocab; + if (old_vocab_size.HasValue) + dict["old_vocab_size"] = old_vocab_size.Value; + var op = tf.OpDefLib._apply_op_helper("GenerateVocabRemapping", name: name, keywords: dict); + int _idx = 0; + var remapping = op.outputs[_idx++]; + var num_present = op.outputs[_idx++]; + return (remapping, num_present); + } + + /// + /// Store the input tensor in the state of the current session. + /// + /// + /// The tensor to be stored. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'GetSessionHandle'. + /// + /// + /// The handle for the tensor stored in the session state, represented + /// as a string. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor get_session_handle(Tensor value, string name = "GetSessionHandle") + { + var dict = new Dictionary(); + dict["value"] = value; + var op = tf.OpDefLib._apply_op_helper("GetSessionHandle", name: name, keywords: dict); + return op.output; + } + + /// + /// Store the input tensor in the state of the current session. + /// + /// + /// The tensor to be stored. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'GetSessionHandleV2'. + /// + /// + /// The handle for the tensor stored in the session state, represented + /// as a ResourceHandle object. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor get_session_handle_v2(Tensor value, string name = "GetSessionHandleV2") + { + var dict = new Dictionary(); + dict["value"] = value; + var op = tf.OpDefLib._apply_op_helper("GetSessionHandleV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Get the value of the tensor specified by its handle. + /// + /// + /// The handle for a tensor stored in the session state. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'GetSessionTensor'. + /// + /// + /// Optional argument + /// The type of the output value. + /// + /// + /// The tensor for the given handle. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor get_session_tensor(Tensor handle, TF_DataType dtype, string name = "GetSessionTensor") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["dtype"] = dtype; + var op = tf.OpDefLib._apply_op_helper("GetSessionTensor", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the truth value of (x &gt; y) element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Greater'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: Greater supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor greater(Tensor x, Tensor y, string name = "Greater") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("Greater", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the truth value of (x &gt;= y) element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'GreaterEqual'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: GreaterEqual supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor greater_equal(Tensor x, Tensor y, string name = "GreaterEqual") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("GreaterEqual", name: name, keywords: dict); + return op.output; + } + + /// + /// Gives a guarantee to the TF runtime that the input tensor is a constant. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'GuaranteeConst'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The runtime is then free to make optimizations based on this. + /// + /// Only accepts value typed tensors as inputs and rejects resource variable handles + /// as input. + /// + /// Returns the input tensor without modification. + /// + public static Tensor guarantee_const(Tensor input, string name = "GuaranteeConst") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("GuaranteeConst", name: name, keywords: dict); + return op.output; + } + + /// + /// Convert one or more images from HSV to RGB. + /// + /// + /// 1-D or higher rank. HSV data to convert. Last dimension must be size 3. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'HSVToRGB'. + /// + /// + /// images converted to RGB. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Outputs a tensor of the same shape as the images tensor, containing the RGB + /// value of the pixels. The output is only well defined if the value in images + /// are in [0,1]. + /// + /// See rgb_to_hsv for a description of the HSV encoding. + /// + public static Tensor h_s_v_to_r_g_b(Tensor images, string name = "HSVToRGB") + { + var dict = new Dictionary(); + dict["images"] = images; + var op = tf.OpDefLib._apply_op_helper("HSVToRGB", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a non-initialized hash table. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'HashTable'. + /// + /// + /// Optional argument + /// Type of the table keys. + /// + /// + /// Optional argument + /// Type of the table values. + /// + /// + /// If non-empty, this table is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this table is shared under the given name across + /// multiple sessions. + /// + /// + /// If true and shared_name is empty, the table is shared + /// using the node name. + /// + /// + /// Handle to a table. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op creates a hash table, specifying the type of its keys and values. + /// Before using the table you will have to initialize it. After initialization the + /// table will be immutable. + /// + public static Tensor hash_table(TF_DataType key_dtype, TF_DataType value_dtype, string container = null, string shared_name = null, bool? use_node_name_sharing = null, string name = "HashTable") + { + var dict = new Dictionary(); + dict["key_dtype"] = key_dtype; + dict["value_dtype"] = value_dtype; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + if (use_node_name_sharing.HasValue) + dict["use_node_name_sharing"] = use_node_name_sharing.Value; + var op = tf.OpDefLib._apply_op_helper("HashTable", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a non-initialized hash table. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'HashTableV2'. + /// + /// + /// Optional argument + /// Type of the table keys. + /// + /// + /// Optional argument + /// Type of the table values. + /// + /// + /// If non-empty, this table is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this table is shared under the given name across + /// multiple sessions. + /// + /// + /// If true and shared_name is empty, the table is shared + /// using the node name. + /// + /// + /// Handle to a table. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op creates a hash table, specifying the type of its keys and values. + /// Before using the table you will have to initialize it. After initialization the + /// table will be immutable. + /// + public static Tensor hash_table_v2(TF_DataType key_dtype, TF_DataType value_dtype, string container = null, string shared_name = null, bool? use_node_name_sharing = null, string name = "HashTableV2") + { + var dict = new Dictionary(); + dict["key_dtype"] = key_dtype; + dict["value_dtype"] = value_dtype; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + if (use_node_name_sharing.HasValue) + dict["use_node_name_sharing"] = use_node_name_sharing.Value; + var op = tf.OpDefLib._apply_op_helper("HashTableV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Return histogram of values. + /// + /// + /// Numeric Tensor. + /// + /// + /// Shape [2] Tensor of same dtype as values. + /// values &lt;= value_range[0] will be mapped to hist[0], + /// values &gt;= value_range[1] will be mapped to hist[-1]. + /// + /// + /// Scalar int32 Tensor. Number of histogram bins. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'HistogramFixedWidth'. + /// + /// + /// + /// + /// A 1-D Tensor holding histogram of values. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given the tensor values, this operation returns a rank 1 histogram counting + /// the number of entries in values that fall into every bin. The bins are + /// equal width and determined by the arguments value_range and nbins. + /// + /// + /// # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) + /// nbins = 5 + /// value_range = [0.0, 5.0] + /// new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] + /// + /// with tf.get_default_session() as sess: + /// hist = tf.histogram_fixed_width(new_values, value_range, nbins=5) + /// variables.global_variables_initializer().run() + /// sess.run(hist) =&gt; [2, 1, 1, 0, 2] + /// + /// + public static Tensor histogram_fixed_width(Tensor values, Tensor value_range, Tensor nbins, TF_DataType? dtype = null, string name = "HistogramFixedWidth") + { + var dict = new Dictionary(); + dict["values"] = values; + dict["value_range"] = value_range; + dict["nbins"] = nbins; + if (dtype.HasValue) + dict["dtype"] = dtype.Value; + var op = tf.OpDefLib._apply_op_helper("HistogramFixedWidth", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs a Summary protocol buffer with a histogram. + /// + /// + /// Scalar. Tag to use for the Summary.Value. + /// + /// + /// Any shape. Values to use to build the histogram. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'HistogramSummary'. + /// + /// + /// Scalar. Serialized Summary protocol buffer. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The generated + /// [Summary](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) + /// has one summary value containing a histogram for values. + /// + /// This op reports an InvalidArgument error if any value is not finite. + /// + public static Tensor histogram_summary(Tensor tag, Tensor values, string name = "HistogramSummary") + { + var dict = new Dictionary(); + dict["tag"] = tag; + dict["values"] = values; + var op = tf.OpDefLib._apply_op_helper("HistogramSummary", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns a constant tensor on the host. Only for writing C++ tests. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'HostConst'. + /// + /// + /// Optional argument + /// Attr value is the tensor to return. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor host_const(Tensor value, TF_DataType dtype, string name = "HostConst") + { + var dict = new Dictionary(); + dict["value"] = value; + dict["dtype"] = dtype; + var op = tf.OpDefLib._apply_op_helper("HostConst", name: name, keywords: dict); + return op.output; + } + + /// + /// Inverse fast Fourier transform. + /// + /// + /// A complex64 tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IFFT'. + /// + /// + /// A complex64 tensor of the same shape as input. The inner-most + /// dimension of input is replaced with its inverse 1D Fourier transform. + /// + /// @compatibility(numpy) + /// Equivalent to np.fft.ifft + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the inverse 1-dimensional discrete Fourier transform over the + /// inner-most dimension of input. + /// + public static Tensor i_f_f_t(Tensor input, string name = "IFFT") + { + return tf.Context.ExecuteOp("IFFT", name, new ExecuteOpArgs(input)); + } + + /// + /// Inverse 2D fast Fourier transform. + /// + /// + /// A complex64 tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IFFT2D'. + /// + /// + /// A complex64 tensor of the same shape as input. The inner-most 2 + /// dimensions of input are replaced with their inverse 2D Fourier transform. + /// + /// @compatibility(numpy) + /// Equivalent to np.fft.ifft2 + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the inverse 2-dimensional discrete Fourier transform over the + /// inner-most 2 dimensions of input. + /// + public static Tensor i_f_f_t2d(Tensor input, string name = "IFFT2D") + { + return tf.Context.ExecuteOp("IFFT2D", name, new ExecuteOpArgs(input)); + } + + /// + /// Inverse 3D fast Fourier transform. + /// + /// + /// A complex64 tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IFFT3D'. + /// + /// + /// A complex64 tensor of the same shape as input. The inner-most 3 + /// dimensions of input are replaced with their inverse 3D Fourier transform. + /// + /// @compatibility(numpy) + /// Equivalent to np.fft.ifftn with 3 dimensions. + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the inverse 3-dimensional discrete Fourier transform over the + /// inner-most 3 dimensions of input. + /// + public static Tensor i_f_f_t3d(Tensor input, string name = "IFFT3D") + { + return tf.Context.ExecuteOp("IFFT3D", name, new ExecuteOpArgs(input)); + } + + /// + /// Inverse real-valued fast Fourier transform. + /// + /// + /// A complex64 tensor. + /// + /// + /// An int32 tensor of shape [1]. The FFT length. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IRFFT'. + /// + /// + /// A float32 tensor of the same rank as input. The inner-most + /// dimension of input is replaced with the fft_length samples of its inverse + /// 1D Fourier transform. + /// + /// @compatibility(numpy) + /// Equivalent to np.fft.irfft + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the inverse 1-dimensional discrete Fourier transform of a real-valued + /// signal over the inner-most dimension of input. + /// + /// The inner-most dimension of input is assumed to be the result of RFFT: the + /// fft_length / 2 + 1 unique components of the DFT of a real-valued signal. If + /// fft_length is not provided, it is computed from the size of the inner-most + /// dimension of input (fft_length = 2 * (inner - 1)). If the FFT length used to + /// compute input is odd, it should be provided since it cannot be inferred + /// properly. + /// + /// Along the axis IRFFT is computed on, if fft_length / 2 + 1 is smaller + /// than the corresponding dimension of input, the dimension is cropped. If it is + /// larger, the dimension is padded with zeros. + /// + public static Tensor i_r_f_f_t(Tensor input, Tensor fft_length, string name = "IRFFT") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["fft_length"] = fft_length; + var op = tf.OpDefLib._apply_op_helper("IRFFT", name: name, keywords: dict); + return op.output; + } + + /// + /// Inverse 2D real-valued fast Fourier transform. + /// + /// + /// A complex64 tensor. + /// + /// + /// An int32 tensor of shape [2]. The FFT length for each dimension. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IRFFT2D'. + /// + /// + /// A float32 tensor of the same rank as input. The inner-most 2 + /// dimensions of input are replaced with the fft_length samples of their + /// inverse 2D Fourier transform. + /// + /// @compatibility(numpy) + /// Equivalent to np.fft.irfft2 + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the inverse 2-dimensional discrete Fourier transform of a real-valued + /// signal over the inner-most 2 dimensions of input. + /// + /// The inner-most 2 dimensions of input are assumed to be the result of RFFT2D: + /// The inner-most dimension contains the fft_length / 2 + 1 unique components of + /// the DFT of a real-valued signal. If fft_length is not provided, it is computed + /// from the size of the inner-most 2 dimensions of input. If the FFT length used + /// to compute input is odd, it should be provided since it cannot be inferred + /// properly. + /// + /// Along each axis IRFFT2D is computed on, if fft_length (or + /// fft_length / 2 + 1 for the inner-most dimension) is smaller than the + /// corresponding dimension of input, the dimension is cropped. If it is larger, + /// the dimension is padded with zeros. + /// + public static Tensor i_r_f_f_t2d(Tensor input, Tensor fft_length, string name = "IRFFT2D") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["fft_length"] = fft_length; + var op = tf.OpDefLib._apply_op_helper("IRFFT2D", name: name, keywords: dict); + return op.output; + } + + /// + /// Inverse 3D real-valued fast Fourier transform. + /// + /// + /// A complex64 tensor. + /// + /// + /// An int32 tensor of shape [3]. The FFT length for each dimension. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IRFFT3D'. + /// + /// + /// A float32 tensor of the same rank as input. The inner-most 3 + /// dimensions of input are replaced with the fft_length samples of their + /// inverse 3D real Fourier transform. + /// + /// @compatibility(numpy) + /// Equivalent to np.irfftn with 3 dimensions. + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the inverse 3-dimensional discrete Fourier transform of a real-valued + /// signal over the inner-most 3 dimensions of input. + /// + /// The inner-most 3 dimensions of input are assumed to be the result of RFFT3D: + /// The inner-most dimension contains the fft_length / 2 + 1 unique components of + /// the DFT of a real-valued signal. If fft_length is not provided, it is computed + /// from the size of the inner-most 3 dimensions of input. If the FFT length used + /// to compute input is odd, it should be provided since it cannot be inferred + /// properly. + /// + /// Along each axis IRFFT3D is computed on, if fft_length (or + /// fft_length / 2 + 1 for the inner-most dimension) is smaller than the + /// corresponding dimension of input, the dimension is cropped. If it is larger, + /// the dimension is padded with zeros. + /// + public static Tensor i_r_f_f_t3d(Tensor input, Tensor fft_length, string name = "IRFFT3D") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["fft_length"] = fft_length; + var op = tf.OpDefLib._apply_op_helper("IRFFT3D", name: name, keywords: dict); + return op.output; + } + + /// + /// Return a tensor with the same shape and contents as the input tensor or value. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Identity'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor identity(Tensor input, string name = "Identity") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("Identity", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns a list of tensors with the same shapes and contents as the input + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IdentityN'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// tensors. + /// + /// This op can be used to override the gradient for complicated functions. For + /// example, suppose y = f(x) and we wish to apply a custom function g for backprop + /// such that dx = g(dy). In Python, + /// + /// + /// with tf.get_default_graph().gradient_override_map( + /// {'IdentityN': 'OverrideGradientWithG'}): + /// y, _ = identity_n([f(x), x]) + /// + /// @tf.RegisterGradient('OverrideGradientWithG') + /// def ApplyG(op, dy, _): + /// return [None, g(dy)] # Do not backprop to f(x). + /// + /// + public static Tensor[] identity_n(Tensor[] input, string name = "IdentityN") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("IdentityN", name: name, keywords: dict); + int _idx = 0; + var output = Enumerable.Range(0, op.OutputListLength("output")).Select(_ => op.outputs[_idx++]).ToArray(); + return (output); + } + + /// + /// A Reader that outputs the queued work as both the key and value. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IdentityReader'. + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// The handle to reference the Reader. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// To use, enqueue strings in a Queue. ReaderRead will take the front + /// work string and output (work, work). + /// + public static Tensor identity_reader(string container = null, string shared_name = null, string name = "IdentityReader") + { + var dict = new Dictionary(); + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("IdentityReader", name: name, keywords: dict); + return op.output; + } + + /// + /// A Reader that outputs the queued work as both the key and value. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IdentityReaderV2'. + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// The handle to reference the Reader. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// To use, enqueue strings in a Queue. ReaderRead will take the front + /// work string and output (work, work). + /// + public static Tensor identity_reader_v2(string container = null, string shared_name = null, string name = "IdentityReaderV2") + { + var dict = new Dictionary(); + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("IdentityReaderV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Compute the lower regularized incomplete Gamma function P(a, x). + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Igamma'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The lower regularized incomplete Gamma function is defined as: + /// + /// + /// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) + /// + /// where + /// + /// \\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\) + /// + /// is the lower incomplete Gamma function. + /// + /// Note, above Q(a, x) (Igammac) is the upper regularized complete + /// Gamma function. + /// + public static Tensor igamma(Tensor a, Tensor x, string name = "Igamma") + { + var dict = new Dictionary(); + dict["a"] = a; + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Igamma", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient of igamma(a, x) wrt a. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IgammaGradA'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor igamma_grad_a(Tensor a, Tensor x, string name = "IgammaGradA") + { + var dict = new Dictionary(); + dict["a"] = a; + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("IgammaGradA", name: name, keywords: dict); + return op.output; + } + + /// + /// Compute the upper regularized incomplete Gamma function Q(a, x). + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Igammac'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The upper regularized incomplete Gamma function is defined as: + /// + /// \\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\) + /// + /// where + /// + /// \\(Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt\\) + /// + /// is the upper incomplete Gama function. + /// + /// Note, above P(a, x) (Igamma) is the lower regularized complete + /// Gamma function. + /// + public static Tensor igammac(Tensor a, Tensor x, string name = "Igammac") + { + var dict = new Dictionary(); + dict["a"] = a; + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Igammac", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the imaginary part of a complex number. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Imag'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor input of complex numbers, this operation returns a tensor of + /// type float that is the imaginary part of each element in input. All + /// elements in input must be complex numbers of the form \\(a + bj\\), where *a* + /// is the real part and *b* is the imaginary part returned by this operation. + /// + /// For example: + /// + /// + /// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] + /// tf.imag(input) ==&gt; [4.75, 5.75] + /// + /// + public static Tensor imag(Tensor input, TF_DataType? a_Tout = null, string name = "Imag") + { + TF_DataType Tin = input.GetDataType(); + return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout })); + + // return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input })); + } + + /// + /// Outputs a Summary protocol buffer with images. + /// + /// + /// Scalar. Used to build the tag attribute of the summary values. + /// + /// + /// 4-D of shape [batch_size, height, width, channels] where + /// channels is 1, 3, or 4. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ImageSummary'. + /// + /// + /// Max number of batch elements to generate images for. + /// + /// + /// Color to use for pixels with non-finite values. + /// + /// + /// Scalar. Serialized Summary protocol buffer. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The summary has up to max_images summary values containing images. The + /// images are built from tensor which must be 4-D with shape [batch_size, + /// height, width, channels] and where channels can be: + /// + /// * 1: tensor is interpreted as Grayscale. + /// * 3: tensor is interpreted as RGB. + /// * 4: tensor is interpreted as RGBA. + /// + /// The images have the same number of channels as the input tensor. For float + /// input, the values are normalized one image at a time to fit in the range + /// [0, 255]. uint8 values are unchanged. The op uses two different + /// normalization algorithms: + /// + /// * If the input values are all positive, they are rescaled so the largest one + /// is 255. + /// + /// * If any input value is negative, the values are shifted so input value 0.0 + /// is at 127. They are then rescaled so that either the smallest value is 0, + /// or the largest one is 255. + /// + /// The tag argument is a scalar Tensor of type string. It is used to + /// build the tag of the summary values: + /// + /// * If max_images is 1, the summary value tag is '*tag*/image'. + /// * If max_images is greater than 1, the summary value tags are + /// generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. + /// + /// The bad_color argument is the color to use in the generated images for + /// non-finite input values. It is a uint8 1-D tensor of length channels. + /// Each element must be in the range [0, 255] (It represents the value of a + /// pixel in the output image). Non-finite values in the input tensor are + /// replaced by this tensor in the output image. The default value is the color + /// red. + /// + public static Tensor image_summary(Tensor tag, Tensor tensor, int? max_images = null, Tensor bad_color = null, string name = "ImageSummary") + { + var dict = new Dictionary(); + dict["tag"] = tag; + dict["tensor"] = tensor; + if (max_images.HasValue) + dict["max_images"] = max_images.Value; + if (bad_color != null) + dict["bad_color"] = bad_color; + var op = tf.OpDefLib._apply_op_helper("ImageSummary", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns immutable tensor from memory region. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ImmutableConst'. + /// + /// + /// Optional argument + /// Type of the returned tensor. + /// + /// + /// Optional argument + /// Shape of the returned tensor. + /// + /// + /// Optional argument + /// Name of readonly memory region used by the tensor, see + /// NewReadOnlyMemoryRegionFromFile in tensorflow::Env. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The current implementation memmaps the tensor from a file. + /// + public static Tensor immutable_const(TF_DataType dtype, Shape shape, string memory_region_name, string name = "ImmutableConst") + { + var dict = new Dictionary(); + dict["dtype"] = dtype; + dict["shape"] = shape; + dict["memory_region_name"] = memory_region_name; + var op = tf.OpDefLib._apply_op_helper("ImmutableConst", name: name, keywords: dict); + return op.output; + } + + /// + /// Says whether the targets are in the top K predictions. + /// + /// + /// A batch_size x classes tensor. + /// + /// + /// A batch_size vector of class ids. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InTopK'. + /// + /// + /// Optional argument + /// Number of top elements to look at for computing precision. + /// + /// + /// Computed Precision at k as a bool Tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This outputs a batch_size bool array, an entry out[i] is true if the + /// prediction for the target class is among the top k predictions among + /// all predictions for example i. Note that the behavior of InTopK differs + /// from the TopK op in its handling of ties; if multiple classes have the + /// same prediction value and straddle the top-k boundary, all of those + /// classes are considered to be in the top k. + /// + /// More formally, let + /// + /// \\(predictions_i\\) be the predictions for all classes for example i, + /// \\(targets_i\\) be the target class for example i, + /// \\(out_i\\) be the output for example i, + /// + /// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ + /// + public static Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK") + { + var dict = new Dictionary(); + dict["predictions"] = predictions; + dict["targets"] = targets; + dict["k"] = k; + var op = tf.OpDefLib._apply_op_helper("InTopK", name: name, keywords: dict); + return op.output; + } + + /// + /// Says whether the targets are in the top K predictions. + /// + /// + /// A batch_size x classes tensor. + /// + /// + /// A batch_size vector of class ids. + /// + /// + /// Number of top elements to look at for computing precision. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InTopKV2'. + /// + /// + /// Computed precision at k as a bool Tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This outputs a batch_size bool array, an entry out[i] is true if the + /// prediction for the target class is among the top k predictions among + /// all predictions for example i. Note that the behavior of InTopK differs + /// from the TopK op in its handling of ties; if multiple classes have the + /// same prediction value and straddle the top-k boundary, all of those + /// classes are considered to be in the top k. + /// + /// More formally, let + /// + /// \\(predictions_i\\) be the predictions for all classes for example i, + /// \\(targets_i\\) be the target class for example i, + /// \\(out_i\\) be the output for example i, + /// + /// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ + /// + public static Tensor in_top_k_v2(Tensor predictions, Tensor targets, Tensor k, string name = "InTopKV2") + { + var dict = new Dictionary(); + dict["predictions"] = predictions; + dict["targets"] = targets; + dict["k"] = k; + var op = tf.OpDefLib._apply_op_helper("InTopKV2", name: name, keywords: dict); + return op.output; + } + + /// + /// A placeholder op for a value that will be fed into the computation. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InfeedDequeue'. + /// + /// + /// Optional argument + /// The type of elements in the tensor. + /// + /// + /// Optional argument + /// The shape of the tensor. + /// + /// + /// A tensor that will be provided using the infeed mechanism. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor infeed_dequeue(TF_DataType dtype, Shape shape, string name = "InfeedDequeue") + { + var dict = new Dictionary(); + dict["dtype"] = dtype; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("InfeedDequeue", name: name, keywords: dict); + return op.output; + } + + /// + /// A placeholder op for multiple values that will be fed into the computation + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InfeedDequeueTuple'. + /// + /// + /// Optional argument + /// The element types of each element in outputs. + /// + /// + /// Optional argument + /// The shapes of each tensor in outputs. + /// + /// + /// A list of tensors that will be provided using the infeed mechanism. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// simultaneously as an XLA tuple. + /// + public static Tensor[] infeed_dequeue_tuple(TF_DataType[] dtypes, Shape[] shapes, string name = "InfeedDequeueTuple") + { + var dict = new Dictionary(); + dict["dtypes"] = dtypes; + dict["shapes"] = shapes; + var op = tf.OpDefLib._apply_op_helper("InfeedDequeueTuple", name: name, keywords: dict); + int _idx = 0; + var outputs = Enumerable.Range(0, op.OutputListLength("outputs")).Select(_ => op.outputs[_idx++]).ToArray(); + return (outputs); + } + + /// + /// An op which feeds a single Tensor value into the computation. + /// + /// + /// A tensor that will be provided using the infeed mechanism. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InfeedEnqueue'. + /// + /// + /// The shape of the tensor. + /// + /// + /// The TPU device to use. This should be -1 when the Op + /// is running on a TPU device, and &gt;= 0 when the Op is running on the CPU + /// device. + /// + /// + /// Returns the description of the operation + /// + public static Operation infeed_enqueue(Tensor input, Shape shape = null, int? device_ordinal = null, string name = "InfeedEnqueue") + { + var dict = new Dictionary(); + dict["input"] = input; + if (shape != null) + dict["shape"] = shape; + if (device_ordinal.HasValue) + dict["device_ordinal"] = device_ordinal.Value; + var op = tf.OpDefLib._apply_op_helper("InfeedEnqueue", name: name, keywords: dict); + return op; + } + + /// + /// An op which feeds multiple Tensor values into the computation as an XLA tuple. + /// + /// + /// A list of tensors that will be provided using the infeed mechanism. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InfeedEnqueueTuple'. + /// + /// + /// Optional argument + /// The shapes of each tensor in inputs. + /// + /// + /// The TPU device to use. This should be -1 when the Op + /// is running on a TPU device, and &gt;= 0 when the Op is running on the CPU + /// device. + /// + /// + /// Returns the description of the operation + /// + public static Operation infeed_enqueue_tuple(Tensor[] inputs, Shape[] shapes, int? device_ordinal = null, string name = "InfeedEnqueueTuple") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + dict["shapes"] = shapes; + if (device_ordinal.HasValue) + dict["device_ordinal"] = device_ordinal.Value; + var op = tf.OpDefLib._apply_op_helper("InfeedEnqueueTuple", name: name, keywords: dict); + return op; + } + + /// + /// Table initializer that takes two tensors for keys and values respectively. + /// + /// + /// Handle to a table which will be initialized. + /// + /// + /// Keys of type Tkey. + /// + /// + /// Values of type Tval. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InitializeTable'. + /// + /// + /// Returns the description of the operation + /// + public static Operation initialize_table(Tensor table_handle, Tensor keys, Tensor values, string name = "InitializeTable") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + dict["keys"] = keys; + dict["values"] = values; + var op = tf.OpDefLib._apply_op_helper("InitializeTable", name: name, keywords: dict); + return op; + } + + /// + /// Initializes a table from a text file. + /// + /// + /// Handle to a table which will be initialized. + /// + /// + /// Filename of a vocabulary text file. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InitializeTableFromTextFile'. + /// + /// + /// Optional argument + /// Column index in a line to get the table key values from. + /// + /// + /// Optional argument + /// Column index that represents information of a line to get the table + /// value values from. + /// + /// + /// Number of elements of the file, use -1 if unknown. + /// + /// + /// Delimiter to separate fields in a line. + /// + /// + /// Returns the description of the operation + /// + /// + /// It inserts one key-value pair into the table for each line of the file. + /// The key and value is extracted from the whole line content, elements from the + /// split line based on delimiter or the line number (starting from zero). + /// Where to extract the key and value from a line is specified by key_index and + /// value_index. + /// + /// - A value of -1 means use the line number(starting from zero), expects int64. + /// - A value of -2 means use the whole line content, expects string. + /// - A value &gt;= 0 means use the index (starting at zero) of the split line based + /// on delimiter. + /// + public static Operation initialize_table_from_text_file(Tensor table_handle, Tensor filename, int key_index, int value_index, int? vocab_size = null, string delimiter = null, string name = "InitializeTableFromTextFile") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + dict["filename"] = filename; + dict["key_index"] = key_index; + dict["value_index"] = value_index; + if (vocab_size.HasValue) + dict["vocab_size"] = vocab_size.Value; + if (delimiter != null) + dict["delimiter"] = delimiter; + var op = tf.OpDefLib._apply_op_helper("InitializeTableFromTextFile", name: name, keywords: dict); + return op; + } + + /// + /// Initializes a table from a text file. + /// + /// + /// Handle to a table which will be initialized. + /// + /// + /// Filename of a vocabulary text file. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InitializeTableFromTextFileV2'. + /// + /// + /// Optional argument + /// Column index in a line to get the table key values from. + /// + /// + /// Optional argument + /// Column index that represents information of a line to get the table + /// value values from. + /// + /// + /// Number of elements of the file, use -1 if unknown. + /// + /// + /// Delimiter to separate fields in a line. + /// + /// + /// Returns the description of the operation + /// + /// + /// It inserts one key-value pair into the table for each line of the file. + /// The key and value is extracted from the whole line content, elements from the + /// split line based on delimiter or the line number (starting from zero). + /// Where to extract the key and value from a line is specified by key_index and + /// value_index. + /// + /// - A value of -1 means use the line number(starting from zero), expects int64. + /// - A value of -2 means use the whole line content, expects string. + /// - A value &gt;= 0 means use the index (starting at zero) of the split line based + /// on delimiter. + /// + public static Operation initialize_table_from_text_file_v2(Tensor table_handle, Tensor filename, int key_index, int value_index, int? vocab_size = null, string delimiter = null, string name = "InitializeTableFromTextFileV2") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + dict["filename"] = filename; + dict["key_index"] = key_index; + dict["value_index"] = value_index; + if (vocab_size.HasValue) + dict["vocab_size"] = vocab_size.Value; + if (delimiter != null) + dict["delimiter"] = delimiter; + var op = tf.OpDefLib._apply_op_helper("InitializeTableFromTextFileV2", name: name, keywords: dict); + return op; + } + + /// + /// Table initializer that takes two tensors for keys and values respectively. + /// + /// + /// Handle to a table which will be initialized. + /// + /// + /// Keys of type Tkey. + /// + /// + /// Values of type Tval. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InitializeTableV2'. + /// + /// + /// Returns the description of the operation + /// + public static Operation initialize_table_v2(Tensor table_handle, Tensor keys, Tensor values, string name = "InitializeTableV2") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + dict["keys"] = keys; + dict["values"] = values; + var op = tf.OpDefLib._apply_op_helper("InitializeTableV2", name: name, keywords: dict); + return op; + } + + /// + /// Adds v into specified rows of x. + /// + /// Computes y = x; y[i, :] += v; return y. + /// + /// + /// A Tensor of type T. + /// + /// + /// A vector. Indices into the left-most dimension of x. + /// + /// + /// A Tensor of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InplaceAdd'. + /// + /// + /// A Tensor of type T. An alias of x. The content of y is undefined if there are duplicates in i. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor inplace_add(Tensor x, Tensor i, Tensor v, string name = "InplaceAdd") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["i"] = i; + dict["v"] = v; + var op = tf.OpDefLib._apply_op_helper("InplaceAdd", name: name, keywords: dict); + return op.output; + } + + /// + /// Subtracts v into specified rows of x. + /// + /// Computes y = x; y[i, :] -= v; return y. + /// + /// + /// A Tensor of type T. + /// + /// + /// A vector. Indices into the left-most dimension of x. + /// + /// + /// A Tensor of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InplaceSub'. + /// + /// + /// A Tensor of type T. An alias of x. The content of y is undefined if there are duplicates in i. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor inplace_sub(Tensor x, Tensor i, Tensor v, string name = "InplaceSub") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["i"] = i; + dict["v"] = v; + var op = tf.OpDefLib._apply_op_helper("InplaceSub", name: name, keywords: dict); + return op.output; + } + + /// + /// Updates specified rows with values in v. + /// + /// Computes x[i, :] = v; return x. + /// + /// + /// A tensor of type T. + /// + /// + /// A vector. Indices into the left-most dimension of x. + /// + /// + /// A Tensor of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InplaceUpdate'. + /// + /// + /// A Tensor of type T. An alias of x. The content of y is undefined if there are duplicates in i. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor inplace_update(Tensor x, Tensor i, Tensor v, string name = "InplaceUpdate") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["i"] = i; + dict["v"] = v; + var op = tf.OpDefLib._apply_op_helper("InplaceUpdate", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the reciprocal of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Inv'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// I.e., \\(y = 1 / x\\). + /// + public static Tensor inv(Tensor x, string name = "Inv") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Inv", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient for the inverse of x wrt its input. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InvGrad'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Specifically, grad = -dy * y*y, where y = 1/x, and dy + /// is the corresponding input gradient. + /// + public static Tensor inv_grad(Tensor y, Tensor dy, string name = "InvGrad") + { + var dict = new Dictionary(); + dict["y"] = y; + dict["dy"] = dy; + var op = tf.OpDefLib._apply_op_helper("InvGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Flips all bits elementwise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Invert'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The result will have exactly those bits set, that are not set in x. The + /// computation is performed on the underlying representation of x. + /// + public static Tensor invert(Tensor x, string name = "Invert") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Invert", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the inverse permutation of a tensor. + /// + /// + /// 1-D. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'InvertPermutation'. + /// + /// + /// 1-D. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation computes the inverse of an index permutation. It takes a 1-D + /// integer tensor x, which represents the indices of a zero-based array, and + /// swaps each value with its index position. In other words, for an output tensor + /// y and an input tensor x, this operation computes the following: + /// + /// y[x[i]] = i for i in [0, 1, ..., len(x) - 1] + /// + /// The values must include 0. There can be no duplicate values or negative values. + /// + /// For example: + /// + /// + /// # tensor x is [3, 4, 0, 2, 1] + /// invert_permutation(x) ==&gt; [2, 4, 3, 0, 1] + /// + /// + public static Tensor invert_permutation(Tensor x, string name = "InvertPermutation") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("InvertPermutation", name: name, keywords: dict); + return op.output; + } + + /// + /// Checks whether a tree ensemble has been initialized. + /// + /// + /// Handle to the tree ensemble resouce. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IsBoostedTreesEnsembleInitialized'. + /// + /// + /// output boolean on whether it is initialized or not. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor is_boosted_trees_ensemble_initialized(Tensor tree_ensemble_handle, string name = "IsBoostedTreesEnsembleInitialized") + { + var dict = new Dictionary(); + dict["tree_ensemble_handle"] = tree_ensemble_handle; + var op = tf.OpDefLib._apply_op_helper("IsBoostedTreesEnsembleInitialized", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns which elements of x are finite. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IsFinite'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// @compatibility(numpy) + /// Equivalent to np.isfinite + /// @end_compatibility + /// + public static Tensor is_finite(Tensor x, string name = "IsFinite") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("IsFinite", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns which elements of x are Inf. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IsInf'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// @compatibility(numpy) + /// Equivalent to np.isinf + /// @end_compatibility + /// + public static Tensor is_inf(Tensor x, string name = "IsInf") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("IsInf", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns which elements of x are NaN. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IsNan'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// @compatibility(numpy) + /// Equivalent to np.isnan + /// @end_compatibility + /// + public static Tensor is_nan(Tensor x, string name = "IsNan") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("IsNan", name: name, keywords: dict); + return op.output; + } + + /// + /// Checks whether a tensor has been initialized. + /// + /// + /// Should be from a Variable node. May be uninitialized. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IsVariableInitialized'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Outputs boolean scalar indicating whether the tensor has been initialized. + /// + public static Tensor is_variable_initialized(Tensor referecne, string name = "IsVariableInitialized") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + var op = tf.OpDefLib._apply_op_helper("IsVariableInitialized", name: name, keywords: dict); + return op.output; + } + + /// + /// A container for an iterator resource. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Iterator'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// A handle to the iterator that can be passed to a "MakeIterator" + /// or "IteratorGetNext" op. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor iterator(string shared_name, string container, TF_DataType[] output_types, Shape[] output_shapes, string name = "Iterator") + { + var dict = new Dictionary(); + dict["shared_name"] = shared_name; + dict["container"] = container; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("Iterator", name: name, keywords: dict); + return op.output; + } + + /// + /// Converts the given string representing a handle to an iterator to a resource. + /// + /// + /// A string representation of the given handle. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IteratorFromStringHandle'. + /// + /// + /// If specified, defines the type of each tuple component in an + /// element produced by the resulting iterator. + /// + /// + /// If specified, defines the shape of each tuple component in an + /// element produced by the resulting iterator. + /// + /// + /// A handle to an iterator resource. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor iterator_from_string_handle(Tensor string_handle, TF_DataType[] output_types = null, Shape[] output_shapes = null, string name = "IteratorFromStringHandle") + { + var dict = new Dictionary(); + dict["string_handle"] = string_handle; + if (output_types != null) + dict["output_types"] = output_types; + if (output_shapes != null) + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("IteratorFromStringHandle", name: name, keywords: dict); + return op.output; + } + + /// + /// Gets the next output from the given iterator . + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IteratorGetNext'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor[] iterator_get_next(Tensor iterator, TF_DataType[] output_types, Shape[] output_shapes, string name = "IteratorGetNext") + { + var dict = new Dictionary(); + dict["iterator"] = iterator; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("IteratorGetNext", name: name, keywords: dict); + int _idx = 0; + var components = Enumerable.Range(0, op.OutputListLength("components")).Select(_ => op.outputs[_idx++]).ToArray(); + return (components); + } + + /// + /// Gets the next output from the given iterator as an Optional variant. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IteratorGetNextAsOptional'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor iterator_get_next_as_optional(Tensor iterator, TF_DataType[] output_types, Shape[] output_shapes, string name = "IteratorGetNextAsOptional") + { + var dict = new Dictionary(); + dict["iterator"] = iterator; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("IteratorGetNextAsOptional", name: name, keywords: dict); + return op.output; + } + + /// + /// Gets the next output from the given iterator. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IteratorGetNextSync'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation is a synchronous version IteratorGetNext. It should only be used + /// in situations where the iterator does not block the calling thread, or where + /// the calling thread is not a member of the thread pool used to execute parallel + /// operations (e.g. in eager mode). + /// + public static Tensor[] iterator_get_next_sync(Tensor iterator, TF_DataType[] output_types, Shape[] output_shapes, string name = "IteratorGetNextSync") + { + var dict = new Dictionary(); + dict["iterator"] = iterator; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("IteratorGetNextSync", name: name, keywords: dict); + int _idx = 0; + var components = Enumerable.Range(0, op.OutputListLength("components")).Select(_ => op.outputs[_idx++]).ToArray(); + return (components); + } + + /// + /// Converts the given resource_handle representing an iterator to a string. + /// + /// + /// A handle to an iterator resource. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'IteratorToStringHandle'. + /// + /// + /// A string representation of the given handle. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor iterator_to_string_handle(Tensor resource_handle, string name = "IteratorToStringHandle") + { + var dict = new Dictionary(); + dict["resource_handle"] = resource_handle; + var op = tf.OpDefLib._apply_op_helper("IteratorToStringHandle", name: name, keywords: dict); + return op.output; + } + + /// + /// L2 Loss. + /// + /// + /// Typically 2-D, but may have any dimensions. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'L2Loss'. + /// + /// + /// 0-D. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes half the L2 norm of a tensor without the sqrt: + /// + /// output = sum(t ** 2) / 2 + /// + public static Tensor l2loss(Tensor t, string name = "L2Loss") + { + var dict = new Dictionary(); + dict["t"] = t; + var op = tf.OpDefLib._apply_op_helper("L2Loss", name: name, keywords: dict); + return op.output; + } + + /// + /// A Reader that outputs the records from a LMDB file. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LMDBReader'. + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// The handle to reference the Reader. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor l_m_d_b_reader(string container = null, string shared_name = null, string name = "LMDBReader") + { + var dict = new Dictionary(); + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("LMDBReader", name: name, keywords: dict); + return op.output; + } + + /// + /// Local Response Normalization. + /// + /// + /// 4-D. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LRN'. + /// + /// + /// 0-D. Half-width of the 1-D normalization window. + /// + /// + /// An offset (usually positive to avoid dividing by 0). + /// + /// + /// A scale factor, usually positive. + /// + /// + /// An exponent. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The 4-D input tensor is treated as a 3-D array of 1-D vectors (along the last + /// dimension), and each vector is normalized independently. Within a given vector, + /// each component is divided by the weighted, squared sum of inputs within + /// depth_radius. In detail, + /// + /// sqr_sum[a, b, c, d] = + /// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) + /// output = input / (bias + alpha * sqr_sum) ** beta + /// + /// For details, see [Krizhevsky et al., ImageNet classification with deep + /// convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). + /// + public static Tensor l_r_n(Tensor input, int? depth_radius = null, float? bias = null, float? alpha = null, float? beta = null, string name = "LRN") + { + var dict = new Dictionary(); + dict["input"] = input; + if (depth_radius.HasValue) + dict["depth_radius"] = depth_radius.Value; + if (bias.HasValue) + dict["bias"] = bias.Value; + if (alpha.HasValue) + dict["alpha"] = alpha.Value; + if (beta.HasValue) + dict["beta"] = beta.Value; + var op = tf.OpDefLib._apply_op_helper("LRN", name: name, keywords: dict); + return op.output; + } + + /// + /// Gradients for Local Response Normalization. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LRNGrad'. + /// + /// + /// A depth radius. + /// + /// + /// An offset (usually &gt; 0 to avoid dividing by 0). + /// + /// + /// A scale factor, usually positive. + /// + /// + /// An exponent. + /// + /// + /// The gradients for LRN. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor l_r_n_grad(Tensor input_grads, Tensor input_image, Tensor output_image, int? depth_radius = null, float? bias = null, float? alpha = null, float? beta = null, string name = "LRNGrad") + { + var dict = new Dictionary(); + dict["input_grads"] = input_grads; + dict["input_image"] = input_image; + dict["output_image"] = output_image; + if (depth_radius.HasValue) + dict["depth_radius"] = depth_radius.Value; + if (bias.HasValue) + dict["bias"] = bias.Value; + if (alpha.HasValue) + dict["alpha"] = alpha.Value; + if (beta.HasValue) + dict["beta"] = beta.Value; + var op = tf.OpDefLib._apply_op_helper("LRNGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Records the latency of producing input_dataset elements in a StatsAggregator. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LatencyStatsDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor latency_stats_dataset(Tensor input_dataset, Tensor tag, TF_DataType[] output_types, Shape[] output_shapes, string name = "LatencyStatsDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["tag"] = tag; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("LatencyStatsDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Generates labels for candidate sampling with a learned unigram distribution. + /// + /// + /// A batch_size * num_true matrix, in which each row contains the + /// IDs of the num_true target_classes in the corresponding original label. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LearnedUnigramCandidateSampler'. + /// + /// + /// Optional argument + /// Number of true labels per context. + /// + /// + /// Optional argument + /// Number of candidates to randomly sample. + /// + /// + /// Optional argument + /// If unique is true, we sample with rejection, so that all sampled + /// candidates in a batch are unique. This requires some approximation to + /// estimate the post-rejection sampling probabilities. + /// + /// + /// Optional argument + /// The sampler will sample integers from the interval [0, range_max). + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// An second seed to avoid seed collision. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sampled_candidates : A vector of length num_sampled, in which each element is + /// the ID of a sampled candidate. + /// true_expected_count : A batch_size * num_true matrix, representing + /// the number of times each candidate is expected to occur in a batch + /// of sampled candidates. If unique=true, then this is a probability. + /// sampled_expected_count : A vector of length num_sampled, for each sampled + /// candidate representing the number of times the candidate is expected + /// to occur in a batch of sampled candidates. If unique=true, then this is a + /// probability. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// See explanations of candidate sampling and the data formats at + /// go/candidate-sampling. + /// + /// For each batch, this op picks a single set of sampled candidate labels. + /// + /// The advantages of sampling candidates per-batch are simplicity and the + /// possibility of efficient dense matrix multiplication. The disadvantage is that + /// the sampled candidates must be chosen independently of the context and of the + /// true labels. + /// + public static (Tensor sampled_candidates, Tensor true_expected_count, Tensor sampled_expected_count) learned_unigram_candidate_sampler(Tensor true_classes, int num_true, int num_sampled, bool unique, int range_max, int? seed = null, int? seed2 = null, string name = "LearnedUnigramCandidateSampler") + { + var dict = new Dictionary(); + dict["true_classes"] = true_classes; + dict["num_true"] = num_true; + dict["num_sampled"] = num_sampled; + dict["unique"] = unique; + dict["range_max"] = range_max; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("LearnedUnigramCandidateSampler", name: name, keywords: dict); + int _idx = 0; + var sampled_candidates = op.outputs[_idx++]; + var true_expected_count = op.outputs[_idx++]; + var sampled_expected_count = op.outputs[_idx++]; + return (sampled_candidates, true_expected_count, sampled_expected_count); + } + + /// + /// Elementwise computes the bitwise left-shift of x and y. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LeftShift'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// If y is negative, or greater than or equal to the width of x in bits the + /// result is implementation defined. + /// + public static Tensor left_shift(Tensor x, Tensor y, string name = "LeftShift") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("LeftShift", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the truth value of (x &lt; y) element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Less'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: Less supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor less(Tensor x, Tensor y, string name = "Less") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("Less", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the truth value of (x &lt;= y) element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LessEqual'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: LessEqual supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor less_equal(Tensor x, Tensor y, string name = "LessEqual") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("LessEqual", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the log of the absolute value of Gamma(x) element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Lgamma'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor lgamma(Tensor x, string name = "Lgamma") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Lgamma", name: name, keywords: dict); + return op.output; + } + + /// + /// Generates values in an interval. + /// + /// + /// 0-D tensor. First entry in the range. + /// + /// + /// 0-D tensor. Last entry in the range. + /// + /// + /// 0-D tensor. Number of values to generate. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LinSpace'. + /// + /// + /// 1-D. The generated values. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// A sequence of num evenly-spaced values are generated beginning at start. + /// If num &gt; 1, the values in the sequence increase by stop - start / num - 1, + /// so that the last one is exactly stop. + /// + /// For example: + /// + /// + /// tf.linspace(10.0, 12.0, 3, name="linspace") =&gt; [ 10.0 11.0 12.0] + /// + /// + public static Tensor lin_space(Tensor start, Tensor stop, Tensor num, string name = "LinSpace") + { + var dict = new Dictionary(); + dict["start"] = start; + dict["stop"] = stop; + dict["num"] = num; + var op = tf.OpDefLib._apply_op_helper("LinSpace", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the difference between two lists of numbers or strings. + /// + /// + /// 1-D. Values to keep. + /// + /// + /// 1-D. Values to remove. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ListDiff'. + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : 1-D. Values present in x but not in y. + /// idx : 1-D. Positions of x values preserved in out. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Given a list x and a list y, this operation returns a list out that + /// represents all values that are in x but not in y. The returned list out + /// is sorted in the same order that the numbers appear in x (duplicates are + /// preserved). This operation also returns a list idx that represents the + /// position of each out element in x. In other words: + /// + /// out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1] + /// + /// For example, given this input: + /// + /// + /// x = [1, 2, 3, 4, 5, 6] + /// y = [1, 3, 5] + /// + /// + /// This operation would return: + /// + /// + /// out ==&gt; [2, 4, 6] + /// idx ==&gt; [1, 3, 5] + /// + /// + public static (Tensor output, Tensor idx) list_diff(Tensor x, Tensor y, TF_DataType? out_idx = null, string name = "ListDiff") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + if (out_idx.HasValue) + dict["out_idx"] = out_idx.Value; + var op = tf.OpDefLib._apply_op_helper("ListDiff", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var idx = op.outputs[_idx++]; + return (output, idx); + } + + /// + /// Loads a 2-D (matrix) Tensor with name old_tensor_name from the checkpoint + /// + /// + /// Path to the TensorFlow checkpoint (version 2, TensorBundle) from + /// which the old matrix Tensor will be loaded. + /// + /// + /// Name of the 2-D Tensor to load from checkpoint. + /// + /// + /// An int Tensor of row remappings (generally created by + /// generate_vocab_remapping). Even if no row remapping is needed, this must + /// still be an index-valued Tensor (e.g. [0, 1, 2, ...]), or a shifted + /// index-valued Tensor (e.g. [8, 9, 10, ...], for partitioned Variables). + /// + /// + /// An int Tensor of column remappings (generally created by + /// generate_vocab_remapping). May be a size-0 Tensor if only row remapping + /// is to be done (e.g. column ordering is the same). + /// + /// + /// A float Tensor containing values to fill in for cells + /// in the output matrix that are not loaded from the checkpoint. Length must be + /// exactly the same as the number of missing / new cells. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LoadAndRemapMatrix'. + /// + /// + /// Optional argument + /// Number of rows (length of the 1st dimension) in the output matrix. + /// + /// + /// Optional argument + /// Number of columns (length of the 2nd dimension) in the output matrix. + /// + /// + /// The maximum number of rows to load from the checkpoint at + /// once. If less than or equal to 0, the entire matrix will be loaded into + /// memory. Setting this arg trades increased disk reads for lower memory usage. + /// + /// + /// Output matrix containing existing values loaded from the + /// checkpoint, and with any missing values filled in from initializing_values. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// at ckpt_path and potentially reorders its rows and columns using the + /// specified remappings. + /// + /// Most users should use one of the wrapper initializers (such as + /// tf.contrib.framework.load_and_remap_matrix_initializer) instead of this + /// function directly. + /// + /// The remappings are 1-D tensors with the following properties: + /// + /// * row_remapping must have exactly num_rows entries. Row i of the output + /// matrix will be initialized from the row corresponding to index + /// row_remapping[i] in the old Tensor from the checkpoint. + /// * col_remapping must have either 0 entries (indicating that no column + /// reordering is needed) or num_cols entries. If specified, column j of the + /// output matrix will be initialized from the column corresponding to index + /// col_remapping[j] in the old Tensor from the checkpoint. + /// * A value of -1 in either of the remappings signifies a "missing" entry. In that + /// case, values from the initializing_values tensor will be used to fill that + /// missing row or column. If row_remapping has r missing entries and + /// col_remapping has c missing entries, then the following condition must be + /// true: + /// + /// (r * num_cols) + (c * num_rows) - (r * c) == len(initializing_values) + /// + /// The remapping tensors can be generated using the GenerateVocabRemapping op. + /// + /// As an example, with row_remapping = [1, 0, -1], col_remapping = [0, 2, -1], + /// initializing_values = [0.5, -0.5, 0.25, -0.25, 42], and w(i, j) representing + /// the value from row i, column j of the old tensor in the checkpoint, the output + /// matrix will look like the following: + /// + /// [[w(1, 0), w(1, 2), 0.5], + /// [w(0, 0), w(0, 2), -0.5], + /// [0.25, -0.25, 42]] + /// + public static Tensor load_and_remap_matrix(Tensor ckpt_path, Tensor old_tensor_name, Tensor row_remapping, Tensor col_remapping, Tensor initializing_values, int num_rows, int num_cols, int? max_rows_in_memory = null, string name = "LoadAndRemapMatrix") + { + var dict = new Dictionary(); + dict["ckpt_path"] = ckpt_path; + dict["old_tensor_name"] = old_tensor_name; + dict["row_remapping"] = row_remapping; + dict["col_remapping"] = col_remapping; + dict["initializing_values"] = initializing_values; + dict["num_rows"] = num_rows; + dict["num_cols"] = num_cols; + if (max_rows_in_memory.HasValue) + dict["max_rows_in_memory"] = max_rows_in_memory.Value; + var op = tf.OpDefLib._apply_op_helper("LoadAndRemapMatrix", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes natural logarithm of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Log'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// I.e., \\(y = \log_e x\\). + /// + public static Tensor log(Tensor x, string name = "Log") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Log", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes natural logarithm of (1 + x) element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Log1p'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// I.e., \\(y = \log_e (1 + x)\\). + /// + public static Tensor log1p(Tensor x, string name = "Log1p") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Log1p", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the sign and the log of the absolute value of the determinant of + /// + /// + /// Shape is [N, M, M]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LogMatrixDeterminant'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sign : The signs of the log determinants of the inputs. Shape is [N]. + /// log_abs_determinant : The logs of the absolute values of the determinants + /// of the N input matrices. Shape is [N]. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// one or more square matrices. + /// + /// The input is a tensor of shape [N, M, M] whose inner-most 2 dimensions + /// form square matrices. The outputs are two tensors containing the signs and + /// absolute values of the log determinants for all N input submatrices + /// [..., :, :] such that the determinant = sign*exp(log_abs_determinant). + /// The log_abs_determinant is computed as det(P)*sum(log(diag(LU))) where LU + /// is the LU decomposition of the input and P is the corresponding + /// permutation matrix. + /// + public static (Tensor sign, Tensor log_abs_determinant) log_matrix_determinant(Tensor input, string name = "LogMatrixDeterminant") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("LogMatrixDeterminant", name: name, keywords: dict); + int _idx = 0; + var sign = op.outputs[_idx++]; + var log_abs_determinant = op.outputs[_idx++]; + return (sign, log_abs_determinant); + } + + /// + /// Computes log softmax activations. + /// + /// + /// 2-D with shape [batch_size, num_classes]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LogSoftmax'. + /// + /// + /// Same shape as logits. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// For each batch i and class j we have + /// + /// logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i]))) + /// + public static Tensor log_softmax(Tensor logits, string name = "LogSoftmax") + { + var dict = new Dictionary(); + dict["logits"] = logits; + var op = tf.OpDefLib._apply_op_helper("LogSoftmax", name: name, keywords: dict); + return op.output; + } + + /// + /// Generates labels for candidate sampling with a log-uniform distribution. + /// + /// + /// A batch_size * num_true matrix, in which each row contains the + /// IDs of the num_true target_classes in the corresponding original label. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LogUniformCandidateSampler'. + /// + /// + /// Optional argument + /// Number of true labels per context. + /// + /// + /// Optional argument + /// Number of candidates to randomly sample. + /// + /// + /// Optional argument + /// If unique is true, we sample with rejection, so that all sampled + /// candidates in a batch are unique. This requires some approximation to + /// estimate the post-rejection sampling probabilities. + /// + /// + /// Optional argument + /// The sampler will sample integers from the interval [0, range_max). + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// An second seed to avoid seed collision. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sampled_candidates : A vector of length num_sampled, in which each element is + /// the ID of a sampled candidate. + /// true_expected_count : A batch_size * num_true matrix, representing + /// the number of times each candidate is expected to occur in a batch + /// of sampled candidates. If unique=true, then this is a probability. + /// sampled_expected_count : A vector of length num_sampled, for each sampled + /// candidate representing the number of times the candidate is expected + /// to occur in a batch of sampled candidates. If unique=true, then this is a + /// probability. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// See explanations of candidate sampling and the data formats at + /// go/candidate-sampling. + /// + /// For each batch, this op picks a single set of sampled candidate labels. + /// + /// The advantages of sampling candidates per-batch are simplicity and the + /// possibility of efficient dense matrix multiplication. The disadvantage is that + /// the sampled candidates must be chosen independently of the context and of the + /// true labels. + /// + public static (Tensor sampled_candidates, Tensor true_expected_count, Tensor sampled_expected_count) log_uniform_candidate_sampler(Tensor true_classes, int num_true, int num_sampled, bool unique, int range_max, int? seed = null, int? seed2 = null, string name = "LogUniformCandidateSampler") + { + var dict = new Dictionary(); + dict["true_classes"] = true_classes; + dict["num_true"] = num_true; + dict["num_sampled"] = num_sampled; + dict["unique"] = unique; + dict["range_max"] = range_max; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("LogUniformCandidateSampler", name: name, keywords: dict); + int _idx = 0; + var sampled_candidates = op.outputs[_idx++]; + var true_expected_count = op.outputs[_idx++]; + var sampled_expected_count = op.outputs[_idx++]; + return (sampled_candidates, true_expected_count, sampled_expected_count); + } + + /// + /// Returns the truth value of x AND y element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LogicalAnd'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: LogicalAnd supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor logical_and(Tensor x, Tensor y, string name = "LogicalAnd") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("LogicalAnd", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the truth value of NOT x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LogicalNot'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor logical_not(Tensor x, string name = "LogicalNot") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("LogicalNot", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the truth value of x OR y element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LogicalOr'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: LogicalOr supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor logical_or(Tensor x, Tensor y, string name = "LogicalOr") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("LogicalOr", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs all keys and values in the table. + /// + /// + /// Handle to the table. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LookupTableExport'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Returns a tuple with multiple values, as follows: + /// keys : Vector of all keys present in the table. + /// values : Tensor of all values in the table. Indexed in parallel with keys. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor keys, Tensor values) lookup_table_export(Tensor table_handle, TF_DataType Tkeys, TF_DataType Tvalues, string name = "LookupTableExport") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + dict["Tkeys"] = Tkeys; + dict["Tvalues"] = Tvalues; + var op = tf.OpDefLib._apply_op_helper("LookupTableExport", name: name, keywords: dict); + int _idx = 0; + var keys = op.outputs[_idx++]; + var values = op.outputs[_idx++]; + return (keys, values); + } + + /// + /// Outputs all keys and values in the table. + /// + /// + /// Handle to the table. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LookupTableExportV2'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Returns a tuple with multiple values, as follows: + /// keys : Vector of all keys present in the table. + /// values : Tensor of all values in the table. Indexed in parallel with keys. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor keys, Tensor values) lookup_table_export_v2(Tensor table_handle, TF_DataType Tkeys, TF_DataType Tvalues, string name = "LookupTableExportV2") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + dict["Tkeys"] = Tkeys; + dict["Tvalues"] = Tvalues; + var op = tf.OpDefLib._apply_op_helper("LookupTableExportV2", name: name, keywords: dict); + int _idx = 0; + var keys = op.outputs[_idx++]; + var values = op.outputs[_idx++]; + return (keys, values); + } + + /// + /// Looks up keys in a table, outputs the corresponding values. + /// + /// + /// Handle to the table. + /// + /// + /// Any shape. Keys to look up. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LookupTableFind'. + /// + /// + /// Same shape as keys. Values found in the table, or default_values + /// for missing keys. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The tensor keys must of the same type as the keys of the table. + /// The output values is of the type of the table values. + /// + /// The scalar default_value is the value output for keys not present in the + /// table. It must also be of the same type as the table values. + /// + public static Tensor lookup_table_find(Tensor table_handle, Tensor keys, Tensor default_value, string name = "LookupTableFind") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + dict["keys"] = keys; + dict["default_value"] = default_value; + var op = tf.OpDefLib._apply_op_helper("LookupTableFind", name: name, keywords: dict); + return op.output; + } + + /// + /// Looks up keys in a table, outputs the corresponding values. + /// + /// + /// Handle to the table. + /// + /// + /// Any shape. Keys to look up. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LookupTableFindV2'. + /// + /// + /// Same shape as keys. Values found in the table, or default_values + /// for missing keys. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The tensor keys must of the same type as the keys of the table. + /// The output values is of the type of the table values. + /// + /// The scalar default_value is the value output for keys not present in the + /// table. It must also be of the same type as the table values. + /// + public static Tensor lookup_table_find_v2(Tensor table_handle, Tensor keys, Tensor default_value, string name = "LookupTableFindV2") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + dict["keys"] = keys; + dict["default_value"] = default_value; + var op = tf.OpDefLib._apply_op_helper("LookupTableFindV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Replaces the contents of the table with the specified keys and values. + /// + /// + /// Handle to the table. + /// + /// + /// Any shape. Keys to look up. + /// + /// + /// Values to associate with keys. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LookupTableImport'. + /// + /// + /// Returns the description of the operation + /// + /// + /// The tensor keys must be of the same type as the keys of the table. + /// The tensor values must be of the type of the table values. + /// + public static Operation lookup_table_import(Tensor table_handle, Tensor keys, Tensor values, string name = "LookupTableImport") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + dict["keys"] = keys; + dict["values"] = values; + var op = tf.OpDefLib._apply_op_helper("LookupTableImport", name: name, keywords: dict); + return op; + } + + /// + /// Replaces the contents of the table with the specified keys and values. + /// + /// + /// Handle to the table. + /// + /// + /// Any shape. Keys to look up. + /// + /// + /// Values to associate with keys. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LookupTableImportV2'. + /// + /// + /// Returns the description of the operation + /// + /// + /// The tensor keys must be of the same type as the keys of the table. + /// The tensor values must be of the type of the table values. + /// + public static Operation lookup_table_import_v2(Tensor table_handle, Tensor keys, Tensor values, string name = "LookupTableImportV2") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + dict["keys"] = keys; + dict["values"] = values; + var op = tf.OpDefLib._apply_op_helper("LookupTableImportV2", name: name, keywords: dict); + return op; + } + + /// + /// Updates the table to associates keys with values. + /// + /// + /// Handle to the table. + /// + /// + /// Any shape. Keys to look up. + /// + /// + /// Values to associate with keys. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LookupTableInsert'. + /// + /// + /// Returns the description of the operation + /// + /// + /// The tensor keys must be of the same type as the keys of the table. + /// The tensor values must be of the type of the table values. + /// + public static Operation lookup_table_insert(Tensor table_handle, Tensor keys, Tensor values, string name = "LookupTableInsert") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + dict["keys"] = keys; + dict["values"] = values; + var op = tf.OpDefLib._apply_op_helper("LookupTableInsert", name: name, keywords: dict); + return op; + } + + /// + /// Updates the table to associates keys with values. + /// + /// + /// Handle to the table. + /// + /// + /// Any shape. Keys to look up. + /// + /// + /// Values to associate with keys. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LookupTableInsertV2'. + /// + /// + /// Returns the description of the operation + /// + /// + /// The tensor keys must be of the same type as the keys of the table. + /// The tensor values must be of the type of the table values. + /// + public static Operation lookup_table_insert_v2(Tensor table_handle, Tensor keys, Tensor values, string name = "LookupTableInsertV2") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + dict["keys"] = keys; + dict["values"] = values; + var op = tf.OpDefLib._apply_op_helper("LookupTableInsertV2", name: name, keywords: dict); + return op; + } + + /// + /// Computes the number of elements in the given table. + /// + /// + /// Handle to the table. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LookupTableSize'. + /// + /// + /// Scalar that contains number of elements in the table. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor lookup_table_size(Tensor table_handle, string name = "LookupTableSize") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + var op = tf.OpDefLib._apply_op_helper("LookupTableSize", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the number of elements in the given table. + /// + /// + /// Handle to the table. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LookupTableSizeV2'. + /// + /// + /// Scalar that contains number of elements in the table. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor lookup_table_size_v2(Tensor table_handle, string name = "LookupTableSizeV2") + { + var dict = new Dictionary(); + dict["table_handle"] = table_handle; + var op = tf.OpDefLib._apply_op_helper("LookupTableSizeV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Forwards the input to the output. + /// + /// + /// A boolean scalar, representing the branch predicate of the Switch op. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'LoopCond'. + /// + /// + /// The same tensor as input. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operator represents the loop termination condition used by the + /// "pivot" switches of a loop. + /// + public static Tensor loop_cond(Tensor input, string name = "LoopCond") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("LoopCond", name: name, keywords: dict); + return op.output; + } + + /// + /// Makes a new iterator from the given dataset and stores it in iterator. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MakeIterator'. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation may be executed multiple times. Each execution will reset the + /// iterator in iterator to the first element of dataset. + /// + public static Operation make_iterator(Tensor dataset, Tensor iterator, string name = "MakeIterator") + { + var dict = new Dictionary(); + dict["dataset"] = dataset; + dict["iterator"] = iterator; + var op = tf.OpDefLib._apply_op_helper("MakeIterator", name: name, keywords: dict); + return op; + } + + /// + /// Op removes all elements in the underlying container. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MapClear'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Returns the description of the operation + /// + public static Operation map_clear(TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "MapClear") + { + var dict = new Dictionary(); + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("MapClear", name: name, keywords: dict); + return op; + } + + /// + /// Op returns the number of incomplete elements in the underlying container. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MapIncompleteSize'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor map_incomplete_size(TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "MapIncompleteSize") + { + var dict = new Dictionary(); + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("MapIncompleteSize", name: name, keywords: dict); + return op.output; + } + + /// + /// Op peeks at the values at the specified key. If the + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MapPeek'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// underlying container does not contain this key + /// this op will block until it does. + /// + public static Tensor[] map_peek(Tensor key, Tensor indices, TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "MapPeek") + { + var dict = new Dictionary(); + dict["key"] = key; + dict["indices"] = indices; + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("MapPeek", name: name, keywords: dict); + int _idx = 0; + var values = Enumerable.Range(0, op.OutputListLength("values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (values); + } + + /// + /// Op returns the number of elements in the underlying container. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MapSize'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor map_size(TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "MapSize") + { + var dict = new Dictionary(); + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("MapSize", name: name, keywords: dict); + return op.output; + } + + /// + /// Stage (key, values) in the underlying container which behaves like a hashtable. + /// + /// + /// int64 + /// + /// + /// + /// + /// a list of tensors + /// dtypes A list of data types that inserted values should adhere to. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MapStage'. + /// + /// + /// Optional argument + /// + /// + /// Maximum number of elements in the Staging Area. If &gt; 0, inserts + /// on the container will block when the capacity is reached. + /// + /// + /// + /// + /// If non-empty, this queue is placed in the given container. Otherwise, + /// a default container is used. + /// + /// + /// It is necessary to match this name to the matching Unstage Op. + /// + /// + /// Returns the description of the operation + /// + public static Operation map_stage(Tensor key, Tensor indices, Tensor[] values, TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "MapStage") + { + var dict = new Dictionary(); + dict["key"] = key; + dict["indices"] = indices; + dict["values"] = values; + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("MapStage", name: name, keywords: dict); + return op; + } + + /// + /// Op removes and returns the values associated with the key + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MapUnstage'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// from the underlying container. If the underlying container + /// does not contain this key, the op will block until it does. + /// + public static Tensor[] map_unstage(Tensor key, Tensor indices, TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "MapUnstage") + { + var dict = new Dictionary(); + dict["key"] = key; + dict["indices"] = indices; + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("MapUnstage", name: name, keywords: dict); + int _idx = 0; + var values = Enumerable.Range(0, op.OutputListLength("values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (values); + } + + /// + /// Op removes and returns a random (key, value) + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MapUnstageNoKey'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// key : + /// values : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// from the underlying container. If the underlying container + /// does not contain elements, the op will block until it does. + /// + public static (Tensor key, Tensor[] values) map_unstage_no_key(Tensor indices, TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "MapUnstageNoKey") + { + var dict = new Dictionary(); + dict["indices"] = indices; + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("MapUnstageNoKey", name: name, keywords: dict); + int _idx = 0; + var key = op.outputs[_idx++]; + var values = Enumerable.Range(0, op.OutputListLength("values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (key, values); + } + + /// + /// Multiply the matrix "a" by the matrix "b". + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatMul'. + /// + /// + /// If true, "a" is transposed before multiplication. + /// + /// + /// If true, "b" is transposed before multiplication. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The inputs must be two-dimensional matrices and the inner dimension of + /// "a" (after being transposed if transpose_a is true) must match the + /// outer dimension of "b" (after being transposed if transposed_b is + /// true). + /// + /// *Note*: The default kernel implementation for MatMul on GPUs uses + /// cublas. + /// + public static Tensor mat_mul(Tensor a, Tensor b, bool? transpose_a = null, bool? transpose_b = null, string name = "MatMul") + { + var dict = new Dictionary(); + dict["a"] = a; + dict["b"] = b; + if (transpose_a.HasValue) + dict["transpose_a"] = transpose_a.Value; + if (transpose_b.HasValue) + dict["transpose_b"] = transpose_b.Value; + var op = tf.OpDefLib._apply_op_helper("MatMul", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the set of files matching one or more glob patterns. + /// + /// + /// Shell wildcard pattern(s). Scalar or vector of type string. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatchingFiles'. + /// + /// + /// A vector of matching filenames. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Note that this routine only supports wildcard characters in the + /// basename portion of the pattern, not in the directory portion. + /// Note also that the order of filenames returned can be non-deterministic. + /// + public static Tensor matching_files(Tensor pattern, string name = "MatchingFiles") + { + var dict = new Dictionary(); + dict["pattern"] = pattern; + var op = tf.OpDefLib._apply_op_helper("MatchingFiles", name: name, keywords: dict); + return op.output; + } + + /// + /// Copy a tensor setting everything outside a central band in each innermost matrix + /// + /// + /// Rank k tensor. + /// + /// + /// 0-D tensor. Number of subdiagonals to keep. If negative, keep entire + /// lower triangle. + /// + /// + /// 0-D tensor. Number of superdiagonals to keep. If negative, keep + /// entire upper triangle. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatrixBandPart'. + /// + /// + /// Rank k tensor of the same shape as input. The extracted banded tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// to zero. + /// + /// The band part is computed as follows: + /// Assume input has k dimensions [I, J, K, ..., M, N], then the output is a + /// tensor with the same shape where + /// + /// band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]. + /// + /// The indicator function + /// + /// in_band(m, n) = (num_lower &lt; 0 || (m-n) &lt;= num_lower)) && + /// (num_upper &lt; 0 || (n-m) &lt;= num_upper). + /// + /// For example: + /// + /// + /// # if 'input' is [[ 0, 1, 2, 3] + /// [-1, 0, 1, 2] + /// [-2, -1, 0, 1] + /// [-3, -2, -1, 0]], + /// + /// tf.matrix_band_part(input, 1, -1) ==&gt; [[ 0, 1, 2, 3] + /// [-1, 0, 1, 2] + /// [ 0, -1, 0, 1] + /// [ 0, 0, -1, 0]], + /// + /// tf.matrix_band_part(input, 2, 1) ==&gt; [[ 0, 1, 0, 0] + /// [-1, 0, 1, 0] + /// [-2, -1, 0, 1] + /// [ 0, -2, -1, 0]] + /// + /// + /// Useful special cases: + /// + /// + /// tf.matrix_band_part(input, 0, -1) ==&gt; Upper triangular part. + /// tf.matrix_band_part(input, -1, 0) ==&gt; Lower triangular part. + /// tf.matrix_band_part(input, 0, 0) ==&gt; Diagonal. + /// + /// + public static Tensor matrix_band_part(Tensor input, Tensor num_lower, Tensor num_upper, string name = "MatrixBandPart") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["num_lower"] = num_lower; + dict["num_upper"] = num_upper; + var op = tf.OpDefLib._apply_op_helper("MatrixBandPart", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the determinant of one or more square matrices. + /// + /// + /// Shape is [..., M, M]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatrixDeterminant'. + /// + /// + /// Shape is [...]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The input is a tensor of shape [..., M, M] whose inner-most 2 dimensions + /// form square matrices. The output is a tensor containing the determinants + /// for all input submatrices [..., :, :]. + /// + public static Tensor matrix_determinant(Tensor input, string name = "MatrixDeterminant") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("MatrixDeterminant", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns a batched diagonal tensor with a given batched diagonal values. + /// + /// + /// Rank k, where k &gt;= 1. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatrixDiag'. + /// + /// + /// Rank k+1, with output.shape = diagonal.shape + [diagonal.shape[-1]]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a diagonal, this operation returns a tensor with the diagonal and + /// everything else padded with zeros. The diagonal is computed as follows: + /// + /// Assume diagonal has k dimensions [I, J, K, ..., N], then the output is a + /// tensor of rank k+1 with dimensions [I, J, K, ..., N, N] where: + /// + /// output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]. + /// + /// For example: + /// + /// + /// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] + /// + /// and diagonal.shape = (2, 4) + /// + /// tf.matrix_diag(diagonal) ==&gt; [[[1, 0, 0, 0] + /// [0, 2, 0, 0] + /// [0, 0, 3, 0] + /// [0, 0, 0, 4]], + /// [[5, 0, 0, 0] + /// [0, 6, 0, 0] + /// [0, 0, 7, 0] + /// [0, 0, 0, 8]]] + /// + /// which has shape (2, 4, 4) + /// + /// + public static Tensor matrix_diag(Tensor diagonal, string name = "MatrixDiag") + { + var dict = new Dictionary(); + dict["diagonal"] = diagonal; + var op = tf.OpDefLib._apply_op_helper("MatrixDiag", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the batched diagonal part of a batched tensor. + /// + /// + /// Rank k tensor where k &gt;= 2. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatrixDiagPart'. + /// + /// + /// The extracted diagonal(s) having shape + /// diagonal.shape = input.shape[:-2] + [min(input.shape[-2:])]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation returns a tensor with the diagonal part + /// of the batched input. The diagonal part is computed as follows: + /// + /// Assume input has k dimensions [I, J, K, ..., M, N], then the output is a + /// tensor of rank k - 1 with dimensions [I, J, K, ..., min(M, N)] where: + /// + /// diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]. + /// + /// The input must be at least a matrix. + /// + /// For example: + /// + /// + /// # 'input' is [[[1, 0, 0, 0] + /// [0, 2, 0, 0] + /// [0, 0, 3, 0] + /// [0, 0, 0, 4]], + /// [[5, 0, 0, 0] + /// [0, 6, 0, 0] + /// [0, 0, 7, 0] + /// [0, 0, 0, 8]]] + /// + /// and input.shape = (2, 4, 4) + /// + /// tf.matrix_diag_part(input) ==&gt; [[1, 2, 3, 4], [5, 6, 7, 8]] + /// + /// which has shape (2, 4) + /// + /// + public static Tensor matrix_diag_part(Tensor input, string name = "MatrixDiagPart") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("MatrixDiagPart", name: name, keywords: dict); + return op.output; + } + + /// + /// Deprecated, use python implementation tf.linalg.matrix_exponential. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatrixExponential'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor matrix_exponential(Tensor input, string name = "MatrixExponential") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("MatrixExponential", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the inverse of one or more square invertible matrices or their + /// + /// + /// Shape is [..., M, M]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatrixInverse'. + /// + /// + /// + /// + /// Shape is [..., M, M]. + /// + /// @compatibility(numpy) + /// Equivalent to np.linalg.inv + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// adjoints (conjugate transposes). + /// + /// The input is a tensor of shape [..., M, M] whose inner-most 2 dimensions + /// form square matrices. The output is a tensor of the same shape as the input + /// containing the inverse for all input submatrices [..., :, :]. + /// + /// The op uses LU decomposition with partial pivoting to compute the inverses. + /// + /// If a matrix is not invertible there is no guarantee what the op does. It + /// may detect the condition and raise an exception or it may simply return a + /// garbage result. + /// + public static Tensor matrix_inverse(Tensor input, bool? adjoint = null, string name = "MatrixInverse") + { + var dict = new Dictionary(); + dict["input"] = input; + if (adjoint.HasValue) + dict["adjoint"] = adjoint.Value; + var op = tf.OpDefLib._apply_op_helper("MatrixInverse", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the matrix logarithm of one or more square matrices: + /// + /// + /// Shape is [..., M, M]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatrixLogarithm'. + /// + /// + /// Shape is [..., M, M]. + /// + /// @compatibility(scipy) + /// Equivalent to scipy.linalg.logm + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// + /// \\(log(exp(A)) = A\\) + /// + /// This op is only defined for complex matrices. If A is positive-definite and + /// real, then casting to a complex matrix, taking the logarithm and casting back + /// to a real matrix will give the correct result. + /// + /// This function computes the matrix logarithm using the Schur-Parlett algorithm. + /// Details of the algorithm can be found in Section 11.6.2 of: + /// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008. + /// ISBN 978-0-898716-46-7. + /// + /// The input is a tensor of shape [..., M, M] whose inner-most 2 dimensions + /// form square matrices. The output is a tensor of the same shape as the input + /// containing the exponential for all input submatrices [..., :, :]. + /// + public static Tensor matrix_logarithm(Tensor input, string name = "MatrixLogarithm") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("MatrixLogarithm", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns a batched matrix tensor with new batched diagonal values. + /// + /// + /// Rank k+1, where k &gt;= 1. + /// + /// + /// Rank k, where k &gt;= 1. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatrixSetDiag'. + /// + /// + /// Rank k+1, with output.shape = input.shape. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given input and diagonal, this operation returns a tensor with the + /// same shape and values as input, except for the main diagonal of the + /// innermost matrices. These will be overwritten by the values in diagonal. + /// + /// The output is computed as follows: + /// + /// Assume input has k+1 dimensions [I, J, K, ..., M, N] and diagonal has + /// k dimensions [I, J, K, ..., min(M, N)]. Then the output is a + /// tensor of rank k+1 with dimensions [I, J, K, ..., M, N] where: + /// + /// * output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n] for m == n. + /// * output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n] for m != n. + /// + public static Tensor matrix_set_diag(Tensor input, Tensor diagonal, string name = "MatrixSetDiag") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["diagonal"] = diagonal; + var op = tf.OpDefLib._apply_op_helper("MatrixSetDiag", name: name, keywords: dict); + return op.output; + } + + /// + /// Solves systems of linear equations. + /// + /// + /// Shape is [..., M, M]. + /// + /// + /// Shape is [..., M, K]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatrixSolve'. + /// + /// + /// Boolean indicating whether to solve with matrix or its (block-wise) + /// adjoint. + /// + /// + /// Shape is [..., M, K]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Matrix is a tensor of shape [..., M, M] whose inner-most 2 dimensions + /// form square matrices. Rhs is a tensor of shape [..., M, K]. The output is + /// a tensor shape [..., M, K]. If adjoint is False then each output matrix + /// satisfies matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]. + /// If adjoint is True then each output matrix satisfies + /// adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]. + /// + public static Tensor matrix_solve(Tensor matrix, Tensor rhs, bool? adjoint = null, string name = "MatrixSolve") + { + var dict = new Dictionary(); + dict["matrix"] = matrix; + dict["rhs"] = rhs; + if (adjoint.HasValue) + dict["adjoint"] = adjoint.Value; + var op = tf.OpDefLib._apply_op_helper("MatrixSolve", name: name, keywords: dict); + return op.output; + } + + /// + /// Solves one or more linear least-squares problems. + /// + /// + /// Shape is [..., M, N]. + /// + /// + /// Shape is [..., M, K]. + /// + /// + /// Scalar tensor. + /// + /// @compatibility(numpy) + /// Equivalent to np.linalg.lstsq + /// @end_compatibility + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatrixSolveLs'. + /// + /// + /// + /// + /// Shape is [..., N, K]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// matrix is a tensor of shape [..., M, N] whose inner-most 2 dimensions + /// form real or complex matrices of size [M, N]. Rhs is a tensor of the same + /// type as matrix and shape [..., M, K]. + /// The output is a tensor shape [..., N, K] where each output matrix solves + /// each of the equations + /// matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] + /// in the least squares sense. + /// + /// We use the following notation for (complex) matrix and right-hand sides + /// in the batch: + /// + /// matrix=\\(A \in \mathbb{C}^{m \times n}\\), + /// rhs=\\(B \in \mathbb{C}^{m \times k}\\), + /// output=\\(X \in \mathbb{C}^{n \times k}\\), + /// l2_regularizer=\\(\lambda \in \mathbb{R}\\). + /// + /// If fast is True, then the solution is computed by solving the normal + /// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then + /// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares + /// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + \lambda ||Z||_F^2\\). + /// If \\(m \lt n\\) then output is computed as + /// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the + /// minimum-norm solution to the under-determined linear system, i.e. + /// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\), + /// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable + /// when \\(A\\) is numerically full rank and has a condition number + /// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is + /// sufficiently large. + /// + /// If fast is False an algorithm based on the numerically robust complete + /// orthogonal decomposition is used. This computes the minimum-norm + /// least-squares solution, even when \\(A\\) is rank deficient. This path is + /// typically 6-7 times slower than the fast path. If fast is False then + /// l2_regularizer is ignored. + /// + public static Tensor matrix_solve_ls(Tensor matrix, Tensor rhs, Tensor l2_regularizer, bool? fast = null, string name = "MatrixSolveLs") + { + var dict = new Dictionary(); + dict["matrix"] = matrix; + dict["rhs"] = rhs; + dict["l2_regularizer"] = l2_regularizer; + if (fast.HasValue) + dict["fast"] = fast.Value; + var op = tf.OpDefLib._apply_op_helper("MatrixSolveLs", name: name, keywords: dict); + return op.output; + } + + /// + /// Solves systems of linear equations with upper or lower triangular matrices by + /// + /// + /// Shape is [..., M, M]. + /// + /// + /// Shape is [..., M, K]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MatrixTriangularSolve'. + /// + /// + /// Boolean indicating whether the innermost matrices in matrix are + /// lower or upper triangular. + /// + /// + /// Boolean indicating whether to solve with matrix or its (block-wise) + /// adjoint. + /// + /// @compatibility(numpy) + /// Equivalent to scipy.linalg.solve_triangular + /// @end_compatibility + /// + /// + /// Shape is [..., M, K]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// backsubstitution. + /// + /// matrix is a tensor of shape [..., M, M] whose inner-most 2 dimensions form + /// square matrices. If lower is True then the strictly upper triangular part + /// of each inner-most matrix is assumed to be zero and not accessed. + /// If lower is False then the strictly lower triangular part of each inner-most + /// matrix is assumed to be zero and not accessed. + /// rhs is a tensor of shape [..., M, K]. + /// + /// The output is a tensor of shape [..., M, K]. If adjoint is + /// True then the innermost matrices in output satisfy matrix equations + /// matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]. + /// If adjoint is False then the strictly then the innermost matrices in + /// output satisfy matrix equations + /// adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]. + /// + public static Tensor matrix_triangular_solve(Tensor matrix, Tensor rhs, bool? lower = null, bool? adjoint = null, string name = "MatrixTriangularSolve") + { + var dict = new Dictionary(); + dict["matrix"] = matrix; + dict["rhs"] = rhs; + if (lower.HasValue) + dict["lower"] = lower.Value; + if (adjoint.HasValue) + dict["adjoint"] = adjoint.Value; + var op = tf.OpDefLib._apply_op_helper("MatrixTriangularSolve", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the maximum of elements across dimensions of a tensor. + /// + /// + /// The tensor to reduce. + /// + /// + /// The dimensions to reduce. Must be in the range + /// [-rank(input), rank(input)). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Max'. + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// The reduced tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Reduces input along the dimensions given in axis. Unless + /// keep_dims is true, the rank of the tensor is reduced by 1 for each entry in + /// axis. If keep_dims is true, the reduced dimensions are + /// retained with length 1. + /// + public static Tensor max(Tensor input, Tensor reduction_indices, bool? keep_dims = null, string name = "Max") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["reduction_indices"] = reduction_indices; + if (keep_dims.HasValue) + dict["keep_dims"] = keep_dims.Value; + var op = tf.OpDefLib._apply_op_helper("Max", name: name, keywords: dict); + return op.output; + } + + /// + /// Performs max pooling on the input. + /// + /// + /// 4-D input to pool over. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MaxPool'. + /// + /// + /// Optional argument + /// The size of the window for each dimension of the input tensor. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// The max pooled output tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor max_pool(Tensor input, int[] ksize, int[] strides, string padding, string data_format = null, string name = "MaxPool") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("MaxPool", name: name, keywords: dict); + return op.output; + } + + /// + /// Performs 3D max pooling on the input. + /// + /// + /// Shape [batch, depth, rows, cols, channels] tensor to pool over. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MaxPool3D'. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The size of the window for each dimension of + /// the input tensor. Must have ksize[0] = ksize[4] = 1. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of input. Must have strides[0] = strides[4] = 1. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// The max pooled output tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor max_pool3d(Tensor input, int[] ksize, int[] strides, string padding, string data_format = null, string name = "MaxPool3D") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("MaxPool3D", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes gradients of max pooling function. + /// + /// + /// The original input tensor. + /// + /// + /// The original output tensor. + /// + /// + /// Output backprop of shape [batch, depth, rows, cols, channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MaxPool3DGrad'. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The size of the window for each dimension of + /// the input tensor. Must have ksize[0] = ksize[4] = 1. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of input. Must have strides[0] = strides[4] = 1. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor max_pool3d_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, string data_format = null, string name = "MaxPool3DGrad") + { + var dict = new Dictionary(); + dict["orig_input"] = orig_input; + dict["orig_output"] = orig_output; + dict["grad"] = grad; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("MaxPool3DGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes second-order gradients of the maxpooling function. + /// + /// + /// The original input tensor. + /// + /// + /// The original output tensor. + /// + /// + /// Output backprop of shape [batch, depth, rows, cols, channels]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MaxPool3DGradGrad'. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The size of the window for each dimension of + /// the input tensor. Must have ksize[0] = ksize[4] = 1. + /// + /// + /// Optional argument + /// 1-D tensor of length 5. The stride of the sliding window for each + /// dimension of input. Must have strides[0] = strides[4] = 1. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// The data format of the input and output data. With the + /// default format "NDHWC", the data is stored in the order of: + /// [batch, in_depth, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCDHW", the data storage order is: + /// [batch, in_channels, in_depth, in_height, in_width]. + /// + /// + /// Gradients of gradients w.r.t. the input to max_pool. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor max_pool3d_grad_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, string data_format = null, string name = "MaxPool3DGradGrad") + { + var dict = new Dictionary(); + dict["orig_input"] = orig_input; + dict["orig_output"] = orig_output; + dict["grad"] = grad; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("MaxPool3DGradGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes gradients of the maxpooling function. + /// + /// + /// The original input tensor. + /// + /// + /// The original output tensor. + /// + /// + /// 4-D. Gradients w.r.t. the output of max_pool. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MaxPoolGrad'. + /// + /// + /// Optional argument + /// The size of the window for each dimension of the input tensor. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// Gradients w.r.t. the input to max_pool. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, string data_format = null, string name = "MaxPoolGrad") + { + var dict = new Dictionary(); + dict["orig_input"] = orig_input; + dict["orig_output"] = orig_output; + dict["grad"] = grad; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("MaxPoolGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes second-order gradients of the maxpooling function. + /// + /// + /// The original input tensor. + /// + /// + /// The original output tensor. + /// + /// + /// 4-D. Gradients of gradients w.r.t. the input of max_pool. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MaxPoolGradGrad'. + /// + /// + /// Optional argument + /// The size of the window for each dimension of the input tensor. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// Gradients of gradients w.r.t. the input to max_pool. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor max_pool_grad_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, string data_format = null, string name = "MaxPoolGradGrad") + { + var dict = new Dictionary(); + dict["orig_input"] = orig_input; + dict["orig_output"] = orig_output; + dict["grad"] = grad; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("MaxPoolGradGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes second-order gradients of the maxpooling function. + /// + /// + /// The original input tensor. + /// + /// + /// The original output tensor. + /// + /// + /// 4-D. Gradients of gradients w.r.t. the input of max_pool. + /// + /// + /// The size of the window for each dimension of the input tensor. + /// + /// + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MaxPoolGradGradV2'. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// Gradients of gradients w.r.t. the input to max_pool. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor max_pool_grad_grad_v2(Tensor orig_input, Tensor orig_output, Tensor grad, Tensor ksize, Tensor strides, string padding, string data_format = null, string name = "MaxPoolGradGradV2") + { + var dict = new Dictionary(); + dict["orig_input"] = orig_input; + dict["orig_output"] = orig_output; + dict["grad"] = grad; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("MaxPoolGradGradV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes second-order gradients of the maxpooling function. + /// + /// + /// The original input. + /// + /// + /// 4-D with shape [batch, height, width, channels]. Gradients w.r.t. the + /// input of max_pool. + /// + /// + /// The indices of the maximum values chosen for each output of max_pool. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MaxPoolGradGradWithArgmax'. + /// + /// + /// Optional argument + /// The size of the window for each dimension of the input tensor. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Gradients of gradients w.r.t. the input of max_pool. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor max_pool_grad_grad_with_argmax(Tensor input, Tensor grad, Tensor argmax, int[] ksize, int[] strides, string padding, string name = "MaxPoolGradGradWithArgmax") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["grad"] = grad; + dict["argmax"] = argmax; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + var op = tf.OpDefLib._apply_op_helper("MaxPoolGradGradWithArgmax", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes gradients of the maxpooling function. + /// + /// + /// The original input tensor. + /// + /// + /// The original output tensor. + /// + /// + /// 4-D. Gradients w.r.t. the output of max_pool. + /// + /// + /// The size of the window for each dimension of the input tensor. + /// + /// + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MaxPoolGradV2'. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// Gradients w.r.t. the input to max_pool. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor max_pool_grad_v2(Tensor orig_input, Tensor orig_output, Tensor grad, Tensor ksize, Tensor strides, string padding, string data_format = null, string name = "MaxPoolGradV2") + { + var dict = new Dictionary(); + dict["orig_input"] = orig_input; + dict["orig_output"] = orig_output; + dict["grad"] = grad; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("MaxPoolGradV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes gradients of the maxpooling function. + /// + /// + /// The original input. + /// + /// + /// 4-D with shape [batch, height, width, channels]. Gradients w.r.t. the + /// output of max_pool. + /// + /// + /// The indices of the maximum values chosen for each output of max_pool. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MaxPoolGradWithArgmax'. + /// + /// + /// Optional argument + /// The size of the window for each dimension of the input tensor. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Gradients w.r.t. the input of max_pool. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor max_pool_grad_with_argmax(Tensor input, Tensor grad, Tensor argmax, int[] ksize, int[] strides, string padding, string name = "MaxPoolGradWithArgmax") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["grad"] = grad; + dict["argmax"] = argmax; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + var op = tf.OpDefLib._apply_op_helper("MaxPoolGradWithArgmax", name: name, keywords: dict); + return op.output; + } + + /// + /// Performs max pooling on the input. + /// + /// + /// 4-D input to pool over. + /// + /// + /// The size of the window for each dimension of the input tensor. + /// + /// + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MaxPoolV2'. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Specify the data format of the input and output data. With the + /// default format "NHWC", the data is stored in the order of: + /// [batch, in_height, in_width, in_channels]. + /// Alternatively, the format could be "NCHW", the data storage order of: + /// [batch, in_channels, in_height, in_width]. + /// + /// + /// The max pooled output tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor max_pool_v2(Tensor input, Tensor ksize, Tensor strides, string padding, string data_format = null, string name = "MaxPoolV2") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("MaxPoolV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Performs max pooling on the input and outputs both max values and indices. + /// + /// + /// 4-D with shape [batch, height, width, channels]. Input to pool over. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MaxPoolWithArgmax'. + /// + /// + /// Optional argument + /// The size of the window for each dimension of the input tensor. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the + /// input tensor. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : The max pooled output tensor. + /// argmax : 4-D. The flattened indices of the max values chosen for each output. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The indices in argmax are flattened, so that a maximum value at position + /// [b, y, x, c] becomes flattened index + /// ((b * height + y) * width + x) * channels + c. + /// + /// The indices returned are always in [0, height) x [0, width) before flattening, + /// even if padding is involved and the mathematically correct answer is outside + /// (either negative or too large). This is a bug, but fixing it is difficult to do + /// in a safe backwards compatible way, especially due to flattening. + /// + public static (Tensor output, Tensor argmax) max_pool_with_argmax(Tensor input, int[] ksize, int[] strides, string padding, TF_DataType? Targmax = null, string name = "MaxPoolWithArgmax") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + if (Targmax.HasValue) + dict["Targmax"] = Targmax.Value; + var op = tf.OpDefLib._apply_op_helper("MaxPoolWithArgmax", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var argmax = op.outputs[_idx++]; + return (output, argmax); + } + + /// + /// Returns the max of x and y (i.e. x &gt; y ? x : y) element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Maximum'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: Maximum supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor maximum(Tensor x, Tensor y, string name = "Maximum") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("Maximum", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the mean of elements across dimensions of a tensor. + /// + /// + /// The tensor to reduce. + /// + /// + /// The dimensions to reduce. Must be in the range + /// [-rank(input), rank(input)). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Mean'. + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// The reduced tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Reduces input along the dimensions given in axis. Unless + /// keep_dims is true, the rank of the tensor is reduced by 1 for each entry in + /// axis. If keep_dims is true, the reduced dimensions are + /// retained with length 1. + /// + public static Tensor mean(Tensor input, Tensor reduction_indices, bool? keep_dims = null, string name = "Mean") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["reduction_indices"] = reduction_indices; + if (keep_dims.HasValue) + dict["keep_dims"] = keep_dims.Value; + var op = tf.OpDefLib._apply_op_helper("Mean", name: name, keywords: dict); + return op.output; + } + + /// + /// Forwards the value of an available tensor from inputs to output. + /// + /// + /// The input tensors, exactly one of which will become available. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Merge'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : Will be set to the available input tensor. + /// value_index : The index of the chosen input tensor in inputs. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Merge waits for at least one of the tensors in inputs to become available. + /// It is usually combined with Switch to implement branching. + /// + /// Merge forwards the first tensor to become available to output, and sets + /// value_index to its index in inputs. + /// + public static (Tensor output, Tensor value_index) merge(Tensor[] inputs, string name = "Merge") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + var op = tf.OpDefLib._apply_op_helper("Merge", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var value_index = op.outputs[_idx++]; + return (output, value_index); + } + + /// + /// Merges summaries. + /// + /// + /// Can be of any shape. Each must contain serialized Summary protocol + /// buffers. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MergeSummary'. + /// + /// + /// Scalar. Serialized Summary protocol buffer. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op creates a + /// [Summary](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) + /// protocol buffer that contains the union of all the values in the input + /// summaries. + /// + /// When the Op is run, it reports an InvalidArgument error if multiple values + /// in the summaries to merge use the same tag. + /// + public static Tensor merge_summary(Tensor[] inputs, string name = "MergeSummary") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + var op = tf.OpDefLib._apply_op_helper("MergeSummary", name: name, keywords: dict); + return op.output; + } + + /// + /// V2 format specific: merges the metadata files of sharded checkpoints. The + /// + /// + /// prefixes of V2 checkpoints to merge. + /// + /// + /// scalar. The desired final prefix. Allowed to be the same + /// as one of the checkpoint_prefixes. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MergeV2Checkpoints'. + /// + /// + /// see above. + /// + /// + /// Returns the description of the operation + /// + /// + /// result is one logical checkpoint, with one physical metadata file and renamed + /// data files. + /// + /// Intended for "grouping" multiple checkpoints in a sharded checkpoint setup. + /// + /// If delete_old_dirs is true, attempts to delete recursively the dirname of each + /// path in the input checkpoint_prefixes. This is useful when those paths are non + /// user-facing temporary locations. + /// + public static Operation merge_v2_checkpoints(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs = true, bool allow_missing_files = false, string name = "MergeV2Checkpoints") + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "MergeV2Checkpoints", name, + checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files)); + result = null; + return null; + //try + //{ + // var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, + // new object[] { checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files })); + // result = null; + // return null; + //} + //catch (System.Exception) + //{ + // return merge_v2_checkpoints_eager_fallback(checkpoint_prefixes, destination_prefix, delete_old_dirs: delete_old_dirs, + // allow_missing_files: allow_missing_files, name: name, ctx: ctx); + //} + } + var dict = new Dictionary(); + dict["checkpoint_prefixes"] = checkpoint_prefixes; + dict["destination_prefix"] = destination_prefix; + dict["delete_old_dirs"] = delete_old_dirs; + var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict); + return op; + } + + //public static Operation merge_v2_checkpoints_eager_fallback(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs, bool allow_missing_files, string name, Context ctx) + //{ + // checkpoint_prefixes = ops.convert_to_tensor(checkpoint_prefixes, TF_DataType.TF_STRING); + // destination_prefix = ops.convert_to_tensor(destination_prefix, TF_DataType.TF_STRING); + // var inputs_flat = new Tensor[] { checkpoint_prefixes, destination_prefix }; + // var attrs = new object[] { "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files }; + // var result = execute.quick_execute("MergeV2Checkpoints", 0, inputs_flat, attrs, ctx, name); + // result = null; + // return null; + //} + + /// + /// Transforms a spectrogram into a form that's useful for speech recognition. + /// + /// + /// Typically produced by the Spectrogram op, with magnitude_squared + /// set to true. + /// + /// + /// How many samples per second the source audio used. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Mfcc'. + /// + /// + /// The highest frequency to use when calculating the + /// ceptstrum. + /// + /// + /// The lowest frequency to use when calculating the + /// ceptstrum. + /// + /// + /// Resolution of the Mel bank used internally. + /// + /// + /// How many output channels to produce per time slice. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Mel Frequency Cepstral Coefficients are a way of representing audio data that's + /// been effective as an input feature for machine learning. They are created by + /// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the + /// higher frequencies that are less significant to the human ear. They have a long + /// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum + /// is a good resource to learn more. + /// + public static Tensor mfcc(Tensor spectrogram, Tensor sample_rate, float? upper_frequency_limit = null, float? lower_frequency_limit = null, int? filterbank_channel_count = null, int? dct_coefficient_count = null, string name = "Mfcc") + { + var dict = new Dictionary(); + dict["spectrogram"] = spectrogram; + dict["sample_rate"] = sample_rate; + if (upper_frequency_limit.HasValue) + dict["upper_frequency_limit"] = upper_frequency_limit.Value; + if (lower_frequency_limit.HasValue) + dict["lower_frequency_limit"] = lower_frequency_limit.Value; + if (filterbank_channel_count.HasValue) + dict["filterbank_channel_count"] = filterbank_channel_count.Value; + if (dct_coefficient_count.HasValue) + dict["dct_coefficient_count"] = dct_coefficient_count.Value; + var op = tf.OpDefLib._apply_op_helper("Mfcc", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the minimum of elements across dimensions of a tensor. + /// + /// + /// The tensor to reduce. + /// + /// + /// The dimensions to reduce. Must be in the range + /// [-rank(input), rank(input)). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Min'. + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// The reduced tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Reduces input along the dimensions given in axis. Unless + /// keep_dims is true, the rank of the tensor is reduced by 1 for each entry in + /// axis. If keep_dims is true, the reduced dimensions are + /// retained with length 1. + /// + public static Tensor min(Tensor input, Tensor reduction_indices, bool? keep_dims = null, string name = "Min") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["reduction_indices"] = reduction_indices; + if (keep_dims.HasValue) + dict["keep_dims"] = keep_dims.Value; + var op = tf.OpDefLib._apply_op_helper("Min", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the min of x and y (i.e. x &lt; y ? x : y) element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Minimum'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: Minimum supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor minimum(Tensor x, Tensor y, string name = "Minimum") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("Minimum", name: name, keywords: dict); + return op.output; + } + + /// + /// Pads a tensor with mirrored values. + /// + /// + /// The input tensor to be padded. + /// + /// + /// A two-column matrix specifying the padding sizes. The number of + /// rows must be the same as the rank of input. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MirrorPad'. + /// + /// + /// Optional argument + /// Either REFLECT or SYMMETRIC. In reflect mode the padded regions + /// do not include the borders, while in symmetric mode the padded regions + /// do include the borders. For example, if input is [1, 2, 3] and paddings + /// is [0, 2], then the output is [1, 2, 3, 2, 1] in reflect mode, and + /// it is [1, 2, 3, 3, 2] in symmetric mode. + /// + /// + /// The padded tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation pads a input with mirrored values according to the paddings + /// you specify. paddings is an integer tensor with shape [n, 2], where n is + /// the rank of input. For each dimension D of input, paddings[D, 0] indicates + /// how many values to add before the contents of input in that dimension, and + /// paddings[D, 1] indicates how many values to add after the contents of input + /// in that dimension. Both paddings[D, 0] and paddings[D, 1] must be no greater + /// than input.dim_size(D) (or input.dim_size(D) - 1) if copy_border is true + /// (if false, respectively). + /// + /// The padded size of each dimension D of the output is: + /// + /// paddings(D, 0) + input.dim_size(D) + paddings(D, 1) + /// + /// For example: + /// + /// + /// # 't' is [[1, 2, 3], [4, 5, 6]]. + /// # 'paddings' is [[1, 1]], [2, 2]]. + /// # 'mode' is SYMMETRIC. + /// # rank of 't' is 2. + /// pad(t, paddings) ==&gt; [[2, 1, 1, 2, 3, 3, 2] + /// [2, 1, 1, 2, 3, 3, 2] + /// [5, 4, 4, 5, 6, 6, 5] + /// [5, 4, 4, 5, 6, 6, 5]] + /// + /// + public static Tensor mirror_pad(Tensor input, Tensor paddings, string mode, string name = "MirrorPad") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["paddings"] = paddings; + dict["mode"] = mode; + var op = tf.OpDefLib._apply_op_helper("MirrorPad", name: name, keywords: dict); + return op.output; + } + + /// + /// Gradient op for MirrorPad op. This op folds a mirror-padded tensor. + /// + /// + /// The input tensor to be folded. + /// + /// + /// A two-column matrix specifying the padding sizes. The number of + /// rows must be the same as the rank of input. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MirrorPadGrad'. + /// + /// + /// Optional argument + /// The mode used in the MirrorPad op. + /// + /// + /// The folded tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation folds the padded areas of input by MirrorPad according to the + /// paddings you specify. paddings must be the same as paddings argument + /// given to the corresponding MirrorPad op. + /// + /// The folded size of each dimension D of the output is: + /// + /// input.dim_size(D) - paddings(D, 0) - paddings(D, 1) + /// + /// For example: + /// + /// + /// # 't' is [[1, 2, 3], [4, 5, 6], [7, 8, 9]]. + /// # 'paddings' is [[0, 1]], [0, 1]]. + /// # 'mode' is SYMMETRIC. + /// # rank of 't' is 2. + /// pad(t, paddings) ==&gt; [[ 1, 5] + /// [11, 28]] + /// + /// + public static Tensor mirror_pad_grad(Tensor input, Tensor paddings, string mode, string name = "MirrorPadGrad") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["paddings"] = paddings; + dict["mode"] = mode; + var op = tf.OpDefLib._apply_op_helper("MirrorPadGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns element-wise remainder of division. This emulates C semantics in that + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Mod'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// the result here is consistent with a truncating divide. E.g. + /// tf.truncatediv(x, y) * y + truncate_mod(x, y) = x. + /// + /// *NOTE*: Mod supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor mod(Tensor x, Tensor y, string name = "Mod") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("Mod", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns x * y element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Mul'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: Multiply supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor mul(Tensor x, Tensor y, string name = "Mul") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("Mul", name: name, keywords: dict); + return op.output; + } + + /// + /// Draws samples from a multinomial distribution. + /// + /// + /// 2-D Tensor with shape [batch_size, num_classes]. Each slice [i, :] + /// represents the unnormalized log probabilities for all classes. + /// + /// + /// 0-D. Number of independent samples to draw for each row slice. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Multinomial'. + /// + /// + /// If either seed or seed2 is set to be non-zero, the internal random number + /// generator is seeded by the given seed. Otherwise, a random seed is used. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// + /// + /// 2-D Tensor with shape [batch_size, num_samples]. Each slice [i, :] + /// contains the drawn class labels with range [0, num_classes). + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor multinomial(Tensor logits, Tensor num_samples, int? seed = null, int? seed2 = null, TF_DataType? output_dtype = null, string name = "Multinomial") + { + var dict = new Dictionary(); + dict["logits"] = logits; + dict["num_samples"] = num_samples; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + if (output_dtype.HasValue) + dict["output_dtype"] = output_dtype.Value; + var op = tf.OpDefLib._apply_op_helper("Multinomial", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates an empty hash table that uses tensors as the backing store. + /// + /// + /// The key used to represent empty key buckets internally. Must not + /// be used in insert or lookup operations. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MutableDenseHashTable'. + /// + /// + /// Optional argument + /// Type of the table values. + /// + /// + /// If non-empty, this table is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this table is shared under the given name across + /// multiple sessions. + /// + /// + /// + /// + /// The shape of each value. + /// + /// + /// The initial number of hash table buckets. Must be a power + /// to 2. + /// + /// + /// The maximum ratio between number of entries and number of + /// buckets before growing the table. Must be between 0 and 1. + /// + /// + /// Handle to a table. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// It uses "open addressing" with quadratic reprobing to resolve + /// collisions. + /// + /// This op creates a mutable hash table, specifying the type of its keys and + /// values. Each value must be a scalar. Data can be inserted into the table using + /// the insert operations. It does not support the initialization operation. + /// + public static Tensor mutable_dense_hash_table(Tensor empty_key, TF_DataType value_dtype, string container = null, string shared_name = null, bool? use_node_name_sharing = null, Shape value_shape = null, int? initial_num_buckets = null, float? max_load_factor = null, string name = "MutableDenseHashTable") + { + var dict = new Dictionary(); + dict["empty_key"] = empty_key; + dict["value_dtype"] = value_dtype; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + if (use_node_name_sharing.HasValue) + dict["use_node_name_sharing"] = use_node_name_sharing.Value; + if (value_shape != null) + dict["value_shape"] = value_shape; + if (initial_num_buckets.HasValue) + dict["initial_num_buckets"] = initial_num_buckets.Value; + if (max_load_factor.HasValue) + dict["max_load_factor"] = max_load_factor.Value; + var op = tf.OpDefLib._apply_op_helper("MutableDenseHashTable", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates an empty hash table that uses tensors as the backing store. + /// + /// + /// The key used to represent empty key buckets internally. Must not + /// be used in insert or lookup operations. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MutableDenseHashTableV2'. + /// + /// + /// Optional argument + /// Type of the table values. + /// + /// + /// If non-empty, this table is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this table is shared under the given name across + /// multiple sessions. + /// + /// + /// + /// + /// The shape of each value. + /// + /// + /// The initial number of hash table buckets. Must be a power + /// to 2. + /// + /// + /// The maximum ratio between number of entries and number of + /// buckets before growing the table. Must be between 0 and 1. + /// + /// + /// Handle to a table. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// It uses "open addressing" with quadratic reprobing to resolve + /// collisions. + /// + /// This op creates a mutable hash table, specifying the type of its keys and + /// values. Each value must be a scalar. Data can be inserted into the table using + /// the insert operations. It does not support the initialization operation. + /// + public static Tensor mutable_dense_hash_table_v2(Tensor empty_key, TF_DataType value_dtype, string container = null, string shared_name = null, bool? use_node_name_sharing = null, Shape value_shape = null, int? initial_num_buckets = null, float? max_load_factor = null, string name = "MutableDenseHashTableV2") + { + var dict = new Dictionary(); + dict["empty_key"] = empty_key; + dict["value_dtype"] = value_dtype; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + if (use_node_name_sharing.HasValue) + dict["use_node_name_sharing"] = use_node_name_sharing.Value; + if (value_shape != null) + dict["value_shape"] = value_shape; + if (initial_num_buckets.HasValue) + dict["initial_num_buckets"] = initial_num_buckets.Value; + if (max_load_factor.HasValue) + dict["max_load_factor"] = max_load_factor.Value; + var op = tf.OpDefLib._apply_op_helper("MutableDenseHashTableV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates an empty hash table. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MutableHashTable'. + /// + /// + /// Optional argument + /// Type of the table keys. + /// + /// + /// Optional argument + /// Type of the table values. + /// + /// + /// If non-empty, this table is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this table is shared under the given name across + /// multiple sessions. + /// + /// + /// If true and shared_name is empty, the table is shared + /// using the node name. + /// + /// + /// Handle to a table. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op creates a mutable hash table, specifying the type of its keys and + /// values. Each value must be a scalar. Data can be inserted into the table using + /// the insert operations. It does not support the initialization operation. + /// + public static Tensor mutable_hash_table(TF_DataType key_dtype, TF_DataType value_dtype, string container = null, string shared_name = null, bool? use_node_name_sharing = null, string name = "MutableHashTable") + { + var dict = new Dictionary(); + dict["key_dtype"] = key_dtype; + dict["value_dtype"] = value_dtype; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + if (use_node_name_sharing.HasValue) + dict["use_node_name_sharing"] = use_node_name_sharing.Value; + var op = tf.OpDefLib._apply_op_helper("MutableHashTable", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates an empty hash table. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MutableHashTableOfTensors'. + /// + /// + /// Optional argument + /// Type of the table keys. + /// + /// + /// Optional argument + /// Type of the table values. + /// + /// + /// If non-empty, this table is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this table is shared under the given name across + /// multiple sessions. + /// + /// + /// + /// + /// + /// + /// Handle to a table. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op creates a mutable hash table, specifying the type of its keys and + /// values. Each value must be a vector. Data can be inserted into the table using + /// the insert operations. It does not support the initialization operation. + /// + public static Tensor mutable_hash_table_of_tensors(TF_DataType key_dtype, TF_DataType value_dtype, string container = null, string shared_name = null, bool? use_node_name_sharing = null, Shape value_shape = null, string name = "MutableHashTableOfTensors") + { + var dict = new Dictionary(); + dict["key_dtype"] = key_dtype; + dict["value_dtype"] = value_dtype; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + if (use_node_name_sharing.HasValue) + dict["use_node_name_sharing"] = use_node_name_sharing.Value; + if (value_shape != null) + dict["value_shape"] = value_shape; + var op = tf.OpDefLib._apply_op_helper("MutableHashTableOfTensors", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates an empty hash table. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MutableHashTableOfTensorsV2'. + /// + /// + /// Optional argument + /// Type of the table keys. + /// + /// + /// Optional argument + /// Type of the table values. + /// + /// + /// If non-empty, this table is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this table is shared under the given name across + /// multiple sessions. + /// + /// + /// + /// + /// + /// + /// Handle to a table. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op creates a mutable hash table, specifying the type of its keys and + /// values. Each value must be a vector. Data can be inserted into the table using + /// the insert operations. It does not support the initialization operation. + /// + public static Tensor mutable_hash_table_of_tensors_v2(TF_DataType key_dtype, TF_DataType value_dtype, string container = null, string shared_name = null, bool? use_node_name_sharing = null, Shape value_shape = null, string name = "MutableHashTableOfTensorsV2") + { + var dict = new Dictionary(); + dict["key_dtype"] = key_dtype; + dict["value_dtype"] = value_dtype; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + if (use_node_name_sharing.HasValue) + dict["use_node_name_sharing"] = use_node_name_sharing.Value; + if (value_shape != null) + dict["value_shape"] = value_shape; + var op = tf.OpDefLib._apply_op_helper("MutableHashTableOfTensorsV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates an empty hash table. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MutableHashTableV2'. + /// + /// + /// Optional argument + /// Type of the table keys. + /// + /// + /// Optional argument + /// Type of the table values. + /// + /// + /// If non-empty, this table is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this table is shared under the given name across + /// multiple sessions. + /// + /// + /// If true and shared_name is empty, the table is shared + /// using the node name. + /// + /// + /// Handle to a table. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op creates a mutable hash table, specifying the type of its keys and + /// values. Each value must be a scalar. Data can be inserted into the table using + /// the insert operations. It does not support the initialization operation. + /// + public static Tensor mutable_hash_table_v2(TF_DataType key_dtype, TF_DataType value_dtype, string container = null, string shared_name = null, bool? use_node_name_sharing = null, string name = "MutableHashTableV2") + { + var dict = new Dictionary(); + dict["key_dtype"] = key_dtype; + dict["value_dtype"] = value_dtype; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + if (use_node_name_sharing.HasValue) + dict["use_node_name_sharing"] = use_node_name_sharing.Value; + var op = tf.OpDefLib._apply_op_helper("MutableHashTableV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Locks a mutex resource. The output is the lock. So long as the lock tensor + /// + /// + /// The mutex resource to lock. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MutexLock'. + /// + /// + /// A tensor that keeps a shared pointer to a lock on the mutex; + /// when the Tensor is destroyed, the use count on the shared pointer is decreased + /// by 1. When it reaches 0, the lock is released. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// is alive, any other request to use MutexLock with this mutex will wait. + /// + /// This is particularly useful for creating a critical section when used in + /// conjunction with MutexLockIdentity: + /// + /// + /// + /// mutex = mutex_v2( + /// shared_name=handle_name, container=container, name=name) + /// + /// def execute_in_critical_section(fn, *args, **kwargs): + /// lock = gen_resource_variable_ops.mutex_lock(mutex) + /// + /// with ops.control_dependencies([lock]): + /// r = fn(*args, **kwargs) + /// + /// with ops.control_dependencies(nest.flatten(r)): + /// with ops.colocate_with(mutex): + /// ensure_lock_exists = mutex_lock_identity(lock) + /// + /// # Make sure that if any element of r is accessed, all of + /// # them are executed together. + /// r = nest.map_structure(tf.identity, r) + /// + /// with ops.control_dependencies([ensure_lock_exists]): + /// return nest.map_structure(tf.identity, r) + /// + /// + /// While fn is running in the critical section, no other functions which wish to + /// use this critical section may run. + /// + /// Often the use case is that two executions of the same graph, in parallel, + /// wish to run fn; and we wish to ensure that only one of them executes + /// at a time. This is especially important if fn modifies one or more + /// variables at a time. + /// + /// It is also useful if two separate functions must share a resource, but we + /// wish to ensure the usage is exclusive. + /// + public static Tensor mutex_lock(Tensor mutex, string name = "MutexLock") + { + var dict = new Dictionary(); + dict["mutex"] = mutex; + var op = tf.OpDefLib._apply_op_helper("MutexLock", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a Mutex resource that can be locked by MutexLock. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'MutexV2'. + /// + /// + /// If non-empty, this variable is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this variable is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// The mutex resource. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor mutex_v2(string container = null, string shared_name = null, string name = "MutexV2") + { + var dict = new Dictionary(); + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("MutexV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes numerical negative value element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Neg'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// I.e., \\(y = -x\\). + /// + public static Tensor neg(Tensor x, string name = "Neg") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Neg", name: name, keywords: dict); + return op.output; + } + + /// + /// Training via negative sampling. + /// + /// + /// input word embedding. + /// + /// + /// output word embedding. + /// + /// + /// A vector of word ids. + /// + /// + /// A vector of word ids. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'NegTrain'. + /// + /// + /// Optional argument + /// Count of words in the vocabulary. + /// + /// + /// Optional argument + /// Number of negative samples per example. + /// + /// + /// Returns the description of the operation + /// + public static Operation neg_train(Tensor w_in, Tensor w_out, Tensor examples, Tensor labels, Tensor lr, int[] vocab_count, int num_negative_samples, string name = "NegTrain") + { + var dict = new Dictionary(); + dict["w_in"] = w_in; + dict["w_out"] = w_out; + dict["examples"] = examples; + dict["labels"] = labels; + dict["lr"] = lr; + dict["vocab_count"] = vocab_count; + dict["num_negative_samples"] = num_negative_samples; + var op = tf.OpDefLib._apply_op_helper("NegTrain", name: name, keywords: dict); + return op; + } + + /// + /// Makes its input available to the next iteration. + /// + /// + /// The tensor to be made available to the next iteration. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'NextIteration'. + /// + /// + /// The same tensor as data. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor next_iteration(Tensor data, string name = "NextIteration") + { + var dict = new Dictionary(); + dict["data"] = data; + var op = tf.OpDefLib._apply_op_helper("NextIteration", name: name, keywords: dict); + return op.output; + } + + /// + /// Does nothing. Only useful as a placeholder for control edges. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'NoOp'. + /// + /// + /// Returns the description of the operation + /// + public static Operation no_op(string name = "NoOp") + { + var dict = new Dictionary(); + var op = tf.OpDefLib._apply_op_helper("NoOp", name: name, keywords: dict); + return op; + } + + /// + /// Greedily selects a subset of bounding boxes in descending order of score, + /// + /// + /// A 2-D float tensor of shape [num_boxes, 4]. + /// + /// + /// A 1-D float tensor of shape [num_boxes] representing a single + /// score corresponding to each box (each row of boxes). + /// + /// + /// A scalar integer tensor representing the maximum number of + /// boxes to be selected by non max suppression. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'NonMaxSuppression'. + /// + /// + /// A float representing the threshold for deciding whether boxes + /// overlap too much with respect to IOU. + /// + /// + /// A 1-D integer tensor of shape [M] representing the selected + /// indices from the boxes tensor, where M &lt;= max_output_size. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// pruning away boxes that have high intersection-over-union (IOU) overlap + /// with previously selected boxes. Bounding boxes are supplied as + /// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any + /// diagonal pair of box corners and the coordinates can be provided as normalized + /// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm + /// is agnostic to where the origin is in the coordinate system. Note that this + /// algorithm is invariant to orthogonal transformations and translations + /// of the coordinate system; thus translating or reflections of the coordinate + /// system result in the same boxes being selected by the algorithm. + /// The output of this operation is a set of integers indexing into the input + /// collection of bounding boxes representing the selected boxes. The bounding + /// box coordinates corresponding to the selected indices can then be obtained + /// using the tf.gather operation. For example: + /// selected_indices = tf.image.non_max_suppression( + /// boxes, scores, max_output_size, iou_threshold) + /// selected_boxes = tf.gather(boxes, selected_indices) + /// + public static Tensor non_max_suppression(Tensor boxes, Tensor scores, Tensor max_output_size, float? iou_threshold = null, string name = "NonMaxSuppression") + { + var dict = new Dictionary(); + dict["boxes"] = boxes; + dict["scores"] = scores; + dict["max_output_size"] = max_output_size; + if (iou_threshold.HasValue) + dict["iou_threshold"] = iou_threshold.Value; + var op = tf.OpDefLib._apply_op_helper("NonMaxSuppression", name: name, keywords: dict); + return op.output; + } + + /// + /// Greedily selects a subset of bounding boxes in descending order of score, + /// + /// + /// A 2-D float tensor of shape [num_boxes, 4]. + /// + /// + /// A 1-D float tensor of shape [num_boxes] representing a single + /// score corresponding to each box (each row of boxes). + /// + /// + /// A scalar integer tensor representing the maximum number of + /// boxes to be selected by non max suppression. + /// + /// + /// A 0-D float tensor representing the threshold for deciding whether + /// boxes overlap too much with respect to IOU. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'NonMaxSuppressionV2'. + /// + /// + /// A 1-D integer tensor of shape [M] representing the selected + /// indices from the boxes tensor, where M &lt;= max_output_size. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// pruning away boxes that have high intersection-over-union (IOU) overlap + /// with previously selected boxes. Bounding boxes are supplied as + /// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any + /// diagonal pair of box corners and the coordinates can be provided as normalized + /// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm + /// is agnostic to where the origin is in the coordinate system. Note that this + /// algorithm is invariant to orthogonal transformations and translations + /// of the coordinate system; thus translating or reflections of the coordinate + /// system result in the same boxes being selected by the algorithm. + /// + /// The output of this operation is a set of integers indexing into the input + /// collection of bounding boxes representing the selected boxes. The bounding + /// box coordinates corresponding to the selected indices can then be obtained + /// using the tf.gather operation. For example: + /// + /// selected_indices = tf.image.non_max_suppression_v2( + /// boxes, scores, max_output_size, iou_threshold) + /// selected_boxes = tf.gather(boxes, selected_indices) + /// + public static Tensor non_max_suppression_v2(Tensor boxes, Tensor scores, Tensor max_output_size, Tensor iou_threshold, string name = "NonMaxSuppressionV2") + { + var dict = new Dictionary(); + dict["boxes"] = boxes; + dict["scores"] = scores; + dict["max_output_size"] = max_output_size; + dict["iou_threshold"] = iou_threshold; + var op = tf.OpDefLib._apply_op_helper("NonMaxSuppressionV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Greedily selects a subset of bounding boxes in descending order of score, + /// + /// + /// A 2-D float tensor of shape [num_boxes, 4]. + /// + /// + /// A 1-D float tensor of shape [num_boxes] representing a single + /// score corresponding to each box (each row of boxes). + /// + /// + /// A scalar integer tensor representing the maximum number of + /// boxes to be selected by non max suppression. + /// + /// + /// A 0-D float tensor representing the threshold for deciding whether + /// boxes overlap too much with respect to IOU. + /// + /// + /// A 0-D float tensor representing the threshold for deciding when to remove + /// boxes based on score. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'NonMaxSuppressionV3'. + /// + /// + /// A 1-D integer tensor of shape [M] representing the selected + /// indices from the boxes tensor, where M &lt;= max_output_size. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// pruning away boxes that have high intersection-over-union (IOU) overlap + /// with previously selected boxes. Bounding boxes with score less than + /// score_threshold are removed. Bounding boxes are supplied as + /// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any + /// diagonal pair of box corners and the coordinates can be provided as normalized + /// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm + /// is agnostic to where the origin is in the coordinate system and more + /// generally is invariant to orthogonal transformations and translations + /// of the coordinate system; thus translating or reflections of the coordinate + /// system result in the same boxes being selected by the algorithm. + /// The output of this operation is a set of integers indexing into the input + /// collection of bounding boxes representing the selected boxes. The bounding + /// box coordinates corresponding to the selected indices can then be obtained + /// using the tf.gather operation. For example: + /// selected_indices = tf.image.non_max_suppression_v2( + /// boxes, scores, max_output_size, iou_threshold, score_threshold) + /// selected_boxes = tf.gather(boxes, selected_indices) + /// + public static Tensor non_max_suppression_v3(Tensor boxes, Tensor scores, Tensor max_output_size, Tensor iou_threshold, Tensor score_threshold, string name = "NonMaxSuppressionV3") + { + var dict = new Dictionary(); + dict["boxes"] = boxes; + dict["scores"] = scores; + dict["max_output_size"] = max_output_size; + dict["iou_threshold"] = iou_threshold; + dict["score_threshold"] = score_threshold; + var op = tf.OpDefLib._apply_op_helper("NonMaxSuppressionV3", name: name, keywords: dict); + return op.output; + } + + /// + /// Greedily selects a subset of bounding boxes in descending order of score, + /// + /// + /// A 2-D float tensor of shape [num_boxes, 4]. + /// + /// + /// A 1-D float tensor of shape [num_boxes] representing a single + /// score corresponding to each box (each row of boxes). + /// + /// + /// A scalar integer tensor representing the maximum number of + /// boxes to be selected by non max suppression. + /// + /// + /// A 0-D float tensor representing the threshold for deciding whether + /// boxes overlap too much with respect to IOU. + /// + /// + /// A 0-D float tensor representing the threshold for deciding when to remove + /// boxes based on score. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'NonMaxSuppressionV4'. + /// + /// + /// If true, the output selected_indices is padded to be of length + /// max_output_size. Defaults to false. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// selected_indices : A 1-D integer tensor of shape [M] representing the selected + /// indices from the boxes tensor, where M &lt;= max_output_size. + /// valid_outputs : A 0-D integer tensor representing the number of valid elements in + /// selected_indices, with the valid elements appearing first. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// pruning away boxes that have high intersection-over-union (IOU) overlap + /// with previously selected boxes. Bounding boxes with score less than + /// score_threshold are removed. Bounding boxes are supplied as + /// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any + /// diagonal pair of box corners and the coordinates can be provided as normalized + /// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm + /// is agnostic to where the origin is in the coordinate system and more + /// generally is invariant to orthogonal transformations and translations + /// of the coordinate system; thus translating or reflections of the coordinate + /// system result in the same boxes being selected by the algorithm. + /// The output of this operation is a set of integers indexing into the input + /// collection of bounding boxes representing the selected boxes. The bounding + /// box coordinates corresponding to the selected indices can then be obtained + /// using the tf.gather operation. For example: + /// selected_indices = tf.image.non_max_suppression_v2( + /// boxes, scores, max_output_size, iou_threshold, score_threshold) + /// selected_boxes = tf.gather(boxes, selected_indices) + /// + public static (Tensor selected_indices, Tensor valid_outputs) non_max_suppression_v4(Tensor boxes, Tensor scores, Tensor max_output_size, Tensor iou_threshold, Tensor score_threshold, bool? pad_to_max_output_size = null, string name = "NonMaxSuppressionV4") + { + var dict = new Dictionary(); + dict["boxes"] = boxes; + dict["scores"] = scores; + dict["max_output_size"] = max_output_size; + dict["iou_threshold"] = iou_threshold; + dict["score_threshold"] = score_threshold; + if (pad_to_max_output_size.HasValue) + dict["pad_to_max_output_size"] = pad_to_max_output_size.Value; + var op = tf.OpDefLib._apply_op_helper("NonMaxSuppressionV4", name: name, keywords: dict); + int _idx = 0; + var selected_indices = op.outputs[_idx++]; + var valid_outputs = op.outputs[_idx++]; + return (selected_indices, valid_outputs); + } + + /// + /// Greedily selects a subset of bounding boxes in descending order of score, + /// + /// + /// A 2-D float tensor of shape [num_boxes, num_boxes] representing + /// the n-by-n box overlap values. + /// + /// + /// A 1-D float tensor of shape [num_boxes] representing a single + /// score corresponding to each box (each row of boxes). + /// + /// + /// A scalar integer tensor representing the maximum number of + /// boxes to be selected by non max suppression. + /// + /// + /// A 0-D float tensor representing the threshold for deciding whether + /// boxes overlap too. + /// + /// + /// A 0-D float tensor representing the threshold for deciding when to remove + /// boxes based on score. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'NonMaxSuppressionWithOverlaps'. + /// + /// + /// A 1-D integer tensor of shape [M] representing the selected + /// indices from the boxes tensor, where M &lt;= max_output_size. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// pruning away boxes that have high overlaps + /// with previously selected boxes. Bounding boxes with score less than + /// score_threshold are removed. N-by-n overlap values are supplied as square matrix, + /// which allows for defining a custom overlap criterium (eg. intersection over union, + /// intersection over area, etc.). + /// + /// The output of this operation is a set of integers indexing into the input + /// collection of bounding boxes representing the selected boxes. The bounding + /// box coordinates corresponding to the selected indices can then be obtained + /// using the tf.gather operation. For example: + /// + /// selected_indices = tf.image.non_max_suppression_with_overlaps( + /// overlaps, scores, max_output_size, overlap_threshold, score_threshold) + /// selected_boxes = tf.gather(boxes, selected_indices) + /// + public static Tensor non_max_suppression_with_overlaps(Tensor overlaps, Tensor scores, Tensor max_output_size, Tensor overlap_threshold, Tensor score_threshold, string name = "NonMaxSuppressionWithOverlaps") + { + var dict = new Dictionary(); + dict["overlaps"] = overlaps; + dict["scores"] = scores; + dict["max_output_size"] = max_output_size; + dict["overlap_threshold"] = overlap_threshold; + dict["score_threshold"] = score_threshold; + var op = tf.OpDefLib._apply_op_helper("NonMaxSuppressionWithOverlaps", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the truth value of (x != y) element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'NotEqual'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: NotEqual supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor not_equal(Tensor x, Tensor y, string name = "NotEqual") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("NotEqual", name: name, keywords: dict); + return op.output; + } + + /// + /// Finds values of the n-th order statistic for the last dimension. + /// + /// + /// 1-D or higher with last dimension at least n+1. + /// + /// + /// 0-D. Position of sorted vector to select along the last dimension (along + /// each row for matrices). Valid range of n is [0, input.shape[:-1]) + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'NthElement'. + /// + /// + /// When set to True, find the nth-largest value in the vector and vice + /// versa. + /// + /// + /// The n-th order statistic along each last dimensional slice. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// If the input is a vector (rank-1), finds the entries which is the nth-smallest + /// value in the vector and outputs their values as scalar tensor. + /// + /// For matrices (resp. higher rank input), computes the entries which is the + /// nth-smallest value in each row (resp. vector along the last dimension). Thus, + /// + /// values.shape = input.shape[:-1] + /// + public static Tensor nth_element(Tensor input, Tensor n, bool? reverse = null, string name = "NthElement") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["n"] = n; + if (reverse.HasValue) + dict["reverse"] = reverse.Value; + var op = tf.OpDefLib._apply_op_helper("NthElement", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns a one-hot tensor. + /// + /// + /// A tensor of indices. + /// + /// + /// A scalar defining the depth of the one hot dimension. + /// + /// + /// A scalar defining the value to fill in output when indices[j] = i. + /// + /// + /// A scalar defining the value to fill in output when indices[j] != i. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OneHot'. + /// + /// + /// The axis to fill (default: -1, a new inner-most axis). + /// + /// + /// The one-hot tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The locations represented by indices in indices take value on_value, + /// while all other locations take value off_value. + /// + /// If the input indices is rank N, the output will have rank N+1, + /// The new axis is created at dimension axis (default: the new axis is + /// appended at the end). + /// + /// If indices is a scalar the output shape will be a vector of length depth. + /// + /// If indices is a vector of length features, the output shape will be: + /// + /// features x depth if axis == -1 + /// depth x features if axis == 0 + /// + /// + /// If indices is a matrix (batch) with shape [batch, features], + /// the output shape will be: + /// + /// batch x features x depth if axis == -1 + /// batch x depth x features if axis == 1 + /// depth x batch x features if axis == 0 + /// + /// + /// + /// Examples + /// ========= + /// + /// Suppose that + /// + /// + /// indices = [0, 2, -1, 1] + /// depth = 3 + /// on_value = 5.0 + /// off_value = 0.0 + /// axis = -1 + /// + /// + /// Then output is [4 x 3]: + /// + /// + /// output = + /// [5.0 0.0 0.0] // one_hot(0) + /// [0.0 0.0 5.0] // one_hot(2) + /// [0.0 0.0 0.0] // one_hot(-1) + /// [0.0 5.0 0.0] // one_hot(1) + /// + /// + /// Suppose that + /// + /// + /// indices = [0, 2, -1, 1] + /// depth = 3 + /// on_value = 0.0 + /// off_value = 3.0 + /// axis = 0 + /// + /// + /// Then output is [3 x 4]: + /// + /// + /// output = + /// [0.0 3.0 3.0 3.0] + /// [3.0 3.0 3.0 0.0] + /// [3.0 3.0 3.0 3.0] + /// [3.0 0.0 3.0 3.0] + /// // ^ one_hot(0) + /// // ^ one_hot(2) + /// // ^ one_hot(-1) + /// // ^ one_hot(1) + /// + /// Suppose that + /// + /// + /// indices = [[0, 2], [1, -1]] + /// depth = 3 + /// on_value = 1.0 + /// off_value = 0.0 + /// axis = -1 + /// + /// + /// Then output is [2 x 2 x 3]: + /// + /// + /// output = + /// [ + /// [1.0, 0.0, 0.0] // one_hot(0) + /// [0.0, 0.0, 1.0] // one_hot(2) + /// ][ + /// [0.0, 1.0, 0.0] // one_hot(1) + /// [0.0, 0.0, 0.0] // one_hot(-1) + /// ] + /// + /// + public static Tensor one_hot(Tensor indices, Tensor depth, Tensor on_value, Tensor off_value, int? axis = null, string name = "OneHot") + { + var dict = new Dictionary(); + dict["indices"] = indices; + dict["depth"] = depth; + dict["on_value"] = on_value; + dict["off_value"] = off_value; + if (axis.HasValue) + dict["axis"] = axis.Value; + var op = tf.OpDefLib._apply_op_helper("OneHot", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns a tensor of ones with the same shape and type as x. + /// + /// + /// a tensor of type T. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OnesLike'. + /// + /// + /// a tensor of the same shape and type as x but filled with ones. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor ones_like(Tensor x, string name = "OnesLike") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("OnesLike", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset by applying optimizations to input_dataset. + /// + /// + /// A variant tensor representing the input dataset. + /// + /// + /// A tf.string vector tf.Tensor identifying optimizations to use. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OptimizeDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Creates a dataset by applying optimizations to input_dataset. + /// + public static Tensor optimize_dataset(Tensor input_dataset, Tensor optimizations, TF_DataType[] output_types, Shape[] output_shapes, string name = "OptimizeDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["optimizations"] = optimizations; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("OptimizeDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Constructs an Optional variant from a tuple of tensors. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OptionalFromValue'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor optional_from_value(Tensor[] components, string name = "OptionalFromValue") + { + var dict = new Dictionary(); + dict["components"] = components; + var op = tf.OpDefLib._apply_op_helper("OptionalFromValue", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the value stored in an Optional variant or raises an error if none exists. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OptionalGetValue'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor[] optional_get_value(Tensor optional, TF_DataType[] output_types, Shape[] output_shapes, string name = "OptionalGetValue") + { + var dict = new Dictionary(); + dict["optional"] = optional; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("OptionalGetValue", name: name, keywords: dict); + int _idx = 0; + var components = Enumerable.Range(0, op.OutputListLength("components")).Select(_ => op.outputs[_idx++]).ToArray(); + return (components); + } + + /// + /// Returns true if and only if the given Optional variant has a value. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OptionalHasValue'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor optional_has_value(Tensor optional, string name = "OptionalHasValue") + { + var dict = new Dictionary(); + dict["optional"] = optional; + var op = tf.OpDefLib._apply_op_helper("OptionalHasValue", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates an Optional variant with no value. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OptionalNone'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor optional_none(string name = "OptionalNone") + { + var dict = new Dictionary(); + var op = tf.OpDefLib._apply_op_helper("OptionalNone", name: name, keywords: dict); + return op.output; + } + + /// + /// Op removes all elements in the underlying container. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OrderedMapClear'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Returns the description of the operation + /// + public static Operation ordered_map_clear(TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "OrderedMapClear") + { + var dict = new Dictionary(); + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("OrderedMapClear", name: name, keywords: dict); + return op; + } + + /// + /// Op returns the number of incomplete elements in the underlying container. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OrderedMapIncompleteSize'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor ordered_map_incomplete_size(TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "OrderedMapIncompleteSize") + { + var dict = new Dictionary(); + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("OrderedMapIncompleteSize", name: name, keywords: dict); + return op.output; + } + + /// + /// Op peeks at the values at the specified key. If the + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OrderedMapPeek'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// underlying container does not contain this key + /// this op will block until it does. This Op is optimized for + /// performance. + /// + public static Tensor[] ordered_map_peek(Tensor key, Tensor indices, TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "OrderedMapPeek") + { + var dict = new Dictionary(); + dict["key"] = key; + dict["indices"] = indices; + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("OrderedMapPeek", name: name, keywords: dict); + int _idx = 0; + var values = Enumerable.Range(0, op.OutputListLength("values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (values); + } + + /// + /// Op returns the number of elements in the underlying container. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OrderedMapSize'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor ordered_map_size(TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "OrderedMapSize") + { + var dict = new Dictionary(); + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("OrderedMapSize", name: name, keywords: dict); + return op.output; + } + + /// + /// Stage (key, values) in the underlying container which behaves like a ordered + /// + /// + /// int64 + /// + /// + /// + /// + /// a list of tensors + /// dtypes A list of data types that inserted values should adhere to. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OrderedMapStage'. + /// + /// + /// Optional argument + /// + /// + /// Maximum number of elements in the Staging Area. If &gt; 0, inserts + /// on the container will block when the capacity is reached. + /// + /// + /// + /// + /// If non-empty, this queue is placed in the given container. Otherwise, + /// a default container is used. + /// + /// + /// It is necessary to match this name to the matching Unstage Op. + /// + /// + /// Returns the description of the operation + /// + /// + /// associative container. Elements are ordered by key. + /// + public static Operation ordered_map_stage(Tensor key, Tensor indices, Tensor[] values, TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "OrderedMapStage") + { + var dict = new Dictionary(); + dict["key"] = key; + dict["indices"] = indices; + dict["values"] = values; + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("OrderedMapStage", name: name, keywords: dict); + return op; + } + + /// + /// Op removes and returns the values associated with the key + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OrderedMapUnstage'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// from the underlying container. If the underlying container + /// does not contain this key, the op will block until it does. + /// + public static Tensor[] ordered_map_unstage(Tensor key, Tensor indices, TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "OrderedMapUnstage") + { + var dict = new Dictionary(); + dict["key"] = key; + dict["indices"] = indices; + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("OrderedMapUnstage", name: name, keywords: dict); + int _idx = 0; + var values = Enumerable.Range(0, op.OutputListLength("values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (values); + } + + /// + /// Op removes and returns the (key, value) element with the smallest + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OrderedMapUnstageNoKey'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// key : + /// values : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// key from the underlying container. If the underlying container + /// does not contain elements, the op will block until it does. + /// + public static (Tensor key, Tensor[] values) ordered_map_unstage_no_key(Tensor indices, TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "OrderedMapUnstageNoKey") + { + var dict = new Dictionary(); + dict["indices"] = indices; + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("OrderedMapUnstageNoKey", name: name, keywords: dict); + int _idx = 0; + var key = op.outputs[_idx++]; + var values = Enumerable.Range(0, op.OutputListLength("values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (key, values); + } + + /// + /// Retrieves a single tensor from the computation outfeed. This operation will + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OutfeedDequeue'. + /// + /// + /// Optional argument + /// The type of elements in the tensor. + /// + /// + /// Optional argument + /// The shape of the tensor. + /// + /// + /// The TPU device to use. This should be -1 when the Op + /// is running on a TPU device, and &gt;= 0 when the Op is running on the CPU + /// device. + /// + /// + /// A tensor that will be read from the device outfeed. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// block indefinitely until data is available. + /// + public static Tensor outfeed_dequeue(TF_DataType dtype, Shape shape, int? device_ordinal = null, string name = "OutfeedDequeue") + { + var dict = new Dictionary(); + dict["dtype"] = dtype; + dict["shape"] = shape; + if (device_ordinal.HasValue) + dict["device_ordinal"] = device_ordinal.Value; + var op = tf.OpDefLib._apply_op_helper("OutfeedDequeue", name: name, keywords: dict); + return op.output; + } + + /// + /// Retrieve multiple values that will be emitted by the computation as an XLA + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OutfeedDequeueTuple'. + /// + /// + /// Optional argument + /// The element types of each element in outputs. + /// + /// + /// Optional argument + /// The shapes of each tensor in outputs. + /// + /// + /// The TPU device to use. This should be -1 when the Op + /// is running on a TPU device, and &gt;= 0 when the Op is running on the CPU + /// device. + /// + /// + /// A list of tensors that will be read from the outfeed. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// tuple. This operations will block indefinitely until data is available. + /// Output i corresponds to XLA tuple element i. + /// + public static Tensor[] outfeed_dequeue_tuple(TF_DataType[] dtypes, Shape[] shapes, int? device_ordinal = null, string name = "OutfeedDequeueTuple") + { + var dict = new Dictionary(); + dict["dtypes"] = dtypes; + dict["shapes"] = shapes; + if (device_ordinal.HasValue) + dict["device_ordinal"] = device_ordinal.Value; + var op = tf.OpDefLib._apply_op_helper("OutfeedDequeueTuple", name: name, keywords: dict); + int _idx = 0; + var outputs = Enumerable.Range(0, op.OutputListLength("outputs")).Select(_ => op.outputs[_idx++]).ToArray(); + return (outputs); + } + + /// + /// An op which emits a single Tensor value from an XLA computation. + /// + /// + /// A tensor that will be inserted into the outfeed queue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OutfeedEnqueue'. + /// + /// + /// Returns the description of the operation + /// + public static Operation outfeed_enqueue(Tensor input, string name = "OutfeedEnqueue") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("OutfeedEnqueue", name: name, keywords: dict); + return op; + } + + /// + /// An op which emits multiple Tensor values from an XLA computation. + /// + /// + /// A list of tensors that will be inserted into the outfeed queue as an + /// XLA tuple. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'OutfeedEnqueueTuple'. + /// + /// + /// Returns the description of the operation + /// + public static Operation outfeed_enqueue_tuple(Tensor[] inputs, string name = "OutfeedEnqueueTuple") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + var op = tf.OpDefLib._apply_op_helper("OutfeedEnqueueTuple", name: name, keywords: dict); + return op; + } + + /// + /// Packs a list of N rank-R tensors into one rank-(R+1) tensor. + /// + /// + /// Must be of same shape and type. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Pack'. + /// + /// + /// Dimension along which to pack. Negative values wrap around, so the + /// valid range is [-(R+1), R+1). + /// + /// + /// The packed tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Packs the N tensors in values into a tensor with rank one higher than each + /// tensor in values, by packing them along the axis dimension. + /// Given a list of tensors of shape (A, B, C); + /// + /// if axis == 0 then the output tensor will have the shape (N, A, B, C). + /// if axis == 1 then the output tensor will have the shape (A, N, B, C). + /// Etc. + /// + /// For example: + /// + /// + /// # 'x' is [1, 4] + /// # 'y' is [2, 5] + /// # 'z' is [3, 6] + /// pack([x, y, z]) =&gt; [[1, 4], [2, 5], [3, 6]] # Pack along first dim. + /// pack([x, y, z], axis=1) =&gt; [[1, 2, 3], [4, 5, 6]] + /// + /// + /// This is the opposite of unpack. + /// + public static Tensor pack(Tensor[] values, int? axis = null, string name = "Pack") + { + var dict = new Dictionary(); + dict["values"] = values; + if (axis.HasValue) + dict["axis"] = axis.Value; + var op = tf.OpDefLib._apply_op_helper("Pack", name: name, keywords: dict); + return op.output; + } + + /// + /// Pads a tensor with zeros. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Pad'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation pads a input with zeros according to the paddings you + /// specify. paddings is an integer tensor with shape [Dn, 2], where n is the + /// rank of input. For each dimension D of input, paddings[D, 0] indicates + /// how many zeros to add before the contents of input in that dimension, and + /// paddings[D, 1] indicates how many zeros to add after the contents of input + /// in that dimension. + /// + /// The padded size of each dimension D of the output is: + /// + /// paddings(D, 0) + input.dim_size(D) + paddings(D, 1) + /// + /// For example: + /// + /// + /// # 't' is [[1, 1], [2, 2]] + /// # 'paddings' is [[1, 1], [2, 2]] + /// # rank of 't' is 2 + /// pad(t, paddings) ==&gt; [[0, 0, 0, 0, 0, 0] + /// [0, 0, 1, 1, 0, 0] + /// [0, 0, 2, 2, 0, 0] + /// [0, 0, 0, 0, 0, 0]] + /// + /// + /// + public static Tensor pad(Tensor input, Tensor paddings, string name = "Pad") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["paddings"] = paddings; + var op = tf.OpDefLib._apply_op_helper("Pad", name: name, keywords: dict); + return op.output; + } + + /// + /// Pads a tensor. + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PadV2'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation pads input according to the paddings and constant_values + /// you specify. paddings is an integer tensor with shape [Dn, 2], where n is + /// the rank of input. For each dimension D of input, paddings[D, 0] indicates + /// how many padding values to add before the contents of input in that dimension, + /// and paddings[D, 1] indicates how many padding values to add after the contents + /// of input in that dimension. constant_values is a scalar tensor of the same + /// type as input that indicates the value to use for padding input. + /// + /// The padded size of each dimension D of the output is: + /// + /// paddings(D, 0) + input.dim_size(D) + paddings(D, 1) + /// + /// For example: + /// + /// + /// # 't' is [[1, 1], [2, 2]] + /// # 'paddings' is [[1, 1], [2, 2]] + /// # 'constant_values' is 0 + /// # rank of 't' is 2 + /// pad(t, paddings) ==&gt; [[0, 0, 0, 0, 0, 0] + /// [0, 0, 1, 1, 0, 0] + /// [0, 0, 2, 2, 0, 0] + /// [0, 0, 0, 0, 0, 0]] + /// + /// + public static Tensor pad_v2(Tensor input, Tensor paddings, Tensor constant_values, string name = "PadV2") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["paddings"] = paddings; + dict["constant_values"] = constant_values; + var op = tf.OpDefLib._apply_op_helper("PadV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that batches and pads batch_size elements from the input. + /// + /// + /// + /// + /// A scalar representing the number of elements to accumulate in a + /// batch. + /// + /// + /// A list of int64 tensors representing the desired padded shapes + /// of the corresponding output components. These shapes may be partially + /// specified, using -1 to indicate that a particular dimension should be + /// padded to the maximum size of all batch elements. + /// + /// + /// A list of scalars containing the padding value to use for + /// each of the outputs. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PaddedBatchDataset'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor padded_batch_dataset(Tensor input_dataset, Tensor batch_size, Tensor[] padded_shapes, Tensor[] padding_values, Shape[] output_shapes, string name = "PaddedBatchDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["batch_size"] = batch_size; + dict["padded_shapes"] = padded_shapes; + dict["padding_values"] = padding_values; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("PaddedBatchDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that batches and pads batch_size elements from the input. + /// + /// + /// + /// + /// A scalar representing the number of elements to accumulate in a + /// batch. + /// + /// + /// A list of int64 tensors representing the desired padded shapes + /// of the corresponding output components. These shapes may be partially + /// specified, using -1 to indicate that a particular dimension should be + /// padded to the maximum size of all batch elements. + /// + /// + /// A list of scalars containing the padding value to use for + /// each of the outputs. + /// + /// + /// A scalar representing whether the last batch should be dropped in case its size + /// is smaller than desired. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PaddedBatchDatasetV2'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor padded_batch_dataset_v2(Tensor input_dataset, Tensor batch_size, Tensor[] padded_shapes, Tensor[] padding_values, Tensor drop_remainder, Shape[] output_shapes, string name = "PaddedBatchDatasetV2") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["batch_size"] = batch_size; + dict["padded_shapes"] = padded_shapes; + dict["padding_values"] = padding_values; + dict["drop_remainder"] = drop_remainder; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("PaddedBatchDatasetV2", name: name, keywords: dict); + return op.output; + } + + /// + /// A queue that produces elements in first-in first-out order. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PaddingFIFOQueue'. + /// + /// + /// Optional argument + /// The type of each component in a value. + /// + /// + /// The shape of each component in a value. The length of this attr must + /// be either 0 or the same as the length of component_types. + /// Shapes of fixed rank but variable size are allowed by setting + /// any shape dimension to -1. In this case, the inputs' shape may vary along + /// the given dimension, and DequeueMany will pad the given dimension with + /// zeros up to the maximum shape of all elements in the given batch. + /// If the length of this attr is 0, different queue elements may have + /// different ranks and shapes, but only one element may be dequeued at a time. + /// + /// + /// The upper bound on the number of elements in this queue. + /// Negative numbers mean no limit. + /// + /// + /// If non-empty, this queue is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this queue will be shared under the given name + /// across multiple sessions. + /// + /// + /// The handle to the queue. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Variable-size shapes are allowed by setting the corresponding shape dimensions + /// to 0 in the shape attr. In this case DequeueMany will pad up to the maximum + /// size of any given element in the minibatch. See below for details. + /// + public static Tensor padding_f_i_f_o_queue(TF_DataType[] component_types, Shape[] shapes = null, int? capacity = null, string container = null, string shared_name = null, string name = "PaddingFIFOQueue") + { + var dict = new Dictionary(); + dict["component_types"] = component_types; + if (shapes != null) + dict["shapes"] = shapes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("PaddingFIFOQueue", name: name, keywords: dict); + return op.output; + } + + /// + /// A queue that produces elements in first-in first-out order. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PaddingFIFOQueueV2'. + /// + /// + /// Optional argument + /// The type of each component in a value. + /// + /// + /// The shape of each component in a value. The length of this attr must + /// be either 0 or the same as the length of component_types. + /// Shapes of fixed rank but variable size are allowed by setting + /// any shape dimension to -1. In this case, the inputs' shape may vary along + /// the given dimension, and DequeueMany will pad the given dimension with + /// zeros up to the maximum shape of all elements in the given batch. + /// If the length of this attr is 0, different queue elements may have + /// different ranks and shapes, but only one element may be dequeued at a time. + /// + /// + /// The upper bound on the number of elements in this queue. + /// Negative numbers mean no limit. + /// + /// + /// If non-empty, this queue is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this queue will be shared under the given name + /// across multiple sessions. + /// + /// + /// The handle to the queue. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Variable-size shapes are allowed by setting the corresponding shape dimensions + /// to 0 in the shape attr. In this case DequeueMany will pad up to the maximum + /// size of any given element in the minibatch. See below for details. + /// + public static Tensor padding_f_i_f_o_queue_v2(TF_DataType[] component_types, Shape[] shapes = null, int? capacity = null, string container = null, string shared_name = null, string name = "PaddingFIFOQueueV2") + { + var dict = new Dictionary(); + dict["component_types"] = component_types; + if (shapes != null) + dict["shapes"] = shapes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("PaddingFIFOQueueV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Concatenates a list of N tensors along the first dimension. + /// + /// + /// Tensors to be concatenated. All must have size 1 in the first dimension + /// and same shape. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ParallelConcat'. + /// + /// + /// Optional argument + /// the final shape of the result; should be equal to the shapes of any input + /// but with the number of input values in the first dimension. + /// + /// + /// The concatenated tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The input tensors are all required to have size 1 in the first dimension. + /// + /// For example: + /// + /// + /// # 'x' is [[1, 4]] + /// # 'y' is [[2, 5]] + /// # 'z' is [[3, 6]] + /// parallel_concat([x, y, z]) =&gt; [[1, 4], [2, 5], [3, 6]] # Pack along first dim. + /// + /// + /// The difference between concat and parallel_concat is that concat requires all + /// of the inputs be computed before the operation will begin but doesn't require + /// that the input shapes be known during graph construction. Parallel concat + /// will copy pieces of the input into the output as they become available, in + /// some situations this can provide a performance benefit. + /// + public static Tensor parallel_concat(Tensor[] values, Shape shape, string name = "ParallelConcat") + { + var dict = new Dictionary(); + dict["values"] = values; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("ParallelConcat", name: name, keywords: dict); + return op.output; + } + + /// + /// Interleave the values from the data tensors into a single tensor. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ParallelDynamicStitch'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Builds a merged tensor such that + /// + /// + /// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] + /// + /// + /// For example, if each indices[m] is scalar or vector, we have + /// + /// + /// # Scalar indices: + /// merged[indices[m], ...] = data[m][...] + /// + /// # Vector indices: + /// merged[indices[m][i], ...] = data[m][i, ...] + /// + /// + /// Each data[i].shape must start with the corresponding indices[i].shape, + /// and the rest of data[i].shape must be constant w.r.t. i. That is, we + /// must have data[i].shape = indices[i].shape + constant. In terms of this + /// constant, the output shape is + /// + /// merged.shape = [max(indices)] + constant + /// + /// Values may be merged in parallel, so if an index appears in both indices[m][i] + /// and indices[n][j], the result may be invalid. This differs from the normal + /// DynamicStitch operator that defines the behavior in that case. + /// + /// For example: + /// + /// + /// indices[0] = 6 + /// indices[1] = [4, 1] + /// indices[2] = [[5, 2], [0, 3]] + /// data[0] = [61, 62] + /// data[1] = [[41, 42], [11, 12]] + /// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] + /// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], + /// [51, 52], [61, 62]] + /// + /// + /// This method can be used to merge partitions created by dynamic_partition + /// as illustrated on the following example: + /// + /// + /// # Apply function (increments x_i) on elements for which a certain condition + /// # apply (x_i != -1 in this example). + /// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) + /// condition_mask=tf.not_equal(x,tf.constant(-1.)) + /// partitioned_data = tf.dynamic_partition( + /// x, tf.cast(condition_mask, tf.int32) , 2) + /// partitioned_data[1] = partitioned_data[1] + 1.0 + /// condition_indices = tf.dynamic_partition( + /// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) + /// x = tf.dynamic_stitch(condition_indices, partitioned_data) + /// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain + /// # unchanged. + /// + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/DynamicStitch.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor parallel_dynamic_stitch(Tensor[] indices, Tensor[] data, string name = "ParallelDynamicStitch") + { + var dict = new Dictionary(); + dict["indices"] = indices; + dict["data"] = data; + var op = tf.OpDefLib._apply_op_helper("ParallelDynamicStitch", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs random values from a normal distribution. The parameters may each be a + /// + /// + /// The shape of the output tensor. Batches are indexed by the 0th dimension. + /// + /// + /// The mean parameter of each batch. + /// + /// + /// The standard deviation parameter of each batch. Must be greater than 0. + /// + /// + /// The minimum cutoff. May be -infinity. + /// + /// + /// The maximum cutoff. May be +infinity, and must be more than the minval + /// for each batch. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ParameterizedTruncatedNormal'. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// A matrix of shape num_batches x samples_per_batch, filled with random + /// truncated normal values using the parameters for each row. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// scalar which applies to the entire output, or a vector of length shape[0] which + /// stores the parameters for each batch. + /// + public static Tensor parameterized_truncated_normal(Tensor shape, Tensor means, Tensor stdevs, Tensor minvals, Tensor maxvals, int? seed = null, int? seed2 = null, string name = "ParameterizedTruncatedNormal") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["means"] = means; + dict["stdevs"] = stdevs; + dict["minvals"] = minvals; + dict["maxvals"] = maxvals; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("ParameterizedTruncatedNormal", name: name, keywords: dict); + return op.output; + } + + /// + /// Transforms a vector of brain.Example protos (as strings) into typed tensors. + /// + /// + /// A vector containing a batch of binary serialized Example protos. + /// + /// + /// A vector containing the names of the serialized protos. + /// May contain, for example, table key (descriptive) names for the + /// corresponding serialized protos. These are purely useful for debugging + /// purposes, and the presence of values here has no effect on the output. + /// May also be an empty vector if no names are available. + /// If non-empty, this vector must be the same length as "serialized". + /// + /// + /// A list of Nsparse string Tensors (scalars). + /// The keys expected in the Examples' features associated with sparse values. + /// + /// + /// A list of Ndense string Tensors (scalars). + /// The keys expected in the Examples' features associated with dense values. + /// + /// + /// A list of Ndense Tensors (some may be empty). + /// dense_defaults[j] provides default values + /// when the example's feature_map lacks dense_key[j]. If an empty Tensor is + /// provided for dense_defaults[j], then the Feature dense_keys[j] is required. + /// The input type is inferred from dense_defaults[j], even when it's empty. + /// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, + /// then the shape of dense_defaults[j] must match that of dense_shapes[j]. + /// If dense_shapes[j] has an undefined major dimension (variable strides dense + /// feature), dense_defaults[j] must contain a single element: + /// the padding element. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ParseExample'. + /// + /// + /// Optional argument + /// A list of Nsparse types; the data types of data in each Feature + /// given in sparse_keys. + /// Currently the ParseExample supports DT_FLOAT (FloatList), + /// DT_INT64 (Int64List), and DT_STRING (BytesList). + /// + /// + /// Optional argument + /// A list of Ndense shapes; the shapes of data in each Feature + /// given in dense_keys. + /// The number of elements in the Feature corresponding to dense_key[j] + /// must always equal dense_shapes[j].NumEntries(). + /// If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output + /// Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN): + /// The dense outputs are just the inputs row-stacked by batch. + /// This works for dense_shapes[j] = (-1, D1, ..., DN). In this case + /// the shape of the output Tensor dense_values[j] will be + /// (|serialized|, M, D1, .., DN), where M is the maximum number of blocks + /// of elements of length D1 * .... * DN, across all minibatch entries + /// in the input. Any minibatch entry with less than M blocks of elements of + /// length D1 * ... * DN will be padded with the corresponding default_value + /// scalar element along the second dimension. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sparse_indices : + /// sparse_values : + /// sparse_shapes : + /// dense_values : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor[] sparse_indices, Tensor[] sparse_values, Tensor[] sparse_shapes, Tensor[] dense_values) parse_example(Tensor serialized, Tensor names, Tensor[] sparse_keys, Tensor[] dense_keys, Tensor[] dense_defaults, TF_DataType[] sparse_types, Shape[] dense_shapes, string name = "ParseExample") + { + var dict = new Dictionary(); + dict["serialized"] = serialized; + dict["names"] = names; + dict["sparse_keys"] = sparse_keys; + dict["dense_keys"] = dense_keys; + dict["dense_defaults"] = dense_defaults; + dict["sparse_types"] = sparse_types; + dict["dense_shapes"] = dense_shapes; + var op = tf.OpDefLib._apply_op_helper("ParseExample", name: name, keywords: dict); + int _idx = 0; + var sparse_indices = Enumerable.Range(0, op.OutputListLength("sparse_indices")).Select(_ => op.outputs[_idx++]).ToArray(); + var sparse_values = Enumerable.Range(0, op.OutputListLength("sparse_values")).Select(_ => op.outputs[_idx++]).ToArray(); + var sparse_shapes = Enumerable.Range(0, op.OutputListLength("sparse_shapes")).Select(_ => op.outputs[_idx++]).ToArray(); + var dense_values = Enumerable.Range(0, op.OutputListLength("dense_values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (sparse_indices, sparse_values, sparse_shapes, dense_values); + } + + /// + /// Transforms input_dataset containing Example protos as vectors of DT_STRING into a dataset of Tensor or SparseTensor objects representing the parsed features. + /// + /// + /// + /// + /// + /// + /// A dict mapping string keys to Tensors. + /// The keys of the dict must match the dense_keys of the feature. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ParseExampleDataset'. + /// + /// + /// Optional argument + /// A list of string keys in the examples features. + /// The results for these keys will be returned as SparseTensor objects. + /// + /// + /// Optional argument + /// A list of Ndense string Tensors (scalars). + /// The keys expected in the Examples features associated with dense values. + /// + /// + /// Optional argument + /// A list of DTypes of the same length as sparse_keys. + /// Only tf.float32 (FloatList), tf.int64 (Int64List), + /// and tf.string (BytesList) are supported. + /// + /// + /// Optional argument + /// List of tuples with the same length as dense_keys. + /// The shape of the data for each dense feature referenced by dense_keys. + /// Required for any input tensors identified by dense_keys. Must be + /// either fully defined, or may contain an unknown first dimension. + /// An unknown first dimension means the feature is treated as having + /// a variable number of blocks, and the output shape along this dimension + /// is considered unknown at graph build time. Padding is applied for + /// minibatch elements smaller than the maximum number of blocks for the + /// given feature along this dimension. + /// + /// + /// Optional argument + /// The type list for the return values. + /// + /// + /// Optional argument + /// The list of shapes being produced. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor parse_example_dataset(Tensor input_dataset, Tensor num_parallel_calls, Tensor[] dense_defaults, string[] sparse_keys, string[] dense_keys, TF_DataType[] sparse_types, Shape[] dense_shapes, TF_DataType[] output_types, Shape[] output_shapes, string name = "ParseExampleDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["num_parallel_calls"] = num_parallel_calls; + dict["dense_defaults"] = dense_defaults; + dict["sparse_keys"] = sparse_keys; + dict["dense_keys"] = dense_keys; + dict["sparse_types"] = sparse_types; + dict["dense_shapes"] = dense_shapes; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("ParseExampleDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Transforms a vector of brain.SequenceExample protos (as strings) into typed tensors. + /// + /// + /// A vector containing binary serialized SequenceExample protos. + /// + /// + /// A vector containing the names of the serialized protos. + /// May contain, for example, table key (descriptive) name for the + /// corresponding serialized proto. This is purely useful for debugging + /// purposes, and the presence of values here has no effect on the output. + /// May also be an empty vector if no name is available. + /// + /// + /// A list of Ncontext_dense Tensors (some may be empty). + /// context_dense_defaults[j] provides default values + /// when the SequenceExample's context map lacks context_dense_key[j]. + /// If an empty Tensor is provided for context_dense_defaults[j], + /// then the Feature context_dense_keys[j] is required. + /// The input type is inferred from context_dense_defaults[j], even when it's + /// empty. If context_dense_defaults[j] is not empty, its shape must match + /// context_dense_shapes[j]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ParseSequenceExample'. + /// + /// + /// Optional argument + /// A vector listing the + /// FeatureList keys which may be missing from the SequenceExamples. If the + /// associated FeatureList is missing, it is treated as empty. By default, + /// any FeatureList not listed in this vector must exist in the SequenceExamples. + /// + /// + /// Optional argument + /// A list of Ncontext_sparse string Tensors (scalars). + /// The keys expected in the Examples' features associated with context_sparse + /// values. + /// + /// + /// Optional argument + /// A list of Ncontext_dense string Tensors (scalars). + /// The keys expected in the SequenceExamples' context features associated with + /// dense values. + /// + /// + /// Optional argument + /// A list of Nfeature_list_sparse string Tensors + /// (scalars). The keys expected in the FeatureLists associated with sparse + /// values. + /// + /// + /// Optional argument + /// A list of Nfeature_list_dense string Tensors (scalars). + /// The keys expected in the SequenceExamples' feature_lists associated + /// with lists of dense values. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A list of Ncontext_sparse types; the data types of data in + /// each context Feature given in context_sparse_keys. + /// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), + /// DT_INT64 (Int64List), and DT_STRING (BytesList). + /// + /// + /// + /// + /// A list of Ncontext_dense shapes; the shapes of data in + /// each context Feature given in context_dense_keys. + /// The number of elements in the Feature corresponding to context_dense_key[j] + /// must always equal context_dense_shapes[j].NumEntries(). + /// The shape of context_dense_values[j] will match context_dense_shapes[j]. + /// + /// + /// A list of Nfeature_list_sparse types; the data types + /// of data in each FeatureList given in feature_list_sparse_keys. + /// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), + /// DT_INT64 (Int64List), and DT_STRING (BytesList). + /// + /// + /// A list of Nfeature_list_dense shapes; the shapes of + /// data in each FeatureList given in feature_list_dense_keys. + /// The shape of each Feature in the FeatureList corresponding to + /// feature_list_dense_key[j] must always equal + /// feature_list_dense_shapes[j].NumEntries(). + /// + /// + /// Returns a tuple with multiple values, as follows: + /// context_sparse_indices : + /// context_sparse_values : + /// context_sparse_shapes : + /// context_dense_values : + /// feature_list_sparse_indices : + /// feature_list_sparse_values : + /// feature_list_sparse_shapes : + /// feature_list_dense_values : + /// feature_list_dense_lengths : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor[] context_sparse_indices, Tensor[] context_sparse_values, Tensor[] context_sparse_shapes, Tensor[] context_dense_values, Tensor[] feature_list_sparse_indices, Tensor[] feature_list_sparse_values, Tensor[] feature_list_sparse_shapes, Tensor[] feature_list_dense_values, Tensor[] feature_list_dense_lengths) parse_sequence_example(Tensor serialized, Tensor debug_name, Tensor[] context_dense_defaults, string[] feature_list_dense_missing_assumed_empty, string[] context_sparse_keys, string[] context_dense_keys, string[] feature_list_sparse_keys, string[] feature_list_dense_keys, int? Ncontext_sparse = null, int? Ncontext_dense = null, int? Nfeature_list_sparse = null, int? Nfeature_list_dense = null, TF_DataType[] context_sparse_types = null, TF_DataType[] feature_list_dense_types = null, Shape[] context_dense_shapes = null, TF_DataType[] feature_list_sparse_types = null, Shape[] feature_list_dense_shapes = null, string name = "ParseSequenceExample") + { + var dict = new Dictionary(); + dict["serialized"] = serialized; + dict["debug_name"] = debug_name; + dict["context_dense_defaults"] = context_dense_defaults; + dict["feature_list_dense_missing_assumed_empty"] = feature_list_dense_missing_assumed_empty; + dict["context_sparse_keys"] = context_sparse_keys; + dict["context_dense_keys"] = context_dense_keys; + dict["feature_list_sparse_keys"] = feature_list_sparse_keys; + dict["feature_list_dense_keys"] = feature_list_dense_keys; + if (Ncontext_sparse.HasValue) + dict["Ncontext_sparse"] = Ncontext_sparse.Value; + if (Ncontext_dense.HasValue) + dict["Ncontext_dense"] = Ncontext_dense.Value; + if (Nfeature_list_sparse.HasValue) + dict["Nfeature_list_sparse"] = Nfeature_list_sparse.Value; + if (Nfeature_list_dense.HasValue) + dict["Nfeature_list_dense"] = Nfeature_list_dense.Value; + if (context_sparse_types != null) + dict["context_sparse_types"] = context_sparse_types; + if (feature_list_dense_types != null) + dict["feature_list_dense_types"] = feature_list_dense_types; + if (context_dense_shapes != null) + dict["context_dense_shapes"] = context_dense_shapes; + if (feature_list_sparse_types != null) + dict["feature_list_sparse_types"] = feature_list_sparse_types; + if (feature_list_dense_shapes != null) + dict["feature_list_dense_shapes"] = feature_list_dense_shapes; + var op = tf.OpDefLib._apply_op_helper("ParseSequenceExample", name: name, keywords: dict); + int _idx = 0; + var context_sparse_indices = Enumerable.Range(0, op.OutputListLength("context_sparse_indices")).Select(_ => op.outputs[_idx++]).ToArray(); + var context_sparse_values = Enumerable.Range(0, op.OutputListLength("context_sparse_values")).Select(_ => op.outputs[_idx++]).ToArray(); + var context_sparse_shapes = Enumerable.Range(0, op.OutputListLength("context_sparse_shapes")).Select(_ => op.outputs[_idx++]).ToArray(); + var context_dense_values = Enumerable.Range(0, op.OutputListLength("context_dense_values")).Select(_ => op.outputs[_idx++]).ToArray(); + var feature_list_sparse_indices = Enumerable.Range(0, op.OutputListLength("feature_list_sparse_indices")).Select(_ => op.outputs[_idx++]).ToArray(); + var feature_list_sparse_values = Enumerable.Range(0, op.OutputListLength("feature_list_sparse_values")).Select(_ => op.outputs[_idx++]).ToArray(); + var feature_list_sparse_shapes = Enumerable.Range(0, op.OutputListLength("feature_list_sparse_shapes")).Select(_ => op.outputs[_idx++]).ToArray(); + var feature_list_dense_values = Enumerable.Range(0, op.OutputListLength("feature_list_dense_values")).Select(_ => op.outputs[_idx++]).ToArray(); + var feature_list_dense_lengths = Enumerable.Range(0, op.OutputListLength("feature_list_dense_lengths")).Select(_ => op.outputs[_idx++]).ToArray(); + return (context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values, feature_list_dense_lengths); + } + + /// + /// Transforms a tf.Example proto (as a string) into typed tensors. + /// + /// + /// A vector containing a batch of binary serialized Example protos. + /// + /// + /// A list of Tensors (some may be empty), whose length matches + /// the length of dense_keys. dense_defaults[j] provides default values + /// when the example's feature_map lacks dense_key[j]. If an empty Tensor is + /// provided for dense_defaults[j], then the Feature dense_keys[j] is required. + /// The input type is inferred from dense_defaults[j], even when it's empty. + /// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, + /// then the shape of dense_defaults[j] must match that of dense_shapes[j]. + /// If dense_shapes[j] has an undefined major dimension (variable strides dense + /// feature), dense_defaults[j] must contain a single element: + /// the padding element. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ParseSingleExample'. + /// + /// + /// Optional argument + /// The number of sparse features to be parsed from the example. This + /// must match the lengths of sparse_keys and sparse_types. + /// + /// + /// Optional argument + /// A list of num_sparse strings. + /// The keys expected in the Examples' features associated with sparse values. + /// + /// + /// Optional argument + /// The keys expected in the Examples' features associated with dense + /// values. + /// + /// + /// Optional argument + /// A list of num_sparse types; the data types of data in each + /// Feature given in sparse_keys. + /// Currently the ParseSingleExample op supports DT_FLOAT (FloatList), + /// DT_INT64 (Int64List), and DT_STRING (BytesList). + /// + /// + /// Optional argument + /// The shapes of data in each Feature given in dense_keys. + /// The length of this list must match the length of dense_keys. The + /// number of elements in the Feature corresponding to dense_key[j] must + /// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] == + /// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j] + /// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1, + /// ..., DN), the shape of the output Tensor dense_values[j] will be (M, + /// D1, .., DN), where M is the number of blocks of elements of length + /// D1 * .... * DN, in the input. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sparse_indices : + /// sparse_values : + /// sparse_shapes : + /// dense_values : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor[] sparse_indices, Tensor[] sparse_values, Tensor[] sparse_shapes, Tensor[] dense_values) parse_single_example(Tensor serialized, Tensor[] dense_defaults, int num_sparse, string[] sparse_keys, string[] dense_keys, TF_DataType[] sparse_types, Shape[] dense_shapes, string name = "ParseSingleExample") + { + var dict = new Dictionary(); + dict["serialized"] = serialized; + dict["dense_defaults"] = dense_defaults; + dict["num_sparse"] = num_sparse; + dict["sparse_keys"] = sparse_keys; + dict["dense_keys"] = dense_keys; + dict["sparse_types"] = sparse_types; + dict["dense_shapes"] = dense_shapes; + var op = tf.OpDefLib._apply_op_helper("ParseSingleExample", name: name, keywords: dict); + int _idx = 0; + var sparse_indices = Enumerable.Range(0, op.OutputListLength("sparse_indices")).Select(_ => op.outputs[_idx++]).ToArray(); + var sparse_values = Enumerable.Range(0, op.OutputListLength("sparse_values")).Select(_ => op.outputs[_idx++]).ToArray(); + var sparse_shapes = Enumerable.Range(0, op.OutputListLength("sparse_shapes")).Select(_ => op.outputs[_idx++]).ToArray(); + var dense_values = Enumerable.Range(0, op.OutputListLength("dense_values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (sparse_indices, sparse_values, sparse_shapes, dense_values); + } + + /// + /// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. + /// + /// + /// A scalar containing a binary serialized SequenceExample proto. + /// + /// + /// A vector listing the + /// FeatureList keys which may be missing from the SequenceExample. If the + /// associated FeatureList is missing, it is treated as empty. By default, + /// any FeatureList not listed in this vector must exist in the SequenceExample. + /// + /// + /// A list of Ncontext_sparse string Tensors (scalars). + /// The keys expected in the Examples' features associated with context_sparse + /// values. + /// + /// + /// A list of Ncontext_dense string Tensors (scalars). + /// The keys expected in the SequenceExamples' context features associated with + /// dense values. + /// + /// + /// A list of Nfeature_list_sparse string Tensors + /// (scalars). The keys expected in the FeatureLists associated with sparse + /// values. + /// + /// + /// A list of Nfeature_list_dense string Tensors (scalars). + /// The keys expected in the SequenceExamples' feature_lists associated + /// with lists of dense values. + /// + /// + /// A list of Ncontext_dense Tensors (some may be empty). + /// context_dense_defaults[j] provides default values + /// when the SequenceExample's context map lacks context_dense_key[j]. + /// If an empty Tensor is provided for context_dense_defaults[j], + /// then the Feature context_dense_keys[j] is required. + /// The input type is inferred from context_dense_defaults[j], even when it's + /// empty. If context_dense_defaults[j] is not empty, its shape must match + /// context_dense_shapes[j]. + /// + /// + /// A scalar containing the name of the serialized proto. + /// May contain, for example, table key (descriptive) name for the + /// corresponding serialized proto. This is purely useful for debugging + /// purposes, and the presence of values here has no effect on the output. + /// May also be an empty scalar if no name is available. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ParseSingleSequenceExample'. + /// + /// + /// A list of Ncontext_sparse types; the data types of data in + /// each context Feature given in context_sparse_keys. + /// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), + /// DT_INT64 (Int64List), and DT_STRING (BytesList). + /// + /// + /// + /// + /// A list of Ncontext_dense shapes; the shapes of data in + /// each context Feature given in context_dense_keys. + /// The number of elements in the Feature corresponding to context_dense_key[j] + /// must always equal context_dense_shapes[j].NumEntries(). + /// The shape of context_dense_values[j] will match context_dense_shapes[j]. + /// + /// + /// A list of Nfeature_list_sparse types; the data types + /// of data in each FeatureList given in feature_list_sparse_keys. + /// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), + /// DT_INT64 (Int64List), and DT_STRING (BytesList). + /// + /// + /// A list of Nfeature_list_dense shapes; the shapes of + /// data in each FeatureList given in feature_list_dense_keys. + /// The shape of each Feature in the FeatureList corresponding to + /// feature_list_dense_key[j] must always equal + /// feature_list_dense_shapes[j].NumEntries(). + /// + /// + /// Returns a tuple with multiple values, as follows: + /// context_sparse_indices : + /// context_sparse_values : + /// context_sparse_shapes : + /// context_dense_values : + /// feature_list_sparse_indices : + /// feature_list_sparse_values : + /// feature_list_sparse_shapes : + /// feature_list_dense_values : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor[] context_sparse_indices, Tensor[] context_sparse_values, Tensor[] context_sparse_shapes, Tensor[] context_dense_values, Tensor[] feature_list_sparse_indices, Tensor[] feature_list_sparse_values, Tensor[] feature_list_sparse_shapes, Tensor[] feature_list_dense_values) parse_single_sequence_example(Tensor serialized, Tensor feature_list_dense_missing_assumed_empty, Tensor[] context_sparse_keys, Tensor[] context_dense_keys, Tensor[] feature_list_sparse_keys, Tensor[] feature_list_dense_keys, Tensor[] context_dense_defaults, Tensor debug_name, TF_DataType[] context_sparse_types = null, TF_DataType[] feature_list_dense_types = null, Shape[] context_dense_shapes = null, TF_DataType[] feature_list_sparse_types = null, Shape[] feature_list_dense_shapes = null, string name = "ParseSingleSequenceExample") + { + var dict = new Dictionary(); + dict["serialized"] = serialized; + dict["feature_list_dense_missing_assumed_empty"] = feature_list_dense_missing_assumed_empty; + dict["context_sparse_keys"] = context_sparse_keys; + dict["context_dense_keys"] = context_dense_keys; + dict["feature_list_sparse_keys"] = feature_list_sparse_keys; + dict["feature_list_dense_keys"] = feature_list_dense_keys; + dict["context_dense_defaults"] = context_dense_defaults; + dict["debug_name"] = debug_name; + if (context_sparse_types != null) + dict["context_sparse_types"] = context_sparse_types; + if (feature_list_dense_types != null) + dict["feature_list_dense_types"] = feature_list_dense_types; + if (context_dense_shapes != null) + dict["context_dense_shapes"] = context_dense_shapes; + if (feature_list_sparse_types != null) + dict["feature_list_sparse_types"] = feature_list_sparse_types; + if (feature_list_dense_shapes != null) + dict["feature_list_dense_shapes"] = feature_list_dense_shapes; + var op = tf.OpDefLib._apply_op_helper("ParseSingleSequenceExample", name: name, keywords: dict); + int _idx = 0; + var context_sparse_indices = Enumerable.Range(0, op.OutputListLength("context_sparse_indices")).Select(_ => op.outputs[_idx++]).ToArray(); + var context_sparse_values = Enumerable.Range(0, op.OutputListLength("context_sparse_values")).Select(_ => op.outputs[_idx++]).ToArray(); + var context_sparse_shapes = Enumerable.Range(0, op.OutputListLength("context_sparse_shapes")).Select(_ => op.outputs[_idx++]).ToArray(); + var context_dense_values = Enumerable.Range(0, op.OutputListLength("context_dense_values")).Select(_ => op.outputs[_idx++]).ToArray(); + var feature_list_sparse_indices = Enumerable.Range(0, op.OutputListLength("feature_list_sparse_indices")).Select(_ => op.outputs[_idx++]).ToArray(); + var feature_list_sparse_values = Enumerable.Range(0, op.OutputListLength("feature_list_sparse_values")).Select(_ => op.outputs[_idx++]).ToArray(); + var feature_list_sparse_shapes = Enumerable.Range(0, op.OutputListLength("feature_list_sparse_shapes")).Select(_ => op.outputs[_idx++]).ToArray(); + var feature_list_dense_values = Enumerable.Range(0, op.OutputListLength("feature_list_dense_values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values); + } + + /// + /// Transforms a serialized tensorflow.TensorProto proto into a Tensor. + /// + /// + /// A scalar string containing a serialized TensorProto proto. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ParseTensor'. + /// + /// + /// Optional argument + /// The type of the serialized tensor. The provided type must match the + /// type of the serialized tensor and no implicit conversion will take place. + /// + /// + /// A Tensor of type out_type. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor parse_tensor(Tensor serialized, TF_DataType out_type, string name = "ParseTensor") + { + var dict = new Dictionary(); + dict["serialized"] = serialized; + dict["out_type"] = out_type; + var op = tf.OpDefLib._apply_op_helper("ParseTensor", name: name, keywords: dict); + return op.output; + } + + /// + /// A placeholder op for a value that will be fed into the computation. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Placeholder'. + /// + /// + /// Optional argument + /// The type of elements in the tensor. + /// + /// + /// (Optional) The shape of the tensor. If the shape has 0 dimensions, the + /// shape is unconstrained. + /// + /// + /// A placeholder tensor that must be replaced using the feed mechanism. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// N.B. This operation will fail with an error if it is executed. It is + /// intended as a way to represent a value that will always be fed, and to + /// provide attrs that enable the fed value to be checked at runtime. + /// + public static Tensor placeholder(TF_DataType dtype, Shape shape = null, string name = "Placeholder") + { + var dict = new Dictionary(); + dict["dtype"] = dtype; + if (shape != null) + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("Placeholder", name: name, keywords: dict); + return op.output; + } + + /// + /// A placeholder op for a value that will be fed into the computation. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PlaceholderV2'. + /// + /// + /// Optional argument + /// The type of elements in the tensor. + /// + /// + /// Optional argument + /// The shape of the tensor. The shape can be any partially-specified + /// shape. To be unconstrained, pass in a shape with unknown rank. + /// + /// + /// A placeholder tensor that must be replaced using the feed mechanism. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// N.B. This operation will fail with an error if it is executed. It is + /// intended as a way to represent a value that will always be fed, and to + /// provide attrs that enable the fed value to be checked at runtime. + /// + public static Tensor placeholder_v2(TF_DataType dtype, Shape shape, string name = "PlaceholderV2") + { + var dict = new Dictionary(); + dict["dtype"] = dtype; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("PlaceholderV2", name: name, keywords: dict); + return op.output; + } + + /// + /// A placeholder op that passes through input when its output is not fed. + /// + /// + /// The default value to produce when output is not fed. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PlaceholderWithDefault'. + /// + /// + /// Optional argument + /// The (possibly partial) shape of the tensor. + /// + /// + /// A placeholder tensor that defaults to input if it is not fed. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor placeholder_with_default(Tensor input, Shape shape, string name = "PlaceholderWithDefault") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("PlaceholderWithDefault", name: name, keywords: dict); + return op.output; + } + + /// + /// Compute the polygamma function \\(\psi^{(n)}(x)\\). + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Polygamma'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The polygamma function is defined as: + /// + /// + /// \\(\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x)\\) + /// + /// where \\(\psi(x)\\) is the digamma function. + /// + public static Tensor polygamma(Tensor a, Tensor x, string name = "Polygamma") + { + var dict = new Dictionary(); + dict["a"] = a; + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Polygamma", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes element-wise population count (a.k.a. popcount, bitsum, bitcount). + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PopulationCount'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// For each entry in x, calculates the number of 1 (on) bits in the binary + /// representation of that entry. + /// + /// **NOTE**: It is more efficient to first tf.bitcast your tensors into + /// int32 or int64 and perform the bitcount on the result, than to feed in + /// 8- or 16-bit inputs and then aggregate the resulting counts. + /// + public static Tensor population_count(Tensor x, string name = "PopulationCount") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("PopulationCount", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the power of one value to another. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Pow'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor x and a tensor y, this operation computes \\(x^y\\) for + /// corresponding elements in x and y. For example: + /// + /// + /// # tensor 'x' is [[2, 2]], [3, 3]] + /// # tensor 'y' is [[8, 16], [2, 3]] + /// tf.pow(x, y) ==&gt; [[256, 65536], [9, 27]] + /// + /// + public static Tensor pow(Tensor x, Tensor y, string name = "Pow") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("Pow", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that asynchronously prefetches elements from input_dataset. + /// + /// + /// + /// + /// The maximum number of elements to buffer in an iterator over + /// this dataset. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PrefetchDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor prefetch_dataset(Tensor input_dataset, Tensor buffer_size, TF_DataType[] output_types, Shape[] output_shapes, string name = "PrefetchDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["buffer_size"] = buffer_size; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("PrefetchDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// An identity op that triggers an error if a gradient is requested. + /// + /// + /// any tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PreventGradient'. + /// + /// + /// Will be printed in the error when anyone tries to differentiate + /// this operation. + /// + /// + /// the same input tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// When executed in a graph, this op outputs its input tensor as-is. + /// + /// When building ops to compute gradients, the TensorFlow gradient system + /// will return an error when trying to lookup the gradient of this op, + /// because no gradient must ever be registered for this function. This + /// op exists to prevent subtle bugs from silently returning unimplemented + /// gradients in some corner cases. + /// + public static Tensor prevent_gradient(Tensor input, string message = null, string name = "PreventGradient") + { + var dict = new Dictionary(); + dict["input"] = input; + if (message != null) + dict["message"] = message; + var op = tf.OpDefLib._apply_op_helper("PreventGradient", name: name, keywords: dict); + return op.output; + } + + /// + /// Prints a list of tensors. + /// + /// + /// The tensor passed to output + /// + /// + /// A list of tensors to print out when op is evaluated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Print'. + /// + /// + /// A string, prefix of the error message. + /// + /// + /// Only log first_n number of times. -1 disables logging. + /// + /// + /// Only print this many entries of each tensor. + /// + /// + /// = The unmodified input tensor + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Passes input through to output and prints data when evaluating. + /// + public static Tensor print(Tensor input, Tensor[] data, string message = null, int? first_n = null, int? summarize = null, string name = "Print") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["data"] = data; + if (message != null) + dict["message"] = message; + if (first_n.HasValue) + dict["first_n"] = first_n.Value; + if (summarize.HasValue) + dict["summarize"] = summarize.Value; + var op = tf.OpDefLib._apply_op_helper("Print", name: name, keywords: dict); + return op.output; + } + + /// + /// A queue that produces elements sorted by the first component value. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PriorityQueue'. + /// + /// + /// Optional argument + /// The shape of each component in a value. The length of this attr must + /// be either 0 or the same as the length of component_types. If the length of + /// this attr is 0, the shapes of queue elements are not constrained, and + /// only one element may be dequeued at a time. + /// + /// + /// The type of each component in a value. + /// + /// + /// The upper bound on the number of elements in this queue. + /// Negative numbers mean no limit. + /// + /// + /// If non-empty, this queue is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this queue will be shared under the given name + /// across multiple sessions. + /// + /// + /// The handle to the queue. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Note that the PriorityQueue requires the first component of any element + /// to be a scalar int64, in addition to the other elements declared by + /// component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue + /// and DequeueMany) on a PriorityQueue will all require (resp. output) one extra + /// entry in their input (resp. output) lists. + /// + public static Tensor priority_queue(Shape[] shapes, TF_DataType[] component_types = null, int? capacity = null, string container = null, string shared_name = null, string name = "PriorityQueue") + { + var dict = new Dictionary(); + dict["shapes"] = shapes; + if (component_types != null) + dict["component_types"] = component_types; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("PriorityQueue", name: name, keywords: dict); + return op.output; + } + + /// + /// A queue that produces elements sorted by the first component value. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PriorityQueueV2'. + /// + /// + /// Optional argument + /// The shape of each component in a value. The length of this attr must + /// be either 0 or the same as the length of component_types. If the length of + /// this attr is 0, the shapes of queue elements are not constrained, and + /// only one element may be dequeued at a time. + /// + /// + /// The type of each component in a value. + /// + /// + /// The upper bound on the number of elements in this queue. + /// Negative numbers mean no limit. + /// + /// + /// If non-empty, this queue is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this queue will be shared under the given name + /// across multiple sessions. + /// + /// + /// The handle to the queue. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Note that the PriorityQueue requires the first component of any element + /// to be a scalar int64, in addition to the other elements declared by + /// component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue + /// and DequeueMany) on a PriorityQueue will all require (resp. output) one extra + /// entry in their input (resp. output) lists. + /// + public static Tensor priority_queue_v2(Shape[] shapes, TF_DataType[] component_types = null, int? capacity = null, string container = null, string shared_name = null, string name = "PriorityQueueV2") + { + var dict = new Dictionary(); + dict["shapes"] = shapes; + if (component_types != null) + dict["component_types"] = component_types; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("PriorityQueueV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the product of elements across dimensions of a tensor. + /// + /// + /// The tensor to reduce. + /// + /// + /// The dimensions to reduce. Must be in the range + /// [-rank(input), rank(input)). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Prod'. + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// The reduced tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Reduces input along the dimensions given in axis. Unless + /// keep_dims is true, the rank of the tensor is reduced by 1 for each entry in + /// axis. If keep_dims is true, the reduced dimensions are + /// retained with length 1. + /// + public static Tensor prod(Tensor input, Tensor reduction_indices, bool? keep_dims = null, string name = "Prod") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["reduction_indices"] = reduction_indices; + if (keep_dims.HasValue) + dict["keep_dims"] = keep_dims.Value; + var op = tf.OpDefLib._apply_op_helper("Prod", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the QR decompositions of one or more matrices. + /// + /// + /// A tensor of shape [..., M, N] whose inner-most 2 dimensions + /// form matrices of size [M, N]. Let P be the minimum of M and N. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Qr'. + /// + /// + /// If true, compute full-sized q and r. If false + /// (the default), compute only the leading P columns of q. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// q : Orthonormal basis for range of a. If full_matrices is False then + /// shape is [..., M, P]; if full_matrices is True then shape is + /// [..., M, M]. + /// r : Triangular factor. If full_matrices is False then shape is + /// [..., P, N]. If full_matrices is True then shape is [..., M, N]. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Computes the QR decomposition of each inner matrix in tensor such that + /// tensor[..., :, :] = q[..., :, :] * r[..., :,:]) + /// + /// + /// # a is a tensor. + /// # q is a tensor of orthonormal matrices. + /// # r is a tensor of upper triangular matrices. + /// q, r = qr(a) + /// q_full, r_full = qr(a, full_matrices=True) + /// + /// + public static (Tensor q, Tensor r) qr(Tensor input, bool? full_matrices = null, string name = "Qr") + { + var dict = new Dictionary(); + dict["input"] = input; + if (full_matrices.HasValue) + dict["full_matrices"] = full_matrices.Value; + var op = tf.OpDefLib._apply_op_helper("Qr", name: name, keywords: dict); + int _idx = 0; + var q = op.outputs[_idx++]; + var r = op.outputs[_idx++]; + return (q, r); + } + + /// + /// Use QuantizeAndDequantizeV2 instead. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizeAndDequantize'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor quantize_and_dequantize(Tensor input, bool? signed_input = null, int? num_bits = null, bool? range_given = null, float? input_min = null, float? input_max = null, string name = "QuantizeAndDequantize") + { + var dict = new Dictionary(); + dict["input"] = input; + if (signed_input.HasValue) + dict["signed_input"] = signed_input.Value; + if (num_bits.HasValue) + dict["num_bits"] = num_bits.Value; + if (range_given.HasValue) + dict["range_given"] = range_given.Value; + if (input_min.HasValue) + dict["input_min"] = input_min.Value; + if (input_max.HasValue) + dict["input_max"] = input_max.Value; + var op = tf.OpDefLib._apply_op_helper("QuantizeAndDequantize", name: name, keywords: dict); + return op.output; + } + + /// + /// Quantizes then dequantizes a tensor. + /// + /// + /// Tensor to quantize and then dequantize. + /// + /// + /// If range_given == True, this specifies the minimum input value that needs to + /// be represented, otherwise it is determined from the min value of the input + /// tensor. + /// + /// + /// If range_given == True, this specifies the maximum input value that needs to + /// be represented, otherwise it is determined from the max value of the input + /// tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizeAndDequantizeV2'. + /// + /// + /// Whether the quantization is signed or unsigned. (actually this parameter should + /// have been called &lt;b&gt;signed_output&lt;/b&gt;) + /// + /// + /// The bitwidth of the quantization. + /// + /// + /// Whether the range is given or should be determined from the input tensor. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op simulates the precision loss from the quantized forward pass by: + /// + /// 1. Quantizing the tensor to fixed point numbers, which should match the target + /// quantization method when it is used in inference. + /// 2. Dequantizing it back to floating point numbers for the following ops, most + /// likely matmul. + /// + /// There are different ways to quantize. This version uses only scaling, so 0.0 + /// maps to 0. + /// + /// From the specified 'num_bits' in the quantized output type, it determines + /// minimum and maximum representable quantized values. + /// + /// e.g. + /// + /// * [-128, 127] for signed, num_bits = 8, or + /// * [0, 255] for unsigned, num_bits = 8. + /// + /// If range_given == False, the initial input_min, input_max will be determined + /// automatically as the minimum and maximum values in the input tensor, otherwise + /// the specified values of input_min, input_max are used. + /// + /// Note: If the input_min, input_max are specified, they do not need to equal the + /// actual minimum and maximum values in the tensor. e.g. in some cases it may be + /// beneficial to specify these values such that the low probability extremes of the + /// input distribution are clipped. + /// + /// This op determines the maximum scale_factor that would map the initial + /// [input_min, input_max] range to a range that lies within the representable + /// quantized range. + /// + /// It determines the scale from one of input_min and input_max, then updates the + /// other one to maximize the respresentable range. + /// + /// e.g. + /// + /// * if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0, + /// 5.0]: it would use a scale_factor of -128 / -10.0 = 12.8 In this case, it + /// would update input_max to be 127 / 12.8 = 9.921875 + /// * if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0, + /// 10.0]: it would use a scale_factor of 127 / 10.0 = 12.7 In this case, it + /// would update input_min to be 128.0 / 12.7 = -10.07874 + /// * if the output is unsigned, input_min is forced to be 0, and only the + /// specified input_max is used. + /// + /// After determining the scale_factor and updating the input range, it applies the + /// following to each value in the 'input' tensor. + /// + /// output = round(clamp(value, input_min, input_max) * scale_factor) / scale_factor. + /// + /// + public static Tensor quantize_and_dequantize_v2(Tensor input, Tensor input_min, Tensor input_max, bool? signed_input = null, int? num_bits = null, bool? range_given = null, string name = "QuantizeAndDequantizeV2") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["input_min"] = input_min; + dict["input_max"] = input_max; + if (signed_input.HasValue) + dict["signed_input"] = signed_input.Value; + if (num_bits.HasValue) + dict["num_bits"] = num_bits.Value; + if (range_given.HasValue) + dict["range_given"] = range_given.Value; + var op = tf.OpDefLib._apply_op_helper("QuantizeAndDequantizeV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Quantizes then dequantizes a tensor. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizeAndDequantizeV3'. + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a + /// tensor, so its value can change during training. + /// + public static Tensor quantize_and_dequantize_v3(Tensor input, Tensor input_min, Tensor input_max, Tensor num_bits, bool? signed_input = null, bool? range_given = null, string name = "QuantizeAndDequantizeV3") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["input_min"] = input_min; + dict["input_max"] = input_max; + dict["num_bits"] = num_bits; + if (signed_input.HasValue) + dict["signed_input"] = signed_input.Value; + if (range_given.HasValue) + dict["range_given"] = range_given.Value; + var op = tf.OpDefLib._apply_op_helper("QuantizeAndDequantizeV3", name: name, keywords: dict); + return op.output; + } + + /// + /// Convert the quantized 'input' tensor into a lower-precision 'output', using the + /// + /// + /// + /// + /// The float value that the minimum quantized input value represents. + /// + /// + /// The float value that the maximum quantized input value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizeDownAndShrinkRange'. + /// + /// + /// Optional argument + /// The type of the output. Should be a lower bit depth than Tinput. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : + /// output_min : The float value that the minimum quantized output value represents. + /// output_max : The float value that the maximum quantized output value represents. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// actual distribution of the values to maximize the usage of the lower bit depth + /// and adjusting the output min and max ranges accordingly. + /// + /// [input_min, input_max] are scalar floats that specify the range for the float + /// interpretation of the 'input' data. For example, if input_min is -1.0f and + /// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 + /// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. + /// + /// This operator tries to squeeze as much precision as possible into an output with + /// a lower bit depth by calculating the actual min and max values found in the + /// data. For example, maybe that quint16 input has no values lower than 16,384 and + /// none higher than 49,152. That means only half the range is actually needed, all + /// the float interpretations are between -0.5f and 0.5f, so if we want to compress + /// the data into a quint8 output, we can use that range rather than the theoretical + /// -1.0f to 1.0f that is suggested by the input min and max. + /// + /// In practice, this is most useful for taking output from operations like + /// QuantizedMatMul that can produce higher bit-depth outputs than their inputs and + /// may have large potential output ranges, but in practice have a distribution of + /// input values that only uses a small fraction of the possible range. By feeding + /// that output into this operator, we can reduce it from 32 bits down to 8 with + /// minimal loss of accuracy. + /// + public static (Tensor output, Tensor output_min, Tensor output_max) quantize_down_and_shrink_range(Tensor input, Tensor input_min, Tensor input_max, TF_DataType out_type, string name = "QuantizeDownAndShrinkRange") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["input_min"] = input_min; + dict["input_max"] = input_max; + dict["out_type"] = out_type; + var op = tf.OpDefLib._apply_op_helper("QuantizeDownAndShrinkRange", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var output_min = op.outputs[_idx++]; + var output_max = op.outputs[_idx++]; + return (output, output_min, output_max); + } + + /// + /// Quantize the 'input' tensor of type float to 'output' tensor of type 'T'. + /// + /// + /// + /// + /// The minimum scalar value possibly produced for the input. + /// + /// + /// The maximum scalar value possibly produced for the input. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizeV2'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : The quantized data produced from the float input. + /// output_min : The actual minimum scalar value used for the output. + /// output_max : The actual maximum scalar value used for the output. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// [min_range, max_range] are scalar floats that specify the range for + /// the 'input' data. The 'mode' attribute controls exactly which calculations are + /// used to convert the float values to their quantized equivalents. The + /// 'round_mode' attribute controls which rounding tie-breaking algorithm is used + /// when rounding float values to their quantized equivalents. + /// + /// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: + /// + /// + /// out[i] = (in[i] - min_range) * range(T) / (max_range - min_range) + /// if T == qint8, out[i] -= (range(T) + 1) / 2.0 + /// + /// + /// here range(T) = numeric_limits&lt;T&gt;::max() - numeric_limits&lt;T&gt;::min() + /// + /// *MIN_COMBINED Mode Example* + /// + /// Assume the input is type float and has a possible range of [0.0, 6.0] and the + /// output type is quint8 ([0, 255]). The min_range and max_range values should be + /// specified as 0.0 and 6.0. Quantizing from float to quint8 will multiply each + /// value of the input by 255/6 and cast to quint8. + /// + /// If the output type was qint8 ([-128, 127]), the operation will additionally + /// subtract each value by 128 prior to casting, so that the range of values aligns + /// with the range of qint8. + /// + /// If the mode is 'MIN_FIRST', then this approach is used: + /// + /// + /// num_discrete_values = 1 &lt;&lt; (# of bits in T) + /// range_adjust = num_discrete_values / (num_discrete_values - 1) + /// range = (range_max - range_min) * range_adjust + /// range_scale = num_discrete_values / range + /// quantized = round(input * range_scale) - round(range_min * range_scale) + + /// numeric_limits&lt;T&gt;::min() + /// quantized = max(quantized, numeric_limits&lt;T&gt;::min()) + /// quantized = min(quantized, numeric_limits&lt;T&gt;::max()) + /// + /// + /// The biggest difference between this and MIN_COMBINED is that the minimum range + /// is rounded first, before it's subtracted from the rounded value. With + /// MIN_COMBINED, a small bias is introduced where repeated iterations of quantizing + /// and dequantizing will introduce a larger and larger error. + /// + /// *SCALED mode Example* + /// + /// SCALED mode matches the quantization approach used in + /// QuantizeAndDequantize{V2|V3}. + /// + /// If the mode is SCALED, we do not use the full range of the output type, + /// choosing to elide the lowest possible value for symmetry (e.g., output range is + /// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to + /// 0. + /// + /// We first find the range of values in our tensor. The + /// range we use is always centered on 0, so we find m such that + /// + /// + /// m = max(abs(input_min), abs(input_max)) + /// + /// + /// Our input tensor range is then [-m, m]. + /// + /// Next, we choose our fixed-point quantization buckets, [min_fixed, max_fixed]. + /// If T is signed, this is + /// + /// + /// num_bits = sizeof(T) * 8 + /// [min_fixed, max_fixed] = + /// [-(1 &lt;&lt; (num_bits - 1) - 1), (1 &lt;&lt; (num_bits - 1)) - 1] + /// + /// + /// Otherwise, if T is unsigned, the fixed-point range is + /// + /// + /// [min_fixed, max_fixed] = [0, (1 &lt;&lt; num_bits) - 1] + /// + /// + /// From this we compute our scaling factor, s: + /// + /// + /// s = (max_fixed - min_fixed) / (2 * m) + /// + /// + /// Now we can quantize the elements of our tensor: + /// + /// + /// result = round(input * s) + /// + /// + /// One thing to watch out for is that the operator may choose to adjust the + /// requested minimum and maximum values slightly during the quantization process, + /// so you should always use the output ports as the range for further calculations. + /// For example, if the requested minimum and maximum values are close to equal, + /// they will be separated by a small epsilon value to prevent ill-formed quantized + /// buffers from being created. Otherwise, you can end up with buffers where all the + /// quantized values map to the same float value, which causes problems for + /// operations that have to perform further calculations on them. + /// + public static (Tensor output, Tensor output_min, Tensor output_max) quantize_v2(Tensor input, Tensor min_range, Tensor max_range, TF_DataType T, string mode = null, string round_mode = null, string name = "QuantizeV2") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["min_range"] = min_range; + dict["max_range"] = max_range; + dict["T"] = T; + if (mode != null) + dict["mode"] = mode; + if (round_mode != null) + dict["round_mode"] = round_mode; + var op = tf.OpDefLib._apply_op_helper("QuantizeV2", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var output_min = op.outputs[_idx++]; + var output_max = op.outputs[_idx++]; + return (output, output_min, output_max); + } + + /// + /// Returns x + y element-wise, working on quantized buffers. + /// + /// + /// + /// + /// + /// + /// The float value that the lowest quantized x value represents. + /// + /// + /// The float value that the highest quantized x value represents. + /// + /// + /// The float value that the lowest quantized y value represents. + /// + /// + /// The float value that the highest quantized y value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedAdd'. + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// z : + /// min_z : The float value that the lowest quantized output value represents. + /// max_z : The float value that the highest quantized output value represents. + /// + /// *NOTE*: QuantizedAdd supports limited forms of broadcasting. More about + /// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor z, Tensor min_z, Tensor max_z) quantized_add(Tensor x, Tensor y, Tensor min_x, Tensor max_x, Tensor min_y, Tensor max_y, TF_DataType? Toutput = null, string name = "QuantizedAdd") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + dict["min_x"] = min_x; + dict["max_x"] = max_x; + dict["min_y"] = min_y; + dict["max_y"] = max_y; + if (Toutput.HasValue) + dict["Toutput"] = Toutput.Value; + var op = tf.OpDefLib._apply_op_helper("QuantizedAdd", name: name, keywords: dict); + int _idx = 0; + var z = op.outputs[_idx++]; + var min_z = op.outputs[_idx++]; + var max_z = op.outputs[_idx++]; + return (z, min_z, max_z); + } + + /// + /// Produces the average pool of the input tensor for quantized types. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// The float value that the lowest quantized input value represents. + /// + /// + /// The float value that the highest quantized input value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedAvgPool'. + /// + /// + /// Optional argument + /// The size of the window for each dimension of the input tensor. + /// The length must be 4 to match the number of dimensions of the input. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the input + /// tensor. The length must be 4 to match the number of dimensions of the input. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : + /// min_output : The float value that the lowest quantized output value represents. + /// max_output : The float value that the highest quantized output value represents. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor output, Tensor min_output, Tensor max_output) quantized_avg_pool(Tensor input, Tensor min_input, Tensor max_input, int[] ksize, int[] strides, string padding, string name = "QuantizedAvgPool") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["min_input"] = min_input; + dict["max_input"] = max_input; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + var op = tf.OpDefLib._apply_op_helper("QuantizedAvgPool", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var min_output = op.outputs[_idx++]; + var max_output = op.outputs[_idx++]; + return (output, min_output, max_output); + } + + /// + /// Quantized Batch normalization. + /// + /// + /// A 4D input Tensor. + /// + /// + /// The value represented by the lowest quantized input. + /// + /// + /// The value represented by the highest quantized input. + /// + /// + /// A 1D mean Tensor with size matching the last dimension of t. + /// This is the first output from tf.nn.moments, + /// or a saved moving average thereof. + /// + /// + /// The value represented by the lowest quantized mean. + /// + /// + /// The value represented by the highest quantized mean. + /// + /// + /// A 1D variance Tensor with size matching the last dimension of t. + /// This is the second output from tf.nn.moments, + /// or a saved moving average thereof. + /// + /// + /// The value represented by the lowest quantized variance. + /// + /// + /// The value represented by the highest quantized variance. + /// + /// + /// A 1D beta Tensor with size matching the last dimension of t. + /// An offset to be added to the normalized tensor. + /// + /// + /// The value represented by the lowest quantized offset. + /// + /// + /// The value represented by the highest quantized offset. + /// + /// + /// A 1D gamma Tensor with size matching the last dimension of t. + /// If "scale_after_normalization" is true, this tensor will be multiplied + /// with the normalized tensor. + /// + /// + /// The value represented by the lowest quantized gamma. + /// + /// + /// The value represented by the highest quantized gamma. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedBatchNormWithGlobalNormalization'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// A small float number to avoid dividing by 0. + /// + /// + /// Optional argument + /// A bool indicating whether the resulted tensor + /// needs to be multiplied with gamma. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// result : + /// result_min : + /// result_max : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// This op is deprecated and will be removed in the future. Prefer + /// tf.nn.batch_normalization. + /// + public static (Tensor result, Tensor result_min, Tensor result_max) quantized_batch_norm_with_global_normalization(Tensor t, Tensor t_min, Tensor t_max, Tensor m, Tensor m_min, Tensor m_max, Tensor v, Tensor v_min, Tensor v_max, Tensor beta, Tensor beta_min, Tensor beta_max, Tensor gamma, Tensor gamma_min, Tensor gamma_max, TF_DataType out_type, float variance_epsilon, bool scale_after_normalization, string name = "QuantizedBatchNormWithGlobalNormalization") + { + var dict = new Dictionary(); + dict["t"] = t; + dict["t_min"] = t_min; + dict["t_max"] = t_max; + dict["m"] = m; + dict["m_min"] = m_min; + dict["m_max"] = m_max; + dict["v"] = v; + dict["v_min"] = v_min; + dict["v_max"] = v_max; + dict["beta"] = beta; + dict["beta_min"] = beta_min; + dict["beta_max"] = beta_max; + dict["gamma"] = gamma; + dict["gamma_min"] = gamma_min; + dict["gamma_max"] = gamma_max; + dict["out_type"] = out_type; + dict["variance_epsilon"] = variance_epsilon; + dict["scale_after_normalization"] = scale_after_normalization; + var op = tf.OpDefLib._apply_op_helper("QuantizedBatchNormWithGlobalNormalization", name: name, keywords: dict); + int _idx = 0; + var result = op.outputs[_idx++]; + var result_min = op.outputs[_idx++]; + var result_max = op.outputs[_idx++]; + return (result, result_min, result_max); + } + + /// + /// Adds Tensor 'bias' to Tensor 'input' for Quantized types. + /// + /// + /// + /// + /// A 1D bias Tensor with size matching the last dimension of 'input'. + /// + /// + /// The float value that the lowest quantized input value represents. + /// + /// + /// The float value that the highest quantized input value represents. + /// + /// + /// The float value that the lowest quantized bias value represents. + /// + /// + /// The float value that the highest quantized bias value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedBiasAdd'. + /// + /// + /// Optional argument + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : + /// min_out : The float value that the lowest quantized output value represents. + /// max_out : The float value that the highest quantized output value represents. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Broadcasts the values of bias on dimensions 0..N-2 of 'input'. + /// + public static (Tensor output, Tensor min_out, Tensor max_out) quantized_bias_add(Tensor input, Tensor bias, Tensor min_input, Tensor max_input, Tensor min_bias, Tensor max_bias, TF_DataType out_type, string name = "QuantizedBiasAdd") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["bias"] = bias; + dict["min_input"] = min_input; + dict["max_input"] = max_input; + dict["min_bias"] = min_bias; + dict["max_bias"] = max_bias; + dict["out_type"] = out_type; + var op = tf.OpDefLib._apply_op_helper("QuantizedBiasAdd", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var min_out = op.outputs[_idx++]; + var max_out = op.outputs[_idx++]; + return (output, min_out, max_out); + } + + /// + /// Concatenates quantized tensors along one dimension. + /// + /// + /// 0-D. The dimension along which to concatenate. Must be in the + /// range [0, rank(values)). + /// + /// + /// The N Tensors to concatenate. Their ranks and types must match, + /// and their sizes must match in all dimensions except concat_dim. + /// + /// + /// The minimum scalar values for each of the input tensors. + /// + /// + /// The maximum scalar values for each of the input tensors. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedConcat'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : A Tensor with the concatenation of values stacked along the + /// concat_dim dimension. This tensor's shape matches that of values except + /// in concat_dim where it has the sum of the sizes. + /// output_min : The float value that the minimum quantized output value represents. + /// output_max : The float value that the maximum quantized output value represents. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor output, Tensor output_min, Tensor output_max) quantized_concat(Tensor concat_dim, Tensor[] values, Tensor[] input_mins, Tensor[] input_maxes, string name = "QuantizedConcat") + { + var dict = new Dictionary(); + dict["concat_dim"] = concat_dim; + dict["values"] = values; + dict["input_mins"] = input_mins; + dict["input_maxes"] = input_maxes; + var op = tf.OpDefLib._apply_op_helper("QuantizedConcat", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var output_min = op.outputs[_idx++]; + var output_max = op.outputs[_idx++]; + return (output, output_min, output_max); + } + + /// + /// Computes a 2D convolution given quantized 4D input and filter tensors. + /// + /// + /// + /// + /// filter's input_depth dimension must match input's depth dimensions. + /// + /// + /// The float value that the lowest quantized input value represents. + /// + /// + /// The float value that the highest quantized input value represents. + /// + /// + /// The float value that the lowest quantized filter value represents. + /// + /// + /// The float value that the highest quantized filter value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedConv2D'. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the input + /// tensor. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// + /// + /// 1-D tensor of length 4. The dilation factor for each dimension of + /// input. If set to k &gt; 1, there will be k-1 skipped cells between each + /// filter element on that dimension. The dimension order is determined by the + /// value of data_format, see above for details. Dilations in the batch and + /// depth dimensions must be 1. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : + /// min_output : The float value that the lowest quantized output value represents. + /// max_output : The float value that the highest quantized output value represents. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The inputs are quantized tensors where the lowest value represents the real + /// number of the associated minimum, and the highest represents the maximum. + /// This means that you can only interpret the quantized output in the same way, by + /// taking the returned minimum and maximum values into account. + /// + public static (Tensor output, Tensor min_output, Tensor max_output) quantized_conv2d(Tensor input, Tensor filter, Tensor min_input, Tensor max_input, Tensor min_filter, Tensor max_filter, int[] strides, string padding, TF_DataType? out_type = null, int[] dilations = null, string name = "QuantizedConv2D") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["filter"] = filter; + dict["min_input"] = min_input; + dict["max_input"] = max_input; + dict["min_filter"] = min_filter; + dict["max_filter"] = max_filter; + dict["strides"] = strides; + dict["padding"] = padding; + if (out_type.HasValue) + dict["out_type"] = out_type.Value; + if (dilations != null) + dict["dilations"] = dilations; + var op = tf.OpDefLib._apply_op_helper("QuantizedConv2D", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var min_output = op.outputs[_idx++]; + var max_output = op.outputs[_idx++]; + return (output, min_output, max_output); + } + + /// + /// Quantized Instance normalization. + /// + /// + /// A 4D input Tensor. + /// + /// + /// The value represented by the lowest quantized input. + /// + /// + /// The value represented by the highest quantized input. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedInstanceNorm'. + /// + /// + /// If True, given_y_min and given_y_min + /// and given_y_max are used as the output range. Otherwise, + /// the implementation computes the output range. + /// + /// + /// Output in y_min if output_range_given is True. + /// + /// + /// Output in y_max if output_range_given is True. + /// + /// + /// A small float number to avoid dividing by 0. + /// + /// + /// Minimum value of y_max - y_min + /// + /// + /// Returns a tuple with multiple values, as follows: + /// y : A 4D Tensor. + /// y_min : The value represented by the lowest quantized output. + /// y_max : The value represented by the highest quantized output. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor y, Tensor y_min, Tensor y_max) quantized_instance_norm(Tensor x, Tensor x_min, Tensor x_max, bool? output_range_given = null, float? given_y_min = null, float? given_y_max = null, float? variance_epsilon = null, float? min_separation = null, string name = "QuantizedInstanceNorm") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["x_min"] = x_min; + dict["x_max"] = x_max; + if (output_range_given.HasValue) + dict["output_range_given"] = output_range_given.Value; + if (given_y_min.HasValue) + dict["given_y_min"] = given_y_min.Value; + if (given_y_max.HasValue) + dict["given_y_max"] = given_y_max.Value; + if (variance_epsilon.HasValue) + dict["variance_epsilon"] = variance_epsilon.Value; + if (min_separation.HasValue) + dict["min_separation"] = min_separation.Value; + var op = tf.OpDefLib._apply_op_helper("QuantizedInstanceNorm", name: name, keywords: dict); + int _idx = 0; + var y = op.outputs[_idx++]; + var y_min = op.outputs[_idx++]; + var y_max = op.outputs[_idx++]; + return (y, y_min, y_max); + } + + /// + /// Perform a quantized matrix multiplication of a by the matrix b. + /// + /// + /// Must be a two-dimensional tensor. + /// + /// + /// Must be a two-dimensional tensor. + /// + /// + /// The float value that the lowest quantized a value represents. + /// + /// + /// The float value that the highest quantized a value represents. + /// + /// + /// The float value that the lowest quantized b value represents. + /// + /// + /// The float value that the highest quantized b value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedMatMul'. + /// + /// + /// + /// + /// If true, a is transposed before multiplication. + /// + /// + /// If true, b is transposed before multiplication. + /// + /// + /// The type of output produced by activation function + /// following this operation. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : + /// min_out : The float value that the lowest quantized output value represents. + /// max_out : The float value that the highest quantized output value represents. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The inputs must be two-dimensional matrices and the inner dimension of + /// a (after being transposed if transpose_a is non-zero) must match the + /// outer dimension of b (after being transposed if transposed_b is + /// non-zero). + /// + public static (Tensor output, Tensor min_out, Tensor max_out) quantized_mat_mul(Tensor a, Tensor b, Tensor min_a, Tensor max_a, Tensor min_b, Tensor max_b, TF_DataType? Toutput = null, bool? transpose_a = null, bool? transpose_b = null, TF_DataType? Tactivation = null, string name = "QuantizedMatMul") + { + var dict = new Dictionary(); + dict["a"] = a; + dict["b"] = b; + dict["min_a"] = min_a; + dict["max_a"] = max_a; + dict["min_b"] = min_b; + dict["max_b"] = max_b; + if (Toutput.HasValue) + dict["Toutput"] = Toutput.Value; + if (transpose_a.HasValue) + dict["transpose_a"] = transpose_a.Value; + if (transpose_b.HasValue) + dict["transpose_b"] = transpose_b.Value; + if (Tactivation.HasValue) + dict["Tactivation"] = Tactivation.Value; + var op = tf.OpDefLib._apply_op_helper("QuantizedMatMul", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var min_out = op.outputs[_idx++]; + var max_out = op.outputs[_idx++]; + return (output, min_out, max_out); + } + + /// + /// Produces the max pool of the input tensor for quantized types. + /// + /// + /// The 4D (batch x rows x cols x depth) Tensor to MaxReduce over. + /// + /// + /// The float value that the lowest quantized input value represents. + /// + /// + /// The float value that the highest quantized input value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedMaxPool'. + /// + /// + /// Optional argument + /// The size of the window for each dimension of the input tensor. + /// The length must be 4 to match the number of dimensions of the input. + /// + /// + /// Optional argument + /// The stride of the sliding window for each dimension of the input + /// tensor. The length must be 4 to match the number of dimensions of the input. + /// + /// + /// Optional argument + /// The type of padding algorithm to use. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : + /// min_output : The float value that the lowest quantized output value represents. + /// max_output : The float value that the highest quantized output value represents. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor output, Tensor min_output, Tensor max_output) quantized_max_pool(Tensor input, Tensor min_input, Tensor max_input, int[] ksize, int[] strides, string padding, string name = "QuantizedMaxPool") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["min_input"] = min_input; + dict["max_input"] = max_input; + dict["ksize"] = ksize; + dict["strides"] = strides; + dict["padding"] = padding; + var op = tf.OpDefLib._apply_op_helper("QuantizedMaxPool", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var min_output = op.outputs[_idx++]; + var max_output = op.outputs[_idx++]; + return (output, min_output, max_output); + } + + /// + /// Returns x * y element-wise, working on quantized buffers. + /// + /// + /// + /// + /// + /// + /// The float value that the lowest quantized x value represents. + /// + /// + /// The float value that the highest quantized x value represents. + /// + /// + /// The float value that the lowest quantized y value represents. + /// + /// + /// The float value that the highest quantized y value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedMul'. + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// z : + /// min_z : The float value that the lowest quantized output value represents. + /// max_z : The float value that the highest quantized output value represents. + /// + /// *NOTE*: QuantizedMul supports limited forms of broadcasting. More about + /// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor z, Tensor min_z, Tensor max_z) quantized_mul(Tensor x, Tensor y, Tensor min_x, Tensor max_x, Tensor min_y, Tensor max_y, TF_DataType? Toutput = null, string name = "QuantizedMul") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + dict["min_x"] = min_x; + dict["max_x"] = max_x; + dict["min_y"] = min_y; + dict["max_y"] = max_y; + if (Toutput.HasValue) + dict["Toutput"] = Toutput.Value; + var op = tf.OpDefLib._apply_op_helper("QuantizedMul", name: name, keywords: dict); + int _idx = 0; + var z = op.outputs[_idx++]; + var min_z = op.outputs[_idx++]; + var max_z = op.outputs[_idx++]; + return (z, min_z, max_z); + } + + /// + /// Computes Quantized Rectified Linear: max(features, 0) + /// + /// + /// + /// + /// The float value that the lowest quantized value represents. + /// + /// + /// The float value that the highest quantized value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedRelu'. + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// activations : Has the same output shape as "features". + /// min_activations : The float value that the lowest quantized value represents. + /// max_activations : The float value that the highest quantized value represents. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor activations, Tensor min_activations, Tensor max_activations) quantized_relu(Tensor features, Tensor min_features, Tensor max_features, TF_DataType? out_type = null, string name = "QuantizedRelu") + { + var dict = new Dictionary(); + dict["features"] = features; + dict["min_features"] = min_features; + dict["max_features"] = max_features; + if (out_type.HasValue) + dict["out_type"] = out_type.Value; + var op = tf.OpDefLib._apply_op_helper("QuantizedRelu", name: name, keywords: dict); + int _idx = 0; + var activations = op.outputs[_idx++]; + var min_activations = op.outputs[_idx++]; + var max_activations = op.outputs[_idx++]; + return (activations, min_activations, max_activations); + } + + /// + /// Computes Quantized Rectified Linear 6: min(max(features, 0), 6) + /// + /// + /// + /// + /// The float value that the lowest quantized value represents. + /// + /// + /// The float value that the highest quantized value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedRelu6'. + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// activations : Has the same output shape as "features". + /// min_activations : The float value that the lowest quantized value represents. + /// max_activations : The float value that the highest quantized value represents. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor activations, Tensor min_activations, Tensor max_activations) quantized_relu6(Tensor features, Tensor min_features, Tensor max_features, TF_DataType? out_type = null, string name = "QuantizedRelu6") + { + var dict = new Dictionary(); + dict["features"] = features; + dict["min_features"] = min_features; + dict["max_features"] = max_features; + if (out_type.HasValue) + dict["out_type"] = out_type.Value; + var op = tf.OpDefLib._apply_op_helper("QuantizedRelu6", name: name, keywords: dict); + int _idx = 0; + var activations = op.outputs[_idx++]; + var min_activations = op.outputs[_idx++]; + var max_activations = op.outputs[_idx++]; + return (activations, min_activations, max_activations); + } + + /// + /// Computes Quantized Rectified Linear X: min(max(features, 0), max_value) + /// + /// + /// + /// + /// + /// + /// The float value that the lowest quantized value represents. + /// + /// + /// The float value that the highest quantized value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedReluX'. + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// activations : Has the same output shape as "features". + /// min_activations : The float value that the lowest quantized value represents. + /// max_activations : The float value that the highest quantized value represents. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor activations, Tensor min_activations, Tensor max_activations) quantized_relu_x(Tensor features, Tensor max_value, Tensor min_features, Tensor max_features, TF_DataType? out_type = null, string name = "QuantizedReluX") + { + var dict = new Dictionary(); + dict["features"] = features; + dict["max_value"] = max_value; + dict["min_features"] = min_features; + dict["max_features"] = max_features; + if (out_type.HasValue) + dict["out_type"] = out_type.Value; + var op = tf.OpDefLib._apply_op_helper("QuantizedReluX", name: name, keywords: dict); + int _idx = 0; + var activations = op.outputs[_idx++]; + var min_activations = op.outputs[_idx++]; + var max_activations = op.outputs[_idx++]; + return (activations, min_activations, max_activations); + } + + /// + /// Reshapes a quantized tensor as per the Reshape op. + /// + /// + /// + /// + /// Defines the shape of the output tensor. + /// + /// + /// The minimum value of the input. + /// + /// + /// The maximum value of the input. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedReshape'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : + /// output_min : This value is copied from input_min. + /// output_max : This value is copied from input_max. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// + public static (Tensor output, Tensor output_min, Tensor output_max) quantized_reshape(Tensor tensor, Tensor shape, Tensor input_min, Tensor input_max, string name = "QuantizedReshape") + { + var dict = new Dictionary(); + dict["tensor"] = tensor; + dict["shape"] = shape; + dict["input_min"] = input_min; + dict["input_max"] = input_max; + var op = tf.OpDefLib._apply_op_helper("QuantizedReshape", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var output_min = op.outputs[_idx++]; + var output_max = op.outputs[_idx++]; + return (output, output_min, output_max); + } + + /// + /// Resize quantized images to size using quantized bilinear interpolation. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// = A 1-D int32 Tensor of 2 elements: new_height, new_width. The + /// new size for the images. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QuantizedResizeBilinear'. + /// + /// + /// If true, the centers of the 4 corner pixels of the input and output tensors are + /// aligned, preserving the values at the corner pixels. Defaults to false. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// resized_images : 4-D with shape + /// [batch, new_height, new_width, channels]. + /// out_min : + /// out_max : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Input images and output images must be quantized types. + /// + public static (Tensor resized_images, Tensor out_min, Tensor out_max) quantized_resize_bilinear(Tensor images, Tensor size, Tensor min, Tensor max, bool? align_corners = null, string name = "QuantizedResizeBilinear") + { + var dict = new Dictionary(); + dict["images"] = images; + dict["size"] = size; + dict["min"] = min; + dict["max"] = max; + if (align_corners.HasValue) + dict["align_corners"] = align_corners.Value; + var op = tf.OpDefLib._apply_op_helper("QuantizedResizeBilinear", name: name, keywords: dict); + int _idx = 0; + var resized_images = op.outputs[_idx++]; + var out_min = op.outputs[_idx++]; + var out_max = op.outputs[_idx++]; + return (resized_images, out_min, out_max); + } + + /// + /// Closes the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueClose'. + /// + /// + /// If true, all pending enqueue requests that are + /// blocked on the given queue will be canceled. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation signals that no more elements will be enqueued in the + /// given queue. Subsequent Enqueue(Many) operations will fail. + /// Subsequent Dequeue(Many) operations will continue to succeed if + /// sufficient elements remain in the queue. Subsequent Dequeue(Many) + /// operations that would block will fail immediately. + /// + public static Operation queue_close(Tensor handle, bool? cancel_pending_enqueues = null, string name = "QueueClose") + { + var dict = new Dictionary(); + dict["handle"] = handle; + if (cancel_pending_enqueues.HasValue) + dict["cancel_pending_enqueues"] = cancel_pending_enqueues.Value; + var op = tf.OpDefLib._apply_op_helper("QueueClose", name: name, keywords: dict); + return op; + } + + /// + /// Closes the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueCloseV2'. + /// + /// + /// If true, all pending enqueue requests that are + /// blocked on the given queue will be canceled. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation signals that no more elements will be enqueued in the + /// given queue. Subsequent Enqueue(Many) operations will fail. + /// Subsequent Dequeue(Many) operations will continue to succeed if + /// sufficient elements remain in the queue. Subsequent Dequeue(Many) + /// operations that would block will fail immediately. + /// + public static Operation queue_close_v2(Tensor handle, bool? cancel_pending_enqueues = null, string name = "QueueCloseV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + if (cancel_pending_enqueues.HasValue) + dict["cancel_pending_enqueues"] = cancel_pending_enqueues.Value; + var op = tf.OpDefLib._apply_op_helper("QueueCloseV2", name: name, keywords: dict); + return op; + } + + /// + /// Dequeues a tuple of one or more tensors from the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueDequeue'. + /// + /// + /// Optional argument + /// The type of each component in a tuple. + /// + /// + /// If the queue is empty, this operation will block for up to + /// timeout_ms milliseconds. + /// Note: This option is not supported yet. + /// + /// + /// One or more tensors that were dequeued as a tuple. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation has k outputs, where k is the number of components + /// in the tuples stored in the given queue, and output i is the ith + /// component of the dequeued tuple. + /// + /// N.B. If the queue is empty, this operation will block until an element + /// has been dequeued (or 'timeout_ms' elapses, if specified). + /// + public static Tensor[] queue_dequeue(Tensor handle, TF_DataType[] component_types, int? timeout_ms = null, string name = "QueueDequeue") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["component_types"] = component_types; + if (timeout_ms.HasValue) + dict["timeout_ms"] = timeout_ms.Value; + var op = tf.OpDefLib._apply_op_helper("QueueDequeue", name: name, keywords: dict); + int _idx = 0; + var components = Enumerable.Range(0, op.OutputListLength("components")).Select(_ => op.outputs[_idx++]).ToArray(); + return (components); + } + + /// + /// Dequeues n tuples of one or more tensors from the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// The number of tuples to dequeue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueDequeueMany'. + /// + /// + /// Optional argument + /// The type of each component in a tuple. + /// + /// + /// If the queue has fewer than n elements, this operation + /// will block for up to timeout_ms milliseconds. + /// Note: This option is not supported yet. + /// + /// + /// One or more tensors that were dequeued as a tuple. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// If the queue is closed and there are fewer than n elements, then an + /// OutOfRange error is returned. + /// + /// This operation concatenates queue-element component tensors along the + /// 0th dimension to make a single component tensor. All of the components + /// in the dequeued tuple will have size n in the 0th dimension. + /// + /// This operation has k outputs, where k is the number of components in + /// the tuples stored in the given queue, and output i is the ith + /// component of the dequeued tuple. + /// + /// N.B. If the queue is empty, this operation will block until n elements + /// have been dequeued (or 'timeout_ms' elapses, if specified). + /// + public static Tensor[] queue_dequeue_many(Tensor handle, Tensor n, TF_DataType[] component_types, int? timeout_ms = null, string name = "QueueDequeueMany") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["n"] = n; + dict["component_types"] = component_types; + if (timeout_ms.HasValue) + dict["timeout_ms"] = timeout_ms.Value; + var op = tf.OpDefLib._apply_op_helper("QueueDequeueMany", name: name, keywords: dict); + int _idx = 0; + var components = Enumerable.Range(0, op.OutputListLength("components")).Select(_ => op.outputs[_idx++]).ToArray(); + return (components); + } + + /// + /// Dequeues n tuples of one or more tensors from the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// The number of tuples to dequeue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueDequeueManyV2'. + /// + /// + /// Optional argument + /// The type of each component in a tuple. + /// + /// + /// If the queue has fewer than n elements, this operation + /// will block for up to timeout_ms milliseconds. + /// Note: This option is not supported yet. + /// + /// + /// One or more tensors that were dequeued as a tuple. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// If the queue is closed and there are fewer than n elements, then an + /// OutOfRange error is returned. + /// + /// This operation concatenates queue-element component tensors along the + /// 0th dimension to make a single component tensor. All of the components + /// in the dequeued tuple will have size n in the 0th dimension. + /// + /// This operation has k outputs, where k is the number of components in + /// the tuples stored in the given queue, and output i is the ith + /// component of the dequeued tuple. + /// + /// N.B. If the queue is empty, this operation will block until n elements + /// have been dequeued (or 'timeout_ms' elapses, if specified). + /// + public static Tensor[] queue_dequeue_many_v2(Tensor handle, Tensor n, TF_DataType[] component_types, int? timeout_ms = null, string name = "QueueDequeueManyV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["n"] = n; + dict["component_types"] = component_types; + if (timeout_ms.HasValue) + dict["timeout_ms"] = timeout_ms.Value; + var op = tf.OpDefLib._apply_op_helper("QueueDequeueManyV2", name: name, keywords: dict); + int _idx = 0; + var components = Enumerable.Range(0, op.OutputListLength("components")).Select(_ => op.outputs[_idx++]).ToArray(); + return (components); + } + + /// + /// Dequeues n tuples of one or more tensors from the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// The number of tuples to dequeue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueDequeueUpTo'. + /// + /// + /// Optional argument + /// The type of each component in a tuple. + /// + /// + /// If the queue has fewer than n elements, this operation + /// will block for up to timeout_ms milliseconds. + /// Note: This option is not supported yet. + /// + /// + /// One or more tensors that were dequeued as a tuple. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation is not supported by all queues. If a queue does not support + /// DequeueUpTo, then an Unimplemented error is returned. + /// + /// If the queue is closed and there are more than 0 but less than n + /// elements remaining, then instead of returning an OutOfRange error like + /// QueueDequeueMany, less than n elements are returned immediately. If + /// the queue is closed and there are 0 elements left in the queue, then + /// an OutOfRange error is returned just like in QueueDequeueMany. + /// Otherwise the behavior is identical to QueueDequeueMany: + /// + /// This operation concatenates queue-element component tensors along the + /// 0th dimension to make a single component tensor. All of the components + /// in the dequeued tuple will have size n in the 0th dimension. + /// + /// This operation has k outputs, where k is the number of components in + /// the tuples stored in the given queue, and output i is the ith + /// component of the dequeued tuple. + /// + public static Tensor[] queue_dequeue_up_to(Tensor handle, Tensor n, TF_DataType[] component_types, int? timeout_ms = null, string name = "QueueDequeueUpTo") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["n"] = n; + dict["component_types"] = component_types; + if (timeout_ms.HasValue) + dict["timeout_ms"] = timeout_ms.Value; + var op = tf.OpDefLib._apply_op_helper("QueueDequeueUpTo", name: name, keywords: dict); + int _idx = 0; + var components = Enumerable.Range(0, op.OutputListLength("components")).Select(_ => op.outputs[_idx++]).ToArray(); + return (components); + } + + /// + /// Dequeues n tuples of one or more tensors from the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// The number of tuples to dequeue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueDequeueUpToV2'. + /// + /// + /// Optional argument + /// The type of each component in a tuple. + /// + /// + /// If the queue has fewer than n elements, this operation + /// will block for up to timeout_ms milliseconds. + /// Note: This option is not supported yet. + /// + /// + /// One or more tensors that were dequeued as a tuple. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation is not supported by all queues. If a queue does not support + /// DequeueUpTo, then an Unimplemented error is returned. + /// + /// If the queue is closed and there are more than 0 but less than n + /// elements remaining, then instead of returning an OutOfRange error like + /// QueueDequeueMany, less than n elements are returned immediately. If + /// the queue is closed and there are 0 elements left in the queue, then + /// an OutOfRange error is returned just like in QueueDequeueMany. + /// Otherwise the behavior is identical to QueueDequeueMany: + /// + /// This operation concatenates queue-element component tensors along the + /// 0th dimension to make a single component tensor. All of the components + /// in the dequeued tuple will have size n in the 0th dimension. + /// + /// This operation has k outputs, where k is the number of components in + /// the tuples stored in the given queue, and output i is the ith + /// component of the dequeued tuple. + /// + public static Tensor[] queue_dequeue_up_to_v2(Tensor handle, Tensor n, TF_DataType[] component_types, int? timeout_ms = null, string name = "QueueDequeueUpToV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["n"] = n; + dict["component_types"] = component_types; + if (timeout_ms.HasValue) + dict["timeout_ms"] = timeout_ms.Value; + var op = tf.OpDefLib._apply_op_helper("QueueDequeueUpToV2", name: name, keywords: dict); + int _idx = 0; + var components = Enumerable.Range(0, op.OutputListLength("components")).Select(_ => op.outputs[_idx++]).ToArray(); + return (components); + } + + /// + /// Dequeues a tuple of one or more tensors from the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueDequeueV2'. + /// + /// + /// Optional argument + /// The type of each component in a tuple. + /// + /// + /// If the queue is empty, this operation will block for up to + /// timeout_ms milliseconds. + /// Note: This option is not supported yet. + /// + /// + /// One or more tensors that were dequeued as a tuple. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation has k outputs, where k is the number of components + /// in the tuples stored in the given queue, and output i is the ith + /// component of the dequeued tuple. + /// + /// N.B. If the queue is empty, this operation will block until an element + /// has been dequeued (or 'timeout_ms' elapses, if specified). + /// + public static Tensor[] queue_dequeue_v2(Tensor handle, TF_DataType[] component_types, int? timeout_ms = null, string name = "QueueDequeueV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["component_types"] = component_types; + if (timeout_ms.HasValue) + dict["timeout_ms"] = timeout_ms.Value; + var op = tf.OpDefLib._apply_op_helper("QueueDequeueV2", name: name, keywords: dict); + int _idx = 0; + var components = Enumerable.Range(0, op.OutputListLength("components")).Select(_ => op.outputs[_idx++]).ToArray(); + return (components); + } + + /// + /// Enqueues a tuple of one or more tensors in the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// One or more tensors from which the enqueued tensors should be taken. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueEnqueue'. + /// + /// + /// If the queue is full, this operation will block for up to + /// timeout_ms milliseconds. + /// Note: This option is not supported yet. + /// + /// + /// Returns the description of the operation + /// + /// + /// The components input has k elements, which correspond to the components of + /// tuples stored in the given queue. + /// + /// N.B. If the queue is full, this operation will block until the given + /// element has been enqueued (or 'timeout_ms' elapses, if specified). + /// + public static Operation queue_enqueue(Tensor handle, Tensor[] components, int? timeout_ms = null, string name = "QueueEnqueue") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["components"] = components; + if (timeout_ms.HasValue) + dict["timeout_ms"] = timeout_ms.Value; + var op = tf.OpDefLib._apply_op_helper("QueueEnqueue", name: name, keywords: dict); + return op; + } + + /// + /// Enqueues zero or more tuples of one or more tensors in the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// One or more tensors from which the enqueued tensors should + /// be taken. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueEnqueueMany'. + /// + /// + /// If the queue is too full, this operation will block for up + /// to timeout_ms milliseconds. + /// Note: This option is not supported yet. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation slices each component tensor along the 0th dimension to + /// make multiple queue elements. All of the tuple components must have the + /// same size in the 0th dimension. + /// + /// The components input has k elements, which correspond to the components of + /// tuples stored in the given queue. + /// + /// N.B. If the queue is full, this operation will block until the given + /// elements have been enqueued (or 'timeout_ms' elapses, if specified). + /// + public static Operation queue_enqueue_many(Tensor handle, Tensor[] components, int? timeout_ms = null, string name = "QueueEnqueueMany") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["components"] = components; + if (timeout_ms.HasValue) + dict["timeout_ms"] = timeout_ms.Value; + var op = tf.OpDefLib._apply_op_helper("QueueEnqueueMany", name: name, keywords: dict); + return op; + } + + /// + /// Enqueues zero or more tuples of one or more tensors in the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// One or more tensors from which the enqueued tensors should + /// be taken. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueEnqueueManyV2'. + /// + /// + /// If the queue is too full, this operation will block for up + /// to timeout_ms milliseconds. + /// Note: This option is not supported yet. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation slices each component tensor along the 0th dimension to + /// make multiple queue elements. All of the tuple components must have the + /// same size in the 0th dimension. + /// + /// The components input has k elements, which correspond to the components of + /// tuples stored in the given queue. + /// + /// N.B. If the queue is full, this operation will block until the given + /// elements have been enqueued (or 'timeout_ms' elapses, if specified). + /// + public static Operation queue_enqueue_many_v2(Tensor handle, Tensor[] components, int? timeout_ms = null, string name = "QueueEnqueueManyV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["components"] = components; + if (timeout_ms.HasValue) + dict["timeout_ms"] = timeout_ms.Value; + var op = tf.OpDefLib._apply_op_helper("QueueEnqueueManyV2", name: name, keywords: dict); + return op; + } + + /// + /// Enqueues a tuple of one or more tensors in the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// One or more tensors from which the enqueued tensors should be taken. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueEnqueueV2'. + /// + /// + /// If the queue is full, this operation will block for up to + /// timeout_ms milliseconds. + /// Note: This option is not supported yet. + /// + /// + /// Returns the description of the operation + /// + /// + /// The components input has k elements, which correspond to the components of + /// tuples stored in the given queue. + /// + /// N.B. If the queue is full, this operation will block until the given + /// element has been enqueued (or 'timeout_ms' elapses, if specified). + /// + public static Operation queue_enqueue_v2(Tensor handle, Tensor[] components, int? timeout_ms = null, string name = "QueueEnqueueV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["components"] = components; + if (timeout_ms.HasValue) + dict["timeout_ms"] = timeout_ms.Value; + var op = tf.OpDefLib._apply_op_helper("QueueEnqueueV2", name: name, keywords: dict); + return op; + } + + /// + /// Returns true if queue is closed. + /// + /// + /// The handle to a queue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueIsClosed'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation returns true if the queue is closed and false if the queue + /// is open. + /// + public static Tensor queue_is_closed(Tensor handle, string name = "QueueIsClosed") + { + var dict = new Dictionary(); + dict["handle"] = handle; + var op = tf.OpDefLib._apply_op_helper("QueueIsClosed", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns true if queue is closed. + /// + /// + /// The handle to a queue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueIsClosedV2'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation returns true if the queue is closed and false if the queue + /// is open. + /// + public static Tensor queue_is_closed_v2(Tensor handle, string name = "QueueIsClosedV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + var op = tf.OpDefLib._apply_op_helper("QueueIsClosedV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the number of elements in the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueSize'. + /// + /// + /// The number of elements in the given queue. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor queue_size(Tensor handle, string name = "QueueSize") + { + var dict = new Dictionary(); + dict["handle"] = handle; + var op = tf.OpDefLib._apply_op_helper("QueueSize", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the number of elements in the given queue. + /// + /// + /// The handle to a queue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'QueueSizeV2'. + /// + /// + /// The number of elements in the given queue. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor queue_size_v2(Tensor handle, string name = "QueueSizeV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + var op = tf.OpDefLib._apply_op_helper("QueueSizeV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Real-valued fast Fourier transform. + /// + /// + /// A float32 tensor. + /// + /// + /// An int32 tensor of shape [1]. The FFT length. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RFFT'. + /// + /// + /// A complex64 tensor of the same rank as input. The inner-most + /// dimension of input is replaced with the fft_length / 2 + 1 unique + /// frequency components of its 1D Fourier transform. + /// + /// @compatibility(numpy) + /// Equivalent to np.fft.rfft + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the 1-dimensional discrete Fourier transform of a real-valued signal + /// over the inner-most dimension of input. + /// + /// Since the DFT of a real signal is Hermitian-symmetric, RFFT only returns the + /// fft_length / 2 + 1 unique components of the FFT: the zero-frequency term, + /// followed by the fft_length / 2 positive-frequency terms. + /// + /// Along the axis RFFT is computed on, if fft_length is smaller than the + /// corresponding dimension of input, the dimension is cropped. If it is larger, + /// the dimension is padded with zeros. + /// + public static Tensor r_f_f_t(Tensor input, Tensor fft_length, string name = "RFFT") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["fft_length"] = fft_length; + var op = tf.OpDefLib._apply_op_helper("RFFT", name: name, keywords: dict); + return op.output; + } + + /// + /// 2D real-valued fast Fourier transform. + /// + /// + /// A float32 tensor. + /// + /// + /// An int32 tensor of shape [2]. The FFT length for each dimension. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RFFT2D'. + /// + /// + /// A complex64 tensor of the same rank as input. The inner-most 2 + /// dimensions of input are replaced with their 2D Fourier transform. The + /// inner-most dimension contains fft_length / 2 + 1 unique frequency + /// components. + /// + /// @compatibility(numpy) + /// Equivalent to np.fft.rfft2 + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the 2-dimensional discrete Fourier transform of a real-valued signal + /// over the inner-most 2 dimensions of input. + /// + /// Since the DFT of a real signal is Hermitian-symmetric, RFFT2D only returns the + /// fft_length / 2 + 1 unique components of the FFT for the inner-most dimension + /// of output: the zero-frequency term, followed by the fft_length / 2 + /// positive-frequency terms. + /// + /// Along each axis RFFT2D is computed on, if fft_length is smaller than the + /// corresponding dimension of input, the dimension is cropped. If it is larger, + /// the dimension is padded with zeros. + /// + public static Tensor r_f_f_t2d(Tensor input, Tensor fft_length, string name = "RFFT2D") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["fft_length"] = fft_length; + var op = tf.OpDefLib._apply_op_helper("RFFT2D", name: name, keywords: dict); + return op.output; + } + + /// + /// 3D real-valued fast Fourier transform. + /// + /// + /// A float32 tensor. + /// + /// + /// An int32 tensor of shape [3]. The FFT length for each dimension. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RFFT3D'. + /// + /// + /// A complex64 tensor of the same rank as input. The inner-most 3 + /// dimensions of input are replaced with the their 3D Fourier transform. The + /// inner-most dimension contains fft_length / 2 + 1 unique frequency + /// components. + /// + /// @compatibility(numpy) + /// Equivalent to np.fft.rfftn with 3 dimensions. + /// @end_compatibility + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the 3-dimensional discrete Fourier transform of a real-valued signal + /// over the inner-most 3 dimensions of input. + /// + /// Since the DFT of a real signal is Hermitian-symmetric, RFFT3D only returns the + /// fft_length / 2 + 1 unique components of the FFT for the inner-most dimension + /// of output: the zero-frequency term, followed by the fft_length / 2 + /// positive-frequency terms. + /// + /// Along each axis RFFT3D is computed on, if fft_length is smaller than the + /// corresponding dimension of input, the dimension is cropped. If it is larger, + /// the dimension is padded with zeros. + /// + public static Tensor r_f_f_t3d(Tensor input, Tensor fft_length, string name = "RFFT3D") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["fft_length"] = fft_length; + var op = tf.OpDefLib._apply_op_helper("RFFT3D", name: name, keywords: dict); + return op.output; + } + + /// + /// Converts one or more images from RGB to HSV. + /// + /// + /// 1-D or higher rank. RGB data to convert. Last dimension must be size 3. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RGBToHSV'. + /// + /// + /// images converted to HSV. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Outputs a tensor of the same shape as the images tensor, containing the HSV + /// value of the pixels. The output is only well defined if the value in images + /// are in [0,1]. + /// + /// output[..., 0] contains hue, output[..., 1] contains saturation, and + /// output[..., 2] contains value. All HSV values are in [0,1]. A hue of 0 + /// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. + /// + public static Tensor r_g_b_to_h_s_v(Tensor images, string name = "RGBToHSV") + { + var dict = new Dictionary(); + dict["images"] = images; + var op = tf.OpDefLib._apply_op_helper("RGBToHSV", name: name, keywords: dict); + return op.output; + } + + /// + /// Randomly crop image. + /// + /// + /// 3-D of shape [height, width, channels]. + /// + /// + /// 1-D of length 2 containing: crop_height, crop_width.. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RandomCrop'. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// An second seed to avoid seed collision. + /// + /// + /// 3-D of shape [crop_height, crop_width, channels]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// size is a 1-D int64 tensor with 2 elements representing the crop height and + /// width. The values must be non negative. + /// + /// This Op picks a random location in image and crops a height by width + /// rectangle from that location. The random location is picked so the cropped + /// area will fit inside the original image. + /// + public static Tensor random_crop(Tensor image, Tensor size, int? seed = null, int? seed2 = null, string name = "RandomCrop") + { + var dict = new Dictionary(); + dict["image"] = image; + dict["size"] = size; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("RandomCrop", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a Dataset that returns pseudorandom numbers. + /// + /// + /// A scalar seed for the random number generator. If either seed or + /// seed2 is set to be non-zero, the random number generator is seeded + /// by the given seed. Otherwise, a random seed is used. + /// + /// + /// A second scalar seed to avoid seed collision. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RandomDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor random_dataset(Tensor seed, Tensor seed2, TF_DataType[] output_types, Shape[] output_shapes, string name = "RandomDataset") + { + var dict = new Dictionary(); + dict["seed"] = seed; + dict["seed2"] = seed2; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("RandomDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs random values from the Gamma distribution(s) described by alpha. + /// + /// + /// 1-D integer tensor. Shape of independent samples to draw from each + /// distribution described by the shape parameters given in alpha. + /// + /// + /// A tensor in which each scalar is a "shape" parameter describing the + /// associated gamma distribution. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RandomGamma'. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// A tensor with shape shape + shape(alpha). Each slice + /// [:, ..., :, i0, i1, ...iN] contains the samples drawn for + /// alpha[i0, i1, ...iN]. The dtype of the output matches the dtype of alpha. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op uses the algorithm by Marsaglia et al. to acquire samples via + /// transformation-rejection from pairs of uniform and normal random variables. + /// See http://dl.acm.org/citation.cfm?id=358414 + /// + public static Tensor random_gamma(Tensor shape, Tensor alpha, int? seed = null, int? seed2 = null, string name = "RandomGamma") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["alpha"] = alpha; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("RandomGamma", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the derivative of a Gamma random sample w.r.t. alpha. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RandomGammaGrad'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor random_gamma_grad(Tensor alpha, Tensor sample, string name = "RandomGammaGrad") + { + var dict = new Dictionary(); + dict["alpha"] = alpha; + dict["sample"] = sample; + var op = tf.OpDefLib._apply_op_helper("RandomGammaGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Use RandomPoissonV2 instead. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RandomPoisson'. + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor random_poisson(Tensor shape, Tensor rate, int? seed = null, int? seed2 = null, string name = "RandomPoisson") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["rate"] = rate; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("RandomPoisson", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs random values from the Poisson distribution(s) described by rate. + /// + /// + /// 1-D integer tensor. Shape of independent samples to draw from each + /// distribution described by the shape parameters given in rate. + /// + /// + /// A tensor in which each scalar is a "rate" parameter describing the + /// associated poisson distribution. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RandomPoissonV2'. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// + /// + /// A tensor with shape shape + shape(rate). Each slice + /// [:, ..., :, i0, i1, ...iN] contains the samples drawn for + /// rate[i0, i1, ...iN]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op uses two algorithms, depending on rate. If rate &gt;= 10, then + /// the algorithm by Hormann is used to acquire samples via + /// transformation-rejection. + /// See http://www.sciencedirect.com/science/article/pii/0167668793909974. + /// + /// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform + /// random variables. + /// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer + /// Programming, Volume 2. Addison Wesley + /// + public static Tensor random_poisson_v2(Tensor shape, Tensor rate, int? seed = null, int? seed2 = null, TF_DataType? dtype = null, string name = "RandomPoissonV2") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["rate"] = rate; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + if (dtype.HasValue) + dict["dtype"] = dtype.Value; + var op = tf.OpDefLib._apply_op_helper("RandomPoissonV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Randomly shuffles a tensor along its first dimension. + /// + /// + /// The tensor to be shuffled. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RandomShuffle'. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// A tensor of same shape and type as value, shuffled along its first + /// dimension. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The tensor is shuffled along dimension 0, such that each value[j] is mapped + /// to one and only one output[i]. For example, a mapping that might occur for a + /// 3x2 tensor is: + /// + /// + /// [[1, 2], [[5, 6], + /// [3, 4], ==&gt; [1, 2], + /// [5, 6]] [3, 4]] + /// + /// + public static Tensor random_shuffle(Tensor value, int? seed = null, int? seed2 = null, string name = "RandomShuffle") + { + var dict = new Dictionary(); + dict["value"] = value; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("RandomShuffle", name: name, keywords: dict); + return op.output; + } + + /// + /// A queue that randomizes the order of elements. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RandomShuffleQueue'. + /// + /// + /// Optional argument + /// The type of each component in a value. + /// + /// + /// The shape of each component in a value. The length of this attr must + /// be either 0 or the same as the length of component_types. If the length of + /// this attr is 0, the shapes of queue elements are not constrained, and + /// only one element may be dequeued at a time. + /// + /// + /// The upper bound on the number of elements in this queue. + /// Negative numbers mean no limit. + /// + /// + /// Dequeue will block unless there would be this + /// many elements after the dequeue or the queue is closed. This + /// ensures a minimum level of mixing of elements. + /// + /// + /// If either seed or seed2 is set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, a random seed is used. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// If non-empty, this queue is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this queue will be shared under the given name + /// across multiple sessions. + /// + /// + /// The handle to the queue. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor random_shuffle_queue(TF_DataType[] component_types, Shape[] shapes = null, int? capacity = null, int? min_after_dequeue = null, int? seed = null, int? seed2 = null, string container = null, string shared_name = null, string name = "RandomShuffleQueue") + { + var dict = new Dictionary(); + dict["component_types"] = component_types; + if (shapes != null) + dict["shapes"] = shapes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (min_after_dequeue.HasValue) + dict["min_after_dequeue"] = min_after_dequeue.Value; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("RandomShuffleQueue", name: name, keywords: dict); + return op.output; + } + + /// + /// A queue that randomizes the order of elements. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RandomShuffleQueueV2'. + /// + /// + /// Optional argument + /// The type of each component in a value. + /// + /// + /// The shape of each component in a value. The length of this attr must + /// be either 0 or the same as the length of component_types. If the length of + /// this attr is 0, the shapes of queue elements are not constrained, and + /// only one element may be dequeued at a time. + /// + /// + /// The upper bound on the number of elements in this queue. + /// Negative numbers mean no limit. + /// + /// + /// Dequeue will block unless there would be this + /// many elements after the dequeue or the queue is closed. This + /// ensures a minimum level of mixing of elements. + /// + /// + /// If either seed or seed2 is set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, a random seed is used. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// If non-empty, this queue is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this queue will be shared under the given name + /// across multiple sessions. + /// + /// + /// The handle to the queue. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor random_shuffle_queue_v2(TF_DataType[] component_types, Shape[] shapes = null, int? capacity = null, int? min_after_dequeue = null, int? seed = null, int? seed2 = null, string container = null, string shared_name = null, string name = "RandomShuffleQueueV2") + { + var dict = new Dictionary(); + dict["component_types"] = component_types; + if (shapes != null) + dict["shapes"] = shapes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (min_after_dequeue.HasValue) + dict["min_after_dequeue"] = min_after_dequeue.Value; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("RandomShuffleQueueV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs random values from a normal distribution. + /// + /// + /// The shape of the output tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RandomStandardNormal'. + /// + /// + /// Optional argument + /// The type of the output. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// A tensor of the specified shape filled with random normal values. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The generated values will have mean 0 and standard deviation 1. + /// + public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype, int? seed = null, int? seed2 = null, string name = "RandomStandardNormal") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["dtype"] = dtype; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("RandomStandardNormal", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs random values from a uniform distribution. + /// + /// + /// The shape of the output tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RandomUniform'. + /// + /// + /// Optional argument + /// The type of the output. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// A tensor of the specified shape filled with uniform random values. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The generated values follow a uniform distribution in the range [0, 1). The + /// lower bound 0 is included in the range, while the upper bound 1 is excluded. + /// + public static Tensor random_uniform(Tensor shape, TF_DataType dtype, int? seed = null, int? seed2 = null, string name = "RandomUniform") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["dtype"] = dtype; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("RandomUniform", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs random integers from a uniform distribution. + /// + /// + /// The shape of the output tensor. + /// + /// + /// 0-D. Inclusive lower bound on the generated integers. + /// + /// + /// 0-D. Exclusive upper bound on the generated integers. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RandomUniformInt'. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// A tensor of the specified shape filled with uniform random integers. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The generated values are uniform integers in the range [minval, maxval). + /// The lower bound minval is included in the range, while the upper bound + /// maxval is excluded. + /// + /// The random integers are slightly biased unless maxval - minval is an exact + /// power of two. The bias is small for values of maxval - minval significantly + /// smaller than the range of the output (either 2^32 or 2^64). + /// + public static Tensor random_uniform_int(Tensor shape, Tensor minval, Tensor maxval, int? seed = null, int? seed2 = null, string name = "RandomUniformInt") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["minval"] = minval; + dict["maxval"] = maxval; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("RandomUniformInt", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a sequence of numbers. + /// + /// + /// 0-D (scalar). First entry in the sequence. + /// + /// + /// 0-D (scalar). Upper limit of sequence, exclusive. + /// + /// + /// 0-D (scalar). Optional. Default is 1. Number that increments start. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Range'. + /// + /// + /// 1-D. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation creates a sequence of numbers that begins at start and + /// extends by increments of delta up to but not including limit. + /// + /// For example: + /// + /// + /// # 'start' is 3 + /// # 'limit' is 18 + /// # 'delta' is 3 + /// tf.range(start, limit, delta) ==&gt; [3, 6, 9, 12, 15] + /// + /// + public static Tensor range(Tensor start, Tensor limit, Tensor delta, string name = "Range") + { + var dict = new Dictionary(); + dict["start"] = start; + dict["limit"] = limit; + dict["delta"] = delta; + var op = tf.OpDefLib._apply_op_helper("Range", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset with a range of values. Corresponds to python's xrange. + /// + /// + /// corresponds to start in python's xrange(). + /// + /// + /// corresponds to stop in python's xrange(). + /// + /// + /// corresponds to step in python's xrange(). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RangeDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor range_dataset(Tensor start, Tensor stop, Tensor step, TF_DataType[] output_types, Shape[] output_shapes, string name = "RangeDataset") + { + var dict = new Dictionary(); + dict["start"] = start; + dict["stop"] = stop; + dict["step"] = step; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("RangeDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the rank of a tensor. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Rank'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation returns an integer representing the rank of input. + /// + /// For example: + /// + /// + /// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] + /// # shape of tensor 't' is [2, 2, 3] + /// rank(t) ==&gt; 3 + /// + /// + /// **Note**: The rank of a tensor is not the same as the rank of a matrix. The rank + /// of a tensor is the number of indices required to uniquely select each element + /// of the tensor. Rank is also known as "order", "degree", or "ndims." + /// + public static Tensor rank(Tensor input, string name = "Rank") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("Rank", name: name, keywords: dict); + return op.output; + } + + /// + /// Reads and outputs the entire contents of the input filename. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReadFile'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor read_file(Tensor filename, string name = "ReadFile") + { + var dict = new Dictionary(); + dict["filename"] = filename; + var op = tf.OpDefLib._apply_op_helper("ReadFile", name: name, keywords: dict); + return op.output; + } + + /// + /// Reads the value of a variable. + /// + /// + /// handle to the resource in which to store the variable. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReadVariableOp'. + /// + /// + /// Optional argument + /// the dtype of the value. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The tensor returned by this operation is immutable. + /// + /// The value returned by this operation is guaranteed to be influenced by all the + /// writes on which this operation depends directly or indirectly, and to not be + /// influenced by any of the writes which depend directly or indirectly on this + /// operation. + /// + public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = "ReadVariableOp") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["dtype"] = dtype; + var op = tf.OpDefLib._apply_op_helper("ReadVariableOp", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the number of records this Reader has produced. + /// + /// + /// Handle to a Reader. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderNumRecordsProduced'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This is the same as the number of ReaderRead executions that have + /// succeeded. + /// + public static Tensor reader_num_records_produced(Tensor reader_handle, string name = "ReaderNumRecordsProduced") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + var op = tf.OpDefLib._apply_op_helper("ReaderNumRecordsProduced", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the number of records this Reader has produced. + /// + /// + /// Handle to a Reader. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderNumRecordsProducedV2'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This is the same as the number of ReaderRead executions that have + /// succeeded. + /// + public static Tensor reader_num_records_produced_v2(Tensor reader_handle, string name = "ReaderNumRecordsProducedV2") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + var op = tf.OpDefLib._apply_op_helper("ReaderNumRecordsProducedV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the number of work units this Reader has finished processing. + /// + /// + /// Handle to a Reader. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderNumWorkUnitsCompleted'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor reader_num_work_units_completed(Tensor reader_handle, string name = "ReaderNumWorkUnitsCompleted") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + var op = tf.OpDefLib._apply_op_helper("ReaderNumWorkUnitsCompleted", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the number of work units this Reader has finished processing. + /// + /// + /// Handle to a Reader. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderNumWorkUnitsCompletedV2'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor reader_num_work_units_completed_v2(Tensor reader_handle, string name = "ReaderNumWorkUnitsCompletedV2") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + var op = tf.OpDefLib._apply_op_helper("ReaderNumWorkUnitsCompletedV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the next record (key, value pair) produced by a Reader. + /// + /// + /// Handle to a Reader. + /// + /// + /// Handle to a Queue, with string work items. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderRead'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// key : A scalar. + /// value : A scalar. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Will dequeue from the input queue if necessary (e.g. when the + /// Reader needs to start reading from a new file since it has finished + /// with the previous file). + /// + public static (Tensor key, Tensor value) reader_read(Tensor reader_handle, Tensor queue_handle, string name = "ReaderRead") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + dict["queue_handle"] = queue_handle; + var op = tf.OpDefLib._apply_op_helper("ReaderRead", name: name, keywords: dict); + int _idx = 0; + var key = op.outputs[_idx++]; + var value = op.outputs[_idx++]; + return (key, value); + } + + /// + /// Returns up to num_records (key, value) pairs produced by a Reader. + /// + /// + /// Handle to a Reader. + /// + /// + /// Handle to a Queue, with string work items. + /// + /// + /// number of records to read from Reader. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderReadUpTo'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// keys : A 1-D tensor. + /// values : A 1-D tensor. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Will dequeue from the input queue if necessary (e.g. when the + /// Reader needs to start reading from a new file since it has finished + /// with the previous file). + /// It may return less than num_records even before the last batch. + /// + public static (Tensor keys, Tensor values) reader_read_up_to(Tensor reader_handle, Tensor queue_handle, Tensor num_records, string name = "ReaderReadUpTo") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + dict["queue_handle"] = queue_handle; + dict["num_records"] = num_records; + var op = tf.OpDefLib._apply_op_helper("ReaderReadUpTo", name: name, keywords: dict); + int _idx = 0; + var keys = op.outputs[_idx++]; + var values = op.outputs[_idx++]; + return (keys, values); + } + + /// + /// Returns up to num_records (key, value) pairs produced by a Reader. + /// + /// + /// Handle to a Reader. + /// + /// + /// Handle to a Queue, with string work items. + /// + /// + /// number of records to read from Reader. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderReadUpToV2'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// keys : A 1-D tensor. + /// values : A 1-D tensor. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Will dequeue from the input queue if necessary (e.g. when the + /// Reader needs to start reading from a new file since it has finished + /// with the previous file). + /// It may return less than num_records even before the last batch. + /// + public static (Tensor keys, Tensor values) reader_read_up_to_v2(Tensor reader_handle, Tensor queue_handle, Tensor num_records, string name = "ReaderReadUpToV2") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + dict["queue_handle"] = queue_handle; + dict["num_records"] = num_records; + var op = tf.OpDefLib._apply_op_helper("ReaderReadUpToV2", name: name, keywords: dict); + int _idx = 0; + var keys = op.outputs[_idx++]; + var values = op.outputs[_idx++]; + return (keys, values); + } + + /// + /// Returns the next record (key, value pair) produced by a Reader. + /// + /// + /// Handle to a Reader. + /// + /// + /// Handle to a Queue, with string work items. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderReadV2'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// key : A scalar. + /// value : A scalar. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Will dequeue from the input queue if necessary (e.g. when the + /// Reader needs to start reading from a new file since it has finished + /// with the previous file). + /// + public static (Tensor key, Tensor value) reader_read_v2(Tensor reader_handle, Tensor queue_handle, string name = "ReaderReadV2") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + dict["queue_handle"] = queue_handle; + var op = tf.OpDefLib._apply_op_helper("ReaderReadV2", name: name, keywords: dict); + int _idx = 0; + var key = op.outputs[_idx++]; + var value = op.outputs[_idx++]; + return (key, value); + } + + /// + /// Restore a Reader to its initial clean state. + /// + /// + /// Handle to a Reader. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderReset'. + /// + /// + /// Returns the description of the operation + /// + public static Operation reader_reset(Tensor reader_handle, string name = "ReaderReset") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + var op = tf.OpDefLib._apply_op_helper("ReaderReset", name: name, keywords: dict); + return op; + } + + /// + /// Restore a Reader to its initial clean state. + /// + /// + /// Handle to a Reader. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderResetV2'. + /// + /// + /// Returns the description of the operation + /// + public static Operation reader_reset_v2(Tensor reader_handle, string name = "ReaderResetV2") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + var op = tf.OpDefLib._apply_op_helper("ReaderResetV2", name: name, keywords: dict); + return op; + } + + /// + /// Restore a reader to a previously saved state. + /// + /// + /// Handle to a Reader. + /// + /// + /// Result of a ReaderSerializeState of a Reader with type + /// matching reader_handle. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderRestoreState'. + /// + /// + /// Returns the description of the operation + /// + /// + /// Not all Readers support being restored, so this can produce an + /// Unimplemented error. + /// + public static Operation reader_restore_state(Tensor reader_handle, Tensor state, string name = "ReaderRestoreState") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + dict["state"] = state; + var op = tf.OpDefLib._apply_op_helper("ReaderRestoreState", name: name, keywords: dict); + return op; + } + + /// + /// Restore a reader to a previously saved state. + /// + /// + /// Handle to a Reader. + /// + /// + /// Result of a ReaderSerializeState of a Reader with type + /// matching reader_handle. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderRestoreStateV2'. + /// + /// + /// Returns the description of the operation + /// + /// + /// Not all Readers support being restored, so this can produce an + /// Unimplemented error. + /// + public static Operation reader_restore_state_v2(Tensor reader_handle, Tensor state, string name = "ReaderRestoreStateV2") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + dict["state"] = state; + var op = tf.OpDefLib._apply_op_helper("ReaderRestoreStateV2", name: name, keywords: dict); + return op; + } + + /// + /// Produce a string tensor that encodes the state of a Reader. + /// + /// + /// Handle to a Reader. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderSerializeState'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Not all Readers support being serialized, so this can produce an + /// Unimplemented error. + /// + public static Tensor reader_serialize_state(Tensor reader_handle, string name = "ReaderSerializeState") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + var op = tf.OpDefLib._apply_op_helper("ReaderSerializeState", name: name, keywords: dict); + return op.output; + } + + /// + /// Produce a string tensor that encodes the state of a Reader. + /// + /// + /// Handle to a Reader. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReaderSerializeStateV2'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Not all Readers support being serialized, so this can produce an + /// Unimplemented error. + /// + public static Tensor reader_serialize_state_v2(Tensor reader_handle, string name = "ReaderSerializeStateV2") + { + var dict = new Dictionary(); + dict["reader_handle"] = reader_handle; + var op = tf.OpDefLib._apply_op_helper("ReaderSerializeStateV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the real part of a complex number. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Real'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor input of complex numbers, this operation returns a tensor of + /// type float that is the real part of each element in input. All elements in + /// input must be complex numbers of the form \\(a + bj\\), where *a* is the real + /// part returned by this operation and *b* is the imaginary part. + /// + /// For example: + /// + /// + /// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] + /// tf.real(input) ==&gt; [-2.25, 3.25] + /// + /// + public static Tensor real(Tensor input, TF_DataType? a_Tout = null, string name = "Real") + { + TF_DataType Tin = input.GetDataType(); + return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout })); + +// return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input})); + } + + /// + /// Returns x / y element-wise for real types. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RealDiv'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// If x and y are reals, this will return the floating-point division. + /// + /// *NOTE*: Div supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor real_div(Tensor x, Tensor y, string name = "RealDiv") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("RealDiv", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the reciprocal of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Reciprocal'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// I.e., \\(y = 1 / x\\). + /// + public static Tensor reciprocal(Tensor x, string name = "Reciprocal") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Reciprocal", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient for the inverse of x wrt its input. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReciprocalGrad'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Specifically, grad = -dy * y*y, where y = 1/x, and dy + /// is the corresponding input gradient. + /// + public static Tensor reciprocal_grad(Tensor y, Tensor dy, string name = "ReciprocalGrad") + { + var dict = new Dictionary(); + dict["y"] = y; + dict["dy"] = dy; + var op = tf.OpDefLib._apply_op_helper("ReciprocalGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Emits randomized records. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RecordInput'. + /// + /// + /// Optional argument + /// Glob pattern for the data files. + /// + /// + /// Random seeds used to produce randomized records. + /// + /// + /// Shifts the list of files after the list is randomly + /// shuffled. + /// + /// + /// The randomization shuffling buffer. + /// + /// + /// How many sstables are opened and concurrently iterated over. + /// + /// + /// The batch size. + /// + /// + /// The type of compression for the file. Currently ZLIB and + /// GZIP are supported. Defaults to none. + /// + /// + /// A tensor of shape [batch_size]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor record_input(string file_pattern, int? file_random_seed = null, float? file_shuffle_shift_ratio = null, int? file_buffer_size = null, int? file_parallelism = null, int? batch_size = null, string compression_type = null, string name = "RecordInput") + { + var dict = new Dictionary(); + dict["file_pattern"] = file_pattern; + if (file_random_seed.HasValue) + dict["file_random_seed"] = file_random_seed.Value; + if (file_shuffle_shift_ratio.HasValue) + dict["file_shuffle_shift_ratio"] = file_shuffle_shift_ratio.Value; + if (file_buffer_size.HasValue) + dict["file_buffer_size"] = file_buffer_size.Value; + if (file_parallelism.HasValue) + dict["file_parallelism"] = file_parallelism.Value; + if (batch_size.HasValue) + dict["batch_size"] = batch_size.Value; + if (compression_type != null) + dict["compression_type"] = compression_type; + var op = tf.OpDefLib._apply_op_helper("RecordInput", name: name, keywords: dict); + return op.output; + } + + /// + /// Joins a string Tensor across the given dimensions. + /// + /// + /// The input to be joined. All reduced indices must have non-zero size. + /// + /// + /// The dimensions to reduce over. Dimensions are reduced in the + /// order specified. Omitting reduction_indices is equivalent to passing + /// [n-1, n-2, ..., 0]. Negative indices from -n to -1 are supported. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReduceJoin'. + /// + /// + /// If True, retain reduced dimensions with length 1. + /// + /// + /// The separator to use when joining. + /// + /// + /// Has shape equal to that of the input with reduced dimensions removed or + /// set to 1 depending on keep_dims. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Computes the string join across dimensions in the given string Tensor of shape + /// [\\(d_0, d_1, ..., d_{n-1}\\)]. Returns a new Tensor created by joining the input + /// strings with the given separator (default: empty string). Negative indices are + /// counted backwards from the end, with -1 being equivalent to n - 1. If + /// indices are not specified, joins across all dimensions beginning from n - 1 + /// through 0. + /// + /// For example: + /// + /// + /// # tensor a is [["a", "b"], ["c", "d"]] + /// tf.reduce_join(a, 0) ==&gt; ["ac", "bd"] + /// tf.reduce_join(a, 1) ==&gt; ["ab", "cd"] + /// tf.reduce_join(a, -2) = tf.reduce_join(a, 0) ==&gt; ["ac", "bd"] + /// tf.reduce_join(a, -1) = tf.reduce_join(a, 1) ==&gt; ["ab", "cd"] + /// tf.reduce_join(a, 0, keep_dims=True) ==&gt; [["ac", "bd"]] + /// tf.reduce_join(a, 1, keep_dims=True) ==&gt; [["ab"], ["cd"]] + /// tf.reduce_join(a, 0, separator=".") ==&gt; ["a.c", "b.d"] + /// tf.reduce_join(a, [0, 1]) ==&gt; "acbd" + /// tf.reduce_join(a, [1, 0]) ==&gt; "abcd" + /// tf.reduce_join(a, []) ==&gt; [["a", "b"], ["c", "d"]] + /// tf.reduce_join(a) = tf.reduce_join(a, [1, 0]) ==&gt; "abcd" + /// + /// + public static Tensor reduce_join(Tensor inputs, Tensor reduction_indices, bool? keep_dims = null, string separator = null, string name = "ReduceJoin") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + dict["reduction_indices"] = reduction_indices; + if (keep_dims.HasValue) + dict["keep_dims"] = keep_dims.Value; + if (separator != null) + dict["separator"] = separator; + var op = tf.OpDefLib._apply_op_helper("ReduceJoin", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates or finds a child frame, and makes data available to the child frame. + /// + /// + /// The tensor to be made available to the child frame. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RefEnter'. + /// + /// + /// Optional argument + /// The name of the child frame. + /// + /// + /// If true, the output is constant within the child frame. + /// + /// + /// The number of iterations allowed to run in parallel. + /// + /// + /// The same tensor as data. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The unique frame_name is used by the Executor to identify frames. If + /// is_constant is true, output is a constant in the child frame; otherwise + /// it may be changed in the child frame. At most parallel_iterations iterations + /// are run in parallel in the child frame. + /// + public static Tensor ref_enter(Tensor data, string frame_name, bool? is_constant = null, int? parallel_iterations = null, string name = "RefEnter") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["frame_name"] = frame_name; + if (is_constant.HasValue) + dict["is_constant"] = is_constant.Value; + if (parallel_iterations.HasValue) + dict["parallel_iterations"] = parallel_iterations.Value; + var op = tf.OpDefLib._apply_op_helper("RefEnter", name: name, keywords: dict); + return op.output; + } + + /// + /// Exits the current frame to its parent frame. + /// + /// + /// The tensor to be made available to the parent frame. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RefExit'. + /// + /// + /// The same tensor as data. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Exit makes its input data available to the parent frame. + /// + public static Tensor ref_exit(Tensor data, string name = "RefExit") + { + var dict = new Dictionary(); + dict["data"] = data; + var op = tf.OpDefLib._apply_op_helper("RefExit", name: name, keywords: dict); + return op.output; + } + + /// + /// Return the same ref tensor as the input ref tensor. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RefIdentity'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor ref_identity(Tensor input, string name = "RefIdentity") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("RefIdentity", name: name, keywords: dict); + return op.output; + } + + /// + /// Forwards the value of an available tensor from inputs to output. + /// + /// + /// The input tensors, exactly one of which will become available. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RefMerge'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : Will be set to the available input tensor. + /// value_index : The index of the chosen input tensor in inputs. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Merge waits for at least one of the tensors in inputs to become available. + /// It is usually combined with Switch to implement branching. + /// + /// Merge forwards the first tensor for become available to output, and sets + /// value_index to its index in inputs. + /// + public static (Tensor output, Tensor value_index) ref_merge(Tensor[] inputs, string name = "RefMerge") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + var op = tf.OpDefLib._apply_op_helper("RefMerge", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var value_index = op.outputs[_idx++]; + return (output, value_index); + } + + /// + /// Makes its input available to the next iteration. + /// + /// + /// The tensor to be made available to the next iteration. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RefNextIteration'. + /// + /// + /// The same tensor as data. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor ref_next_iteration(Tensor data, string name = "RefNextIteration") + { + var dict = new Dictionary(); + dict["data"] = data; + var op = tf.OpDefLib._apply_op_helper("RefNextIteration", name: name, keywords: dict); + return op.output; + } + + /// + /// Forwards the indexth element of inputs to output. + /// + /// + /// A scalar that determines the input that gets selected. + /// + /// + /// A list of ref tensors, one of which will be forwarded to output. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RefSelect'. + /// + /// + /// The forwarded tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor ref_select(Tensor index, Tensor[] inputs, string name = "RefSelect") + { + var dict = new Dictionary(); + dict["index"] = index; + dict["inputs"] = inputs; + var op = tf.OpDefLib._apply_op_helper("RefSelect", name: name, keywords: dict); + return op.output; + } + + /// + /// Forwards the ref tensor data to the output port determined by pred. + /// + /// + /// The ref tensor to be forwarded to the appropriate output. + /// + /// + /// A scalar that specifies which output port will receive data. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RefSwitch'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_false : If pred is false, data will be forwarded to this output. + /// output_true : If pred is true, data will be forwarded to this output. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// If pred is true, the data input is forwarded to output_true. Otherwise, + /// the data goes to output_false. + /// + /// See also Switch and Merge. + /// + public static (Tensor output_false, Tensor output_true) ref_switch(Tensor data, Tensor pred, string name = "RefSwitch") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["pred"] = pred; + var op = tf.OpDefLib._apply_op_helper("RefSwitch", name: name, keywords: dict); + int _idx = 0; + var output_false = op.outputs[_idx++]; + var output_true = op.outputs[_idx++]; + return (output_false, output_true); + } + + /// + /// Check if the input matches the regex pattern. + /// + /// + /// A string tensor of the text to be processed. + /// + /// + /// A scalar string tensor containing the regular expression to match the input. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RegexFullMatch'. + /// + /// + /// A bool tensor with the same shape as input. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The input is a string tensor of any shape. The pattern is a scalar + /// string tensor which is applied to every element of the input tensor. + /// The boolean values (True or False) of the output tensor indicate + /// if the input matches the regex pattern provided. + /// + /// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) + /// + public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch") + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "RegexFullMatch", name, input, pattern)); + return result[0]; + } + var dict = new Dictionary(); + dict["input"] = input; + dict["pattern"] = pattern; + var op = tf.OpDefLib._apply_op_helper("RegexFullMatch", name: name, keywords: dict); + return op.output; + } + + /// + /// Replaces the match of pattern in input with rewrite. + /// + /// + /// The text to be processed. + /// + /// + /// The regular expression to match the input. + /// + /// + /// The rewrite to be applied to the matched expresion. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RegexReplace'. + /// + /// + /// If True, the replacement is global, otherwise the replacement + /// is done only on the first match. + /// + /// + /// The text after applying pattern and rewrite. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) + /// + public static Tensor regex_replace(Tensor input, Tensor pattern, Tensor rewrite, bool? replace_global = null, string name = "RegexReplace") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["pattern"] = pattern; + dict["rewrite"] = rewrite; + if (replace_global.HasValue) + dict["replace_global"] = replace_global.Value; + var op = tf.OpDefLib._apply_op_helper("RegexReplace", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes rectified linear: max(features, 0). + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Relu'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor relu(Tensor features, string name = "Relu") + { + var dict = new Dictionary(); + dict["features"] = features; + var op = tf.OpDefLib._apply_op_helper("Relu", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes rectified linear 6: min(max(features, 0), 6). + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Relu6'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor relu6(Tensor features, string name = "Relu6") + { + var dict = new Dictionary(); + dict["features"] = features; + var op = tf.OpDefLib._apply_op_helper("Relu6", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes rectified linear 6 gradients for a Relu6 operation. + /// + /// + /// The backpropagated gradients to the corresponding Relu6 operation. + /// + /// + /// The features passed as input to the corresponding Relu6 operation, or + /// its output; using either one produces the same result. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Relu6Grad'. + /// + /// + /// The gradients: + /// gradients * (features &gt; 0) * (features &lt; 6). + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor relu6grad(Tensor gradients, Tensor features, string name = "Relu6Grad") + { + var dict = new Dictionary(); + dict["gradients"] = gradients; + dict["features"] = features; + var op = tf.OpDefLib._apply_op_helper("Relu6Grad", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes rectified linear gradients for a Relu operation. + /// + /// + /// The backpropagated gradients to the corresponding Relu operation. + /// + /// + /// The features passed as input to the corresponding Relu operation, OR + /// the outputs of that operation (both work equivalently). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReluGrad'. + /// + /// + /// gradients * (features &gt; 0). + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor relu_grad(Tensor gradients, Tensor features, string name = "ReluGrad") + { + var dict = new Dictionary(); + dict["gradients"] = gradients; + dict["features"] = features; + var op = tf.OpDefLib._apply_op_helper("ReluGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Execute a sub graph on a remote processor. + /// + /// + /// Arbitrary number of tensors with arbitrary data types + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RemoteFusedGraphExecute'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// Serialized protocol buffer + /// of RemoteFusedGraphExecuteInfo which contains graph specifications. + /// + /// + /// Arbitrary number of tensors with arbitrary data types + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The graph specifications(such as graph itself, input tensors and output names) + /// are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo + /// as serialized_remote_fused_graph_execute_info. + /// The specifications will be passed to a dedicated registered + /// remote fused graph executor. The executor will send the graph specifications + /// to a remote processor and execute that graph. The execution results + /// will be passed to consumer nodes as outputs of this node. + /// + public static Tensor[] remote_fused_graph_execute(Tensor[] inputs, TF_DataType[] Toutputs, string serialized_remote_fused_graph_execute_info, string name = "RemoteFusedGraphExecute") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + dict["Toutputs"] = Toutputs; + dict["serialized_remote_fused_graph_execute_info"] = serialized_remote_fused_graph_execute_info; + var op = tf.OpDefLib._apply_op_helper("RemoteFusedGraphExecute", name: name, keywords: dict); + int _idx = 0; + var outputs = Enumerable.Range(0, op.OutputListLength("outputs")).Select(_ => op.outputs[_idx++]).ToArray(); + return (outputs); + } + + /// + /// Creates a dataset that emits the outputs of input_dataset count times. + /// + /// + /// + /// + /// A scalar representing the number of times that input_dataset should + /// be repeated. A value of -1 indicates that it should be repeated infinitely. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RepeatDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor repeat_dataset(Tensor input_dataset, Tensor count, TF_DataType[] output_types, Shape[] output_shapes, string name = "RepeatDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["count"] = count; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("RepeatDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Given a quantized tensor described by (input, input_min, input_max), outputs a + /// + /// + /// + /// + /// The float value that the minimum quantized input value represents. + /// + /// + /// The float value that the maximum quantized input value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RequantizationRange'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_min : The computed min output. + /// output_max : the computed max output. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// range that covers the actual values present in that tensor. This op is + /// typically used to produce the requested_output_min and requested_output_max for + /// Requantize. + /// + public static (Tensor output_min, Tensor output_max) requantization_range(Tensor input, Tensor input_min, Tensor input_max, string name = "RequantizationRange") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["input_min"] = input_min; + dict["input_max"] = input_max; + var op = tf.OpDefLib._apply_op_helper("RequantizationRange", name: name, keywords: dict); + int _idx = 0; + var output_min = op.outputs[_idx++]; + var output_max = op.outputs[_idx++]; + return (output_min, output_max); + } + + /// + /// Convert the quantized 'input' tensor into a lower-precision 'output', using the + /// + /// + /// + /// + /// The float value that the minimum quantized input value represents. + /// + /// + /// The float value that the maximum quantized input value represents. + /// + /// + /// The float value that the minimum quantized output value represents. + /// + /// + /// The float value that the maximum quantized output value represents. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Requantize'. + /// + /// + /// Optional argument + /// The type of the output. Should be a lower bit depth than Tinput. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output : + /// output_min : The requested_output_min value is copied into this output. + /// output_max : The requested_output_max value is copied into this output. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// output range specified with 'requested_output_min' and 'requested_output_max'. + /// + /// [input_min, input_max] are scalar floats that specify the range for the float + /// interpretation of the 'input' data. For example, if input_min is -1.0f and + /// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 + /// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. + /// + public static (Tensor output, Tensor output_min, Tensor output_max) requantize(Tensor input, Tensor input_min, Tensor input_max, Tensor requested_output_min, Tensor requested_output_max, TF_DataType out_type, string name = "Requantize") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["input_min"] = input_min; + dict["input_max"] = input_max; + dict["requested_output_min"] = requested_output_min; + dict["requested_output_max"] = requested_output_max; + dict["out_type"] = out_type; + var op = tf.OpDefLib._apply_op_helper("Requantize", name: name, keywords: dict); + int _idx = 0; + var output = op.outputs[_idx++]; + var output_min = op.outputs[_idx++]; + var output_max = op.outputs[_idx++]; + return (output, output_min, output_max); + } + + /// + /// Reshapes a tensor. + /// + /// + /// + /// + /// Defines the shape of the output tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Reshape'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given tensor, this operation returns a tensor that has the same values + /// as tensor with shape shape. + /// + /// If one component of shape is the special value -1, the size of that dimension + /// is computed so that the total size remains constant. In particular, a shape + /// of [-1] flattens into 1-D. At most one component of shape can be -1. + /// + /// If shape is 1-D or higher, then the operation returns a tensor with shape + /// shape filled with the values of tensor. In this case, the number of elements + /// implied by shape must be the same as the number of elements in tensor. + /// + /// For example: + /// + /// + /// # tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9] + /// # tensor 't' has shape [9] + /// reshape(t, [3, 3]) ==&gt; [[1, 2, 3], + /// [4, 5, 6], + /// [7, 8, 9]] + /// + /// # tensor 't' is [[[1, 1], [2, 2]], + /// # [[3, 3], [4, 4]]] + /// # tensor 't' has shape [2, 2, 2] + /// reshape(t, [2, 4]) ==&gt; [[1, 1, 2, 2], + /// [3, 3, 4, 4]] + /// + /// # tensor 't' is [[[1, 1, 1], + /// # [2, 2, 2]], + /// # [[3, 3, 3], + /// # [4, 4, 4]], + /// # [[5, 5, 5], + /// # [6, 6, 6]]] + /// # tensor 't' has shape [3, 2, 3] + /// # pass '[-1]' to flatten 't' + /// reshape(t, [-1]) ==&gt; [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6] + /// + /// # -1 can also be used to infer the shape + /// + /// # -1 is inferred to be 9: + /// reshape(t, [2, -1]) ==&gt; [[1, 1, 1, 2, 2, 2, 3, 3, 3], + /// [4, 4, 4, 5, 5, 5, 6, 6, 6]] + /// # -1 is inferred to be 2: + /// reshape(t, [-1, 9]) ==&gt; [[1, 1, 1, 2, 2, 2, 3, 3, 3], + /// [4, 4, 4, 5, 5, 5, 6, 6, 6]] + /// # -1 is inferred to be 3: + /// reshape(t, [ 2, -1, 3]) ==&gt; [[[1, 1, 1], + /// [2, 2, 2], + /// [3, 3, 3]], + /// [[4, 4, 4], + /// [5, 5, 5], + /// [6, 6, 6]]] + /// + /// # tensor 't' is [7] + /// # shape [] reshapes to a scalar + /// reshape(t, []) ==&gt; 7 + /// + /// + public static Tensor reshape(Tensor tensor, Tensor shape, string name = "Reshape") + { + var dict = new Dictionary(); + dict["tensor"] = tensor; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("Reshape", name: name, keywords: dict); + return op.output; + } + + /// + /// Resize images to size using area interpolation. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// = A 1-D int32 Tensor of 2 elements: new_height, new_width. The + /// new size for the images. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResizeArea'. + /// + /// + /// If true, the centers of the 4 corner pixels of the input and output tensors are + /// aligned, preserving the values at the corner pixels. Defaults to false. + /// + /// + /// 4-D with shape + /// [batch, new_height, new_width, channels]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Input images can be of different types but output images are always float. + /// + /// The range of pixel values for the output image might be slightly different + /// from the range for the input image because of limited numerical precision. + /// To guarantee an output range, for example [0.0, 1.0], apply + /// tf.clip_by_value to the output. + /// + /// Each output pixel is computed by first transforming the pixel's footprint into + /// the input tensor and then averaging the pixels that intersect the footprint. An + /// input pixel's contribution to the average is weighted by the fraction of its + /// area that intersects the footprint. This is the same as OpenCV's INTER_AREA. + /// + public static Tensor resize_area(Tensor images, Tensor size, bool? align_corners = null, string name = "ResizeArea") + { + var dict = new Dictionary(); + dict["images"] = images; + dict["size"] = size; + if (align_corners.HasValue) + dict["align_corners"] = align_corners.Value; + var op = tf.OpDefLib._apply_op_helper("ResizeArea", name: name, keywords: dict); + return op.output; + } + + /// + /// Resize images to size using bicubic interpolation. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// = A 1-D int32 Tensor of 2 elements: new_height, new_width. The + /// new size for the images. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResizeBicubic'. + /// + /// + /// If true, the centers of the 4 corner pixels of the input and output tensors are + /// aligned, preserving the values at the corner pixels. Defaults to false. + /// + /// + /// 4-D with shape + /// [batch, new_height, new_width, channels]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Input images can be of different types but output images are always float. + /// + public static Tensor resize_bicubic(Tensor images, Tensor size, bool align_corners = false, bool half_pixel_centers = false, string name = "ResizeBicubic") + { + var dict = new Dictionary(); + dict["images"] = images; + dict["size"] = size; + dict["align_corners"] = align_corners; + var op = tf.OpDefLib._apply_op_helper("ResizeBicubic", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient of bicubic interpolation. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// 4-D with shape [batch, orig_height, orig_width, channels], + /// The image tensor that was resized. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResizeBicubicGrad'. + /// + /// + /// If true, the centers of the 4 corner pixels of the input and grad tensors are + /// aligned. Defaults to false. + /// + /// + /// 4-D with shape [batch, orig_height, orig_width, channels]. + /// Gradients with respect to the input image. Input image must have been + /// float or double. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor resize_bicubic_grad(Tensor grads, Tensor original_image, bool? align_corners = null, string name = "ResizeBicubicGrad") + { + var dict = new Dictionary(); + dict["grads"] = grads; + dict["original_image"] = original_image; + if (align_corners.HasValue) + dict["align_corners"] = align_corners.Value; + var op = tf.OpDefLib._apply_op_helper("ResizeBicubicGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Resize images to size using bilinear interpolation. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// = A 1-D int32 Tensor of 2 elements: new_height, new_width. The + /// new size for the images. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResizeBilinear'. + /// + /// + /// If true, the centers of the 4 corner pixels of the input and output tensors are + /// aligned, preserving the values at the corner pixels. Defaults to false. + /// + /// + /// 4-D with shape + /// [batch, new_height, new_width, channels]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Input images can be of different types but output images are always float. + /// + public static Tensor resize_bilinear(Tensor images, Tensor size, bool? align_corners = null, string name = "ResizeBilinear") + { + var dict = new Dictionary(); + dict["images"] = images; + dict["size"] = size; + if (align_corners.HasValue) + dict["align_corners"] = align_corners.Value; + var op = tf.OpDefLib._apply_op_helper("ResizeBilinear", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient of bilinear interpolation. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// 4-D with shape [batch, orig_height, orig_width, channels], + /// The image tensor that was resized. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResizeBilinearGrad'. + /// + /// + /// If true, the centers of the 4 corner pixels of the input and grad tensors are + /// aligned. Defaults to false. + /// + /// + /// 4-D with shape [batch, orig_height, orig_width, channels]. + /// Gradients with respect to the input image. Input image must have been + /// float or double. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor resize_bilinear_grad(Tensor grads, Tensor original_image, bool? align_corners = null, string name = "ResizeBilinearGrad") + { + var dict = new Dictionary(); + dict["grads"] = grads; + dict["original_image"] = original_image; + if (align_corners.HasValue) + dict["align_corners"] = align_corners.Value; + var op = tf.OpDefLib._apply_op_helper("ResizeBilinearGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Resize images to size using nearest neighbor interpolation. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// = A 1-D int32 Tensor of 2 elements: new_height, new_width. The + /// new size for the images. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResizeNearestNeighbor'. + /// + /// + /// If true, the centers of the 4 corner pixels of the input and output tensors are + /// aligned, preserving the values at the corner pixels. Defaults to false. + /// + /// + /// 4-D with shape + /// [batch, new_height, new_width, channels]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor resize_nearest_neighbor(Tensor images, Tensor size, bool? align_corners = null, string name = "ResizeNearestNeighbor") + { + var dict = new Dictionary(); + dict["images"] = images; + dict["size"] = size; + if (align_corners.HasValue) + dict["align_corners"] = align_corners.Value; + var op = tf.OpDefLib._apply_op_helper("ResizeNearestNeighbor", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient of nearest neighbor interpolation. + /// + /// + /// 4-D with shape [batch, height, width, channels]. + /// + /// + /// = A 1-D int32 Tensor of 2 elements: orig_height, orig_width. The + /// original input size. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResizeNearestNeighborGrad'. + /// + /// + /// If true, the centers of the 4 corner pixels of the input and grad tensors are + /// aligned. Defaults to false. + /// + /// + /// 4-D with shape [batch, orig_height, orig_width, channels]. Gradients + /// with respect to the input image. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool? align_corners = null, string name = "ResizeNearestNeighborGrad") + { + var dict = new Dictionary(); + dict["grads"] = grads; + dict["size"] = size; + if (align_corners.HasValue) + dict["align_corners"] = align_corners.Value; + var op = tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the AdaMax algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Must be a scalar. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Momentum factor. Must be a scalar. + /// + /// + /// Momentum factor. Must be a scalar. + /// + /// + /// Ridge term. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyAdaMax'. + /// + /// + /// If True, updating of the var, m, and v tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// m_t &lt;- beta1 * m_{t-1} + (1 - beta1) * g + /// v_t &lt;- max(beta2 * v_{t-1}, abs(g)) + /// variable &lt;- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon) + /// + public static Operation resource_apply_ada_max(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, bool? use_locking = null, string name = "ResourceApplyAdaMax") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["m"] = m; + dict["v"] = v; + dict["beta1_power"] = beta1_power; + dict["lr"] = lr; + dict["beta1"] = beta1; + dict["beta2"] = beta2; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyAdaMax", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the adadelta scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Decay factor. Must be a scalar. + /// + /// + /// Constant factor. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyAdadelta'. + /// + /// + /// If True, updating of the var, accum and update_accum tensors will be protected by + /// a lock; otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// accum = rho() * accum + (1 - rho()) * grad.square(); + /// update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad; + /// update_accum = rho() * update_accum + (1 - rho()) * update.square(); + /// var -= update; + /// + public static Operation resource_apply_adadelta(Tensor var, Tensor accum, Tensor accum_update, Tensor lr, Tensor rho, Tensor epsilon, Tensor grad, bool? use_locking = null, string name = "ResourceApplyAdadelta") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["accum_update"] = accum_update; + dict["lr"] = lr; + dict["rho"] = rho; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyAdadelta", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the adagrad scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyAdagrad'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// + /// + /// Returns the description of the operation + /// + /// + /// accum += grad * grad + /// var -= lr * grad * (1 / sqrt(accum)) + /// + public static Operation resource_apply_adagrad(Tensor var, Tensor accum, Tensor lr, Tensor grad, bool? use_locking = null, bool? update_slots = null, string name = "ResourceApplyAdagrad") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["lr"] = lr; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + if (update_slots.HasValue) + dict["update_slots"] = update_slots.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyAdagrad", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the proximal adagrad scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// The gradient. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// Training step number. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyAdagradDA'. + /// + /// + /// If True, updating of the var and accum tensors will be protected by + /// a lock; otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Returns the description of the operation + /// + public static Operation resource_apply_adagrad_d_a(Tensor var, Tensor gradient_accumulator, Tensor gradient_squared_accumulator, Tensor grad, Tensor lr, Tensor l1, Tensor l2, Tensor global_step, bool? use_locking = null, string name = "ResourceApplyAdagradDA") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["gradient_accumulator"] = gradient_accumulator; + dict["gradient_squared_accumulator"] = gradient_squared_accumulator; + dict["grad"] = grad; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["global_step"] = global_step; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyAdagradDA", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the Adam algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Momentum factor. Must be a scalar. + /// + /// + /// Momentum factor. Must be a scalar. + /// + /// + /// Ridge term. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyAdam'. + /// + /// + /// If True, updating of the var, m, and v tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// If True, uses the nesterov update. + /// + /// + /// Returns the description of the operation + /// + /// + /// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ + /// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ + /// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ + /// $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ + /// + public static Operation resource_apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power, Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, bool? use_locking = null, bool? use_nesterov = null, string name = "ResourceApplyAdam") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["m"] = m; + dict["v"] = v; + dict["beta1_power"] = beta1_power; + dict["beta2_power"] = beta2_power; + dict["lr"] = lr; + dict["beta1"] = beta1; + dict["beta2"] = beta2; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + if (use_nesterov.HasValue) + dict["use_nesterov"] = use_nesterov.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyAdam", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the AddSign update. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyAddSign'. + /// + /// + /// If True, updating of the var and m tensors is + /// protected by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// m_t &lt;- beta1 * m_{t-1} + (1 - beta1) * g + /// update &lt;- (alpha + sign_decay * sign(g) *sign(m)) * g + /// variable &lt;- variable - lr_t * update + /// + public static Operation resource_apply_add_sign(Tensor var, Tensor m, Tensor lr, Tensor alpha, Tensor sign_decay, Tensor beta, Tensor grad, bool? use_locking = null, string name = "ResourceApplyAddSign") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["m"] = m; + dict["lr"] = lr; + dict["alpha"] = alpha; + dict["sign_decay"] = sign_decay; + dict["beta"] = beta; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyAddSign", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the centered RMSProp algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Decay rate. Must be a scalar. + /// + /// + /// + /// + /// Ridge term. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyCenteredRMSProp'. + /// + /// + /// If True, updating of the var, mg, ms, and mom tensors is + /// protected by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// The centered RMSProp algorithm uses an estimate of the centered second moment + /// (i.e., the variance) for normalization, as opposed to regular RMSProp, which + /// uses the (uncentered) second moment. This often helps with training, but is + /// slightly more expensive in terms of computation and memory. + /// + /// Note that in dense implementation of this algorithm, mg, ms, and mom will + /// update even if the grad is zero, but in this sparse implementation, mg, ms, + /// and mom will not update in iterations during which the grad is zero. + /// + /// mean_square = decay * mean_square + (1-decay) * gradient ** 2 + /// mean_grad = decay * mean_grad + (1-decay) * gradient + /// + /// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) + /// + /// mg &lt;- rho * mg_{t-1} + (1-rho) * grad + /// ms &lt;- rho * ms_{t-1} + (1-rho) * grad * grad + /// mom &lt;- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) + /// var &lt;- var - mom + /// + public static Operation resource_apply_centered_r_m_s_prop(Tensor var, Tensor mg, Tensor ms, Tensor mom, Tensor lr, Tensor rho, Tensor momentum, Tensor epsilon, Tensor grad, bool? use_locking = null, string name = "ResourceApplyCenteredRMSProp") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["mg"] = mg; + dict["ms"] = ms; + dict["mom"] = mom; + dict["lr"] = lr; + dict["rho"] = rho; + dict["momentum"] = momentum; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyCenteredRMSProp", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the Ftrl-proximal scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// The gradient. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regulariation. Must be a scalar. + /// + /// + /// L2 regulariation. Must be a scalar. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyFtrl'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// accum_new = accum + grad * grad + /// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var + /// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 + /// var = (sign(linear) * l1 - linear) / quadratic if |linear| &gt; l1 else 0.0 + /// accum = accum_new + /// + public static Operation resource_apply_ftrl(Tensor var, Tensor accum, Tensor linear, Tensor grad, Tensor lr, Tensor l1, Tensor l2, Tensor lr_power, bool? use_locking = null, string name = "ResourceApplyFtrl") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["linear"] = linear; + dict["grad"] = grad; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["lr_power"] = lr_power; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyFtrl", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the Ftrl-proximal scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// The gradient. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regulariation. Must be a scalar. + /// + /// + /// L2 shrinkage regulariation. Must be a scalar. + /// + /// + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyFtrlV2'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// grad_with_shrinkage = grad + 2 * l2_shrinkage * var + /// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage + /// linear += grad_with_shrinkage + + /// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var + /// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 + /// var = (sign(linear) * l1 - linear) / quadratic if |linear| &gt; l1 else 0.0 + /// accum = accum_new + /// + public static Operation resource_apply_ftrl_v2(Tensor var, Tensor accum, Tensor linear, Tensor grad, Tensor lr, Tensor l1, Tensor l2, Tensor l2_shrinkage, Tensor lr_power, bool? use_locking = null, string name = "ResourceApplyFtrlV2") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["linear"] = linear; + dict["grad"] = grad; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["l2_shrinkage"] = l2_shrinkage; + dict["lr_power"] = lr_power; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyFtrlV2", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' by subtracting 'alpha' * 'delta' from it. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// The change. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyGradientDescent'. + /// + /// + /// If True, the subtraction will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Returns the description of the operation + /// + public static Operation resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool? use_locking = null, string name = "ResourceApplyGradientDescent") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["alpha"] = alpha; + dict["delta"] = delta; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyGradientDescent", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the momentum scheme. Set use_nesterov = True if you + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// Momentum. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyMomentum'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// If True, the tensor passed to compute grad will be + /// var - lr * momentum * accum, so in the end, the var you get is actually + /// var - lr * momentum * accum. + /// + /// + /// Returns the description of the operation + /// + /// + /// want to use Nesterov momentum. + /// + /// accum = accum * momentum + grad + /// var -= lr * accum + /// + public static Operation resource_apply_momentum(Tensor var, Tensor accum, Tensor lr, Tensor grad, Tensor momentum, bool? use_locking = null, bool? use_nesterov = null, string name = "ResourceApplyMomentum") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["lr"] = lr; + dict["grad"] = grad; + dict["momentum"] = momentum; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + if (use_nesterov.HasValue) + dict["use_nesterov"] = use_nesterov.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyMomentum", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the AddSign update. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyPowerSign'. + /// + /// + /// If True, updating of the var and m tensors is + /// protected by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// m_t &lt;- beta1 * m_{t-1} + (1 - beta1) * g + /// update &lt;- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g + /// variable &lt;- variable - lr_t * update + /// + public static Operation resource_apply_power_sign(Tensor var, Tensor m, Tensor lr, Tensor logbase, Tensor sign_decay, Tensor beta, Tensor grad, bool? use_locking = null, string name = "ResourceApplyPowerSign") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["m"] = m; + dict["lr"] = lr; + dict["logbase"] = logbase; + dict["sign_decay"] = sign_decay; + dict["beta"] = beta; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyPowerSign", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyProximalAdagrad'. + /// + /// + /// If True, updating of the var and accum tensors will be protected by + /// a lock; otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// accum += grad * grad + /// prox_v = var - lr * grad * (1 / sqrt(accum)) + /// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} + /// + public static Operation resource_apply_proximal_adagrad(Tensor var, Tensor accum, Tensor lr, Tensor l1, Tensor l2, Tensor grad, bool? use_locking = null, string name = "ResourceApplyProximalAdagrad") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyProximalAdagrad", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' as FOBOS algorithm with fixed learning rate. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// The change. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyProximalGradientDescent'. + /// + /// + /// If True, the subtraction will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// prox_v = var - alpha * delta + /// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} + /// + public static Operation resource_apply_proximal_gradient_descent(Tensor var, Tensor alpha, Tensor l1, Tensor l2, Tensor delta, bool? use_locking = null, string name = "ResourceApplyProximalGradientDescent") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["alpha"] = alpha; + dict["l1"] = l1; + dict["l2"] = l2; + dict["delta"] = delta; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyProximalGradientDescent", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the RMSProp algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Decay rate. Must be a scalar. + /// + /// + /// + /// + /// Ridge term. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceApplyRMSProp'. + /// + /// + /// If True, updating of the var, ms, and mom tensors is protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// Note that in dense implementation of this algorithm, ms and mom will + /// update even if the grad is zero, but in this sparse implementation, ms + /// and mom will not update in iterations during which the grad is zero. + /// + /// mean_square = decay * mean_square + (1-decay) * gradient ** 2 + /// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) + /// + /// ms &lt;- rho * ms_{t-1} + (1-rho) * grad * grad + /// mom &lt;- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) + /// var &lt;- var - mom + /// + public static Operation resource_apply_r_m_s_prop(Tensor var, Tensor ms, Tensor mom, Tensor lr, Tensor rho, Tensor momentum, Tensor epsilon, Tensor grad, bool? use_locking = null, string name = "ResourceApplyRMSProp") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["ms"] = ms; + dict["mom"] = mom; + dict["lr"] = lr; + dict["rho"] = rho; + dict["momentum"] = momentum; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceApplyRMSProp", name: name, keywords: dict); + return op; + } + + /// + /// Increments variable pointed to by 'resource' until it reaches 'limit'. + /// + /// + /// Should be from a scalar Variable node. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceCountUpTo'. + /// + /// + /// Optional argument + /// If incrementing ref would bring it above limit, instead generates an + /// 'OutOfRange' error. + /// + /// + /// Optional argument + /// + /// + /// A copy of the input before increment. If nothing else modifies the + /// input, the values produced will all be distinct. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor resource_count_up_to(Tensor resource, int limit, TF_DataType T, string name = "ResourceCountUpTo") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["limit"] = limit; + dict["T"] = T; + var op = tf.OpDefLib._apply_op_helper("ResourceCountUpTo", name: name, keywords: dict); + return op.output; + } + + /// + /// Gather slices from the variable pointed to by resource according to indices. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceGather'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// indices must be an integer tensor of any dimension (usually 0-D or 1-D). + /// Produces an output tensor with shape indices.shape + params.shape[1:] where: + /// + /// + /// # Scalar indices + /// output[:, ..., :] = params[indices, :, ... :] + /// + /// # Vector indices + /// output[i, :, ..., :] = params[indices[i], :, ... :] + /// + /// # Higher rank indices + /// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] + /// + /// + public static Tensor resource_gather(Tensor resource, Tensor indices, TF_DataType dtype, bool? validate_indices = null, string name = "ResourceGather") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["indices"] = indices; + dict["dtype"] = dtype; + if (validate_indices.HasValue) + dict["validate_indices"] = validate_indices.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceGather", name: name, keywords: dict); + return op.output; + } + + /// + /// Adds sparse updates to the variable referenced by resource. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to add to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceScatterAdd'. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] += updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] += updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] + /// + /// Duplicate entries are handled correctly: if multiple indices reference + /// the same location, their contributions add. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt&gt; + /// &lt;/div&gt; + /// + public static Operation resource_scatter_add(Tensor resource, Tensor indices, Tensor updates, string name = "ResourceScatterAdd") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["indices"] = indices; + dict["updates"] = updates; + var op = tf.OpDefLib._apply_op_helper("ResourceScatterAdd", name: name, keywords: dict); + return op; + } + + /// + /// Divides sparse updates into the variable referenced by resource. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to add to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceScatterDiv'. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] /= updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] /= updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] + /// + /// Duplicate entries are handled correctly: if multiple indices reference + /// the same location, their contributions multiply. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt&gt; + /// &lt;/div&gt; + /// + public static Operation resource_scatter_div(Tensor resource, Tensor indices, Tensor updates, string name = "ResourceScatterDiv") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["indices"] = indices; + dict["updates"] = updates; + var op = tf.OpDefLib._apply_op_helper("ResourceScatterDiv", name: name, keywords: dict); + return op; + } + + /// + /// Reduces sparse updates into the variable referenced by resource using the max operation. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to add to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceScatterMax'. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] = max(ref[indices, ...], updates[...]) + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + /// + /// Duplicate entries are handled correctly: if multiple indices reference + /// the same location, their contributions are combined. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt&gt; + /// &lt;/div&gt; + /// + public static Operation resource_scatter_max(Tensor resource, Tensor indices, Tensor updates, string name = "ResourceScatterMax") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["indices"] = indices; + dict["updates"] = updates; + var op = tf.OpDefLib._apply_op_helper("ResourceScatterMax", name: name, keywords: dict); + return op; + } + + /// + /// Reduces sparse updates into the variable referenced by resource using the min operation. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to add to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceScatterMin'. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] = min(ref[indices, ...], updates[...]) + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + /// + /// Duplicate entries are handled correctly: if multiple indices reference + /// the same location, their contributions are combined. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt&gt; + /// &lt;/div&gt; + /// + public static Operation resource_scatter_min(Tensor resource, Tensor indices, Tensor updates, string name = "ResourceScatterMin") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["indices"] = indices; + dict["updates"] = updates; + var op = tf.OpDefLib._apply_op_helper("ResourceScatterMin", name: name, keywords: dict); + return op; + } + + /// + /// Multiplies sparse updates into the variable referenced by resource. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to add to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceScatterMul'. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] *= updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] *= updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] + /// + /// Duplicate entries are handled correctly: if multiple indices reference + /// the same location, their contributions multiply. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt&gt; + /// &lt;/div&gt; + /// + public static Operation resource_scatter_mul(Tensor resource, Tensor indices, Tensor updates, string name = "ResourceScatterMul") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["indices"] = indices; + dict["updates"] = updates; + var op = tf.OpDefLib._apply_op_helper("ResourceScatterMul", name: name, keywords: dict); + return op; + } + + /// + /// Adds sparse updates to individual values or slices within a given + /// + /// + /// A resource handle. Must be from a VarHandleOp. + /// + /// + /// A Tensor. Must be one of the following types: int32, int64. + /// A tensor of indices into ref. + /// + /// + /// A Tensor. Must have the same type as ref. A tensor of + /// values to add to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceScatterNdAdd'. + /// + /// + /// An optional bool. Defaults to True. If True, the assignment will + /// be protected by a lock; otherwise the behavior is undefined, + /// but may exhibit less contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// variable according to indices. + /// + /// ref is a Tensor with rank P and indices is a Tensor of rank Q. + /// + /// indices must be integer tensor, containing indices into ref. + /// It must be shape [d_0, ..., d_{Q-2}, K] where 0 &lt; K &lt;= P. + /// + /// The innermost dimension of indices (with length K) corresponds to + /// indices into elements (if K = P) or slices (if K &lt; P) along the Kth + /// dimension of ref. + /// + /// updates is Tensor of rank Q-1+P-K with shape: + /// + /// + /// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. + /// + /// + /// For example, say we want to update 4 scattered elements to a rank-1 tensor to + /// 8 elements. In Python, that update would look like this: + /// + /// + /// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) + /// indices = tf.constant([[4], [3], [1] ,[7]]) + /// updates = tf.constant([9, 10, 11, 12]) + /// update = tf.scatter_nd_add(ref, indices, updates) + /// with tf.Session() as sess: + /// print sess.run(update) + /// + /// + /// The resulting update to ref would look like this: + /// + /// [1, 12, 3, 14, 14, 6, 7, 20] + /// + /// See tf.scatter_nd for more details about how to make updates to + /// slices. + /// + public static Operation resource_scatter_nd_add(Tensor referecne, Tensor indices, Tensor updates, bool? use_locking = null, string name = "ResourceScatterNdAdd") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["indices"] = indices; + dict["updates"] = updates; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceScatterNdAdd", name: name, keywords: dict); + return op; + } + + /// + /// Applies sparse updates to individual values or slices within a given + /// + /// + /// A resource handle. Must be from a VarHandleOp. + /// + /// + /// A Tensor. Must be one of the following types: int32, int64. + /// A tensor of indices into ref. + /// + /// + /// A Tensor. Must have the same type as ref. A tensor of updated + /// values to add to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceScatterNdUpdate'. + /// + /// + /// An optional bool. Defaults to True. If True, the assignment will + /// be protected by a lock; otherwise the behavior is undefined, + /// but may exhibit less contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// variable according to indices. + /// + /// ref is a Tensor with rank P and indices is a Tensor of rank Q. + /// + /// indices must be integer tensor, containing indices into ref. + /// It must be shape [d_0, ..., d_{Q-2}, K] where 0 &lt; K &lt;= P. + /// + /// The innermost dimension of indices (with length K) corresponds to + /// indices into elements (if K = P) or slices (if K &lt; P) along the Kth + /// dimension of ref. + /// + /// updates is Tensor of rank Q-1+P-K with shape: + /// + /// + /// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. + /// + /// + /// For example, say we want to update 4 scattered elements to a rank-1 tensor to + /// 8 elements. In Python, that update would look like this: + /// + /// + /// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + /// indices = tf.constant([[4], [3], [1] ,[7]]) + /// updates = tf.constant([9, 10, 11, 12]) + /// update = tf.scatter_nd_update(ref, indices, updates) + /// with tf.Session() as sess: + /// print sess.run(update) + /// + /// + /// The resulting update to ref would look like this: + /// + /// [1, 11, 3, 10, 9, 6, 7, 12] + /// + /// See tf.scatter_nd for more details about how to make updates to + /// slices. + /// + public static Operation resource_scatter_nd_update(Tensor referecne, Tensor indices, Tensor updates, bool? use_locking = null, string name = "ResourceScatterNdUpdate") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["indices"] = indices; + dict["updates"] = updates; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceScatterNdUpdate", name: name, keywords: dict); + return op; + } + + /// + /// Subtracts sparse updates from the variable referenced by resource. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to add to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceScatterSub'. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] -= updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] -= updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] + /// + /// Duplicate entries are handled correctly: if multiple indices reference + /// the same location, their contributions add. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt&gt; + /// &lt;/div&gt; + /// + public static Operation resource_scatter_sub(Tensor resource, Tensor indices, Tensor updates, string name = "ResourceScatterSub") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["indices"] = indices; + dict["updates"] = updates; + var op = tf.OpDefLib._apply_op_helper("ResourceScatterSub", name: name, keywords: dict); + return op; + } + + /// + /// Assigns sparse updates to the variable referenced by resource. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to add to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceScatterUpdate'. + /// + /// + /// Returns the description of the operation + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] = updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] = updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] + /// + public static Operation resource_scatter_update(Tensor resource, Tensor indices, Tensor updates, string name = "ResourceScatterUpdate") + { + var dict = new Dictionary(); + dict["resource"] = resource; + dict["indices"] = indices; + dict["updates"] = updates; + var op = tf.OpDefLib._apply_op_helper("ResourceScatterUpdate", name: name, keywords: dict); + return op; + } + + /// + /// var: Should be from a Variable(). + /// + /// + /// + /// + /// Should be from a Variable(). + /// + /// + /// : Should be from a Variable(). + /// + /// + /// Learning rate. Must be a scalar. + /// + /// + /// Decay factor. Must be a scalar. + /// + /// + /// Constant factor. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceSparseApplyAdadelta'. + /// + /// + /// If True, updating of the var and accum tensors will be protected by + /// a lock; otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Returns the description of the operation + /// + public static Operation resource_sparse_apply_adadelta(Tensor var, Tensor accum, Tensor accum_update, Tensor lr, Tensor rho, Tensor epsilon, Tensor grad, Tensor indices, bool? use_locking = null, string name = "ResourceSparseApplyAdadelta") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["accum_update"] = accum_update; + dict["lr"] = lr; + dict["rho"] = rho; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + dict["indices"] = indices; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceSparseApplyAdadelta", name: name, keywords: dict); + return op; + } + + /// + /// Update relevant entries in '*var' and '*accum' according to the adagrad scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Learning rate. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceSparseApplyAdagrad'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// + /// + /// Returns the description of the operation + /// + /// + /// That is for rows we have grad for, we update var and accum as follows: + /// accum += grad * grad + /// var -= lr * grad * (1 / sqrt(accum)) + /// + public static Operation resource_sparse_apply_adagrad(Tensor var, Tensor accum, Tensor lr, Tensor grad, Tensor indices, bool? use_locking = null, bool? update_slots = null, string name = "ResourceSparseApplyAdagrad") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["lr"] = lr; + dict["grad"] = grad; + dict["indices"] = indices; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + if (update_slots.HasValue) + dict["update_slots"] = update_slots.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceSparseApplyAdagrad", name: name, keywords: dict); + return op; + } + + /// + /// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// Learning rate. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// Training step number. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceSparseApplyAdagradDA'. + /// + /// + /// If True, updating of the var and accum tensors will be protected by + /// a lock; otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Returns the description of the operation + /// + public static Operation resource_sparse_apply_adagrad_d_a(Tensor var, Tensor gradient_accumulator, Tensor gradient_squared_accumulator, Tensor grad, Tensor indices, Tensor lr, Tensor l1, Tensor l2, Tensor global_step, bool? use_locking = null, string name = "ResourceSparseApplyAdagradDA") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["gradient_accumulator"] = gradient_accumulator; + dict["gradient_squared_accumulator"] = gradient_squared_accumulator; + dict["grad"] = grad; + dict["indices"] = indices; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["global_step"] = global_step; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceSparseApplyAdagradDA", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the centered RMSProp algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Decay rate. Must be a scalar. + /// + /// + /// + /// + /// Ridge term. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var, ms and mom. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceSparseApplyCenteredRMSProp'. + /// + /// + /// If True, updating of the var, mg, ms, and mom tensors is + /// protected by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// The centered RMSProp algorithm uses an estimate of the centered second moment + /// (i.e., the variance) for normalization, as opposed to regular RMSProp, which + /// uses the (uncentered) second moment. This often helps with training, but is + /// slightly more expensive in terms of computation and memory. + /// + /// Note that in dense implementation of this algorithm, mg, ms, and mom will + /// update even if the grad is zero, but in this sparse implementation, mg, ms, + /// and mom will not update in iterations during which the grad is zero. + /// + /// mean_square = decay * mean_square + (1-decay) * gradient ** 2 + /// mean_grad = decay * mean_grad + (1-decay) * gradient + /// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) + /// + /// ms &lt;- rho * ms_{t-1} + (1-rho) * grad * grad + /// mom &lt;- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) + /// var &lt;- var - mom + /// + public static Operation resource_sparse_apply_centered_r_m_s_prop(Tensor var, Tensor mg, Tensor ms, Tensor mom, Tensor lr, Tensor rho, Tensor momentum, Tensor epsilon, Tensor grad, Tensor indices, bool? use_locking = null, string name = "ResourceSparseApplyCenteredRMSProp") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["mg"] = mg; + dict["ms"] = ms; + dict["mom"] = mom; + dict["lr"] = lr; + dict["rho"] = rho; + dict["momentum"] = momentum; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + dict["indices"] = indices; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceSparseApplyCenteredRMSProp", name: name, keywords: dict); + return op; + } + + /// + /// Update relevant entries in '*var' according to the Ftrl-proximal scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceSparseApplyFtrl'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// That is for rows we have grad for, we update var, accum and linear as follows: + /// accum_new = accum + grad * grad + /// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var + /// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 + /// var = (sign(linear) * l1 - linear) / quadratic if |linear| &gt; l1 else 0.0 + /// accum = accum_new + /// + public static Operation resource_sparse_apply_ftrl(Tensor var, Tensor accum, Tensor linear, Tensor grad, Tensor indices, Tensor lr, Tensor l1, Tensor l2, Tensor lr_power, bool? use_locking = null, string name = "ResourceSparseApplyFtrl") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["linear"] = linear; + dict["grad"] = grad; + dict["indices"] = indices; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["lr_power"] = lr_power; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceSparseApplyFtrl", name: name, keywords: dict); + return op; + } + + /// + /// Update relevant entries in '*var' according to the Ftrl-proximal scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 shrinkage regulariation. Must be a scalar. + /// + /// + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceSparseApplyFtrlV2'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// That is for rows we have grad for, we update var, accum and linear as follows: + /// grad_with_shrinkage = grad + 2 * l2_shrinkage * var + /// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage + /// linear += grad_with_shrinkage + + /// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var + /// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 + /// var = (sign(linear) * l1 - linear) / quadratic if |linear| &gt; l1 else 0.0 + /// accum = accum_new + /// + public static Operation resource_sparse_apply_ftrl_v2(Tensor var, Tensor accum, Tensor linear, Tensor grad, Tensor indices, Tensor lr, Tensor l1, Tensor l2, Tensor l2_shrinkage, Tensor lr_power, bool? use_locking = null, string name = "ResourceSparseApplyFtrlV2") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["linear"] = linear; + dict["grad"] = grad; + dict["indices"] = indices; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["l2_shrinkage"] = l2_shrinkage; + dict["lr_power"] = lr_power; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceSparseApplyFtrlV2", name: name, keywords: dict); + return op; + } + + /// + /// Update relevant entries in '*var' and '*accum' according to the momentum scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Learning rate. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// Momentum. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceSparseApplyMomentum'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// If True, the tensor passed to compute grad will be + /// var - lr * momentum * accum, so in the end, the var you get is actually + /// var - lr * momentum * accum. + /// + /// + /// Returns the description of the operation + /// + /// + /// Set use_nesterov = True if you want to use Nesterov momentum. + /// + /// That is for rows we have grad for, we update var and accum as follows: + /// + /// accum = accum * momentum + grad + /// var -= lr * accum + /// + public static Operation resource_sparse_apply_momentum(Tensor var, Tensor accum, Tensor lr, Tensor grad, Tensor indices, Tensor momentum, bool? use_locking = null, bool? use_nesterov = null, string name = "ResourceSparseApplyMomentum") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["lr"] = lr; + dict["grad"] = grad; + dict["indices"] = indices; + dict["momentum"] = momentum; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + if (use_nesterov.HasValue) + dict["use_nesterov"] = use_nesterov.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceSparseApplyMomentum", name: name, keywords: dict); + return op; + } + + /// + /// Sparse update entries in '*var' and '*accum' according to FOBOS algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Learning rate. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceSparseApplyProximalAdagrad'. + /// + /// + /// If True, updating of the var and accum tensors will be protected by + /// a lock; otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// That is for rows we have grad for, we update var and accum as follows: + /// accum += grad * grad + /// prox_v = var + /// prox_v -= lr * grad * (1 / sqrt(accum)) + /// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} + /// + public static Operation resource_sparse_apply_proximal_adagrad(Tensor var, Tensor accum, Tensor lr, Tensor l1, Tensor l2, Tensor grad, Tensor indices, bool? use_locking = null, string name = "ResourceSparseApplyProximalAdagrad") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["grad"] = grad; + dict["indices"] = indices; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceSparseApplyProximalAdagrad", name: name, keywords: dict); + return op; + } + + /// + /// Sparse update '*var' as FOBOS algorithm with fixed learning rate. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceSparseApplyProximalGradientDescent'. + /// + /// + /// If True, the subtraction will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// That is for rows we have grad for, we update var as follows: + /// prox_v = var - alpha * grad + /// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} + /// + public static Operation resource_sparse_apply_proximal_gradient_descent(Tensor var, Tensor alpha, Tensor l1, Tensor l2, Tensor grad, Tensor indices, bool? use_locking = null, string name = "ResourceSparseApplyProximalGradientDescent") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["alpha"] = alpha; + dict["l1"] = l1; + dict["l2"] = l2; + dict["grad"] = grad; + dict["indices"] = indices; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceSparseApplyProximalGradientDescent", name: name, keywords: dict); + return op; + } + + /// + /// Update '*var' according to the RMSProp algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Decay rate. Must be a scalar. + /// + /// + /// + /// + /// Ridge term. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var, ms and mom. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceSparseApplyRMSProp'. + /// + /// + /// If True, updating of the var, ms, and mom tensors is protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Returns the description of the operation + /// + /// + /// Note that in dense implementation of this algorithm, ms and mom will + /// update even if the grad is zero, but in this sparse implementation, ms + /// and mom will not update in iterations during which the grad is zero. + /// + /// mean_square = decay * mean_square + (1-decay) * gradient ** 2 + /// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) + /// + /// ms &lt;- rho * ms_{t-1} + (1-rho) * grad * grad + /// mom &lt;- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) + /// var &lt;- var - mom + /// + public static Operation resource_sparse_apply_r_m_s_prop(Tensor var, Tensor ms, Tensor mom, Tensor lr, Tensor rho, Tensor momentum, Tensor epsilon, Tensor grad, Tensor indices, bool? use_locking = null, string name = "ResourceSparseApplyRMSProp") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["ms"] = ms; + dict["mom"] = mom; + dict["lr"] = lr; + dict["rho"] = rho; + dict["momentum"] = momentum; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + dict["indices"] = indices; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceSparseApplyRMSProp", name: name, keywords: dict); + return op; + } + + /// + /// Assign value to the sliced l-value reference of ref. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ResourceStridedSliceAssign'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Returns the description of the operation + /// + /// + /// The values of value are assigned to the positions in the variable + /// ref that are selected by the slice parameters. The slice parameters + /// begin, end, strides, etc. work exactly as in StridedSlice. + /// + /// NOTE this op currently does not support broadcasting and so value's + /// shape must be exactly the shape produced by the slice of ref. + /// + public static Operation resource_strided_slice_assign(Tensor referecne, Tensor begin, Tensor end, Tensor strides, Tensor value, int? begin_mask = null, int? end_mask = null, int? ellipsis_mask = null, int? new_axis_mask = null, int? shrink_axis_mask = null, string name = "ResourceStridedSliceAssign") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["begin"] = begin; + dict["end"] = end; + dict["strides"] = strides; + dict["value"] = value; + if (begin_mask.HasValue) + dict["begin_mask"] = begin_mask.Value; + if (end_mask.HasValue) + dict["end_mask"] = end_mask.Value; + if (ellipsis_mask.HasValue) + dict["ellipsis_mask"] = ellipsis_mask.Value; + if (new_axis_mask.HasValue) + dict["new_axis_mask"] = new_axis_mask.Value; + if (shrink_axis_mask.HasValue) + dict["shrink_axis_mask"] = shrink_axis_mask.Value; + var op = tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name: name, keywords: dict); + return op; + } + + /// + /// Restores a tensor from checkpoint files. + /// + /// + /// Must have a single element. The pattern of the files from + /// which we read the tensor. + /// + /// + /// Must have a single element. The name of the tensor to be + /// restored. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Restore'. + /// + /// + /// Optional argument + /// The type of the tensor to be restored. + /// + /// + /// Index of file to open first if multiple files match + /// file_pattern. + /// + /// + /// The restored tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Reads a tensor stored in one or several files. If there are several files (for + /// instance because a tensor was saved as slices), file_pattern may contain + /// wildcard symbols (* and ?) in the filename portion only, not in the + /// directory portion. + /// + /// If a file_pattern matches several files, preferred_shard can be used to hint + /// in which file the requested tensor is likely to be found. This op will first + /// open the file at index preferred_shard in the list of matching files and try + /// to restore tensors from that file. Only if some tensors or tensor slices are + /// not found in that first file, then the Op opens all the files. Setting + /// preferred_shard to match the value passed as the shard input + /// of a matching Save Op may speed up Restore. This attribute only affects + /// performance, not correctness. The default value -1 means files are processed in + /// order. + /// + /// See also RestoreSlice. + /// + public static Tensor restore(Tensor file_pattern, Tensor tensor_name, TF_DataType dt, int? preferred_shard = null, string name = "Restore") + { + var dict = new Dictionary(); + dict["file_pattern"] = file_pattern; + dict["tensor_name"] = tensor_name; + dict["dt"] = dt; + if (preferred_shard.HasValue) + dict["preferred_shard"] = preferred_shard.Value; + var op = tf.OpDefLib._apply_op_helper("Restore", name: name, keywords: dict); + return op.output; + } + + /// + /// Restores a tensor from checkpoint files. + /// + /// + /// Must have a single element. The pattern of the files from + /// which we read the tensor. + /// + /// + /// Must have a single element. The name of the tensor to be + /// restored. + /// + /// + /// Scalar. The shapes and slice specifications to use when + /// restoring a tensors. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RestoreSlice'. + /// + /// + /// Optional argument + /// The type of the tensor to be restored. + /// + /// + /// Index of file to open first if multiple files match + /// file_pattern. See the documentation for Restore. + /// + /// + /// The restored tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This is like Restore except that restored tensor can be listed as filling + /// only a slice of a larger tensor. shape_and_slice specifies the shape of the + /// larger tensor and the slice that the restored tensor covers. + /// + /// The shape_and_slice input has the same format as the + /// elements of the shapes_and_slices input of the SaveSlices op. + /// + public static Tensor restore_slice(Tensor file_pattern, Tensor tensor_name, Tensor shape_and_slice, TF_DataType dt, int? preferred_shard = null, string name = "RestoreSlice") + { + var dict = new Dictionary(); + dict["file_pattern"] = file_pattern; + dict["tensor_name"] = tensor_name; + dict["shape_and_slice"] = shape_and_slice; + dict["dt"] = dt; + if (preferred_shard.HasValue) + dict["preferred_shard"] = preferred_shard.Value; + var op = tf.OpDefLib._apply_op_helper("RestoreSlice", name: name, keywords: dict); + return op.output; + } + + /// + /// Restores tensors from a V2 checkpoint. + /// + /// + /// Must have a single element. The prefix of a V2 checkpoint. + /// + /// + /// shape {N}. The names of the tensors to be restored. + /// + /// + /// shape {N}. The slice specs of the tensors to be restored. + /// Empty strings indicate that they are non-partitioned tensors. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RestoreV2'. + /// + /// + /// Optional argument + /// shape {N}. The list of expected dtype for the tensors. Must match + /// those stored in the checkpoint. + /// + /// + /// shape {N}. The restored tensors, whose shapes are read from the + /// checkpoint directly. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// For backward compatibility with the V1 format, this Op currently allows + /// restoring from a V1 checkpoint as well: + /// - This Op first attempts to find the V2 index file pointed to by "prefix", and + /// if found proceed to read it as a V2 checkpoint; + /// - Otherwise the V1 read path is invoked. + /// Relying on this behavior is not recommended, as the ability to fall back to read + /// V1 might be deprecated and eventually removed. + /// + /// By default, restores the named tensors in full. If the caller wishes to restore + /// specific slices of stored tensors, "shape_and_slices" should be non-empty + /// strings and correspondingly well-formed. + /// + /// Callers must ensure all the named tensors are indeed stored in the checkpoint. + /// + public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + Dictionary attrs = new(); + attrs["dtypes"] = dtypes; + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( + tf.Context, "RestoreV2", name, prefix, tensor_names, shape_and_slices + ) + { attrs = attrs }); + return result; + } + catch (Exception) + { + try + { + return restore_v2_eager_fallback(prefix, tensor_names, shape_and_slices, dtypes, name, ctx); + } + catch (Exception) + { + + } + } + } + var dict = new Dictionary(); + dict["prefix"] = prefix; + dict["tensor_names"] = tensor_names; + dict["shape_and_slices"] = shape_and_slices; + dict["dtypes"] = dtypes; + var op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, keywords: dict); + int _idx = 0; + var tensors = Enumerable.Range(0, op.OutputListLength("tensors")).Select(_ => op.outputs[_idx++]).ToArray(); + return (tensors); + } + + public static Tensor[] restore_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name, Context ctx) + { + prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); + var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); + var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); + object[] attrs = new object[] { "dtypes", dtypes }; + Tensor[] inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }; + var result = _execute.quick_execute("RestoreV2", dtypes.Length, inputs_flat, attrs, ctx, name); + + if (_execute.must_record_gradient()) + { + // TODO(Rinne); record the gradient + } + return result; + } + + /// + /// Reverses specific dimensions of a tensor. + /// + /// + /// Up to 8-D. + /// + /// + /// 1-D. The dimensions to reverse. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Reverse'. + /// + /// + /// The same shape as tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor, and a bool tensor dims representing the dimensions + /// of tensor, this operation reverses each dimension i of tensor where + /// dims[i] is True. + /// + /// tensor can have up to 8 dimensions. The number of dimensions + /// of tensor must equal the number of elements in dims. In other words: + /// + /// rank(tensor) = size(dims) + /// + /// For example: + /// + /// + /// # tensor 't' is [[[[ 0, 1, 2, 3], + /// # [ 4, 5, 6, 7], + /// # [ 8, 9, 10, 11]], + /// # [[12, 13, 14, 15], + /// # [16, 17, 18, 19], + /// # [20, 21, 22, 23]]]] + /// # tensor 't' shape is [1, 2, 3, 4] + /// + /// # 'dims' is [False, False, False, True] + /// reverse(t, dims) ==&gt; [[[[ 3, 2, 1, 0], + /// [ 7, 6, 5, 4], + /// [ 11, 10, 9, 8]], + /// [[15, 14, 13, 12], + /// [19, 18, 17, 16], + /// [23, 22, 21, 20]]]] + /// + /// # 'dims' is [False, True, False, False] + /// reverse(t, dims) ==&gt; [[[[12, 13, 14, 15], + /// [16, 17, 18, 19], + /// [20, 21, 22, 23] + /// [[ 0, 1, 2, 3], + /// [ 4, 5, 6, 7], + /// [ 8, 9, 10, 11]]]] + /// + /// # 'dims' is [False, False, True, False] + /// reverse(t, dims) ==&gt; [[[[8, 9, 10, 11], + /// [4, 5, 6, 7], + /// [0, 1, 2, 3]] + /// [[20, 21, 22, 23], + /// [16, 17, 18, 19], + /// [12, 13, 14, 15]]]] + /// + /// + public static Tensor reverse(Tensor tensor, Tensor dims, string name = "Reverse") + { + var dict = new Dictionary(); + dict["tensor"] = tensor; + dict["dims"] = dims; + var op = tf.OpDefLib._apply_op_helper("Reverse", name: name, keywords: dict); + return op.output; + } + + /// + /// Reverses variable length slices. + /// + /// + /// The input to reverse. + /// + /// + /// 1-D with length input.dims(batch_dim) and + /// max(seq_lengths) &lt;= input.dims(seq_dim) + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReverseSequence'. + /// + /// + /// Optional argument + /// The dimension which is partially reversed. + /// + /// + /// The dimension along which reversal is performed. + /// + /// + /// The partially reversed input. It has the same shape as input. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op first slices input along the dimension batch_dim, and for each + /// slice i, reverses the first seq_lengths[i] elements along + /// the dimension seq_dim. + /// + /// The elements of seq_lengths must obey seq_lengths[i] &lt;= input.dims[seq_dim], + /// and seq_lengths must be a vector of length input.dims[batch_dim]. + /// + /// The output slice i along dimension batch_dim is then given by input + /// slice i, with the first seq_lengths[i] slices along dimension + /// seq_dim reversed. + /// + /// For example: + /// + /// + /// # Given this: + /// batch_dim = 0 + /// seq_dim = 1 + /// input.dims = (4, 8, ...) + /// seq_lengths = [7, 2, 3, 5] + /// + /// # then slices of input are reversed on seq_dim, but only up to seq_lengths: + /// output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...] + /// output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...] + /// output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...] + /// output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...] + /// + /// # while entries past seq_lens are copied through: + /// output[0, 7:, :, ...] = input[0, 7:, :, ...] + /// output[1, 2:, :, ...] = input[1, 2:, :, ...] + /// output[2, 3:, :, ...] = input[2, 3:, :, ...] + /// output[3, 2:, :, ...] = input[3, 2:, :, ...] + /// + /// + /// In contrast, if: + /// + /// + /// # Given this: + /// batch_dim = 2 + /// seq_dim = 0 + /// input.dims = (8, ?, 4, ...) + /// seq_lengths = [7, 2, 3, 5] + /// + /// # then slices of input are reversed on seq_dim, but only up to seq_lengths: + /// output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...] + /// output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...] + /// output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...] + /// output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...] + /// + /// # while entries past seq_lens are copied through: + /// output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...] + /// output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...] + /// output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...] + /// output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...] + /// + /// + public static Tensor reverse_sequence(Tensor input, Tensor seq_lengths, int seq_dim, int? batch_dim = null, string name = "ReverseSequence") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["seq_lengths"] = seq_lengths; + dict["seq_dim"] = seq_dim; + if (batch_dim.HasValue) + dict["batch_dim"] = batch_dim.Value; + var op = tf.OpDefLib._apply_op_helper("ReverseSequence", name: name, keywords: dict); + return op.output; + } + + /// + /// Reverses specific dimensions of a tensor. + /// + /// + /// Up to 8-D. + /// + /// + /// 1-D. The indices of the dimensions to reverse. Must be in the range + /// [-rank(tensor), rank(tensor)). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ReverseV2'. + /// + /// + /// The same shape as tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// NOTE tf.reverse has now changed behavior in preparation for 1.0. + /// tf.reverse_v2 is currently an alias that will be deprecated before TF 1.0. + /// + /// Given a tensor, and a int32 tensor axis representing the set of + /// dimensions of tensor to reverse. This operation reverses each dimension + /// i for which there exists j s.t. axis[j] == i. + /// + /// tensor can have up to 8 dimensions. The number of dimensions specified + /// in axis may be 0 or more entries. If an index is specified more than + /// once, a InvalidArgument error is raised. + /// + /// For example: + /// + /// + /// # tensor 't' is [[[[ 0, 1, 2, 3], + /// # [ 4, 5, 6, 7], + /// # [ 8, 9, 10, 11]], + /// # [[12, 13, 14, 15], + /// # [16, 17, 18, 19], + /// # [20, 21, 22, 23]]]] + /// # tensor 't' shape is [1, 2, 3, 4] + /// + /// # 'dims' is [3] or 'dims' is [-1] + /// reverse(t, dims) ==&gt; [[[[ 3, 2, 1, 0], + /// [ 7, 6, 5, 4], + /// [ 11, 10, 9, 8]], + /// [[15, 14, 13, 12], + /// [19, 18, 17, 16], + /// [23, 22, 21, 20]]]] + /// + /// # 'dims' is '[1]' (or 'dims' is '[-3]') + /// reverse(t, dims) ==&gt; [[[[12, 13, 14, 15], + /// [16, 17, 18, 19], + /// [20, 21, 22, 23] + /// [[ 0, 1, 2, 3], + /// [ 4, 5, 6, 7], + /// [ 8, 9, 10, 11]]]] + /// + /// # 'dims' is '[2]' (or 'dims' is '[-2]') + /// reverse(t, dims) ==&gt; [[[[8, 9, 10, 11], + /// [4, 5, 6, 7], + /// [0, 1, 2, 3]] + /// [[20, 21, 22, 23], + /// [16, 17, 18, 19], + /// [12, 13, 14, 15]]]] + /// + /// + public static Tensor reverse_v2(Tensor tensor, Tensor axis, string name = "ReverseV2") + { + var dict = new Dictionary(); + dict["tensor"] = tensor; + dict["axis"] = axis; + var op = tf.OpDefLib._apply_op_helper("ReverseV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Elementwise computes the bitwise right-shift of x and y. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RightShift'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Performs a logical shift for unsigned integer types, and an arithmetic shift + /// for signed integer types. + /// + /// If y is negative, or greater than or equal to than the width of x in bits + /// the result is implementation defined. + /// + public static Tensor right_shift(Tensor x, Tensor y, string name = "RightShift") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("RightShift", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns element-wise integer closest to x. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Rint'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// If the result is midway between two representable values, + /// the even representable is chosen. + /// For example: + /// + /// + /// rint(-1.5) ==&gt; -2.0 + /// rint(0.5000001) ==&gt; 1.0 + /// rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==&gt; [-2., -2., -0., 0., 2., 2., 2.] + /// + /// + public static Tensor rint(Tensor x, string name = "Rint") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Rint", name: name, keywords: dict); + return op.output; + } + + /// + /// Rolls the elements of a tensor along an axis. + /// + /// + /// + /// + /// Dimension must be 0-D or 1-D. shift[i] specifies the number of places by which + /// elements are shifted positively (towards larger indices) along the dimension + /// specified by axis[i]. Negative shifts will roll the elements in the opposite + /// direction. + /// + /// + /// Dimension must be 0-D or 1-D. axis[i] specifies the dimension that the shift + /// shift[i] should occur. If the same axis is referenced more than once, the + /// total shift for that axis will be the sum of all the shifts that belong to that + /// axis. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Roll'. + /// + /// + /// Has the same shape and size as the input. The elements are shifted + /// positively (towards larger indices) by the offsets of shift along the + /// dimensions of axis. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The elements are shifted positively (towards larger indices) by the offset of + /// shift along the dimension of axis. Negative shift values will shift + /// elements in the opposite direction. Elements that roll passed the last position + /// will wrap around to the first and vice versa. Multiple shifts along multiple + /// axes may be specified. + /// + /// For example: + /// + /// + /// # 't' is [0, 1, 2, 3, 4] + /// roll(t, shift=2, axis=0) ==&gt; [3, 4, 0, 1, 2] + /// + /// # shifting along multiple dimensions + /// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] + /// roll(t, shift=[1, -2], axis=[0, 1]) ==&gt; [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]] + /// + /// # shifting along the same axis multiple times + /// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] + /// roll(t, shift=[2, -3], axis=[1, 1]) ==&gt; [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]] + /// + /// + public static Tensor roll(Tensor input, Tensor shift, Tensor axis, string name = "Roll") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["shift"] = shift; + dict["axis"] = axis; + var op = tf.OpDefLib._apply_op_helper("Roll", name: name, keywords: dict); + return op.output; + } + + /// + /// Rounds the values of a tensor to the nearest integer, element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Round'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Rounds half to even. Also known as bankers rounding. If you want to round + /// according to the current system rounding mode use std::cint. + /// + public static Tensor round(Tensor x, string name = "Round") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Round", name: name, keywords: dict); + return op.output; + } + + /// + /// Perform batches of RPC requests. + /// + /// + /// 0-D or 1-D. The address (i.e. host_name:port) of the RPC server. + /// If this tensor has more than 1 element, then multiple parallel rpc requests + /// are sent. This argument broadcasts with method and request. + /// + /// + /// 0-D or 1-D. The method address on the RPC server. + /// If this tensor has more than 1 element, then multiple parallel rpc requests + /// are sent. This argument broadcasts with address and request. + /// + /// + /// 0-D or 1-D. Serialized proto strings: the rpc request argument. + /// If this tensor has more than 1 element, then multiple parallel rpc requests + /// are sent. This argument broadcasts with address and method. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Rpc'. + /// + /// + /// RPC protocol to use. Empty string means use the default protocol. + /// Options include 'grpc'. + /// + /// + /// boolean. If true (default), then failures to connect + /// (i.e., the server does not immediately respond) cause an RPC failure. + /// + /// + /// int. If 0 (default), then the kernel will run the RPC + /// request and only time out if the RPC deadline passes or the session times out. + /// If this value is greater than 0, then the op will raise an exception if + /// the RPC takes longer than timeout_in_ms. + /// + /// + /// Same shape as request. Serialized proto strings: the rpc responses. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op asynchronously performs either a single RPC request, or a batch + /// of requests. RPC requests are defined by three main parameters: + /// + /// - address (the host+port or BNS address of the request) + /// - method (the RPC method name for the request) + /// - request (the serialized proto string, or vector of strings, + /// of the RPC request argument). + /// + /// For example, if you have an RPC service running on port localhost:2345, + /// and its interface is configured with the following proto declaration: + /// + /// + /// service MyService { + /// rpc MyMethod(MyRequestProto) returns (MyResponseProto) { + /// } + /// }; + /// + /// + /// then call this op with arguments: + /// + /// + /// address = "localhost:2345" + /// method = "MyService/MyMethod" + /// + /// + /// The request tensor is a string tensor representing serialized MyRequestProto + /// strings; and the output string tensor response will have the same shape + /// and contain (upon successful completion) corresponding serialized + /// MyResponseProto strings. + /// + /// For example, to send a single, empty, MyRequestProto, call + /// this op with request = "". To send 5 **parallel** empty requests, + /// call this op with request = ["", "", "", "", ""]. + /// + /// More generally, one can create a batch of MyRequestProto serialized protos + /// from regular batched tensors using the encode_proto op, and convert + /// the response MyResponseProto serialized protos to batched tensors + /// using the decode_proto op. + /// + /// **NOTE** Working with serialized proto strings is faster than instantiating + /// actual proto objects in memory, so no performance degradation is expected + /// compared to writing custom kernels for this workflow. + /// + /// If the connection fails or the remote worker returns an error + /// status, the op reraises this exception locally. + /// + /// See the TryRpc op if you prefer to handle RPC failures manually in the graph. + /// + public static Tensor rpc(Tensor address, Tensor method, Tensor request, string protocol = null, bool? fail_fast = null, int? timeout_in_ms = null, string name = "Rpc") + { + var dict = new Dictionary(); + dict["address"] = address; + dict["method"] = method; + dict["request"] = request; + if (protocol != null) + dict["protocol"] = protocol; + if (fail_fast.HasValue) + dict["fail_fast"] = fail_fast.Value; + if (timeout_in_ms.HasValue) + dict["timeout_in_ms"] = timeout_in_ms.Value; + var op = tf.OpDefLib._apply_op_helper("Rpc", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes reciprocal of square root of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Rsqrt'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// I.e., \\(y = 1 / \sqrt{x}\\). + /// + public static Tensor rsqrt(Tensor x, string name = "Rsqrt") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Rsqrt", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient for the rsqrt of x wrt its input. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'RsqrtGrad'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Specifically, grad = dy * -0.5 * y^3, where y = rsqrt(x), and dy + /// is the corresponding input gradient. + /// + public static Tensor rsqrt_grad(Tensor y, Tensor dy, string name = "RsqrtGrad") + { + var dict = new Dictionary(); + dict["y"] = y; + dict["dy"] = dy; + var op = tf.OpDefLib._apply_op_helper("RsqrtGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Generate a single randomly distorted bounding box for an image. + /// + /// + /// 1-D, containing [height, width, channels]. + /// + /// + /// 3-D with shape [batch, N, 4] describing the N bounding boxes + /// associated with the image. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SampleDistortedBoundingBox'. + /// + /// + /// If either seed or seed2 are set to non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a random + /// seed. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// The cropped area of the image must contain at least this + /// fraction of any bounding box supplied. The value of this parameter should be + /// non-negative. In the case of 0, the cropped area does not need to overlap + /// any of the bounding boxes supplied. + /// + /// + /// The cropped area of the image must have an aspect ratio = + /// width / height within this range. + /// + /// + /// The cropped area of the image must contain a fraction of the + /// supplied image within this range. + /// + /// + /// Number of attempts at generating a cropped region of the image + /// of the specified constraints. After max_attempts failures, return the entire + /// image. + /// + /// + /// Controls behavior if no bounding boxes supplied. + /// If true, assume an implicit bounding box covering the whole input. If false, + /// raise an error. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// begin : 1-D, containing [offset_height, offset_width, 0]. Provide as input to + /// tf.slice. + /// size : 1-D, containing [target_height, target_width, -1]. Provide as input to + /// tf.slice. + /// bboxes : 3-D with shape [1, 1, 4] containing the distorted bounding box. + /// Provide as input to tf.image.draw_bounding_boxes. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Bounding box annotations are often supplied in addition to ground-truth labels + /// in image recognition or object localization tasks. A common technique for + /// training such a system is to randomly distort an image while preserving + /// its content, i.e. *data augmentation*. This Op outputs a randomly distorted + /// localization of an object, i.e. bounding box, given an image_size, + /// bounding_boxes and a series of constraints. + /// + /// The output of this Op is a single bounding box that may be used to crop the + /// original image. The output is returned as 3 tensors: begin, size and + /// bboxes. The first 2 tensors can be fed directly into tf.slice to crop the + /// image. The latter may be supplied to tf.image.draw_bounding_boxes to visualize + /// what the bounding box looks like. + /// + /// Bounding boxes are supplied and returned as [y_min, x_min, y_max, x_max]. The + /// bounding box coordinates are floats in [0.0, 1.0] relative to the width and + /// height of the underlying image. + /// + /// For example, + /// + /// + /// # Generate a single distorted bounding box. + /// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( + /// tf.shape(image), + /// bounding_boxes=bounding_boxes) + /// + /// # Draw the bounding box in an image summary. + /// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), + /// bbox_for_draw) + /// tf.summary.image('images_with_box', image_with_box) + /// + /// # Employ the bounding box to distort the image. + /// distorted_image = tf.slice(image, begin, size) + /// + /// + /// Note that if no bounding box information is available, setting + /// use_image_if_no_bounding_boxes = true will assume there is a single implicit + /// bounding box covering the whole image. If use_image_if_no_bounding_boxes is + /// false and no bounding boxes are supplied, an error is raised. + /// + public static (Tensor begin, Tensor size, Tensor bboxes) sample_distorted_bounding_box(Tensor image_size, Tensor bounding_boxes, int? seed = null, int? seed2 = null, float? min_object_covered = null, float[] aspect_ratio_range = null, float[] area_range = null, int? max_attempts = null, bool? use_image_if_no_bounding_boxes = null, string name = "SampleDistortedBoundingBox") + { + var dict = new Dictionary(); + dict["image_size"] = image_size; + dict["bounding_boxes"] = bounding_boxes; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + if (min_object_covered.HasValue) + dict["min_object_covered"] = min_object_covered.Value; + if (aspect_ratio_range != null) + dict["aspect_ratio_range"] = aspect_ratio_range; + if (area_range != null) + dict["area_range"] = area_range; + if (max_attempts.HasValue) + dict["max_attempts"] = max_attempts.Value; + if (use_image_if_no_bounding_boxes.HasValue) + dict["use_image_if_no_bounding_boxes"] = use_image_if_no_bounding_boxes.Value; + var op = tf.OpDefLib._apply_op_helper("SampleDistortedBoundingBox", name: name, keywords: dict); + int _idx = 0; + var begin = op.outputs[_idx++]; + var size = op.outputs[_idx++]; + var bboxes = op.outputs[_idx++]; + return (begin, size, bboxes); + } + + /// + /// Generate a single randomly distorted bounding box for an image. + /// + /// + /// 1-D, containing [height, width, channels]. + /// + /// + /// 3-D with shape [batch, N, 4] describing the N bounding boxes + /// associated with the image. + /// + /// + /// The cropped area of the image must contain at least this + /// fraction of any bounding box supplied. The value of this parameter should be + /// non-negative. In the case of 0, the cropped area does not need to overlap + /// any of the bounding boxes supplied. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SampleDistortedBoundingBoxV2'. + /// + /// + /// If either seed or seed2 are set to non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a random + /// seed. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// The cropped area of the image must have an aspect ratio = + /// width / height within this range. + /// + /// + /// The cropped area of the image must contain a fraction of the + /// supplied image within this range. + /// + /// + /// Number of attempts at generating a cropped region of the image + /// of the specified constraints. After max_attempts failures, return the entire + /// image. + /// + /// + /// Controls behavior if no bounding boxes supplied. + /// If true, assume an implicit bounding box covering the whole input. If false, + /// raise an error. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// begin : 1-D, containing [offset_height, offset_width, 0]. Provide as input to + /// tf.slice. + /// size : 1-D, containing [target_height, target_width, -1]. Provide as input to + /// tf.slice. + /// bboxes : 3-D with shape [1, 1, 4] containing the distorted bounding box. + /// Provide as input to tf.image.draw_bounding_boxes. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Bounding box annotations are often supplied in addition to ground-truth labels + /// in image recognition or object localization tasks. A common technique for + /// training such a system is to randomly distort an image while preserving + /// its content, i.e. *data augmentation*. This Op outputs a randomly distorted + /// localization of an object, i.e. bounding box, given an image_size, + /// bounding_boxes and a series of constraints. + /// + /// The output of this Op is a single bounding box that may be used to crop the + /// original image. The output is returned as 3 tensors: begin, size and + /// bboxes. The first 2 tensors can be fed directly into tf.slice to crop the + /// image. The latter may be supplied to tf.image.draw_bounding_boxes to visualize + /// what the bounding box looks like. + /// + /// Bounding boxes are supplied and returned as [y_min, x_min, y_max, x_max]. The + /// bounding box coordinates are floats in [0.0, 1.0] relative to the width and + /// height of the underlying image. + /// + /// For example, + /// + /// + /// # Generate a single distorted bounding box. + /// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( + /// tf.shape(image), + /// bounding_boxes=bounding_boxes) + /// + /// # Draw the bounding box in an image summary. + /// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), + /// bbox_for_draw) + /// tf.summary.image('images_with_box', image_with_box) + /// + /// # Employ the bounding box to distort the image. + /// distorted_image = tf.slice(image, begin, size) + /// + /// + /// Note that if no bounding box information is available, setting + /// use_image_if_no_bounding_boxes = true will assume there is a single implicit + /// bounding box covering the whole image. If use_image_if_no_bounding_boxes is + /// false and no bounding boxes are supplied, an error is raised. + /// + public static (Tensor begin, Tensor size, Tensor bboxes) sample_distorted_bounding_box_v2(Tensor image_size, Tensor bounding_boxes, Tensor min_object_covered, int? seed = null, int? seed2 = null, float[] aspect_ratio_range = null, float[] area_range = null, int? max_attempts = null, bool? use_image_if_no_bounding_boxes = null, string name = "SampleDistortedBoundingBoxV2") + { + var dict = new Dictionary(); + dict["image_size"] = image_size; + dict["bounding_boxes"] = bounding_boxes; + dict["min_object_covered"] = min_object_covered; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + if (aspect_ratio_range != null) + dict["aspect_ratio_range"] = aspect_ratio_range; + if (area_range != null) + dict["area_range"] = area_range; + if (max_attempts.HasValue) + dict["max_attempts"] = max_attempts.Value; + if (use_image_if_no_bounding_boxes.HasValue) + dict["use_image_if_no_bounding_boxes"] = use_image_if_no_bounding_boxes.Value; + var op = tf.OpDefLib._apply_op_helper("SampleDistortedBoundingBoxV2", name: name, keywords: dict); + int _idx = 0; + var begin = op.outputs[_idx++]; + var size = op.outputs[_idx++]; + var bboxes = op.outputs[_idx++]; + return (begin, size, bboxes); + } + + /// + /// Saves the input tensors to disk. + /// + /// + /// Must have a single element. The name of the file to which we write + /// the tensor. + /// + /// + /// Shape [N]. The names of the tensors to be saved. + /// + /// + /// N tensors to save. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Save'. + /// + /// + /// Returns the description of the operation + /// + /// + /// The size of tensor_names must match the number of tensors in data. data[i] + /// is written to filename with name tensor_names[i]. + /// + /// See also SaveSlices. + /// + public static Operation save(Tensor filename, Tensor tensor_names, Tensor[] data, string name = "Save") + { + var dict = new Dictionary(); + dict["filename"] = filename; + dict["tensor_names"] = tensor_names; + dict["data"] = data; + var op = tf.OpDefLib._apply_op_helper("Save", name: name, keywords: dict); + return op; + } + + /// + /// Saves input tensors slices to disk. + /// + /// + /// Must have a single element. The name of the file to which we write the + /// tensor. + /// + /// + /// Shape [N]. The names of the tensors to be saved. + /// + /// + /// Shape [N]. The shapes and slice specifications to use when + /// saving the tensors. + /// + /// + /// N tensors to save. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SaveSlices'. + /// + /// + /// Returns the description of the operation + /// + /// + /// This is like Save except that tensors can be listed in the saved file as being + /// a slice of a larger tensor. shapes_and_slices specifies the shape of the + /// larger tensor and the slice that this tensor covers. shapes_and_slices must + /// have as many elements as tensor_names. + /// + /// Elements of the shapes_and_slices input must either be: + /// + /// * The empty string, in which case the corresponding tensor is + /// saved normally. + /// * A string of the form dim0 dim1 ... dimN-1 slice-spec where the + /// dimI are the dimensions of the larger tensor and slice-spec + /// specifies what part is covered by the tensor to save. + /// + /// slice-spec itself is a :-separated list: slice0:slice1:...:sliceN-1 + /// where each sliceI is either: + /// + /// * The string - meaning that the slice covers all indices of this dimension + /// * start,length where start and length are integers. In that + /// case the slice covers length indices starting at start. + /// + /// See also Save. + /// + public static Operation save_slices(Tensor filename, Tensor tensor_names, Tensor shapes_and_slices, Tensor[] data, string name = "SaveSlices") + { + var dict = new Dictionary(); + dict["filename"] = filename; + dict["tensor_names"] = tensor_names; + dict["shapes_and_slices"] = shapes_and_slices; + dict["data"] = data; + var op = tf.OpDefLib._apply_op_helper("SaveSlices", name: name, keywords: dict); + return op; + } + + /// + /// Saves tensors in V2 checkpoint format. + /// + /// + /// Must have a single element. The prefix of the V2 checkpoint to which we + /// write the tensors. + /// + /// + /// shape {N}. The names of the tensors to be saved. + /// + /// + /// shape {N}. The slice specs of the tensors to be saved. + /// Empty strings indicate that they are non-partitioned tensors. + /// + /// + /// N tensors to save. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SaveV2'. + /// + /// + /// Returns the description of the operation + /// + /// + /// By default, saves the named tensors in full. If the caller wishes to save + /// specific slices of full tensors, "shape_and_slices" should be non-empty strings + /// and correspondingly well-formed. + /// + public static Operation save_v2(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, Tensor[] tensors, string name = "SaveV2") + { + var dict = new Dictionary(); + dict["prefix"] = prefix; + dict["tensor_names"] = tensor_names; + dict["shape_and_slices"] = shape_and_slices; + dict["tensors"] = tensors; + var op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, keywords: dict); + return op; + } + + public static Tensor scale_and_translate(Tensor images_t, Tensor new_size, Tensor[] scale, Tensor zeroes, string kernel_type, bool antialias) + { + throw new NotImplementedException("scale_and_translate"); + } + + /// + /// Outputs a Summary protocol buffer with scalar values. + /// + /// + /// Tags for the summary. + /// + /// + /// Same shape as tags. Values for the summary. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScalarSummary'. + /// + /// + /// Scalar. Serialized Summary protocol buffer. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The input tags and values must have the same shape. The generated summary + /// has a summary value for each tag-value pair in tags and values. + /// + public static Tensor scalar_summary(Tensor tags, Tensor values, string name = "ScalarSummary") + { + var dict = new Dictionary(); + dict["tags"] = tags; + dict["values"] = values; + var op = tf.OpDefLib._apply_op_helper("ScalarSummary", name: name, keywords: dict); + return op.output; + } + + /// + /// Adds sparse updates to a variable reference. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to add to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScatterAdd'. + /// + /// + /// If True, the addition will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// = Same as ref. Returned as a convenience for operations that want + /// to use the updated values after the update is done. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] += updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] += updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] + /// + /// This operation outputs ref after the update is done. + /// This makes it easier to chain operations that need to use the reset value. + /// + /// Duplicate entries are handled correctly: if multiple indices reference + /// the same location, their contributions add. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor scatter_add(Tensor referecne, Tensor indices, Tensor updates, bool? use_locking = null, string name = "ScatterAdd") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["indices"] = indices; + dict["updates"] = updates; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ScatterAdd", name: name, keywords: dict); + return op.output; + } + + /// + /// Divides a variable reference by sparse updates. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of values that ref is divided by. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScatterDiv'. + /// + /// + /// If True, the operation will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// = Same as ref. Returned as a convenience for operations that want + /// to use the updated values after the update is done. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation computes + /// + /// + /// # Scalar indices + /// ref[indices, ...] /= updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] /= updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] + /// + /// + /// This operation outputs ref after the update is done. + /// This makes it easier to chain operations that need to use the reset value. + /// + /// Duplicate entries are handled correctly: if multiple indices reference + /// the same location, their contributions divide. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + public static Tensor scatter_div(Tensor referecne, Tensor indices, Tensor updates, bool? use_locking = null, string name = "ScatterDiv") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["indices"] = indices; + dict["updates"] = updates; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ScatterDiv", name: name, keywords: dict); + return op.output; + } + + /// + /// Reduces sparse updates into a variable reference using the max operation. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to reduce into ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScatterMax'. + /// + /// + /// If True, the update will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// = Same as ref. Returned as a convenience for operations that want + /// to use the updated values after the update is done. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] = max(ref[indices, ...], updates[...]) + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + /// + /// This operation outputs ref after the update is done. + /// This makes it easier to chain operations that need to use the reset value. + /// + /// Duplicate entries are handled correctly: if multiple indices reference + /// the same location, their contributions combine. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor scatter_max(Tensor referecne, Tensor indices, Tensor updates, bool? use_locking = null, string name = "ScatterMax") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["indices"] = indices; + dict["updates"] = updates; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ScatterMax", name: name, keywords: dict); + return op.output; + } + + /// + /// Reduces sparse updates into a variable reference using the min operation. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to reduce into ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScatterMin'. + /// + /// + /// If True, the update will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// = Same as ref. Returned as a convenience for operations that want + /// to use the updated values after the update is done. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] = min(ref[indices, ...], updates[...]) + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + /// + /// This operation outputs ref after the update is done. + /// This makes it easier to chain operations that need to use the reset value. + /// + /// Duplicate entries are handled correctly: if multiple indices reference + /// the same location, their contributions combine. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor scatter_min(Tensor referecne, Tensor indices, Tensor updates, bool? use_locking = null, string name = "ScatterMin") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["indices"] = indices; + dict["updates"] = updates; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ScatterMin", name: name, keywords: dict); + return op.output; + } + + /// + /// Multiplies sparse updates into a variable reference. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to multiply to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScatterMul'. + /// + /// + /// If True, the operation will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// = Same as ref. Returned as a convenience for operations that want + /// to use the updated values after the update is done. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation computes + /// + /// + /// # Scalar indices + /// ref[indices, ...] *= updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] *= updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] + /// + /// + /// This operation outputs ref after the update is done. + /// This makes it easier to chain operations that need to use the reset value. + /// + /// Duplicate entries are handled correctly: if multiple indices reference + /// the same location, their contributions multiply. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + public static Tensor scatter_mul(Tensor referecne, Tensor indices, Tensor updates, bool? use_locking = null, string name = "ScatterMul") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["indices"] = indices; + dict["updates"] = updates; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ScatterMul", name: name, keywords: dict); + return op.output; + } + + /// + /// Scatter updates into a new tensor according to indices. + /// + /// + /// Index tensor. + /// + /// + /// Updates to scatter into output. + /// + /// + /// 1-D. The shape of the resulting tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScatterNd'. + /// + /// + /// A new tensor with the given shape and updates applied according + /// to the indices. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Creates a new tensor by applying sparse updates to individual values or + /// slices within a tensor (initially zero for numeric, empty for string) of + /// the given shape according to indices. This operator is the inverse of the + /// tf.gather_nd operator which extracts values or slices from a given tensor. + /// + /// If indices contains duplicates, then their updates are accumulated (summed). + /// + /// **WARNING**: The order in which updates are applied is nondeterministic, so the + /// output will be nondeterministic if indices contains duplicates -- because + /// of some numerical approximation issues, numbers summed in different order + /// may yield different results. + /// + /// indices is an integer tensor containing indices into a new tensor of shape + /// shape. The last dimension of indices can be at most the rank of shape: + /// + /// indices.shape[-1] &lt;= shape.rank + /// + /// The last dimension of indices corresponds to indices into elements + /// (if indices.shape[-1] = shape.rank) or slices + /// (if indices.shape[-1] &lt; shape.rank) along dimension indices.shape[-1] of + /// shape. updates is a tensor with shape + /// + /// indices.shape[:-1] + shape[indices.shape[-1]:] + /// + /// The simplest form of scatter is to insert individual elements in a tensor by + /// index. For example, say we want to insert 4 scattered elements in a rank-1 + /// tensor with 8 elements. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/ScatterNd1.png" alt&gt; + /// &lt;/div&gt; + /// + /// In Python, this scatter operation would look like this: + /// + /// + /// indices = tf.constant([[4], [3], [1], [7]]) + /// updates = tf.constant([9, 10, 11, 12]) + /// shape = tf.constant([8]) + /// scatter = tf.scatter_nd(indices, updates, shape) + /// with tf.Session() as sess: + /// print(sess.run(scatter)) + /// + /// + /// The resulting tensor would look like this: + /// + /// [0, 11, 0, 10, 9, 0, 0, 12] + /// + /// We can also, insert entire slices of a higher rank tensor all at once. For + /// example, if we wanted to insert two slices in the first dimension of a + /// rank-3 tensor with two matrices of new values. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/ScatterNd2.png" alt&gt; + /// &lt;/div&gt; + /// + /// In Python, this scatter operation would look like this: + /// + /// + /// indices = tf.constant([[0], [2]]) + /// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], + /// [7, 7, 7, 7], [8, 8, 8, 8]], + /// [[5, 5, 5, 5], [6, 6, 6, 6], + /// [7, 7, 7, 7], [8, 8, 8, 8]]]) + /// shape = tf.constant([4, 4, 4]) + /// scatter = tf.scatter_nd(indices, updates, shape) + /// with tf.Session() as sess: + /// print(sess.run(scatter)) + /// + /// + /// The resulting tensor would look like this: + /// + /// [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + /// [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], + /// [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + /// [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] + /// + /// Note that on CPU, if an out of bound index is found, an error is returned. + /// On GPU, if an out of bound index is found, the index is ignored. + /// + public static Tensor scatter_nd(Tensor indices, Tensor updates, Tensor shape, string name = "ScatterNd") + { + var dict = new Dictionary(); + dict["indices"] = indices; + dict["updates"] = updates; + dict["shape"] = shape; + var op = tf.OpDefLib._apply_op_helper("ScatterNd", name: name, keywords: dict); + return op.output; + } + + /// + /// Applies sparse addition between updates and individual values or slices + /// + /// + /// A mutable Tensor. Should be from a Variable node. + /// + /// + /// A Tensor. Must be one of the following types: int32, int64. + /// A tensor of indices into ref. + /// + /// + /// A Tensor. Must have the same type as ref. A tensor of updated values + /// to add to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScatterNdAdd'. + /// + /// + /// An optional bool. Defaults to True. If True, the assignment will + /// be protected by a lock; otherwise the behavior is undefined, + /// but may exhibit less contention. + /// + /// + /// Same as ref. Returned as a convenience for operations that want + /// to use the updated values after the update is done. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// within a given variable according to indices. + /// + /// ref is a Tensor with rank P and indices is a Tensor of rank Q. + /// + /// indices must be integer tensor, containing indices into ref. + /// It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where 0 &lt; K &lt;= P. + /// + /// The innermost dimension of indices (with length K) corresponds to + /// indices into elements (if K = P) or slices (if K &lt; P) along the Kth + /// dimension of ref. + /// + /// updates is Tensor of rank Q-1+P-K with shape: + /// + /// $$[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].$$ + /// + /// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 + /// elements. In Python, that addition would look like this: + /// + /// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + /// indices = tf.constant([[4], [3], [1], [7]]) + /// updates = tf.constant([9, 10, 11, 12]) + /// add = tf.scatter_nd_add(ref, indices, updates) + /// with tf.Session() as sess: + /// print sess.run(add) + /// + /// The resulting update to ref would look like this: + /// + /// [1, 13, 3, 14, 14, 6, 7, 20] + /// + /// See tf.scatter_nd for more details about how to make updates to + /// slices. + /// + public static Tensor scatter_nd_add(Tensor referecne, Tensor indices, Tensor updates, bool? use_locking = null, string name = "ScatterNdAdd") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["indices"] = indices; + dict["updates"] = updates; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ScatterNdAdd", name: name, keywords: dict); + return op.output; + } + + /// + /// Applies sparse addition to input using individual values or slices + /// + /// + /// A Tensor. + /// + /// + /// A Tensor. Must be one of the following types: int32, int64. + /// A tensor of indices into input. + /// + /// + /// A Tensor. Must have the same type as ref. A tensor of updated values + /// to add to input. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScatterNdNonAliasingAdd'. + /// + /// + /// A Tensor with the same shape as input, containing values of input + /// updated with updates. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// from updates according to indices indices. The updates are non-aliasing: + /// input is only modified in-place if no other operations will use it. + /// Otherwise, a copy of input is made. This operation has a gradient with + /// respect to both input and updates. + /// + /// input is a Tensor with rank P and indices is a Tensor of rank Q. + /// + /// indices must be integer tensor, containing indices into input. + /// It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where 0 &lt; K &lt;= P. + /// + /// The innermost dimension of indices (with length K) corresponds to + /// indices into elements (if K = P) or (P-K)-dimensional slices + /// (if K &lt; P) along the Kth dimension of input. + /// + /// updates is Tensor of rank Q-1+P-K with shape: + /// + /// $$[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]].$$ + /// + /// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 + /// elements. In Python, that addition would look like this: + /// + /// input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) + /// indices = tf.constant([[4], [3], [1], [7]]) + /// updates = tf.constant([9, 10, 11, 12]) + /// output = tf.scatter_nd_non_aliasing_add(input, indices, updates) + /// with tf.Session() as sess: + /// print(sess.run(output)) + /// + /// The resulting value output would look like this: + /// + /// [1, 13, 3, 14, 14, 6, 7, 20] + /// + /// See tf.scatter_nd for more details about how to make updates to slices. + /// + public static Tensor scatter_nd_non_aliasing_add(Tensor input, Tensor indices, Tensor updates, string name = "ScatterNdNonAliasingAdd") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["indices"] = indices; + dict["updates"] = updates; + var op = tf.OpDefLib._apply_op_helper("ScatterNdNonAliasingAdd", name: name, keywords: dict); + return op.output; + } + + /// + /// Applies sparse subtraction between updates and individual values or slices + /// + /// + /// A mutable Tensor. Should be from a Variable node. + /// + /// + /// A Tensor. Must be one of the following types: int32, int64. + /// A tensor of indices into ref. + /// + /// + /// A Tensor. Must have the same type as ref. A tensor of updated values + /// to subtract from ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScatterNdSub'. + /// + /// + /// An optional bool. Defaults to True. If True, the assignment will + /// be protected by a lock; otherwise the behavior is undefined, + /// but may exhibit less contention. + /// + /// + /// Same as ref. Returned as a convenience for operations that want + /// to use the updated values after the update is done. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// within a given variable according to indices. + /// + /// ref is a Tensor with rank P and indices is a Tensor of rank Q. + /// + /// indices must be integer tensor, containing indices into ref. + /// It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where 0 &lt; K &lt;= P. + /// + /// The innermost dimension of indices (with length K) corresponds to + /// indices into elements (if K = P) or slices (if K &lt; P) along the Kth + /// dimension of ref. + /// + /// updates is Tensor of rank Q-1+P-K with shape: + /// + /// $$[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].$$ + /// + /// For example, say we want to subtract 4 scattered elements from a rank-1 tensor + /// with 8 elements. In Python, that subtraction would look like this: + /// + /// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + /// indices = tf.constant([[4], [3], [1], [7]]) + /// updates = tf.constant([9, 10, 11, 12]) + /// sub = tf.scatter_nd_sub(ref, indices, updates) + /// with tf.Session() as sess: + /// print sess.run(sub) + /// + /// The resulting update to ref would look like this: + /// + /// [1, -9, 3, -6, -4, 6, 7, -4] + /// + /// See tf.scatter_nd for more details about how to make updates to + /// slices. + /// + public static Tensor scatter_nd_sub(Tensor referecne, Tensor indices, Tensor updates, bool? use_locking = null, string name = "ScatterNdSub") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["indices"] = indices; + dict["updates"] = updates; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ScatterNdSub", name: name, keywords: dict); + return op.output; + } + + /// + /// Applies sparse updates to individual values or slices within a given + /// + /// + /// A mutable Tensor. Should be from a Variable node. + /// + /// + /// A Tensor. Must be one of the following types: int32, int64. + /// A tensor of indices into ref. + /// + /// + /// A Tensor. Must have the same type as ref. A tensor of updated + /// values to add to ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScatterNdUpdate'. + /// + /// + /// An optional bool. Defaults to True. If True, the assignment will + /// be protected by a lock; otherwise the behavior is undefined, + /// but may exhibit less contention. + /// + /// + /// Same as ref. Returned as a convenience for operations that want to + /// use the updated values after the update is done. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// variable according to indices. + /// + /// ref is a Tensor with rank P and indices is a Tensor of rank Q. + /// + /// indices must be integer tensor, containing indices into ref. + /// It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where 0 &lt; K &lt;= P. + /// + /// The innermost dimension of indices (with length K) corresponds to + /// indices into elements (if K = P) or slices (if K &lt; P) along the Kth + /// dimension of ref. + /// + /// updates is Tensor of rank Q-1+P-K with shape: + /// + /// $$[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].$$ + /// + /// For example, say we want to update 4 scattered elements to a rank-1 tensor to + /// 8 elements. In Python, that update would look like this: + /// + /// + /// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) + /// indices = tf.constant([[4], [3], [1] ,[7]]) + /// updates = tf.constant([9, 10, 11, 12]) + /// update = tf.scatter_nd_update(ref, indices, updates) + /// with tf.Session() as sess: + /// print sess.run(update) + /// + /// + /// The resulting update to ref would look like this: + /// + /// [1, 11, 3, 10, 9, 6, 7, 12] + /// + /// See tf.scatter_nd for more details about how to make updates to + /// slices. + /// + /// See also tf.scatter_update and tf.batch_scatter_update. + /// + public static Tensor scatter_nd_update(Tensor referecne, Tensor indices, Tensor updates, bool? use_locking = null, string name = "ScatterNdUpdate") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["indices"] = indices; + dict["updates"] = updates; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ScatterNdUpdate", name: name, keywords: dict); + return op.output; + } + + /// + /// Subtracts sparse updates to a variable reference. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to subtract from ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScatterSub'. + /// + /// + /// If True, the subtraction will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// = Same as ref. Returned as a convenience for operations that want + /// to use the updated values after the update is done. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// + /// # Scalar indices + /// ref[indices, ...] -= updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] -= updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] + /// + /// + /// This operation outputs ref after the update is done. + /// This makes it easier to chain operations that need to use the reset value. + /// + /// Duplicate entries are handled correctly: if multiple indices reference + /// the same location, their (negated) contributions add. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/ScatterSub.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor scatter_sub(Tensor referecne, Tensor indices, Tensor updates, bool? use_locking = null, string name = "ScatterSub") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["indices"] = indices; + dict["updates"] = updates; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ScatterSub", name: name, keywords: dict); + return op.output; + } + + /// + /// Applies sparse updates to a variable reference. + /// + /// + /// Should be from a Variable node. + /// + /// + /// A tensor of indices into the first dimension of ref. + /// + /// + /// A tensor of updated values to store in ref. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ScatterUpdate'. + /// + /// + /// If True, the assignment will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// = Same as ref. Returned as a convenience for operations that want + /// to use the updated values after the update is done. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation computes + /// + /// + /// # Scalar indices + /// ref[indices, ...] = updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] = updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] + /// + /// + /// This operation outputs ref after the update is done. + /// This makes it easier to chain operations that need to use the reset value. + /// + /// If values in ref is to be updated more than once, because there are + /// duplicate entries in indices, the order at which the updates happen + /// for each value is undefined. + /// + /// Requires updates.shape = indices.shape + ref.shape[1:] or updates.shape = []. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt&gt; + /// &lt;/div&gt; + /// + /// See also tf.batch_scatter_update and tf.scatter_nd_update. + /// + public static Tensor scatter_update(Tensor referecne, Tensor indices, Tensor updates, bool? use_locking = null, string name = "ScatterUpdate") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["indices"] = indices; + dict["updates"] = updates; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("ScatterUpdate", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes fingerprints of the input strings. + /// + /// + /// vector of strings to compute fingerprints on. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SdcaFprint'. + /// + /// + /// a (N,2) shaped matrix where N is the number of elements in the input + /// vector. Each row contains the low and high parts of the fingerprint. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor sdca_fprint(Tensor input, string name = "SdcaFprint") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("SdcaFprint", name: name, keywords: dict); + return op.output; + } + + /// + /// Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for + /// + /// + /// a list of vectors which contain example indices. + /// + /// + /// a list of vectors which contain feature indices. + /// + /// + /// a list of vectors which contains feature value + /// associated with each feature group. + /// + /// + /// a list of matrices which contains the dense feature values. + /// + /// + /// a vector which contains the weight associated with each + /// example. + /// + /// + /// a vector which contains the label/target associated with each + /// example. + /// + /// + /// a list of vectors where each value is the indices which has + /// corresponding weights in sparse_weights. This field maybe omitted for the + /// dense approach. + /// + /// + /// a list of vectors where each value is the weight associated with + /// a sparse feature group. + /// + /// + /// a list of vectors where the values are the weights associated + /// with a dense feature group. + /// + /// + /// a list of vectors containing the example state data. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SdcaOptimizer'. + /// + /// + /// Optional argument + /// Type of the primal loss. Currently SdcaSolver supports logistic, + /// squared and hinge losses. + /// + /// + /// Optional argument + /// Symmetric l1 regularization strength. + /// + /// + /// Optional argument + /// Symmetric l2 regularization strength. + /// + /// + /// Optional argument + /// Number of partitions of the global loss function. + /// + /// + /// Optional argument + /// Number of iterations per mini-batch. + /// + /// + /// Whether to use Adaptive SDCA for the inner loop. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// out_example_state_data : a list of vectors containing the updated example state + /// data. + /// out_delta_sparse_weights : a list of vectors where each value is the delta + /// weights associated with a sparse feature group. + /// out_delta_dense_weights : a list of vectors where the values are the delta + /// weights associated with a dense feature group. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// linear models with L1 + L2 regularization. As global optimization objective is + /// strongly-convex, the optimizer optimizes the dual objective at each step. The + /// optimizer applies each update one example at a time. Examples are sampled + /// uniformly, and the optimizer is learning rate free and enjoys linear convergence + /// rate. + /// + /// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).&lt;br&gt; + /// Shai Shalev-Shwartz, Tong Zhang. 2012 + /// + /// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$ + /// + /// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).&lt;br&gt; + /// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, + /// Peter Richtarik, Martin Takac. 2015 + /// + /// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).&lt;br&gt; + /// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 + /// + public static (Tensor out_example_state_data, Tensor[] out_delta_sparse_weights, Tensor[] out_delta_dense_weights) sdca_optimizer(Tensor[] sparse_example_indices, Tensor[] sparse_feature_indices, Tensor[] sparse_feature_values, Tensor[] dense_features, Tensor example_weights, Tensor example_labels, Tensor[] sparse_indices, Tensor[] sparse_weights, Tensor[] dense_weights, Tensor example_state_data, string loss_type, float l1, float l2, int num_loss_partitions, int num_inner_iterations, bool? adaptative = null, string name = "SdcaOptimizer") + { + var dict = new Dictionary(); + dict["sparse_example_indices"] = sparse_example_indices; + dict["sparse_feature_indices"] = sparse_feature_indices; + dict["sparse_feature_values"] = sparse_feature_values; + dict["dense_features"] = dense_features; + dict["example_weights"] = example_weights; + dict["example_labels"] = example_labels; + dict["sparse_indices"] = sparse_indices; + dict["sparse_weights"] = sparse_weights; + dict["dense_weights"] = dense_weights; + dict["example_state_data"] = example_state_data; + dict["loss_type"] = loss_type; + dict["l1"] = l1; + dict["l2"] = l2; + dict["num_loss_partitions"] = num_loss_partitions; + dict["num_inner_iterations"] = num_inner_iterations; + if (adaptative.HasValue) + dict["adaptative"] = adaptative.Value; + var op = tf.OpDefLib._apply_op_helper("SdcaOptimizer", name: name, keywords: dict); + int _idx = 0; + var out_example_state_data = op.outputs[_idx++]; + var out_delta_sparse_weights = Enumerable.Range(0, op.OutputListLength("out_delta_sparse_weights")).Select(_ => op.outputs[_idx++]).ToArray(); + var out_delta_dense_weights = Enumerable.Range(0, op.OutputListLength("out_delta_dense_weights")).Select(_ => op.outputs[_idx++]).ToArray(); + return (out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights); + } + + /// + /// Applies L1 regularization shrink step on the parameters. + /// + /// + /// a list of vectors where each value is the weight associated with a + /// feature group. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SdcaShrinkL1'. + /// + /// + /// Optional argument + /// Symmetric l1 regularization strength. + /// + /// + /// Optional argument + /// Symmetric l2 regularization strength. Should be a positive float. + /// + /// + /// Returns the description of the operation + /// + public static Operation sdca_shrink_l1(Tensor[] weights, float l1, float l2, string name = "SdcaShrinkL1") + { + var dict = new Dictionary(); + dict["weights"] = weights; + dict["l1"] = l1; + dict["l2"] = l2; + var op = tf.OpDefLib._apply_op_helper("SdcaShrinkL1", name: name, keywords: dict); + return op; + } + + /// + /// Computes the maximum along segments of a tensor. + /// + /// + /// + /// + /// A 1-D tensor whose size is equal to the size of data's + /// first dimension. Values should be sorted and can be repeated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SegmentMax'. + /// + /// + /// Has same shape as data, except for dimension 0 which + /// has size k, the number of segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + /// Computes a tensor such that + /// \\(output_i = \max_j(data_j)\\) where max is over j such + /// that segment_ids[j] == i. + /// + /// If the max is empty for a given segment ID i, output[i] = 0. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/SegmentMax.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor segment_max(Tensor data, Tensor segment_ids, string name = "SegmentMax") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["segment_ids"] = segment_ids; + var op = tf.OpDefLib._apply_op_helper("SegmentMax", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the mean along segments of a tensor. + /// + /// + /// + /// + /// A 1-D tensor whose size is equal to the size of data's + /// first dimension. Values should be sorted and can be repeated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SegmentMean'. + /// + /// + /// Has same shape as data, except for dimension 0 which + /// has size k, the number of segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + /// Computes a tensor such that + /// \\(output_i = \frac{\sum_j data_j}{N}\\) where mean is + /// over j such that segment_ids[j] == i and N is the total number of + /// values summed. + /// + /// If the mean is empty for a given segment ID i, output[i] = 0. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/SegmentMean.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor segment_mean(Tensor data, Tensor segment_ids, string name = "SegmentMean") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["segment_ids"] = segment_ids; + var op = tf.OpDefLib._apply_op_helper("SegmentMean", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the minimum along segments of a tensor. + /// + /// + /// + /// + /// A 1-D tensor whose size is equal to the size of data's + /// first dimension. Values should be sorted and can be repeated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SegmentMin'. + /// + /// + /// Has same shape as data, except for dimension 0 which + /// has size k, the number of segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + /// Computes a tensor such that + /// \\(output_i = \min_j(data_j)\\) where min is over j such + /// that segment_ids[j] == i. + /// + /// If the min is empty for a given segment ID i, output[i] = 0. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/SegmentMin.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor segment_min(Tensor data, Tensor segment_ids, string name = "SegmentMin") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["segment_ids"] = segment_ids; + var op = tf.OpDefLib._apply_op_helper("SegmentMin", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the product along segments of a tensor. + /// + /// + /// + /// + /// A 1-D tensor whose size is equal to the size of data's + /// first dimension. Values should be sorted and can be repeated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SegmentProd'. + /// + /// + /// Has same shape as data, except for dimension 0 which + /// has size k, the number of segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + /// Computes a tensor such that + /// \\(output_i = \prod_j data_j\\) where the product is over j such + /// that segment_ids[j] == i. + /// + /// If the product is empty for a given segment ID i, output[i] = 1. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/SegmentProd.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor segment_prod(Tensor data, Tensor segment_ids, string name = "SegmentProd") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["segment_ids"] = segment_ids; + var op = tf.OpDefLib._apply_op_helper("SegmentProd", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the sum along segments of a tensor. + /// + /// + /// + /// + /// A 1-D tensor whose size is equal to the size of data's + /// first dimension. Values should be sorted and can be repeated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SegmentSum'. + /// + /// + /// Has same shape as data, except for dimension 0 which + /// has size k, the number of segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + /// Computes a tensor such that + /// \\(output_i = \sum_j data_j\\) where sum is over j such + /// that segment_ids[j] == i. + /// + /// If the sum is empty for a given segment ID i, output[i] = 0. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/SegmentSum.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor segment_sum(Tensor data, Tensor segment_ids, string name = "SegmentSum") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["segment_ids"] = segment_ids; + var op = tf.OpDefLib._apply_op_helper("SegmentSum", name: name, keywords: dict); + return op.output; + } + + /// + /// Selects elements from x or y, depending on condition. + /// + /// + /// + /// + /// = A Tensor which may have the same shape as condition. + /// If condition is rank 1, x may have higher rank, + /// but its first dimension must match the size of condition. + /// + /// + /// = A Tensor with the same type and shape as x. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Select'. + /// + /// + /// = A Tensor with the same type and shape as x and y. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The x, and y tensors must all have the same shape, and the + /// output will also have that shape. + /// + /// The condition tensor must be a scalar if x and y are scalars. + /// If x and y are vectors or higher rank, then condition must be either a + /// scalar, a vector with size matching the first dimension of x, or must have + /// the same shape as x. + /// + /// The condition tensor acts as a mask that chooses, based on the value at each + /// element, whether the corresponding element / row in the output should be + /// taken from x (if true) or y (if false). + /// + /// If condition is a vector and x and y are higher rank matrices, then + /// it chooses which row (outer dimension) to copy from x and y. + /// If condition has the same shape as x and y, then it chooses which + /// element to copy from x and y. + /// + /// For example: + /// + /// + /// # 'condition' tensor is [[True, False] + /// # [False, True]] + /// # 't' is [[1, 2], + /// # [3, 4]] + /// # 'e' is [[5, 6], + /// # [7, 8]] + /// select(condition, t, e) # =&gt; [[1, 6], [7, 4]] + /// + /// + /// # 'condition' tensor is [True, False] + /// # 't' is [[1, 2], + /// # [3, 4]] + /// # 'e' is [[5, 6], + /// # [7, 8]] + /// select(condition, t, e) ==&gt; [[1, 2], + /// [7, 8]] + /// + /// + /// + public static Tensor select(Tensor condition, Tensor t, Tensor e, string name = "Select") + { + var dict = new Dictionary(); + dict["condition"] = condition; + dict["t"] = t; + dict["e"] = e; + var op = tf.OpDefLib._apply_op_helper("Select", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the Eigen Decomposition of a batch of square self-adjoint matrices. + /// + /// + /// Shape is [..., M, M]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SelfAdjointEig'. + /// + /// + /// Shape is [..., M+1, M]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The input is a tensor of shape [..., M, M] whose inner-most 2 dimensions + /// form square matrices, with the same constraints as the single matrix + /// SelfAdjointEig. + /// + /// The result is a [..., M+1, M] matrix with [..., 0,:] containing the + /// eigenvalues, and subsequent [...,1:, :] containing the eigenvectors. The eigenvalues + /// are sorted in non-decreasing order. + /// + public static Tensor self_adjoint_eig(Tensor input, string name = "SelfAdjointEig") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("SelfAdjointEig", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the eigen decomposition of one or more square self-adjoint matrices. + /// + /// + /// Tensor input of shape [N, N]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SelfAdjointEigV2'. + /// + /// + /// If True then eigenvectors will be computed and returned in v. + /// Otherwise, only the eigenvalues will be computed. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// e : Eigenvalues. Shape is [N]. + /// v : Eigenvectors. Shape is [N, N]. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in + /// input such that input[..., :, :] = v[..., :, :] * diag(e[..., :]). The eigenvalues + /// are sorted in non-decreasing order. + /// + /// + /// # a is a tensor. + /// # e is a tensor of eigenvalues. + /// # v is a tensor of eigenvectors. + /// e, v = self_adjoint_eig(a) + /// e = self_adjoint_eig(a, compute_v=False) + /// + /// + public static (Tensor e, Tensor v) self_adjoint_eig_v2(Tensor input, bool? compute_v = null, string name = "SelfAdjointEigV2") + { + var dict = new Dictionary(); + dict["input"] = input; + if (compute_v.HasValue) + dict["compute_v"] = compute_v.Value; + var op = tf.OpDefLib._apply_op_helper("SelfAdjointEigV2", name: name, keywords: dict); + int _idx = 0; + var e = op.outputs[_idx++]; + var v = op.outputs[_idx++]; + return (e, v); + } + + /// + /// Computes scaled exponential linear: scale * alpha * (exp(features) - 1) + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Selu'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// if &lt; 0, scale * features otherwise. + /// + /// To be used together with + /// initializer = tf.variance_scaling_initializer(scale=1.0, mode='fan_in'). + /// For correct dropout, use tf.contrib.nn.alpha_dropout. + /// + /// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) + /// + public static Tensor selu(Tensor features, string name = "Selu") + { + var dict = new Dictionary(); + dict["features"] = features; + var op = tf.OpDefLib._apply_op_helper("Selu", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes gradients for the scaled exponential linear (Selu) operation. + /// + /// + /// The backpropagated gradients to the corresponding Selu operation. + /// + /// + /// The outputs of the corresponding Selu operation. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SeluGrad'. + /// + /// + /// The gradients: gradients * (outputs + scale * alpha) + /// if outputs &lt; 0, scale * gradients otherwise. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor selu_grad(Tensor gradients, Tensor outputs, string name = "SeluGrad") + { + var dict = new Dictionary(); + dict["gradients"] = gradients; + dict["outputs"] = outputs; + var op = tf.OpDefLib._apply_op_helper("SeluGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Converts the given resource_handle representing an iterator to a variant tensor. + /// + /// + /// A handle to an iterator resource. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SerializeIterator'. + /// + /// + /// A variant tensor storing the state of the iterator contained in the + /// resource. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor serialize_iterator(Tensor resource_handle, string name = "SerializeIterator") + { + var dict = new Dictionary(); + dict["resource_handle"] = resource_handle; + var op = tf.OpDefLib._apply_op_helper("SerializeIterator", name: name, keywords: dict); + return op.output; + } + + /// + /// Serialize an N-minibatch SparseTensor into an [N, 3] Tensor object. + /// + /// + /// 2-D. The indices of the minibatch SparseTensor. + /// + /// + /// 1-D. The values of the minibatch SparseTensor. + /// + /// + /// 1-D. The shape of the minibatch SparseTensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SerializeManySparse'. + /// + /// + /// The dtype to use for serialization; the supported types are string + /// (default) and variant. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The SparseTensor must have rank R greater than 1, and the first dimension + /// is treated as the minibatch dimension. Elements of the SparseTensor + /// must be sorted in increasing order of this first dimension. The serialized + /// SparseTensor objects going into each row of serialized_sparse will have + /// rank R-1. + /// + /// The minibatch size N is extracted from sparse_shape[0]. + /// + public static Tensor serialize_many_sparse(Tensor sparse_indices, Tensor sparse_values, Tensor sparse_shape, TF_DataType? out_type = null, string name = "SerializeManySparse") + { + var dict = new Dictionary(); + dict["sparse_indices"] = sparse_indices; + dict["sparse_values"] = sparse_values; + dict["sparse_shape"] = sparse_shape; + if (out_type.HasValue) + dict["out_type"] = out_type.Value; + var op = tf.OpDefLib._apply_op_helper("SerializeManySparse", name: name, keywords: dict); + return op.output; + } + + /// + /// Serialize a SparseTensor into a [3] Tensor object. + /// + /// + /// 2-D. The indices of the SparseTensor. + /// + /// + /// 1-D. The values of the SparseTensor. + /// + /// + /// 1-D. The shape of the SparseTensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SerializeSparse'. + /// + /// + /// The dtype to use for serialization; the supported types are string + /// (default) and variant. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor serialize_sparse(Tensor sparse_indices, Tensor sparse_values, Tensor sparse_shape, TF_DataType? out_type = null, string name = "SerializeSparse") + { + var dict = new Dictionary(); + dict["sparse_indices"] = sparse_indices; + dict["sparse_values"] = sparse_values; + dict["sparse_shape"] = sparse_shape; + if (out_type.HasValue) + dict["out_type"] = out_type.Value; + var op = tf.OpDefLib._apply_op_helper("SerializeSparse", name: name, keywords: dict); + return op.output; + } + + /// + /// Transforms a Tensor into a serialized TensorProto proto. + /// + /// + /// A Tensor of type T. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SerializeTensor'. + /// + /// + /// A serialized TensorProto proto of the input tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor serialize_tensor(Tensor tensor, string name = "SerializeTensor") + { + var dict = new Dictionary(); + dict["tensor"] = tensor; + var op = tf.OpDefLib._apply_op_helper("SerializeTensor", name: name, keywords: dict); + return op.output; + } + + /// + /// Number of unique elements along last dimension of input set. + /// + /// + /// 2D Tensor, indices of a SparseTensor. + /// + /// + /// 1D Tensor, values of a SparseTensor. + /// + /// + /// 1D Tensor, shape of a SparseTensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SetSize'. + /// + /// + /// + /// + /// For set ranked n, this is a Tensor with rank n-1, and the same 1st + /// n-1 dimensions as set. Each value is the number of unique elements in + /// the corresponding [0...n-1] dimension of set. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Input set is a SparseTensor represented by set_indices, set_values, + /// and set_shape. The last dimension contains values in a set, duplicates are + /// allowed but ignored. + /// + /// If validate_indices is True, this op validates the order and range of set + /// indices. + /// + public static Tensor set_size(Tensor set_indices, Tensor set_values, Tensor set_shape, bool? validate_indices = null, string name = "SetSize") + { + var dict = new Dictionary(); + dict["set_indices"] = set_indices; + dict["set_values"] = set_values; + dict["set_shape"] = set_shape; + if (validate_indices.HasValue) + dict["validate_indices"] = validate_indices.Value; + var op = tf.OpDefLib._apply_op_helper("SetSize", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the shape of a tensor. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Shape'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation returns a 1-D integer tensor representing the shape of input. + /// + /// For example: + /// + /// + /// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] + /// shape(t) ==&gt; [2, 2, 3] + /// + /// + public static Tensor shape(Tensor input, TF_DataType? out_type = null, string name = "Shape") + { + var dict = new Dictionary(); + dict["input"] = input; + if (out_type.HasValue) + dict["out_type"] = out_type.Value; + var op = tf.OpDefLib._apply_op_helper("Shape", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns shape of tensors. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ShapeN'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation returns N 1-D integer tensors representing shape of input[i]s. + /// + public static Tensor[] shape_n(Tensor[] input, TF_DataType? out_type = null, string name = "ShapeN") + { + var dict = new Dictionary(); + dict["input"] = input; + if (out_type.HasValue) + dict["out_type"] = out_type.Value; + var op = tf.OpDefLib._apply_op_helper("ShapeN", name: name, keywords: dict); + int _idx = 0; + var output = Enumerable.Range(0, op.OutputListLength("output")).Select(_ => op.outputs[_idx++]).ToArray(); + return (output); + } + + /// + /// Generate a sharded filename. The filename is printf formatted as + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ShardedFilename'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// %s-%05d-of-%05d, basename, shard, num_shards. + /// + public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename") + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "ShardedFilename", name, basename, shard, num_shards)); + return result[0]; + } + var dict = new Dictionary(); + dict["basename"] = basename; + dict["shard"] = shard; + dict["num_shards"] = num_shards; + var op = tf.OpDefLib._apply_op_helper("ShardedFilename", name: name, keywords: dict); + return op.output; + } + + /// + /// Generate a glob pattern matching all sharded file names. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ShardedFilespec'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor sharded_filespec(Tensor basename, Tensor num_shards, string name = "ShardedFilespec") + { + var dict = new Dictionary(); + dict["basename"] = basename; + dict["num_shards"] = num_shards; + var op = tf.OpDefLib._apply_op_helper("ShardedFilespec", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that shuffles and repeats elements from input_dataset + /// + /// + /// + /// + /// The number of output elements to buffer in an iterator over + /// this dataset. Compare with the min_after_dequeue attr when creating a + /// RandomShuffleQueue. + /// + /// + /// A scalar seed for the random number generator. If either seed or + /// seed2 is set to be non-zero, the random number generator is seeded + /// by the given seed. Otherwise, a random seed is used. + /// + /// + /// A second scalar seed to avoid seed collision. + /// + /// + /// A scalar representing the number of times the underlying dataset + /// should be repeated. The default is -1, which results in infinite repetition. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ShuffleAndRepeatDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// pseudorandomly. + /// + public static Tensor shuffle_and_repeat_dataset(Tensor input_dataset, Tensor buffer_size, Tensor seed, Tensor seed2, Tensor count, TF_DataType[] output_types, Shape[] output_shapes, string name = "ShuffleAndRepeatDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["buffer_size"] = buffer_size; + dict["seed"] = seed; + dict["seed2"] = seed2; + dict["count"] = count; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("ShuffleAndRepeatDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that shuffles elements from input_dataset pseudorandomly. + /// + /// + /// + /// + /// The number of output elements to buffer in an iterator over + /// this dataset. Compare with the min_after_dequeue attr when creating a + /// RandomShuffleQueue. + /// + /// + /// A scalar seed for the random number generator. If either seed or + /// seed2 is set to be non-zero, the random number generator is seeded + /// by the given seed. Otherwise, a random seed is used. + /// + /// + /// A second scalar seed to avoid seed collision. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ShuffleDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// If true, each iterator over this dataset will be given + /// a different pseudorandomly generated seed, based on a sequence seeded by the + /// seed and seed2 inputs. If false, each iterator will be given the same + /// seed, and repeated iteration over this dataset will yield the exact same + /// sequence of results. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor shuffle_dataset(Tensor input_dataset, Tensor buffer_size, Tensor seed, Tensor seed2, TF_DataType[] output_types, Shape[] output_shapes, bool? reshuffle_each_iteration = null, string name = "ShuffleDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["buffer_size"] = buffer_size; + dict["seed"] = seed; + dict["seed2"] = seed2; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + if (reshuffle_each_iteration.HasValue) + dict["reshuffle_each_iteration"] = reshuffle_each_iteration.Value; + var op = tf.OpDefLib._apply_op_helper("ShuffleDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// An op that shuts down a running distributed TPU system. The Op returns + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ShutdownDistributedTPU'. + /// + /// + /// Returns the description of the operation + /// + /// + /// an error if no system is running. + /// + public static Operation shutdown_distributed_t_p_u(string name = "ShutdownDistributedTPU") + { + var dict = new Dictionary(); + var op = tf.OpDefLib._apply_op_helper("ShutdownDistributedTPU", name: name, keywords: dict); + return op; + } + + /// + /// Computes sigmoid of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Sigmoid'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Specifically, y = 1 / (1 + exp(-x)). + /// + public static Tensor sigmoid(Tensor x, string name = "Sigmoid") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Sigmoid", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient of the sigmoid of x wrt its input. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SigmoidGrad'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Specifically, grad = dy * y * (1 - y), where y = sigmoid(x), and + /// dy is the corresponding input gradient. + /// + public static Tensor sigmoid_grad(Tensor y, Tensor dy, string name = "SigmoidGrad") + { + var dict = new Dictionary(); + dict["y"] = y; + dict["dy"] = dy; + var op = tf.OpDefLib._apply_op_helper("SigmoidGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns an element-wise indication of the sign of a number. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Sign'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// y = sign(x) = -1 if x &lt; 0; 0 if x == 0; 1 if x &gt; 0. + /// + /// For complex numbers, y = sign(x) = x / |x| if x != 0, otherwise y = 0. + /// + public static Tensor sign(Tensor x, string name = "Sign") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Sign", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes sin of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Sin'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor sin(Tensor x, string name = "Sin") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Sin", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes hyperbolic sine of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Sinh'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor sinh(Tensor x, string name = "Sinh") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Sinh", name: name, keywords: dict); + return op.output; + } + + /// + /// A placeholder for input pipeline graph optimizations. + /// + /// + /// A variant tensor representing the input dataset. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SinkDataset'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// A placeholder for input pipeline graph optimizations. + /// + public static Tensor sink_dataset(Tensor input_dataset, string name = "SinkDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + var op = tf.OpDefLib._apply_op_helper("SinkDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the size of a tensor. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Size'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation returns an integer representing the number of elements in + /// input. + /// + /// For example: + /// + /// + /// # 't' is [[[1, 1,, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]] + /// size(t) ==&gt; 12 + /// + /// + public static Tensor size(Tensor input, TF_DataType? out_type = null, string name = "Size") + { + var dict = new Dictionary(); + dict["input"] = input; + if (out_type.HasValue) + dict["out_type"] = out_type.Value; + var op = tf.OpDefLib._apply_op_helper("Size", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that skips count elements from the input_dataset. + /// + /// + /// + /// + /// A scalar representing the number of elements from the input_dataset + /// that should be skipped. If count is -1, skips everything. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SkipDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor skip_dataset(Tensor input_dataset, Tensor count, TF_DataType[] output_types, Shape[] output_shapes, string name = "SkipDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["count"] = count; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("SkipDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Parses a text file and creates a batch of examples. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Skipgram'. + /// + /// + /// Optional argument + /// The corpus's text file name. + /// + /// + /// Optional argument + /// The size of produced batch. + /// + /// + /// The number of words to predict to the left and right of the target. + /// + /// + /// The minimum number of word occurrences for it to be included in the + /// vocabulary. + /// + /// + /// Threshold for word occurrence. Words that appear with higher + /// frequency will be randomly down-sampled. Set to 0 to disable. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// vocab_word : A vector of words in the corpus. + /// vocab_freq : Frequencies of words. Sorted in the non-ascending order. + /// words_per_epoch : Number of words per epoch in the data file. + /// current_epoch : The current epoch number. + /// total_words_processed : The total number of words processed so far. + /// examples : A vector of word ids. + /// labels : A vector of word ids. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor vocab_word, Tensor vocab_freq, Tensor words_per_epoch, Tensor current_epoch, Tensor total_words_processed, Tensor examples, Tensor labels) skipgram(string filename, int batch_size, int? window_size = null, int? min_count = null, float? subsample = null, string name = "Skipgram") + { + var dict = new Dictionary(); + dict["filename"] = filename; + dict["batch_size"] = batch_size; + if (window_size.HasValue) + dict["window_size"] = window_size.Value; + if (min_count.HasValue) + dict["min_count"] = min_count.Value; + if (subsample.HasValue) + dict["subsample"] = subsample.Value; + var op = tf.OpDefLib._apply_op_helper("Skipgram", name: name, keywords: dict); + int _idx = 0; + var vocab_word = op.outputs[_idx++]; + var vocab_freq = op.outputs[_idx++]; + var words_per_epoch = op.outputs[_idx++]; + var current_epoch = op.outputs[_idx++]; + var total_words_processed = op.outputs[_idx++]; + var examples = op.outputs[_idx++]; + var labels = op.outputs[_idx++]; + return (vocab_word, vocab_freq, words_per_epoch, current_epoch, total_words_processed, examples, labels); + } + + /// + /// Return a slice from 'input'. + /// + /// + /// + /// + /// begin[i] specifies the offset into the 'i'th dimension of + /// 'input' to slice from. + /// + /// + /// size[i] specifies the number of elements of the 'i'th dimension + /// of 'input' to slice. If size[i] is -1, all remaining elements in dimension + /// i are included in the slice (i.e. this is equivalent to setting + /// size[i] = input.dim_size(i) - begin[i]). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Slice'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The output tensor is a tensor with dimensions described by 'size' + /// whose values are extracted from 'input' starting at the offsets in + /// 'begin'. + /// + /// *Requirements*: + /// 0 &lt;= begin[i] &lt;= begin[i] + size[i] &lt;= Di for i in [0, n) + /// + public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = "Slice") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["begin"] = begin; + dict["size"] = size; + var op = tf.OpDefLib._apply_op_helper("Slice", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that passes a sliding window over input_dataset. + /// + /// + /// + /// + /// A scalar representing the number of elements in the + /// sliding window. + /// + /// + /// A scalar representing the steps moving the sliding window + /// forward in one iteration. It must be positive. + /// + /// + /// A scalar representing the stride of the input elements of the sliding window. + /// It must be positive. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SlideDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor slide_dataset(Tensor input_dataset, Tensor window_size, Tensor window_shift, Tensor window_stride, TF_DataType[] output_types, Shape[] output_shapes, string name = "SlideDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["window_size"] = window_size; + dict["window_shift"] = window_shift; + dict["window_stride"] = window_stride; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("SlideDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns a copy of the input tensor. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Snapshot'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor snapshot(Tensor input, string name = "Snapshot") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("Snapshot", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes softmax activations. + /// + /// + /// 2-D with shape [batch_size, num_classes]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Softmax'. + /// + /// + /// Same shape as logits. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// For each batch i and class j we have + /// + /// $$softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))$$ + /// + public static Tensor softmax(Tensor logits, string name = "Softmax") + { + var dict = new Dictionary(); + dict["logits"] = logits; + var op = tf.OpDefLib._apply_op_helper("Softmax", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes softmax cross entropy cost and gradients to backpropagate. + /// + /// + /// batch_size x num_classes matrix + /// + /// + /// batch_size x num_classes matrix + /// The caller must ensure that each batch of labels represents a valid + /// probability distribution. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SoftmaxCrossEntropyWithLogits'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// loss : Per example loss (batch_size vector). + /// backprop : backpropagated gradients (batch_size x num_classes matrix). + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Inputs are the logits, not probabilities. + /// + public static (Tensor loss, Tensor backprop) softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = "SoftmaxCrossEntropyWithLogits") + { + var dict = new Dictionary(); + dict["features"] = features; + dict["labels"] = labels; + var op = tf.OpDefLib._apply_op_helper("SoftmaxCrossEntropyWithLogits", name: name, keywords: dict); + int _idx = 0; + var loss = op.outputs[_idx++]; + var backprop = op.outputs[_idx++]; + return (loss, backprop); + } + + /// + /// Computes softplus: log(exp(features) + 1). + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Softplus'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor softplus(Tensor features, string name = "Softplus") + { + var dict = new Dictionary(); + dict["features"] = features; + var op = tf.OpDefLib._apply_op_helper("Softplus", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes softplus gradients for a softplus operation. + /// + /// + /// The backpropagated gradients to the corresponding softplus operation. + /// + /// + /// The features passed as input to the corresponding softplus operation. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SoftplusGrad'. + /// + /// + /// The gradients: gradients / (1 + exp(-features)). + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor softplus_grad(Tensor gradients, Tensor features, string name = "SoftplusGrad") + { + var dict = new Dictionary(); + dict["gradients"] = gradients; + dict["features"] = features; + var op = tf.OpDefLib._apply_op_helper("SoftplusGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes softsign: features / (abs(features) + 1). + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Softsign'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor softsign(Tensor features, string name = "Softsign") + { + var dict = new Dictionary(); + dict["features"] = features; + var op = tf.OpDefLib._apply_op_helper("Softsign", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes softsign gradients for a softsign operation. + /// + /// + /// The backpropagated gradients to the corresponding softsign operation. + /// + /// + /// The features passed as input to the corresponding softsign operation. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SoftsignGrad'. + /// + /// + /// The gradients: gradients / (1 + abs(features)) ** 2. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor softsign_grad(Tensor gradients, Tensor features, string name = "SoftsignGrad") + { + var dict = new Dictionary(); + dict["gradients"] = gradients; + dict["features"] = features; + var op = tf.OpDefLib._apply_op_helper("SoftsignGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// SpaceToBatch for 4-D tensors of type T. + /// + /// + /// 4-D with shape [batch, height, width, depth]. + /// + /// + /// 2-D tensor of non-negative integers with shape [2, 2]. It specifies + /// the padding of the input with zeros across the spatial dimensions as follows: + /// + /// paddings = [[pad_top, pad_bottom], [pad_left, pad_right]] + /// + /// The effective spatial dimensions of the zero-padded input tensor will be: + /// + /// height_pad = pad_top + height + pad_bottom + /// width_pad = pad_left + width + pad_right + /// + /// The attr block_size must be greater than one. It indicates the block size. + /// + /// * Non-overlapping blocks of size block_size x block size in the height and + /// width dimensions are rearranged into the batch dimension at each location. + /// * The batch of the output tensor is batch * block_size * block_size. + /// * Both height_pad and width_pad must be divisible by block_size. + /// + /// The shape of the output will be: + /// + /// [batch*block_size*block_size, height_pad/block_size, width_pad/block_size, + /// depth] + /// + /// Some examples: + /// + /// (1) For the following input of shape [1, 2, 2, 1] and block_size of 2: + /// + /// + /// x = [[[[1], [2]], [[3], [4]]]] + /// + /// + /// The output tensor has shape [4, 1, 1, 1] and value: + /// + /// + /// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] + /// + /// + /// (2) For the following input of shape [1, 2, 2, 3] and block_size of 2: + /// + /// + /// x = [[[[1, 2, 3], [4, 5, 6]], + /// [[7, 8, 9], [10, 11, 12]]]] + /// + /// + /// The output tensor has shape [4, 1, 1, 3] and value: + /// + /// + /// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] + /// + /// + /// (3) For the following input of shape [1, 4, 4, 1] and block_size of 2: + /// + /// + /// x = [[[[1], [2], [3], [4]], + /// [[5], [6], [7], [8]], + /// [[9], [10], [11], [12]], + /// [[13], [14], [15], [16]]]] + /// + /// + /// The output tensor has shape [4, 2, 2, 1] and value: + /// + /// + /// x = [[[[1], [3]], [[9], [11]]], + /// [[[2], [4]], [[10], [12]]], + /// [[[5], [7]], [[13], [15]]], + /// [[[6], [8]], [[14], [16]]]] + /// + /// + /// (4) For the following input of shape [2, 2, 4, 1] and block_size of 2: + /// + /// + /// x = [[[[1], [2], [3], [4]], + /// [[5], [6], [7], [8]]], + /// [[[9], [10], [11], [12]], + /// [[13], [14], [15], [16]]]] + /// + /// + /// The output tensor has shape [8, 1, 2, 1] and value: + /// + /// + /// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], + /// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] + /// + /// + /// Among others, this operation is useful for reducing atrous convolution into + /// regular convolution. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SpaceToBatch'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This is a legacy version of the more general SpaceToBatchND. + /// + /// Zero-pads and then rearranges (permutes) blocks of spatial data into batch. + /// More specifically, this op outputs a copy of the input tensor where values from + /// the height and width dimensions are moved to the batch dimension. After + /// the zero-padding, both height and width of the input must be divisible by the + /// block size. + /// + public static Tensor space_to_batch(Tensor input, Tensor paddings, int block_size, string name = "SpaceToBatch") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["paddings"] = paddings; + dict["block_size"] = block_size; + var op = tf.OpDefLib._apply_op_helper("SpaceToBatch", name: name, keywords: dict); + return op.output; + } + + /// + /// SpaceToBatch for N-D tensors of type T. + /// + /// + /// N-D with shape input_shape = [batch] + spatial_shape + remaining_shape, + /// where spatial_shape has M dimensions. + /// + /// + /// 1-D with shape [M], all values must be &gt;= 1. + /// + /// + /// 2-D with shape [M, 2], all values must be &gt;= 0. + /// paddings[i] = [pad_start, pad_end] specifies the padding for input dimension + /// i + 1, which corresponds to spatial dimension i. It is required that + /// block_shape[i] divides input_shape[i + 1] + pad_start + pad_end. + /// + /// This operation is equivalent to the following steps: + /// + /// 1. Zero-pad the start and end of dimensions [1, ..., M] of the + /// input according to paddings to produce padded of shape padded_shape. + /// + /// 2. Reshape padded to reshaped_padded of shape: + /// + /// [batch] + + /// [padded_shape[1] / block_shape[0], + /// block_shape[0], + /// ..., + /// padded_shape[M] / block_shape[M-1], + /// block_shape[M-1]] + + /// remaining_shape + /// + /// 3. Permute dimensions of reshaped_padded to produce + /// permuted_reshaped_padded of shape: + /// + /// block_shape + + /// [batch] + + /// [padded_shape[1] / block_shape[0], + /// ..., + /// padded_shape[M] / block_shape[M-1]] + + /// remaining_shape + /// + /// 4. Reshape permuted_reshaped_padded to flatten block_shape into the batch + /// dimension, producing an output tensor of shape: + /// + /// [batch * prod(block_shape)] + + /// [padded_shape[1] / block_shape[0], + /// ..., + /// padded_shape[M] / block_shape[M-1]] + + /// remaining_shape + /// + /// Some examples: + /// + /// (1) For the following input of shape [1, 2, 2, 1], block_shape = [2, 2], and + /// paddings = [[0, 0], [0, 0]]: + /// + /// + /// x = [[[[1], [2]], [[3], [4]]]] + /// + /// + /// The output tensor has shape [4, 1, 1, 1] and value: + /// + /// + /// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] + /// + /// + /// (2) For the following input of shape [1, 2, 2, 3], block_shape = [2, 2], and + /// paddings = [[0, 0], [0, 0]]: + /// + /// + /// x = [[[[1, 2, 3], [4, 5, 6]], + /// [[7, 8, 9], [10, 11, 12]]]] + /// + /// + /// The output tensor has shape [4, 1, 1, 3] and value: + /// + /// + /// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] + /// + /// + /// (3) For the following input of shape [1, 4, 4, 1], block_shape = [2, 2], and + /// paddings = [[0, 0], [0, 0]]: + /// + /// + /// x = [[[[1], [2], [3], [4]], + /// [[5], [6], [7], [8]], + /// [[9], [10], [11], [12]], + /// [[13], [14], [15], [16]]]] + /// + /// + /// The output tensor has shape [4, 2, 2, 1] and value: + /// + /// + /// x = [[[[1], [3]], [[9], [11]]], + /// [[[2], [4]], [[10], [12]]], + /// [[[5], [7]], [[13], [15]]], + /// [[[6], [8]], [[14], [16]]]] + /// + /// + /// (4) For the following input of shape [2, 2, 4, 1], block_shape = [2, 2], and + /// paddings = [[0, 0], [2, 0]]: + /// + /// + /// x = [[[[1], [2], [3], [4]], + /// [[5], [6], [7], [8]]], + /// [[[9], [10], [11], [12]], + /// [[13], [14], [15], [16]]]] + /// + /// + /// The output tensor has shape [8, 1, 3, 1] and value: + /// + /// + /// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], + /// [[[0], [2], [4]]], [[[0], [10], [12]]], + /// [[[0], [5], [7]]], [[[0], [13], [15]]], + /// [[[0], [6], [8]]], [[[0], [14], [16]]]] + /// + /// + /// Among others, this operation is useful for reducing atrous convolution into + /// regular convolution. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SpaceToBatchND'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation divides "spatial" dimensions [1, ..., M] of the input into a + /// grid of blocks of shape block_shape, and interleaves these blocks with the + /// "batch" dimension (0) such that in the output, the spatial dimensions + /// [1, ..., M] correspond to the position within the grid, and the batch + /// dimension combines both the position within a spatial block and the original + /// batch position. Prior to division into blocks, the spatial dimensions of the + /// input are optionally zero padded according to paddings. See below for a + /// precise description. + /// + public static Tensor space_to_batch_n_d(Tensor input, Tensor block_shape, Tensor paddings, string name = "SpaceToBatchND") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["block_shape"] = block_shape; + dict["paddings"] = paddings; + var op = tf.OpDefLib._apply_op_helper("SpaceToBatchND", name: name, keywords: dict); + return op.output; + } + + /// + /// SpaceToDepth for tensors of type T. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SpaceToDepth'. + /// + /// + /// Optional argument + /// The size of the spatial block. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Rearranges blocks of spatial data, into depth. More specifically, + /// this op outputs a copy of the input tensor where values from the height + /// and width dimensions are moved to the depth dimension. + /// The attr block_size indicates the input block size. + /// + /// * Non-overlapping blocks of size block_size x block size are rearranged + /// into depth at each location. + /// * The depth of the output tensor is block_size * block_size * input_depth. + /// * The Y, X coordinates within each block of the input become the high order + /// component of the output channel index. + /// * The input tensor's height and width must be divisible by block_size. + /// + /// The data_format attr specifies the layout of the input and output tensors + /// with the following options: + /// "NHWC": [ batch, height, width, channels ] + /// "NCHW": [ batch, channels, height, width ] + /// "NCHW_VECT_C": + /// qint8 [ batch, channels / 4, height, width, 4 ] + /// + /// It is useful to consider the operation as transforming a 6-D Tensor. + /// e.g. for data_format = NHWC, + /// Each element in the input tensor can be specified via 6 coordinates, + /// ordered by decreasing memory layout significance as: + /// n,oY,bY,oX,bX,iC (where n=batch index, oX, oY means X or Y coordinates + /// within the output image, bX, bY means coordinates + /// within the input block, iC means input channels). + /// The output would be a transpose to the following layout: + /// n,oY,oX,bY,bX,iC + /// + /// This operation is useful for resizing the activations between convolutions + /// (but keeping all data), e.g. instead of pooling. It is also useful for training + /// purely convolutional models. + /// + /// For example, given an input of shape [1, 2, 2, 1], data_format = "NHWC" and + /// block_size = 2: + /// + /// + /// x = [[[[1], [2]], + /// [[3], [4]]]] + /// + /// + /// This operation will output a tensor of shape [1, 1, 1, 4]: + /// + /// + /// [[[[1, 2, 3, 4]]]] + /// + /// + /// Here, the input has a batch of 1 and each batch element has shape [2, 2, 1], + /// the corresponding output will have a single element (i.e. width and height are + /// both 1) and will have a depth of 4 channels (1 * block_size * block_size). + /// The output element shape is [1, 1, 4]. + /// + /// For an input tensor with larger depth, here of shape [1, 2, 2, 3], e.g. + /// + /// + /// x = [[[[1, 2, 3], [4, 5, 6]], + /// [[7, 8, 9], [10, 11, 12]]]] + /// + /// + /// This operation, for block_size of 2, will return the following tensor of shape + /// [1, 1, 1, 12] + /// + /// + /// [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] + /// + /// + /// Similarly, for the following input of shape [1 4 4 1], and a block size of 2: + /// + /// + /// x = [[[[1], [2], [5], [6]], + /// [[3], [4], [7], [8]], + /// [[9], [10], [13], [14]], + /// [[11], [12], [15], [16]]]] + /// + /// + /// the operator will return the following tensor of shape [1 2 2 4]: + /// + /// + /// x = [[[[1, 2, 3, 4], + /// [5, 6, 7, 8]], + /// [[9, 10, 11, 12], + /// [13, 14, 15, 16]]]] + /// + /// + public static Tensor space_to_depth(Tensor input, int block_size, string data_format = null, string name = "SpaceToDepth") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["block_size"] = block_size; + if (data_format != null) + dict["data_format"] = data_format; + var op = tf.OpDefLib._apply_op_helper("SpaceToDepth", name: name, keywords: dict); + return op.output; + } + + /// + /// Applies a sparse gradient to a given accumulator. + /// + /// + /// The handle to a accumulator. + /// + /// + /// The local_step value at which the sparse gradient was computed. + /// + /// + /// Indices of the sparse gradient to be accumulated. Must be a + /// vector. + /// + /// + /// Values are the non-zero slices of the gradient, and must have + /// the same first dimension as indices, i.e., the nnz represented by indices and + /// values must be consistent. + /// + /// + /// Shape of the sparse gradient to be accumulated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseAccumulatorApplyGradient'. + /// + /// + /// Optional argument + /// Boolean indicating whether gradient_shape is unknown, in which + /// case the input is ignored during validation. + /// + /// + /// Returns the description of the operation + /// + /// + /// Does not add if local_step is smaller than the accumulator's + /// global_step. + /// + public static Operation sparse_accumulator_apply_gradient(Tensor handle, Tensor local_step, Tensor gradient_indices, Tensor gradient_values, Tensor gradient_shape, bool has_known_shape, string name = "SparseAccumulatorApplyGradient") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["local_step"] = local_step; + dict["gradient_indices"] = gradient_indices; + dict["gradient_values"] = gradient_values; + dict["gradient_shape"] = gradient_shape; + dict["has_known_shape"] = has_known_shape; + var op = tf.OpDefLib._apply_op_helper("SparseAccumulatorApplyGradient", name: name, keywords: dict); + return op; + } + + /// + /// Extracts the average sparse gradient in a SparseConditionalAccumulator. + /// + /// + /// The handle to a SparseConditionalAccumulator. + /// + /// + /// Number of gradients required before we return an aggregate. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseAccumulatorTakeGradient'. + /// + /// + /// Optional argument + /// The data type of accumulated gradients. Needs to correspond to the type + /// of the accumulator. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// indices : Indices of the average of the accumulated sparse gradients. + /// values : Values of the average of the accumulated sparse gradients. + /// shape : Shape of the average of the accumulated sparse gradients. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The op will blocks until sufficient (i.e., more than num_required) + /// gradients have been accumulated. If the accumulator has already + /// aggregated more than num_required gradients, it will return its + /// average of the accumulated gradients. Also automatically increments + /// the recorded global_step in the accumulator by 1, and resets the + /// aggregate to 0. + /// + public static (Tensor indices, Tensor values, Tensor shape) sparse_accumulator_take_gradient(Tensor handle, Tensor num_required, TF_DataType dtype, string name = "SparseAccumulatorTakeGradient") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["num_required"] = num_required; + dict["dtype"] = dtype; + var op = tf.OpDefLib._apply_op_helper("SparseAccumulatorTakeGradient", name: name, keywords: dict); + int _idx = 0; + var indices = op.outputs[_idx++]; + var values = op.outputs[_idx++]; + var shape = op.outputs[_idx++]; + return (indices, values, shape); + } + + /// + /// Adds two SparseTensor objects to produce another SparseTensor. + /// + /// + /// 2-D. The indices of the first SparseTensor, size [nnz, ndims] Matrix. + /// + /// + /// 1-D. The values of the first SparseTensor, size [nnz] Vector. + /// + /// + /// 1-D. The shape of the first SparseTensor, size [ndims] Vector. + /// + /// + /// 2-D. The indices of the second SparseTensor, size [nnz, ndims] Matrix. + /// + /// + /// 1-D. The values of the second SparseTensor, size [nnz] Vector. + /// + /// + /// 1-D. The shape of the second SparseTensor, size [ndims] Vector. + /// + /// + /// 0-D. The magnitude threshold that determines if an output value/index + /// pair takes space. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseAdd'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sum_indices : + /// sum_values : + /// sum_shape : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The input SparseTensor objects' indices are assumed ordered in standard + /// lexicographic order. If this is not the case, before this step run + /// SparseReorder to restore index ordering. + /// + /// By default, if two values sum to zero at some index, the output SparseTensor + /// would still include that particular location in its index, storing a zero in the + /// corresponding value slot. To override this, callers can specify thresh, + /// indicating that if the sum has a magnitude strictly smaller than thresh, its + /// corresponding value and index would then not be included. In particular, + /// thresh == 0 (default) means everything is kept and actual thresholding happens + /// only for a positive value. + /// + /// In the following shapes, nnz is the count after taking thresh into account. + /// + public static (Tensor sum_indices, Tensor sum_values, Tensor sum_shape) sparse_add(Tensor a_indices, Tensor a_values, Tensor a_shape, Tensor b_indices, Tensor b_values, Tensor b_shape, Tensor thresh, string name = "SparseAdd") + { + var dict = new Dictionary(); + dict["a_indices"] = a_indices; + dict["a_values"] = a_values; + dict["a_shape"] = a_shape; + dict["b_indices"] = b_indices; + dict["b_values"] = b_values; + dict["b_shape"] = b_shape; + dict["thresh"] = thresh; + var op = tf.OpDefLib._apply_op_helper("SparseAdd", name: name, keywords: dict); + int _idx = 0; + var sum_indices = op.outputs[_idx++]; + var sum_values = op.outputs[_idx++]; + var sum_shape = op.outputs[_idx++]; + return (sum_indices, sum_values, sum_shape); + } + + /// + /// The gradient operator for the SparseAdd op. + /// + /// + /// 1-D with shape [nnz(sum)]. The gradient with respect to + /// the non-empty values of the sum. + /// + /// + /// 2-D. The indices of the SparseTensor A, size [nnz(A), ndims]. + /// + /// + /// 2-D. The indices of the SparseTensor B, size [nnz(B), ndims]. + /// + /// + /// 2-D. The indices of the sum SparseTensor, size + /// [nnz(sum), ndims]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseAddGrad'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// a_val_grad : 1-D with shape [nnz(A)]. The gradient with respect to the + /// non-empty values of A. + /// b_val_grad : 1-D with shape [nnz(B)]. The gradient with respect to the + /// non-empty values of B. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The SparseAdd op calculates A + B, where A, B, and the sum are all represented + /// as SparseTensor objects. This op takes in the upstream gradient w.r.t. + /// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty + /// values of A and B. + /// + public static (Tensor a_val_grad, Tensor b_val_grad) sparse_add_grad(Tensor backprop_val_grad, Tensor a_indices, Tensor b_indices, Tensor sum_indices, string name = "SparseAddGrad") + { + var dict = new Dictionary(); + dict["backprop_val_grad"] = backprop_val_grad; + dict["a_indices"] = a_indices; + dict["b_indices"] = b_indices; + dict["sum_indices"] = sum_indices; + var op = tf.OpDefLib._apply_op_helper("SparseAddGrad", name: name, keywords: dict); + int _idx = 0; + var a_val_grad = op.outputs[_idx++]; + var b_val_grad = op.outputs[_idx++]; + return (a_val_grad, b_val_grad); + } + + /// + /// var: Should be from a Variable(). + /// + /// + /// + /// + /// Should be from a Variable(). + /// + /// + /// : Should be from a Variable(). + /// + /// + /// Learning rate. Must be a scalar. + /// + /// + /// Decay factor. Must be a scalar. + /// + /// + /// Constant factor. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseApplyAdadelta'. + /// + /// + /// If True, updating of the var and accum tensors will be protected by + /// a lock; otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor sparse_apply_adadelta(Tensor var, Tensor accum, Tensor accum_update, Tensor lr, Tensor rho, Tensor epsilon, Tensor grad, Tensor indices, bool? use_locking = null, string name = "SparseApplyAdadelta") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["accum_update"] = accum_update; + dict["lr"] = lr; + dict["rho"] = rho; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + dict["indices"] = indices; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("SparseApplyAdadelta", name: name, keywords: dict); + return op.output; + } + + /// + /// Update relevant entries in '*var' and '*accum' according to the adagrad scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Learning rate. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseApplyAdagrad'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// That is for rows we have grad for, we update var and accum as follows: + /// $$accum += grad * grad$$ + /// $$var -= lr * grad * (1 / sqrt(accum))$$ + /// + public static Tensor sparse_apply_adagrad(Tensor var, Tensor accum, Tensor lr, Tensor grad, Tensor indices, bool? use_locking = null, bool? update_slots = null, string name = "SparseApplyAdagrad") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["lr"] = lr; + dict["grad"] = grad; + dict["indices"] = indices; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + if (update_slots.HasValue) + dict["update_slots"] = update_slots.Value; + var op = tf.OpDefLib._apply_op_helper("SparseApplyAdagrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// Learning rate. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// Training step number. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseApplyAdagradDA'. + /// + /// + /// If True, updating of the var and accum tensors will be protected by + /// a lock; otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor sparse_apply_adagrad_d_a(Tensor var, Tensor gradient_accumulator, Tensor gradient_squared_accumulator, Tensor grad, Tensor indices, Tensor lr, Tensor l1, Tensor l2, Tensor global_step, bool? use_locking = null, string name = "SparseApplyAdagradDA") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["gradient_accumulator"] = gradient_accumulator; + dict["gradient_squared_accumulator"] = gradient_squared_accumulator; + dict["grad"] = grad; + dict["indices"] = indices; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["global_step"] = global_step; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("SparseApplyAdagradDA", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the centered RMSProp algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Decay rate. Must be a scalar. + /// + /// + /// + /// + /// Ridge term. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var, ms and mom. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseApplyCenteredRMSProp'. + /// + /// + /// If True, updating of the var, mg, ms, and mom tensors is + /// protected by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The centered RMSProp algorithm uses an estimate of the centered second moment + /// (i.e., the variance) for normalization, as opposed to regular RMSProp, which + /// uses the (uncentered) second moment. This often helps with training, but is + /// slightly more expensive in terms of computation and memory. + /// + /// Note that in dense implementation of this algorithm, mg, ms, and mom will + /// update even if the grad is zero, but in this sparse implementation, mg, ms, + /// and mom will not update in iterations during which the grad is zero. + /// + /// mean_square = decay * mean_square + (1-decay) * gradient ** 2 + /// mean_grad = decay * mean_grad + (1-decay) * gradient + /// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) + /// + /// $$ms &lt;- rho * ms_{t-1} + (1-rho) * grad * grad$$ + /// $$mom &lt;- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)$$ + /// $$var &lt;- var - mom$$ + /// + public static Tensor sparse_apply_centered_r_m_s_prop(Tensor var, Tensor mg, Tensor ms, Tensor mom, Tensor lr, Tensor rho, Tensor momentum, Tensor epsilon, Tensor grad, Tensor indices, bool? use_locking = null, string name = "SparseApplyCenteredRMSProp") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["mg"] = mg; + dict["ms"] = ms; + dict["mom"] = mom; + dict["lr"] = lr; + dict["rho"] = rho; + dict["momentum"] = momentum; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + dict["indices"] = indices; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("SparseApplyCenteredRMSProp", name: name, keywords: dict); + return op.output; + } + + /// + /// Update relevant entries in '*var' according to the Ftrl-proximal scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseApplyFtrl'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// That is for rows we have grad for, we update var, accum and linear as follows: + /// $$accum_new = accum + grad * grad$$ + /// $$linear += grad + (accum_{new}^{-lr_{power}} - accum^{-lr_{power}} / lr * var$$ + /// $$quadratic = 1.0 / (accum_{new}^{lr_{power}} * lr) + 2 * l2$$ + /// $$var = (sign(linear) * l1 - linear) / quadratic\ if\ |linear| &gt; l1\ else\ 0.0$$ + /// $$accum = accum_{new}$$ + /// + public static Tensor sparse_apply_ftrl(Tensor var, Tensor accum, Tensor linear, Tensor grad, Tensor indices, Tensor lr, Tensor l1, Tensor l2, Tensor lr_power, bool? use_locking = null, string name = "SparseApplyFtrl") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["linear"] = linear; + dict["grad"] = grad; + dict["indices"] = indices; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["lr_power"] = lr_power; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("SparseApplyFtrl", name: name, keywords: dict); + return op.output; + } + + /// + /// Update relevant entries in '*var' according to the Ftrl-proximal scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 shrinkage regulariation. Must be a scalar. + /// + /// + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseApplyFtrlV2'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// That is for rows we have grad for, we update var, accum and linear as follows: + /// grad_with_shrinkage = grad + 2 * l2_shrinkage * var + /// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage + /// linear += grad_with_shrinkage + + /// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var + /// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 + /// var = (sign(linear) * l1 - linear) / quadratic if |linear| &gt; l1 else 0.0 + /// accum = accum_new + /// + public static Tensor sparse_apply_ftrl_v2(Tensor var, Tensor accum, Tensor linear, Tensor grad, Tensor indices, Tensor lr, Tensor l1, Tensor l2, Tensor l2_shrinkage, Tensor lr_power, bool? use_locking = null, string name = "SparseApplyFtrlV2") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["linear"] = linear; + dict["grad"] = grad; + dict["indices"] = indices; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["l2_shrinkage"] = l2_shrinkage; + dict["lr_power"] = lr_power; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("SparseApplyFtrlV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Update relevant entries in '*var' and '*accum' according to the momentum scheme. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Learning rate. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// Momentum. Must be a scalar. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseApplyMomentum'. + /// + /// + /// If True, updating of the var and accum tensors will be protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// If True, the tensor passed to compute grad will be + /// var - lr * momentum * accum, so in the end, the var you get is actually + /// var - lr * momentum * accum. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Set use_nesterov = True if you want to use Nesterov momentum. + /// + /// That is for rows we have grad for, we update var and accum as follows: + /// + /// $$accum = accum * momentum + grad$$ + /// $$var -= lr * accum$$ + /// + public static Tensor sparse_apply_momentum(Tensor var, Tensor accum, Tensor lr, Tensor grad, Tensor indices, Tensor momentum, bool? use_locking = null, bool? use_nesterov = null, string name = "SparseApplyMomentum") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["lr"] = lr; + dict["grad"] = grad; + dict["indices"] = indices; + dict["momentum"] = momentum; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + if (use_nesterov.HasValue) + dict["use_nesterov"] = use_nesterov.Value; + var op = tf.OpDefLib._apply_op_helper("SparseApplyMomentum", name: name, keywords: dict); + return op.output; + } + + /// + /// Sparse update entries in '*var' and '*accum' according to FOBOS algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Learning rate. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseApplyProximalAdagrad'. + /// + /// + /// If True, updating of the var and accum tensors will be protected by + /// a lock; otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// That is for rows we have grad for, we update var and accum as follows: + /// $$accum += grad * grad$$ + /// $$prox_v = var$$ + /// $$prox_v -= lr * grad * (1 / sqrt(accum))$$ + /// $$var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0}$$ + /// + public static Tensor sparse_apply_proximal_adagrad(Tensor var, Tensor accum, Tensor lr, Tensor l1, Tensor l2, Tensor grad, Tensor indices, bool? use_locking = null, string name = "SparseApplyProximalAdagrad") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["accum"] = accum; + dict["lr"] = lr; + dict["l1"] = l1; + dict["l2"] = l2; + dict["grad"] = grad; + dict["indices"] = indices; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("SparseApplyProximalAdagrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Sparse update '*var' as FOBOS algorithm with fixed learning rate. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// L1 regularization. Must be a scalar. + /// + /// + /// L2 regularization. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var and accum. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseApplyProximalGradientDescent'. + /// + /// + /// If True, the subtraction will be protected by a lock; + /// otherwise the behavior is undefined, but may exhibit less contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// That is for rows we have grad for, we update var as follows: + /// $$prox_v = var - alpha * grad$$ + /// $$var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}$$ + /// + public static Tensor sparse_apply_proximal_gradient_descent(Tensor var, Tensor alpha, Tensor l1, Tensor l2, Tensor grad, Tensor indices, bool? use_locking = null, string name = "SparseApplyProximalGradientDescent") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["alpha"] = alpha; + dict["l1"] = l1; + dict["l2"] = l2; + dict["grad"] = grad; + dict["indices"] = indices; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("SparseApplyProximalGradientDescent", name: name, keywords: dict); + return op.output; + } + + /// + /// Update '*var' according to the RMSProp algorithm. + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Should be from a Variable(). + /// + /// + /// Scaling factor. Must be a scalar. + /// + /// + /// Decay rate. Must be a scalar. + /// + /// + /// + /// + /// Ridge term. Must be a scalar. + /// + /// + /// The gradient. + /// + /// + /// A vector of indices into the first dimension of var, ms and mom. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseApplyRMSProp'. + /// + /// + /// If True, updating of the var, ms, and mom tensors is protected + /// by a lock; otherwise the behavior is undefined, but may exhibit less + /// contention. + /// + /// + /// Same as "var". + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Note that in dense implementation of this algorithm, ms and mom will + /// update even if the grad is zero, but in this sparse implementation, ms + /// and mom will not update in iterations during which the grad is zero. + /// + /// mean_square = decay * mean_square + (1-decay) * gradient ** 2 + /// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) + /// + /// $$ms &lt;- rho * ms_{t-1} + (1-rho) * grad * grad$$ + /// $$mom &lt;- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)$$ + /// $$var &lt;- var - mom$$ + /// + public static Tensor sparse_apply_r_m_s_prop(Tensor var, Tensor ms, Tensor mom, Tensor lr, Tensor rho, Tensor momentum, Tensor epsilon, Tensor grad, Tensor indices, bool? use_locking = null, string name = "SparseApplyRMSProp") + { + var dict = new Dictionary(); + dict["var"] = var; + dict["ms"] = ms; + dict["mom"] = mom; + dict["lr"] = lr; + dict["rho"] = rho; + dict["momentum"] = momentum; + dict["epsilon"] = epsilon; + dict["grad"] = grad; + dict["indices"] = indices; + if (use_locking.HasValue) + dict["use_locking"] = use_locking.Value; + var op = tf.OpDefLib._apply_op_helper("SparseApplyRMSProp", name: name, keywords: dict); + return op.output; + } + + /// + /// Concatenates a list of SparseTensor along the specified dimension. + /// + /// + /// 2-D. Indices of each input SparseTensor. + /// + /// + /// 1-D. Non-empty values of each SparseTensor. + /// + /// + /// 1-D. Shapes of each SparseTensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseConcat'. + /// + /// + /// Optional argument + /// Dimension to concatenate along. Must be in range [-rank, rank), + /// where rank is the number of dimensions in each input SparseTensor. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_indices : 2-D. Indices of the concatenated SparseTensor. + /// output_values : 1-D. Non-empty values of the concatenated SparseTensor. + /// output_shape : 1-D. Shape of the concatenated SparseTensor. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Concatenation is with respect to the dense versions of these sparse tensors. + /// It is assumed that each input is a SparseTensor whose elements are ordered + /// along increasing dimension number. + /// + /// All inputs' shapes must match, except for the concat dimension. The + /// indices, values, and shapes lists must have the same length. + /// + /// The output shape is identical to the inputs', except along the concat + /// dimension, where it is the sum of the inputs' sizes along that dimension. + /// + /// The output elements will be resorted to preserve the sort order along + /// increasing dimension number. + /// + /// This op runs in O(M log M) time, where M is the total number of non-empty + /// values across all inputs. This is due to the need for an internal sort in + /// order to concatenate efficiently across an arbitrary dimension. + /// + /// For example, if concat_dim = 1 and the inputs are + /// + /// sp_inputs[0]: shape = [2, 3] + /// [0, 2]: "a" + /// [1, 0]: "b" + /// [1, 1]: "c" + /// + /// sp_inputs[1]: shape = [2, 4] + /// [0, 1]: "d" + /// [0, 2]: "e" + /// + /// then the output will be + /// + /// shape = [2, 7] + /// [0, 2]: "a" + /// [0, 4]: "d" + /// [0, 5]: "e" + /// [1, 0]: "b" + /// [1, 1]: "c" + /// + /// Graphically this is equivalent to doing + /// + /// [ a] concat [ d e ] = [ a d e ] + /// [b c ] [ ] [b c ] + /// + public static (Tensor output_indices, Tensor output_values, Tensor output_shape) sparse_concat(Tensor[] indices, Tensor[] values, Tensor[] shapes, int concat_dim, string name = "SparseConcat") + { + var dict = new Dictionary(); + dict["indices"] = indices; + dict["values"] = values; + dict["shapes"] = shapes; + dict["concat_dim"] = concat_dim; + var op = tf.OpDefLib._apply_op_helper("SparseConcat", name: name, keywords: dict); + int _idx = 0; + var output_indices = op.outputs[_idx++]; + var output_values = op.outputs[_idx++]; + var output_shape = op.outputs[_idx++]; + return (output_indices, output_values, output_shape); + } + + /// + /// A conditional accumulator for aggregating sparse gradients. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseConditionalAccumulator'. + /// + /// + /// Optional argument + /// The type of the value being accumulated. + /// + /// + /// Optional argument + /// The shape of the values. + /// + /// + /// If non-empty, this accumulator is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this accumulator will be shared under the given name + /// across multiple sessions. + /// + /// + /// The handle to the accumulator. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The accumulator accepts gradients marked with local_step greater or + /// equal to the most recent global_step known to the accumulator. The + /// average can be extracted from the accumulator, provided sufficient + /// gradients have been accumulated. Extracting the average automatically + /// resets the aggregate to 0, and increments the global_step recorded by + /// the accumulator. + /// + public static Tensor sparse_conditional_accumulator(TF_DataType dtype, Shape shape, string container = null, string shared_name = null, string name = "SparseConditionalAccumulator") + { + var dict = new Dictionary(); + dict["dtype"] = dtype; + dict["shape"] = shape; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("SparseConditionalAccumulator", name: name, keywords: dict); + return op.output; + } + + /// + /// Generates sparse cross from a list of sparse and dense tensors. + /// + /// + /// 2-D. Indices of each input SparseTensor. + /// + /// + /// 1-D. values of each SparseTensor. + /// + /// + /// 1-D. Shapes of each SparseTensor. + /// + /// + /// 2-D. Columns represented by dense Tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseCross'. + /// + /// + /// Optional argument + /// If true, returns the hash of the cross instead of the string. + /// This will allow us avoiding string manipulations. + /// + /// + /// Optional argument + /// It is used if hashed_output is true. + /// output = hashed_value%num_buckets if num_buckets &gt; 0 else hashed_value. + /// + /// + /// Optional argument + /// Specify the hash_key that will be used by the FingerprintCat64 + /// function to combine the crosses fingerprints. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_indices : 2-D. Indices of the concatenated SparseTensor. + /// output_values : 1-D. Non-empty values of the concatenated or hashed + /// SparseTensor. + /// output_shape : 1-D. Shape of the concatenated SparseTensor. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The op takes two lists, one of 2D SparseTensor and one of 2D Tensor, each + /// representing features of one feature column. It outputs a 2D SparseTensor with + /// the batchwise crosses of these features. + /// + /// For example, if the inputs are + /// + /// inputs[0]: SparseTensor with shape = [2, 2] + /// [0, 0]: "a" + /// [1, 0]: "b" + /// [1, 1]: "c" + /// + /// inputs[1]: SparseTensor with shape = [2, 1] + /// [0, 0]: "d" + /// [1, 0]: "e" + /// + /// inputs[2]: Tensor [["f"], ["g"]] + /// + /// then the output will be + /// + /// shape = [2, 2] + /// [0, 0]: "a_X_d_X_f" + /// [1, 0]: "b_X_e_X_g" + /// [1, 1]: "c_X_e_X_g" + /// + /// if hashed_output=true then the output will be + /// + /// shape = [2, 2] + /// [0, 0]: FingerprintCat64( + /// Fingerprint64("f"), FingerprintCat64( + /// Fingerprint64("d"), Fingerprint64("a"))) + /// [1, 0]: FingerprintCat64( + /// Fingerprint64("g"), FingerprintCat64( + /// Fingerprint64("e"), Fingerprint64("b"))) + /// [1, 1]: FingerprintCat64( + /// Fingerprint64("g"), FingerprintCat64( + /// Fingerprint64("e"), Fingerprint64("c"))) + /// + public static (Tensor output_indices, Tensor output_values, Tensor output_shape) sparse_cross(Tensor[] indices, Tensor[] values, Tensor[] shapes, Tensor[] dense_inputs, bool hashed_output, int num_buckets, int hash_key, TF_DataType out_type, TF_DataType internal_type, string name = "SparseCross") + { + var dict = new Dictionary(); + dict["indices"] = indices; + dict["values"] = values; + dict["shapes"] = shapes; + dict["dense_inputs"] = dense_inputs; + dict["hashed_output"] = hashed_output; + dict["num_buckets"] = num_buckets; + dict["hash_key"] = hash_key; + dict["out_type"] = out_type; + dict["internal_type"] = internal_type; + var op = tf.OpDefLib._apply_op_helper("SparseCross", name: name, keywords: dict); + int _idx = 0; + var output_indices = op.outputs[_idx++]; + var output_values = op.outputs[_idx++]; + var output_shape = op.outputs[_idx++]; + return (output_indices, output_values, output_shape); + } + + /// + /// Adds up a SparseTensor and a dense Tensor, using these special rules: + /// + /// + /// 2-D. N x R matrix with the indices of non-empty values in a + /// SparseTensor, possibly not in canonical ordering. + /// + /// + /// 1-D. N non-empty values corresponding to sp_indices. + /// + /// + /// 1-D. Shape of the input SparseTensor. + /// + /// + /// R-D. The dense Tensor operand. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseDenseCwiseAdd'. + /// + /// + /// 1-D. The N values that are operated on. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// (1) Broadcasts the dense side to have the same shape as the sparse side, if + /// eligible; + /// (2) Then, only the dense values pointed to by the indices of the SparseTensor + /// participate in the cwise addition. + /// + /// By these rules, the result is a logical SparseTensor with exactly the same + /// indices and shape, but possibly with different non-zero values. The output of + /// this Op is the resultant non-zero values. + /// + public static Tensor sparse_dense_cwise_add(Tensor sp_indices, Tensor sp_values, Tensor sp_shape, Tensor dense, string name = "SparseDenseCwiseAdd") + { + var dict = new Dictionary(); + dict["sp_indices"] = sp_indices; + dict["sp_values"] = sp_values; + dict["sp_shape"] = sp_shape; + dict["dense"] = dense; + var op = tf.OpDefLib._apply_op_helper("SparseDenseCwiseAdd", name: name, keywords: dict); + return op.output; + } + + /// + /// Component-wise divides a SparseTensor by a dense Tensor. + /// + /// + /// 2-D. N x R matrix with the indices of non-empty values in a + /// SparseTensor, possibly not in canonical ordering. + /// + /// + /// 1-D. N non-empty values corresponding to sp_indices. + /// + /// + /// 1-D. Shape of the input SparseTensor. + /// + /// + /// R-D. The dense Tensor operand. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseDenseCwiseDiv'. + /// + /// + /// 1-D. The N values that are operated on. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not + /// the other direction. + /// + public static Tensor sparse_dense_cwise_div(Tensor sp_indices, Tensor sp_values, Tensor sp_shape, Tensor dense, string name = "SparseDenseCwiseDiv") + { + var dict = new Dictionary(); + dict["sp_indices"] = sp_indices; + dict["sp_values"] = sp_values; + dict["sp_shape"] = sp_shape; + dict["dense"] = dense; + var op = tf.OpDefLib._apply_op_helper("SparseDenseCwiseDiv", name: name, keywords: dict); + return op.output; + } + + /// + /// Component-wise multiplies a SparseTensor by a dense Tensor. + /// + /// + /// 2-D. N x R matrix with the indices of non-empty values in a + /// SparseTensor, possibly not in canonical ordering. + /// + /// + /// 1-D. N non-empty values corresponding to sp_indices. + /// + /// + /// 1-D. Shape of the input SparseTensor. + /// + /// + /// R-D. The dense Tensor operand. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseDenseCwiseMul'. + /// + /// + /// 1-D. The N values that are operated on. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The output locations corresponding to the implicitly zero elements in the sparse + /// tensor will be zero (i.e., will not take up storage space), regardless of the + /// contents of the dense tensor (even if it's +/-INF and that INF*0 == NaN). + /// + /// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not + /// the other direction. + /// + public static Tensor sparse_dense_cwise_mul(Tensor sp_indices, Tensor sp_values, Tensor sp_shape, Tensor dense, string name = "SparseDenseCwiseMul") + { + var dict = new Dictionary(); + dict["sp_indices"] = sp_indices; + dict["sp_values"] = sp_values; + dict["sp_shape"] = sp_shape; + dict["dense"] = dense; + var op = tf.OpDefLib._apply_op_helper("SparseDenseCwiseMul", name: name, keywords: dict); + return op.output; + } + + /// + /// Fills empty rows in the input 2-D SparseTensor with a default value. + /// + /// + /// 2-D. the indices of the sparse tensor. + /// + /// + /// 1-D. the values of the sparse tensor. + /// + /// + /// 1-D. the shape of the sparse tensor. + /// + /// + /// 0-D. default value to insert into location [row, 0, ..., 0] + /// for rows missing from the input sparse tensor. + /// output indices: 2-D. the indices of the filled sparse tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseFillEmptyRows'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_indices : + /// output_values : 1-D. the values of the filled sparse tensor. + /// empty_row_indicator : 1-D. whether the dense row was missing in the + /// input sparse tensor. + /// reverse_index_map : 1-D. a map from the input indices to the output indices. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The input SparseTensor is represented via the tuple of inputs + /// (indices, values, dense_shape). The output SparseTensor has the + /// same dense_shape but with indices output_indices and values + /// output_values. + /// + /// This op inserts a single entry for every row that doesn't have any values. + /// The index is created as [row, 0, ..., 0] and the inserted value + /// is default_value. + /// + /// For example, suppose sp_input has shape [5, 6] and non-empty values: + /// + /// [0, 1]: a + /// [0, 3]: b + /// [2, 0]: c + /// [3, 1]: d + /// + /// Rows 1 and 4 are empty, so the output will be of shape [5, 6] with values: + /// + /// [0, 1]: a + /// [0, 3]: b + /// [1, 0]: default_value + /// [2, 0]: c + /// [3, 1]: d + /// [4, 0]: default_value + /// + /// The output SparseTensor will be in row-major order and will have the + /// same shape as the input. + /// + /// This op also returns an indicator vector shaped [dense_shape[0]] such that + /// + /// empty_row_indicator[i] = True iff row i was an empty row. + /// + /// And a reverse index map vector shaped [indices.shape[0]] that is used during + /// backpropagation, + /// + /// reverse_index_map[j] = out_j s.t. indices[j, :] == output_indices[out_j, :] + /// + public static (Tensor output_indices, Tensor output_values, Tensor empty_row_indicator, Tensor reverse_index_map) sparse_fill_empty_rows(Tensor indices, Tensor values, Tensor dense_shape, Tensor default_value, string name = "SparseFillEmptyRows") + { + var dict = new Dictionary(); + dict["indices"] = indices; + dict["values"] = values; + dict["dense_shape"] = dense_shape; + dict["default_value"] = default_value; + var op = tf.OpDefLib._apply_op_helper("SparseFillEmptyRows", name: name, keywords: dict); + int _idx = 0; + var output_indices = op.outputs[_idx++]; + var output_values = op.outputs[_idx++]; + var empty_row_indicator = op.outputs[_idx++]; + var reverse_index_map = op.outputs[_idx++]; + return (output_indices, output_values, empty_row_indicator, reverse_index_map); + } + + /// + /// The gradient of SparseFillEmptyRows. + /// + /// + /// 1-D. The reverse index map from SparseFillEmptyRows. + /// + /// + /// 1-D. The gradients from backprop. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseFillEmptyRowsGrad'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// d_values : 1-D. The backprop into values. + /// d_default_value : 0-D. The backprop into default_value. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Takes vectors reverse_index_map, shaped [N], and grad_values, + /// shaped [N_full], where N_full &gt;= N and copies data into either + /// d_values or d_default_value. Here d_values is shaped [N] and + /// d_default_value is a scalar. + /// + /// d_values[j] = grad_values[reverse_index_map[j]] + /// d_default_value = sum_{k : 0 .. N_full - 1} ( + /// grad_values[k] * 1{k not in reverse_index_map}) + /// + public static (Tensor d_values, Tensor d_default_value) sparse_fill_empty_rows_grad(Tensor reverse_index_map, Tensor grad_values, string name = "SparseFillEmptyRowsGrad") + { + var dict = new Dictionary(); + dict["reverse_index_map"] = reverse_index_map; + dict["grad_values"] = grad_values; + var op = tf.OpDefLib._apply_op_helper("SparseFillEmptyRowsGrad", name: name, keywords: dict); + int _idx = 0; + var d_values = op.outputs[_idx++]; + var d_default_value = op.outputs[_idx++]; + return (d_values, d_default_value); + } + + /// + /// Multiply matrix "a" by matrix "b". + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseMatMul'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The inputs must be two-dimensional matrices and the inner dimension of "a" must + /// match the outer dimension of "b". Both "a" and "b" must be Tensors not + /// SparseTensors. This op is optimized for the case where at least one of "a" or + /// "b" is sparse, in the sense that they have a large proportion of zero values. + /// The breakeven for using this versus a dense matrix multiply on one platform was + /// 30% zero values in the sparse matrix. + /// + /// The gradient computation of this operation will only take advantage of sparsity + /// in the input gradient when that gradient comes from a Relu. + /// + public static Tensor sparse_mat_mul(Tensor a, Tensor b, bool? transpose_a = null, bool? transpose_b = null, bool? a_is_sparse = null, bool? b_is_sparse = null, string name = "SparseMatMul") + { + var dict = new Dictionary(); + dict["a"] = a; + dict["b"] = b; + if (transpose_a.HasValue) + dict["transpose_a"] = transpose_a.Value; + if (transpose_b.HasValue) + dict["transpose_b"] = transpose_b.Value; + if (a_is_sparse.HasValue) + dict["a_is_sparse"] = a_is_sparse.Value; + if (b_is_sparse.HasValue) + dict["b_is_sparse"] = b_is_sparse.Value; + var op = tf.OpDefLib._apply_op_helper("SparseMatMul", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the max of elements across dimensions of a SparseTensor. + /// + /// + /// 2-D. N x R matrix with the indices of non-empty values in a + /// SparseTensor, possibly not in canonical ordering. + /// + /// + /// 1-D. N non-empty values corresponding to input_indices. + /// + /// + /// 1-D. Shape of the input SparseTensor. + /// + /// + /// 1-D. Length-K vector containing the reduction axes. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseReduceMax'. + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// R-K-D. The reduced Tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This Op takes a SparseTensor and is the sparse counterpart to + /// tf.reduce_max(). In particular, this Op also returns a dense Tensor + /// instead of a sparse one. + /// + /// Reduces sp_input along the dimensions given in reduction_axes. Unless + /// keep_dims is true, the rank of the tensor is reduced by 1 for each entry in + /// reduction_axes. If keep_dims is true, the reduced dimensions are retained + /// with length 1. + /// + /// If reduction_axes has no entries, all dimensions are reduced, and a tensor + /// with a single element is returned. Additionally, the axes can be negative, + /// which are interpreted according to the indexing rules in Python. + /// + public static Tensor sparse_reduce_max(Tensor input_indices, Tensor input_values, Tensor input_shape, Tensor reduction_axes, bool? keep_dims = null, string name = "SparseReduceMax") + { + var dict = new Dictionary(); + dict["input_indices"] = input_indices; + dict["input_values"] = input_values; + dict["input_shape"] = input_shape; + dict["reduction_axes"] = reduction_axes; + if (keep_dims.HasValue) + dict["keep_dims"] = keep_dims.Value; + var op = tf.OpDefLib._apply_op_helper("SparseReduceMax", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the max of elements across dimensions of a SparseTensor. + /// + /// + /// 2-D. N x R matrix with the indices of non-empty values in a + /// SparseTensor, possibly not in canonical ordering. + /// + /// + /// 1-D. N non-empty values corresponding to input_indices. + /// + /// + /// 1-D. Shape of the input SparseTensor. + /// + /// + /// 1-D. Length-K vector containing the reduction axes. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseReduceMaxSparse'. + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_indices : + /// output_values : + /// output_shape : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// This Op takes a SparseTensor and is the sparse counterpart to + /// tf.reduce_max(). In contrast to SparseReduceMax, this Op returns a + /// SparseTensor. + /// + /// Reduces sp_input along the dimensions given in reduction_axes. Unless + /// keep_dims is true, the rank of the tensor is reduced by 1 for each entry in + /// reduction_axes. If keep_dims is true, the reduced dimensions are retained + /// with length 1. + /// + /// If reduction_axes has no entries, all dimensions are reduced, and a tensor + /// with a single element is returned. Additionally, the axes can be negative, + /// which are interpreted according to the indexing rules in Python. + /// + public static (Tensor output_indices, Tensor output_values, Tensor output_shape) sparse_reduce_max_sparse(Tensor input_indices, Tensor input_values, Tensor input_shape, Tensor reduction_axes, bool? keep_dims = null, string name = "SparseReduceMaxSparse") + { + var dict = new Dictionary(); + dict["input_indices"] = input_indices; + dict["input_values"] = input_values; + dict["input_shape"] = input_shape; + dict["reduction_axes"] = reduction_axes; + if (keep_dims.HasValue) + dict["keep_dims"] = keep_dims.Value; + var op = tf.OpDefLib._apply_op_helper("SparseReduceMaxSparse", name: name, keywords: dict); + int _idx = 0; + var output_indices = op.outputs[_idx++]; + var output_values = op.outputs[_idx++]; + var output_shape = op.outputs[_idx++]; + return (output_indices, output_values, output_shape); + } + + /// + /// Computes the sum of elements across dimensions of a SparseTensor. + /// + /// + /// 2-D. N x R matrix with the indices of non-empty values in a + /// SparseTensor, possibly not in canonical ordering. + /// + /// + /// 1-D. N non-empty values corresponding to input_indices. + /// + /// + /// 1-D. Shape of the input SparseTensor. + /// + /// + /// 1-D. Length-K vector containing the reduction axes. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseReduceSum'. + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// R-K-D. The reduced Tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This Op takes a SparseTensor and is the sparse counterpart to + /// tf.reduce_sum(). In particular, this Op also returns a dense Tensor + /// instead of a sparse one. + /// + /// Reduces sp_input along the dimensions given in reduction_axes. Unless + /// keep_dims is true, the rank of the tensor is reduced by 1 for each entry in + /// reduction_axes. If keep_dims is true, the reduced dimensions are retained + /// with length 1. + /// + /// If reduction_axes has no entries, all dimensions are reduced, and a tensor + /// with a single element is returned. Additionally, the axes can be negative, + /// which are interpreted according to the indexing rules in Python. + /// + public static Tensor sparse_reduce_sum(Tensor input_indices, Tensor input_values, Tensor input_shape, Tensor reduction_axes, bool? keep_dims = null, string name = "SparseReduceSum") + { + var dict = new Dictionary(); + dict["input_indices"] = input_indices; + dict["input_values"] = input_values; + dict["input_shape"] = input_shape; + dict["reduction_axes"] = reduction_axes; + if (keep_dims.HasValue) + dict["keep_dims"] = keep_dims.Value; + var op = tf.OpDefLib._apply_op_helper("SparseReduceSum", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the sum of elements across dimensions of a SparseTensor. + /// + /// + /// 2-D. N x R matrix with the indices of non-empty values in a + /// SparseTensor, possibly not in canonical ordering. + /// + /// + /// 1-D. N non-empty values corresponding to input_indices. + /// + /// + /// 1-D. Shape of the input SparseTensor. + /// + /// + /// 1-D. Length-K vector containing the reduction axes. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseReduceSumSparse'. + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_indices : + /// output_values : + /// output_shape : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// This Op takes a SparseTensor and is the sparse counterpart to + /// tf.reduce_sum(). In contrast to SparseReduceSum, this Op returns a + /// SparseTensor. + /// + /// Reduces sp_input along the dimensions given in reduction_axes. Unless + /// keep_dims is true, the rank of the tensor is reduced by 1 for each entry in + /// reduction_axes. If keep_dims is true, the reduced dimensions are retained + /// with length 1. + /// + /// If reduction_axes has no entries, all dimensions are reduced, and a tensor + /// with a single element is returned. Additionally, the axes can be negative, + /// which are interpreted according to the indexing rules in Python. + /// + public static (Tensor output_indices, Tensor output_values, Tensor output_shape) sparse_reduce_sum_sparse(Tensor input_indices, Tensor input_values, Tensor input_shape, Tensor reduction_axes, bool? keep_dims = null, string name = "SparseReduceSumSparse") + { + var dict = new Dictionary(); + dict["input_indices"] = input_indices; + dict["input_values"] = input_values; + dict["input_shape"] = input_shape; + dict["reduction_axes"] = reduction_axes; + if (keep_dims.HasValue) + dict["keep_dims"] = keep_dims.Value; + var op = tf.OpDefLib._apply_op_helper("SparseReduceSumSparse", name: name, keywords: dict); + int _idx = 0; + var output_indices = op.outputs[_idx++]; + var output_values = op.outputs[_idx++]; + var output_shape = op.outputs[_idx++]; + return (output_indices, output_values, output_shape); + } + + /// + /// Reorders a SparseTensor into the canonical, row-major ordering. + /// + /// + /// 2-D. N x R matrix with the indices of non-empty values in a + /// SparseTensor, possibly not in canonical ordering. + /// + /// + /// 1-D. N non-empty values corresponding to input_indices. + /// + /// + /// 1-D. Shape of the input SparseTensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseReorder'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_indices : 2-D. N x R matrix with the same indices as input_indices, but + /// in canonical row-major ordering. + /// output_values : 1-D. N non-empty values corresponding to output_indices. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Note that by convention, all sparse ops preserve the canonical ordering along + /// increasing dimension number. The only time ordering can be violated is during + /// manual manipulation of the indices and values vectors to add entries. + /// + /// Reordering does not affect the shape of the SparseTensor. + /// + /// If the tensor has rank R and N non-empty values, input_indices has + /// shape [N, R], input_values has length N, and input_shape has length R. + /// + public static (Tensor output_indices, Tensor output_values) sparse_reorder(Tensor input_indices, Tensor input_values, Tensor input_shape, string name = "SparseReorder") + { + var dict = new Dictionary(); + dict["input_indices"] = input_indices; + dict["input_values"] = input_values; + dict["input_shape"] = input_shape; + var op = tf.OpDefLib._apply_op_helper("SparseReorder", name: name, keywords: dict); + int _idx = 0; + var output_indices = op.outputs[_idx++]; + var output_values = op.outputs[_idx++]; + return (output_indices, output_values); + } + + /// + /// Reshapes a SparseTensor to represent values in a new dense shape. + /// + /// + /// 2-D. N x R_in matrix with the indices of non-empty values in a + /// SparseTensor. + /// + /// + /// 1-D. R_in vector with the input SparseTensor's dense shape. + /// + /// + /// 1-D. R_out vector with the requested new dense shape. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseReshape'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_indices : 2-D. N x R_out matrix with the updated indices of non-empty + /// values in the output SparseTensor. + /// output_shape : 1-D. R_out vector with the full dense shape of the output + /// SparseTensor. This is the same as new_shape but with any -1 dimensions + /// filled in. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// This operation has the same semantics as reshape on the represented dense + /// tensor. The input_indices are recomputed based on the requested new_shape. + /// + /// If one component of new_shape is the special value -1, the size of that + /// dimension is computed so that the total dense size remains constant. At + /// most one component of new_shape can be -1. The number of dense elements + /// implied by new_shape must be the same as the number of dense elements + /// originally implied by input_shape. + /// + /// Reshaping does not affect the order of values in the SparseTensor. + /// + /// If the input tensor has rank R_in and N non-empty values, and new_shape + /// has length R_out, then input_indices has shape [N, R_in], + /// input_shape has length R_in, output_indices has shape [N, R_out], and + /// output_shape has length R_out. + /// + public static (Tensor output_indices, Tensor output_shape) sparse_reshape(Tensor input_indices, Tensor input_shape, Tensor new_shape, string name = "SparseReshape") + { + var dict = new Dictionary(); + dict["input_indices"] = input_indices; + dict["input_shape"] = input_shape; + dict["new_shape"] = new_shape; + var op = tf.OpDefLib._apply_op_helper("SparseReshape", name: name, keywords: dict); + int _idx = 0; + var output_indices = op.outputs[_idx++]; + var output_shape = op.outputs[_idx++]; + return (output_indices, output_shape); + } + + /// + /// Computes the mean along sparse segments of a tensor. + /// + /// + /// + /// + /// A 1-D tensor. Has same rank as segment_ids. + /// + /// + /// A 1-D tensor. Values should be sorted and can be repeated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSegmentMean'. + /// + /// + /// Has same shape as data, except for dimension 0 which + /// has size k, the number of segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + /// Like SegmentMean, but segment_ids can have rank less than data's first + /// dimension, selecting a subset of dimension 0, specified by indices. + /// + public static Tensor sparse_segment_mean(Tensor data, Tensor indices, Tensor segment_ids, string name = "SparseSegmentMean") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["indices"] = indices; + dict["segment_ids"] = segment_ids; + var op = tf.OpDefLib._apply_op_helper("SparseSegmentMean", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes gradients for SparseSegmentMean. + /// + /// + /// gradient propagated to the SparseSegmentMean op. + /// + /// + /// indices passed to the corresponding SparseSegmentMean op. + /// + /// + /// segment_ids passed to the corresponding SparseSegmentMean op. + /// + /// + /// dimension 0 of "data" passed to SparseSegmentMean op. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSegmentMeanGrad'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Returns tensor "output" with same shape as grad, except for dimension 0 whose + /// value is output_dim0. + /// + public static Tensor sparse_segment_mean_grad(Tensor grad, Tensor indices, Tensor segment_ids, Tensor output_dim0, string name = "SparseSegmentMeanGrad") + { + var dict = new Dictionary(); + dict["grad"] = grad; + dict["indices"] = indices; + dict["segment_ids"] = segment_ids; + dict["output_dim0"] = output_dim0; + var op = tf.OpDefLib._apply_op_helper("SparseSegmentMeanGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the mean along sparse segments of a tensor. + /// + /// + /// + /// + /// A 1-D tensor. Has same rank as segment_ids. + /// + /// + /// A 1-D tensor. Values should be sorted and can be repeated. + /// + /// + /// Should equal the number of distinct segment IDs. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSegmentMeanWithNumSegments'. + /// + /// + /// Has same shape as data, except for dimension 0 which has size + /// num_segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Like SparseSegmentMean, but allows missing ids in segment_ids. If an id is + /// misisng, the output tensor at that position will be zeroed. + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + public static Tensor sparse_segment_mean_with_num_segments(Tensor data, Tensor indices, Tensor segment_ids, Tensor num_segments, string name = "SparseSegmentMeanWithNumSegments") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["indices"] = indices; + dict["segment_ids"] = segment_ids; + dict["num_segments"] = num_segments; + var op = tf.OpDefLib._apply_op_helper("SparseSegmentMeanWithNumSegments", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the sum along sparse segments of a tensor divided by the sqrt of N. + /// + /// + /// + /// + /// A 1-D tensor. Has same rank as segment_ids. + /// + /// + /// A 1-D tensor. Values should be sorted and can be repeated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSegmentSqrtN'. + /// + /// + /// Has same shape as data, except for dimension 0 which + /// has size k, the number of segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// N is the size of the segment being reduced. + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + public static Tensor sparse_segment_sqrt_n(Tensor data, Tensor indices, Tensor segment_ids, string name = "SparseSegmentSqrtN") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["indices"] = indices; + dict["segment_ids"] = segment_ids; + var op = tf.OpDefLib._apply_op_helper("SparseSegmentSqrtN", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes gradients for SparseSegmentSqrtN. + /// + /// + /// gradient propagated to the SparseSegmentSqrtN op. + /// + /// + /// indices passed to the corresponding SparseSegmentSqrtN op. + /// + /// + /// segment_ids passed to the corresponding SparseSegmentSqrtN op. + /// + /// + /// dimension 0 of "data" passed to SparseSegmentSqrtN op. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSegmentSqrtNGrad'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Returns tensor "output" with same shape as grad, except for dimension 0 whose + /// value is output_dim0. + /// + public static Tensor sparse_segment_sqrt_n_grad(Tensor grad, Tensor indices, Tensor segment_ids, Tensor output_dim0, string name = "SparseSegmentSqrtNGrad") + { + var dict = new Dictionary(); + dict["grad"] = grad; + dict["indices"] = indices; + dict["segment_ids"] = segment_ids; + dict["output_dim0"] = output_dim0; + var op = tf.OpDefLib._apply_op_helper("SparseSegmentSqrtNGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the sum along sparse segments of a tensor divided by the sqrt of N. + /// + /// + /// + /// + /// A 1-D tensor. Has same rank as segment_ids. + /// + /// + /// A 1-D tensor. Values should be sorted and can be repeated. + /// + /// + /// Should equal the number of distinct segment IDs. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSegmentSqrtNWithNumSegments'. + /// + /// + /// Has same shape as data, except for dimension 0 which + /// has size k, the number of segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// N is the size of the segment being reduced. + /// + /// Like SparseSegmentSqrtN, but allows missing ids in segment_ids. If an id is + /// misisng, the output tensor at that position will be zeroed. + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + public static Tensor sparse_segment_sqrt_n_with_num_segments(Tensor data, Tensor indices, Tensor segment_ids, Tensor num_segments, string name = "SparseSegmentSqrtNWithNumSegments") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["indices"] = indices; + dict["segment_ids"] = segment_ids; + dict["num_segments"] = num_segments; + var op = tf.OpDefLib._apply_op_helper("SparseSegmentSqrtNWithNumSegments", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the sum along sparse segments of a tensor. + /// + /// + /// + /// + /// A 1-D tensor. Has same rank as segment_ids. + /// + /// + /// A 1-D tensor. Values should be sorted and can be repeated. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSegmentSum'. + /// + /// + /// Has same shape as data, except for dimension 0 which + /// has size k, the number of segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + /// Like SegmentSum, but segment_ids can have rank less than data's first + /// dimension, selecting a subset of dimension 0, specified by indices. + /// + /// For example: + /// + /// + /// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) + /// + /// # Select two rows, one segment. + /// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0])) + /// # =&gt; [[0 0 0 0]] + /// + /// # Select two rows, two segment. + /// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1])) + /// # =&gt; [[ 1 2 3 4] + /// # [-1 -2 -3 -4]] + /// + /// # Select all rows, two segments. + /// tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1])) + /// # =&gt; [[0 0 0 0] + /// # [5 6 7 8]] + /// + /// # Which is equivalent to: + /// tf.segment_sum(c, tf.constant([0, 0, 1])) + /// + /// + public static Tensor sparse_segment_sum(Tensor data, Tensor indices, Tensor segment_ids, string name = "SparseSegmentSum") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["indices"] = indices; + dict["segment_ids"] = segment_ids; + var op = tf.OpDefLib._apply_op_helper("SparseSegmentSum", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the sum along sparse segments of a tensor. + /// + /// + /// + /// + /// A 1-D tensor. Has same rank as segment_ids. + /// + /// + /// A 1-D tensor. Values should be sorted and can be repeated. + /// + /// + /// Should equal the number of distinct segment IDs. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSegmentSumWithNumSegments'. + /// + /// + /// Has same shape as data, except for dimension 0 which + /// has size num_segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Like SparseSegmentSum, but allows missing ids in segment_ids. If an id is + /// misisng, the output tensor at that position will be zeroed. + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + /// For example: + /// + /// + /// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) + /// + /// tf.sparse_segment_sum_with_num_segments( + /// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3) + /// # =&gt; [[0 0 0 0] + /// # [0 0 0 0] + /// # [0 0 0 0]] + /// + /// tf.sparse_segment_sum_with_num_segments(c, + /// tf.constant([0, 1]), + /// tf.constant([0, 2], + /// num_segments=4)) + /// # =&gt; [[ 1 2 3 4] + /// # [ 0 0 0 0] + /// # [-1 -2 -3 -4] + /// # [ 0 0 0 0]] + /// + /// + public static Tensor sparse_segment_sum_with_num_segments(Tensor data, Tensor indices, Tensor segment_ids, Tensor num_segments, string name = "SparseSegmentSumWithNumSegments") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["indices"] = indices; + dict["segment_ids"] = segment_ids; + dict["num_segments"] = num_segments; + var op = tf.OpDefLib._apply_op_helper("SparseSegmentSumWithNumSegments", name: name, keywords: dict); + return op.output; + } + + /// + /// Slice a SparseTensor based on the start and size. + /// + /// + /// 2-D tensor represents the indices of the sparse tensor. + /// + /// + /// 1-D tensor represents the values of the sparse tensor. + /// + /// + /// 1-D. tensor represents the shape of the sparse tensor. + /// + /// + /// 1-D. tensor represents the start of the slice. + /// + /// + /// 1-D. tensor represents the size of the slice. + /// output indices: A list of 1-D tensors represents the indices of the output + /// sparse tensors. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSlice'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_indices : + /// output_values : A list of 1-D tensors represents the values of the output sparse + /// tensors. + /// output_shape : A list of 1-D tensors represents the shape of the output sparse + /// tensors. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// For example, if the input is + /// + /// input_tensor = shape = [2, 7] + /// [ a d e ] + /// [b c ] + /// + /// Graphically the output tensors are: + /// + /// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] + /// [ a ] + /// [b c ] + /// + /// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] + /// [ d e ] + /// [ ] + /// + public static (Tensor output_indices, Tensor output_values, Tensor output_shape) sparse_slice(Tensor indices, Tensor values, Tensor shape, Tensor start, Tensor size, string name = "SparseSlice") + { + var dict = new Dictionary(); + dict["indices"] = indices; + dict["values"] = values; + dict["shape"] = shape; + dict["start"] = start; + dict["size"] = size; + var op = tf.OpDefLib._apply_op_helper("SparseSlice", name: name, keywords: dict); + int _idx = 0; + var output_indices = op.outputs[_idx++]; + var output_values = op.outputs[_idx++]; + var output_shape = op.outputs[_idx++]; + return (output_indices, output_values, output_shape); + } + + /// + /// The gradient operator for the SparseSlice op. + /// + /// + /// 1-D. The gradient with respect to + /// the non-empty values of the sliced SparseTensor. + /// + /// + /// 2-D. The indices of the input SparseTensor. + /// + /// + /// 1-D. tensor represents the start of the slice. + /// + /// + /// 2-D. The indices of the sliced SparseTensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSliceGrad'. + /// + /// + /// 1-D. The gradient with respect to the non-empty values of input SparseTensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op takes in the upstream gradient w.r.t. non-empty values of + /// the sliced SparseTensor, and outputs the gradients w.r.t. + /// the non-empty values of input SparseTensor. + /// + public static Tensor sparse_slice_grad(Tensor backprop_val_grad, Tensor input_indices, Tensor input_start, Tensor output_indices, string name = "SparseSliceGrad") + { + var dict = new Dictionary(); + dict["backprop_val_grad"] = backprop_val_grad; + dict["input_indices"] = input_indices; + dict["input_start"] = input_start; + dict["output_indices"] = output_indices; + var op = tf.OpDefLib._apply_op_helper("SparseSliceGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Applies softmax to a batched N-D SparseTensor. + /// + /// + /// 2-D. NNZ x R matrix with the indices of non-empty values in a + /// SparseTensor, in canonical ordering. + /// + /// + /// 1-D. NNZ non-empty values corresponding to sp_indices. + /// + /// + /// 1-D. Shape of the input SparseTensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSoftmax'. + /// + /// + /// 1-D. The NNZ values for the result SparseTensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The inputs represent an N-D SparseTensor with logical shape [..., B, C] + /// (where N &gt;= 2), and with indices sorted in the canonical lexicographic order. + /// + /// This op is equivalent to applying the normal tf.nn.softmax() to each innermost + /// logical submatrix with shape [B, C], but with the catch that *the implicitly + /// zero elements do not participate*. Specifically, the algorithm is equivalent + /// to the following: + /// + /// (1) Applies tf.nn.softmax() to a densified view of each innermost submatrix + /// with shape [B, C], along the size-C dimension; + /// (2) Masks out the original implicitly-zero locations; + /// (3) Renormalizes the remaining elements. + /// + /// Hence, the SparseTensor result has exactly the same non-zero indices and + /// shape. + /// + public static Tensor sparse_softmax(Tensor sp_indices, Tensor sp_values, Tensor sp_shape, string name = "SparseSoftmax") + { + var dict = new Dictionary(); + dict["sp_indices"] = sp_indices; + dict["sp_values"] = sp_values; + dict["sp_shape"] = sp_shape; + var op = tf.OpDefLib._apply_op_helper("SparseSoftmax", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes softmax cross entropy cost and gradients to backpropagate. + /// + /// + /// batch_size x num_classes matrix + /// + /// + /// batch_size vector with values in [0, num_classes). + /// This is the label for the given minibatch entry. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSoftmaxCrossEntropyWithLogits'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// loss : Per example loss (batch_size vector). + /// backprop : backpropagated gradients (batch_size x num_classes matrix). + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Unlike SoftmaxCrossEntropyWithLogits, this operation does not accept + /// a matrix of label probabilities, but rather a single label per row + /// of features. This label is considered to have probability 1.0 for the + /// given row. + /// + /// Inputs are the logits, not probabilities. + /// + public static (Tensor loss, Tensor backprop) sparse_softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = "SparseSoftmaxCrossEntropyWithLogits") + { + var dict = new Dictionary(); + dict["features"] = features; + dict["labels"] = labels; + var op = tf.OpDefLib._apply_op_helper("SparseSoftmaxCrossEntropyWithLogits", name: name, keywords: dict); + int _idx = 0; + var loss = op.outputs[_idx++]; + var backprop = op.outputs[_idx++]; + return (loss, backprop); + } + + /// + /// Returns the element-wise max of two SparseTensors. + /// + /// + /// 2-D. N x R matrix with the indices of non-empty values in a + /// SparseTensor, in the canonical lexicographic ordering. + /// + /// + /// 1-D. N non-empty values corresponding to a_indices. + /// + /// + /// 1-D. Shape of the input SparseTensor. + /// + /// + /// counterpart to a_indices for the other operand. + /// + /// + /// counterpart to a_values for the other operand; must be of the same dtype. + /// + /// + /// counterpart to a_shape for the other operand; the two shapes must be equal. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSparseMaximum'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_indices : 2-D. The indices of the output SparseTensor. + /// output_values : 1-D. The values of the output SparseTensor. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. + /// + public static (Tensor output_indices, Tensor output_values) sparse_sparse_maximum(Tensor a_indices, Tensor a_values, Tensor a_shape, Tensor b_indices, Tensor b_values, Tensor b_shape, string name = "SparseSparseMaximum") + { + var dict = new Dictionary(); + dict["a_indices"] = a_indices; + dict["a_values"] = a_values; + dict["a_shape"] = a_shape; + dict["b_indices"] = b_indices; + dict["b_values"] = b_values; + dict["b_shape"] = b_shape; + var op = tf.OpDefLib._apply_op_helper("SparseSparseMaximum", name: name, keywords: dict); + int _idx = 0; + var output_indices = op.outputs[_idx++]; + var output_values = op.outputs[_idx++]; + return (output_indices, output_values); + } + + /// + /// Returns the element-wise min of two SparseTensors. + /// + /// + /// 2-D. N x R matrix with the indices of non-empty values in a + /// SparseTensor, in the canonical lexicographic ordering. + /// + /// + /// 1-D. N non-empty values corresponding to a_indices. + /// + /// + /// 1-D. Shape of the input SparseTensor. + /// + /// + /// counterpart to a_indices for the other operand. + /// + /// + /// counterpart to a_values for the other operand; must be of the same dtype. + /// + /// + /// counterpart to a_shape for the other operand; the two shapes must be equal. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSparseMinimum'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_indices : 2-D. The indices of the output SparseTensor. + /// output_values : 1-D. The values of the output SparseTensor. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. + /// + public static (Tensor output_indices, Tensor output_values) sparse_sparse_minimum(Tensor a_indices, Tensor a_values, Tensor a_shape, Tensor b_indices, Tensor b_values, Tensor b_shape, string name = "SparseSparseMinimum") + { + var dict = new Dictionary(); + dict["a_indices"] = a_indices; + dict["a_values"] = a_values; + dict["a_shape"] = a_shape; + dict["b_indices"] = b_indices; + dict["b_values"] = b_values; + dict["b_shape"] = b_shape; + var op = tf.OpDefLib._apply_op_helper("SparseSparseMinimum", name: name, keywords: dict); + int _idx = 0; + var output_indices = op.outputs[_idx++]; + var output_values = op.outputs[_idx++]; + return (output_indices, output_values); + } + + /// + /// Split a SparseTensor into num_split tensors along one dimension. + /// + /// + /// 0-D. The dimension along which to split. Must be in the range + /// [0, rank(shape)). + /// + /// + /// 2-D tensor represents the indices of the sparse tensor. + /// + /// + /// 1-D tensor represents the values of the sparse tensor. + /// + /// + /// 1-D. tensor represents the shape of the sparse tensor. + /// output indices: A list of 1-D tensors represents the indices of the output + /// sparse tensors. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSplit'. + /// + /// + /// Optional argument + /// The number of ways to split. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_indices : + /// output_values : A list of 1-D tensors represents the values of the output sparse + /// tensors. + /// output_shape : A list of 1-D tensors represents the shape of the output sparse + /// tensors. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// If the shape[split_dim] is not an integer multiple of num_split. Slices + /// [0 : shape[split_dim] % num_split] gets one extra dimension. + /// For example, if split_dim = 1 and num_split = 2 and the input is + /// + /// input_tensor = shape = [2, 7] + /// [ a d e ] + /// [b c ] + /// + /// Graphically the output tensors are: + /// + /// output_tensor[0] = shape = [2, 4] + /// [ a ] + /// [b c ] + /// + /// output_tensor[1] = shape = [2, 3] + /// [ d e ] + /// [ ] + /// + public static (Tensor[] output_indices, Tensor[] output_values, Tensor[] output_shape) sparse_split(Tensor split_dim, Tensor indices, Tensor values, Tensor shape, int num_split, string name = "SparseSplit") + { + var dict = new Dictionary(); + dict["split_dim"] = split_dim; + dict["indices"] = indices; + dict["values"] = values; + dict["shape"] = shape; + dict["num_split"] = num_split; + var op = tf.OpDefLib._apply_op_helper("SparseSplit", name: name, keywords: dict); + int _idx = 0; + var output_indices = Enumerable.Range(0, op.OutputListLength("output_indices")).Select(_ => op.outputs[_idx++]).ToArray(); + var output_values = Enumerable.Range(0, op.OutputListLength("output_values")).Select(_ => op.outputs[_idx++]).ToArray(); + var output_shape = Enumerable.Range(0, op.OutputListLength("output_shape")).Select(_ => op.outputs[_idx++]).ToArray(); + return (output_indices, output_values, output_shape); + } + + /// + /// Adds up a SparseTensor and a dense Tensor, producing a dense Tensor. + /// + /// + /// 2-D. The indices of the SparseTensor, with shape [nnz, ndims]. + /// + /// + /// 1-D. The values of the SparseTensor, with shape [nnz]. + /// + /// + /// 1-D. The shape of the SparseTensor, with shape [ndims]. + /// + /// + /// ndims-D Tensor. With shape a_shape. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseTensorDenseAdd'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This Op does not require a_indices be sorted in standard lexicographic order. + /// + public static Tensor sparse_tensor_dense_add(Tensor a_indices, Tensor a_values, Tensor a_shape, Tensor b, string name = "SparseTensorDenseAdd") + { + var dict = new Dictionary(); + dict["a_indices"] = a_indices; + dict["a_values"] = a_values; + dict["a_shape"] = a_shape; + dict["b"] = b; + var op = tf.OpDefLib._apply_op_helper("SparseTensorDenseAdd", name: name, keywords: dict); + return op.output; + } + + /// + /// Multiply SparseTensor (of rank 2) "A" by dense matrix "B". + /// + /// + /// 2-D. The indices of the SparseTensor, size [nnz, 2] Matrix. + /// + /// + /// 1-D. The values of the SparseTensor, size [nnz] Vector. + /// + /// + /// 1-D. The shape of the SparseTensor, size [2] Vector. + /// + /// + /// 2-D. A dense Matrix. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseTensorDenseMatMul'. + /// + /// + /// Use the adjoint of A in the matrix multiply. If A is complex, this + /// is transpose(conj(A)). Otherwise it's transpose(A). + /// + /// + /// Use the adjoint of B in the matrix multiply. If B is complex, this + /// is transpose(conj(B)). Otherwise it's transpose(B). + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// No validity checking is performed on the indices of A. However, the following + /// input format is recommended for optimal behavior: + /// + /// if adjoint_a == false: + /// A should be sorted in lexicographically increasing order. Use SparseReorder + /// if you're not sure. + /// if adjoint_a == true: + /// A should be sorted in order of increasing dimension 1 (i.e., "column major" + /// order instead of "row major" order). + /// + public static Tensor sparse_tensor_dense_mat_mul(Tensor a_indices, Tensor a_values, Tensor a_shape, Tensor b, bool? adjoint_a = null, bool? adjoint_b = null, string name = "SparseTensorDenseMatMul") + { + var dict = new Dictionary(); + dict["a_indices"] = a_indices; + dict["a_values"] = a_values; + dict["a_shape"] = a_shape; + dict["b"] = b; + if (adjoint_a.HasValue) + dict["adjoint_a"] = adjoint_a.Value; + if (adjoint_b.HasValue) + dict["adjoint_b"] = adjoint_b.Value; + var op = tf.OpDefLib._apply_op_helper("SparseTensorDenseMatMul", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that splits a SparseTensor into elements row-wise. + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseTensorSliceDataset'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor sparse_tensor_slice_dataset(Tensor indices, Tensor values, Tensor dense_shape, string name = "SparseTensorSliceDataset") + { + var dict = new Dictionary(); + dict["indices"] = indices; + dict["values"] = values; + dict["dense_shape"] = dense_shape; + var op = tf.OpDefLib._apply_op_helper("SparseTensorSliceDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Converts a sparse representation into a dense tensor. + /// + /// + /// 0-D, 1-D, or 2-D. sparse_indices[i] contains the complete + /// index where sparse_values[i] will be placed. + /// + /// + /// 1-D. Shape of the dense output tensor. + /// + /// + /// 1-D. Values corresponding to each row of sparse_indices, + /// or a scalar value to be used for all sparse indices. + /// + /// + /// Scalar value to set for indices not specified in + /// sparse_indices. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseToDense'. + /// + /// + /// If true, indices are checked to make sure they are sorted in + /// lexicographic order and that there are no repeats. + /// + /// + /// Dense output tensor of shape output_shape. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Builds an array dense with shape output_shape such that + /// + /// + /// # If sparse_indices is scalar + /// dense[i] = (i == sparse_indices ? sparse_values : default_value) + /// + /// # If sparse_indices is a vector, then for each i + /// dense[sparse_indices[i]] = sparse_values[i] + /// + /// # If sparse_indices is an n by d matrix, then for each i in [0, n) + /// dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] + /// + /// + /// All other values in dense are set to default_value. If sparse_values is a + /// scalar, all sparse indices are set to this single value. + /// + /// Indices should be sorted in lexicographic order, and indices must not + /// contain any repeats. If validate_indices is true, these properties + /// are checked during execution. + /// + public static Tensor sparse_to_dense(Tensor sparse_indices, Tensor output_shape, Tensor sparse_values, Tensor default_value, bool? validate_indices = null, string name = "SparseToDense") + { + var dict = new Dictionary(); + dict["sparse_indices"] = sparse_indices; + dict["output_shape"] = output_shape; + dict["sparse_values"] = sparse_values; + dict["default_value"] = default_value; + if (validate_indices.HasValue) + dict["validate_indices"] = validate_indices.Value; + var op = tf.OpDefLib._apply_op_helper("SparseToDense", name: name, keywords: dict); + return op.output; + } + + /// + /// Applies set operation along last dimension of 2 SparseTensor inputs. + /// + /// + /// 2D Tensor, indices of a SparseTensor. Must be in row-major + /// order. + /// + /// + /// 1D Tensor, values of a SparseTensor. Must be in row-major + /// order. + /// + /// + /// 1D Tensor, shape of a SparseTensor. set1_shape[0...n-1] must + /// be the same as set2_shape[0...n-1], set1_shape[n] is the + /// max set size across 0...n-1 dimensions. + /// + /// + /// 2D Tensor, indices of a SparseTensor. Must be in row-major + /// order. + /// + /// + /// 1D Tensor, values of a SparseTensor. Must be in row-major + /// order. + /// + /// + /// 1D Tensor, shape of a SparseTensor. set2_shape[0...n-1] must + /// be the same as set1_shape[0...n-1], set2_shape[n] is the + /// max set size across 0...n-1 dimensions. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseToSparseSetOperation'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// result_indices : 2D indices of a SparseTensor. + /// result_values : 1D values of a SparseTensor. + /// result_shape : 1D Tensor shape of a SparseTensor. result_shape[0...n-1] is + /// the same as the 1st n-1 dimensions of set1 and set2, result_shape[n] + /// is the max result set size across all 0...n-1 dimensions. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// See SetOperationOp::SetOperationFromContext for values of set_operation. + /// + /// If validate_indices is True, SparseToSparseSetOperation validates the + /// order and range of set1 and set2 indices. + /// + /// Input set1 is a SparseTensor represented by set1_indices, set1_values, + /// and set1_shape. For set1 ranked n, 1st n-1 dimensions must be the same + /// as set2. Dimension n contains values in a set, duplicates are allowed but + /// ignored. + /// + /// Input set2 is a SparseTensor represented by set2_indices, set2_values, + /// and set2_shape. For set2 ranked n, 1st n-1 dimensions must be the same + /// as set1. Dimension n contains values in a set, duplicates are allowed but + /// ignored. + /// + /// If validate_indices is True, this op validates the order and range of set1 + /// and set2 indices. + /// + /// Output result is a SparseTensor represented by result_indices, + /// result_values, and result_shape. For set1 and set2 ranked n, this + /// has rank n and the same 1st n-1 dimensions as set1 and set2. The nth + /// dimension contains the result of set_operation applied to the corresponding + /// [0...n-1] dimension of set. + /// + public static (Tensor result_indices, Tensor result_values, Tensor result_shape) sparse_to_sparse_set_operation(Tensor set1_indices, Tensor set1_values, Tensor set1_shape, Tensor set2_indices, Tensor set2_values, Tensor set2_shape, string set_operation, bool? validate_indices = null, string name = "SparseToSparseSetOperation") + { + var dict = new Dictionary(); + dict["set1_indices"] = set1_indices; + dict["set1_values"] = set1_values; + dict["set1_shape"] = set1_shape; + dict["set2_indices"] = set2_indices; + dict["set2_values"] = set2_values; + dict["set2_shape"] = set2_shape; + dict["set_operation"] = set_operation; + if (validate_indices.HasValue) + dict["validate_indices"] = validate_indices.Value; + var op = tf.OpDefLib._apply_op_helper("SparseToSparseSetOperation", name: name, keywords: dict); + int _idx = 0; + var result_indices = op.outputs[_idx++]; + var result_values = op.outputs[_idx++]; + var result_shape = op.outputs[_idx++]; + return (result_indices, result_values, result_shape); + } + + /// + /// Splits a tensor into num_split tensors along one dimension. + /// + /// + /// 0-D. The dimension along which to split. Must be in the range + /// [-rank(value), rank(value)). + /// + /// + /// The tensor to split. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Split'. + /// + /// + /// Optional argument + /// The number of ways to split. Must evenly divide + /// value.shape[split_dim]. + /// + /// + /// They are identically shaped tensors, whose shape matches that of value + /// except along axis, where their sizes are + /// values.shape[split_dim] / num_split. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor[] split(Tensor split_dim, Tensor value, int num_split, string name = "Split") + { + var dict = new Dictionary(); + dict["split_dim"] = split_dim; + dict["value"] = value; + dict["num_split"] = num_split; + var op = tf.OpDefLib._apply_op_helper("Split", name: name, keywords: dict); + int _idx = 0; + var output = Enumerable.Range(0, op.OutputListLength("output")).Select(_ => op.outputs[_idx++]).ToArray(); + return (output); + } + + /// + /// Splits a tensor into num_split tensors along one dimension. + /// + /// + /// The tensor to split. + /// + /// + /// list containing the sizes of each output tensor along the split + /// dimension. Must sum to the dimension of value along split_dim. + /// Can contain one -1 indicating that dimension is to be inferred. + /// + /// + /// 0-D. The dimension along which to split. Must be in the range + /// [-rank(value), rank(value)). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SplitV'. + /// + /// + /// Optional argument + /// + /// + /// Tensors whose shape matches that of value + /// except along axis, where their sizes are + /// size_splits[i]. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor[] split_v(Tensor value, Tensor size_splits, Tensor split_dim, int num_split, string name = "SplitV") + { + var dict = new Dictionary(); + dict["value"] = value; + dict["size_splits"] = size_splits; + dict["split_dim"] = split_dim; + dict["num_split"] = num_split; + var op = tf.OpDefLib._apply_op_helper("SplitV", name: name, keywords: dict); + int _idx = 0; + var output = Enumerable.Range(0, op.OutputListLength("output")).Select(_ => op.outputs[_idx++]).ToArray(); + return (output); + } + + /// + /// Creates a dataset that executes a SQL query and emits rows of the result set. + /// + /// + /// The database type. Currently, the only supported type is 'sqlite'. + /// + /// + /// A connection string to connect to the database. + /// + /// + /// A SQL query to execute. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SqlDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor sql_dataset(Tensor driver_name, Tensor data_source_name, Tensor query, TF_DataType[] output_types, Shape[] output_shapes, string name = "SqlDataset") + { + var dict = new Dictionary(); + dict["driver_name"] = driver_name; + dict["data_source_name"] = data_source_name; + dict["query"] = query; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("SqlDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes square root of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Sqrt'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// I.e., \\(y = \sqrt{x} = x^{1/2}\\). + /// + public static Tensor sqrt(Tensor x, string name = "Sqrt") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Sqrt", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient for the sqrt of x wrt its input. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SqrtGrad'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Specifically, grad = dy * 0.5 / y, where y = sqrt(x), and dy + /// is the corresponding input gradient. + /// + public static Tensor sqrt_grad(Tensor y, Tensor dy, string name = "SqrtGrad") + { + var dict = new Dictionary(); + dict["y"] = y; + dict["dy"] = dy; + var op = tf.OpDefLib._apply_op_helper("SqrtGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes square of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Square'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// I.e., \\(y = x * x = x^2\\). + /// + public static Tensor square(Tensor x, string name = "Square") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Square", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns (x - y)(x - y) element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SquaredDifference'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: SquaredDifference supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor squared_difference(Tensor x, Tensor y, string name = "SquaredDifference") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("SquaredDifference", name: name, keywords: dict); + return op.output; + } + + /// + /// Removes dimensions of size 1 from the shape of a tensor. + /// + /// + /// The input to squeeze. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Squeeze'. + /// + /// + /// If specified, only squeezes the dimensions listed. The dimension + /// index starts at 0. It is an error to squeeze a dimension that is not 1. Must + /// be in the range [-rank(input), rank(input)). + /// + /// + /// Contains the same data as input, but has one or more dimensions of + /// size 1 removed. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor input, this operation returns a tensor of the same type with + /// all dimensions of size 1 removed. If you don't want to remove all size 1 + /// dimensions, you can remove specific size 1 dimensions by specifying + /// axis. + /// + /// For example: + /// + /// + /// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] + /// shape(squeeze(t)) ==&gt; [2, 3] + /// + /// + /// Or, to remove specific size 1 dimensions: + /// + /// + /// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] + /// shape(squeeze(t, [2, 4])) ==&gt; [1, 2, 3, 1] + /// + /// + public static Tensor squeeze(Tensor input, int[] squeeze_dims = null, string name = "Squeeze") + { + var dict = new Dictionary(); + dict["input"] = input; + if (squeeze_dims != null) + dict["squeeze_dims"] = squeeze_dims; + var op = tf.OpDefLib._apply_op_helper("Squeeze", name: name, keywords: dict); + return op.output; + } + + /// + /// Deprecated, use StackV2. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Stack'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor stack(TF_DataType elem_type, string stack_name = null, string name = "Stack") + { + var dict = new Dictionary(); + dict["elem_type"] = elem_type; + if (stack_name != null) + dict["stack_name"] = stack_name; + var op = tf.OpDefLib._apply_op_helper("Stack", name: name, keywords: dict); + return op.output; + } + + /// + /// Deprecated, use StackCloseV2. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StackClose'. + /// + /// + /// Returns the description of the operation + /// + public static Operation stack_close(Tensor handle, string name = "StackClose") + { + var dict = new Dictionary(); + dict["handle"] = handle; + var op = tf.OpDefLib._apply_op_helper("StackClose", name: name, keywords: dict); + return op; + } + + /// + /// Delete the stack from its resource container. + /// + /// + /// The handle to a stack. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StackCloseV2'. + /// + /// + /// Returns the description of the operation + /// + public static Operation stack_close_v2(Tensor handle, string name = "StackCloseV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + var op = tf.OpDefLib._apply_op_helper("StackCloseV2", name: name, keywords: dict); + return op; + } + + /// + /// Deprecated, use StackPopV2. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StackPop'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor stack_pop(Tensor handle, TF_DataType elem_type, string name = "StackPop") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["elem_type"] = elem_type; + var op = tf.OpDefLib._apply_op_helper("StackPop", name: name, keywords: dict); + return op.output; + } + + /// + /// Pop the element at the top of the stack. + /// + /// + /// The handle to a stack. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StackPopV2'. + /// + /// + /// Optional argument + /// The type of the elem that is popped. + /// + /// + /// The tensor that is popped from the top of the stack. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor stack_pop_v2(Tensor handle, TF_DataType elem_type, string name = "StackPopV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["elem_type"] = elem_type; + var op = tf.OpDefLib._apply_op_helper("StackPopV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Deprecated, use StackPushV2. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StackPush'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor stack_push(Tensor handle, Tensor elem, bool? swap_memory = null, string name = "StackPush") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["elem"] = elem; + if (swap_memory.HasValue) + dict["swap_memory"] = swap_memory.Value; + var op = tf.OpDefLib._apply_op_helper("StackPush", name: name, keywords: dict); + return op.output; + } + + /// + /// Push an element onto the stack. + /// + /// + /// The handle to a stack. + /// + /// + /// The tensor to be pushed onto the stack. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StackPushV2'. + /// + /// + /// Swap elem to CPU. Default to false. + /// + /// + /// The same tensor as the input 'elem'. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor stack_push_v2(Tensor handle, Tensor elem, bool? swap_memory = null, string name = "StackPushV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["elem"] = elem; + if (swap_memory.HasValue) + dict["swap_memory"] = swap_memory.Value; + var op = tf.OpDefLib._apply_op_helper("StackPushV2", name: name, keywords: dict); + return op.output; + } + + /// + /// A stack that produces elements in first-in last-out order. + /// + /// + /// The maximum size of the stack if non-negative. If negative, the stack + /// size is unlimited. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StackV2'. + /// + /// + /// Optional argument + /// The type of the elements on the stack. + /// + /// + /// Overrides the name used for the temporary stack resource. Default + /// value is the name of the 'Stack' op (which is guaranteed unique). + /// + /// + /// The handle to the stack. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor stack_v2(Tensor max_size, TF_DataType elem_type, string stack_name = null, string name = "StackV2") + { + var dict = new Dictionary(); + dict["max_size"] = max_size; + dict["elem_type"] = elem_type; + if (stack_name != null) + dict["stack_name"] = stack_name; + var op = tf.OpDefLib._apply_op_helper("StackV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Stage values similar to a lightweight Enqueue. + /// + /// + /// a list of tensors + /// dtypes A list of data types that inserted values should adhere to. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Stage'. + /// + /// + /// Maximum number of elements in the Staging Area. If &gt; 0, inserts + /// on the container will block when the capacity is reached. + /// + /// + /// The maximum number of bytes allowed for Tensors in the Staging Area. + /// If &gt; 0, inserts will block until sufficient space is available. + /// + /// + /// If non-empty, this queue is placed in the given container. Otherwise, + /// a default container is used. + /// + /// + /// It is necessary to match this name to the matching Unstage Op. + /// + /// + /// Returns the description of the operation + /// + /// + /// The basic functionality of this Op is similar to a queue with many + /// fewer capabilities and options. This Op is optimized for performance. + /// + public static Operation stage(Tensor[] values, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "Stage") + { + var dict = new Dictionary(); + dict["values"] = values; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("Stage", name: name, keywords: dict); + return op; + } + + /// + /// Op removes all elements in the underlying container. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StageClear'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// Returns the description of the operation + /// + public static Operation stage_clear(TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "StageClear") + { + var dict = new Dictionary(); + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("StageClear", name: name, keywords: dict); + return op; + } + + /// + /// Op peeks at the values at the specified index. If the + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StagePeek'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// underlying container does not contain sufficient elements + /// this op will block until it does. This Op is optimized for + /// performance. + /// + public static Tensor[] stage_peek(Tensor index, TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "StagePeek") + { + var dict = new Dictionary(); + dict["index"] = index; + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("StagePeek", name: name, keywords: dict); + int _idx = 0; + var values = Enumerable.Range(0, op.OutputListLength("values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (values); + } + + /// + /// Op returns the number of elements in the underlying container. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StageSize'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor stage_size(TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "StageSize") + { + var dict = new Dictionary(); + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("StageSize", name: name, keywords: dict); + return op.output; + } + + /// + /// Draws samples from a multinomial distribution. + /// + /// + /// 2-D Tensor with shape [batch_size, num_classes]. Each slice [i, :] + /// represents the unnormalized log probabilities for all classes. + /// + /// + /// 0-D. Number of independent samples to draw for each row slice. + /// + /// + /// 2 seeds (shape [2]). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StatelessMultinomial'. + /// + /// + /// + /// + /// 2-D Tensor with shape [batch_size, num_samples]. Each slice [i, :] + /// contains the drawn class labels with range [0, num_classes). + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor stateless_multinomial(Tensor logits, Tensor num_samples, Tensor seed, TF_DataType? output_dtype = null, string name = "StatelessMultinomial") + { + var dict = new Dictionary(); + dict["logits"] = logits; + dict["num_samples"] = num_samples; + dict["seed"] = seed; + if (output_dtype.HasValue) + dict["output_dtype"] = output_dtype.Value; + var op = tf.OpDefLib._apply_op_helper("StatelessMultinomial", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs deterministic pseudorandom values from a normal distribution. + /// + /// + /// The shape of the output tensor. + /// + /// + /// 2 seeds (shape [2]). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StatelessRandomNormal'. + /// + /// + /// The type of the output. + /// + /// + /// Random values with specified shape. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The generated values will have mean 0 and standard deviation 1. + /// + /// The outputs are a deterministic function of shape and seed. + /// + public static Tensor stateless_random_normal(Tensor shape, Tensor seed, TF_DataType? dtype = null, string name = "StatelessRandomNormal") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["seed"] = seed; + if (dtype.HasValue) + dict["dtype"] = dtype.Value; + var op = tf.OpDefLib._apply_op_helper("StatelessRandomNormal", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs deterministic pseudorandom random values from a uniform distribution. + /// + /// + /// The shape of the output tensor. + /// + /// + /// 2 seeds (shape [2]). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StatelessRandomUniform'. + /// + /// + /// The type of the output. + /// + /// + /// Random values with specified shape. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The generated values follow a uniform distribution in the range [0, 1). The + /// lower bound 0 is included in the range, while the upper bound 1 is excluded. + /// + /// The outputs are a deterministic function of shape and seed. + /// + public static Tensor stateless_random_uniform(Tensor shape, Tensor seed, TF_DataType? dtype = null, string name = "StatelessRandomUniform") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["seed"] = seed; + if (dtype.HasValue) + dict["dtype"] = dtype.Value; + var op = tf.OpDefLib._apply_op_helper("StatelessRandomUniform", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs deterministic pseudorandom values from a truncated normal distribution. + /// + /// + /// The shape of the output tensor. + /// + /// + /// 2 seeds (shape [2]). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StatelessTruncatedNormal'. + /// + /// + /// The type of the output. + /// + /// + /// Random values with specified shape. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The generated values follow a normal distribution with mean 0 and standard + /// deviation 1, except that values whose magnitude is more than 2 standard + /// deviations from the mean are dropped and re-picked. + /// + /// The outputs are a deterministic function of shape and seed. + /// + public static Tensor stateless_truncated_normal(Tensor shape, Tensor seed, TF_DataType? dtype = null, string name = "StatelessTruncatedNormal") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["seed"] = seed; + if (dtype.HasValue) + dict["dtype"] = dtype.Value; + var op = tf.OpDefLib._apply_op_helper("StatelessTruncatedNormal", name: name, keywords: dict); + return op.output; + } + + /// + /// Replaces the match of pattern in input with rewrite. + /// + /// + /// The text to be processed. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StaticRegexReplace'. + /// + /// + /// Optional argument + /// The regular expression to match the input. + /// + /// + /// Optional argument + /// The rewrite to be applied to the matched expresion. + /// + /// + /// If True, the replacement is global, otherwise the replacement + /// is done only on the first match. + /// + /// + /// The text after applying pattern and rewrite. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) + /// + public static Tensor static_regex_replace(Tensor input, string pattern, string rewrite, bool? replace_global = null, string name = "StaticRegexReplace") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["pattern"] = pattern; + dict["rewrite"] = rewrite; + if (replace_global.HasValue) + dict["replace_global"] = replace_global.Value; + var op = tf.OpDefLib._apply_op_helper("StaticRegexReplace", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a statistics manager resource. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StatsAggregatorHandle'. + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor stats_aggregator_handle(string container = null, string shared_name = null, string name = "StatsAggregatorHandle") + { + var dict = new Dictionary(); + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("StatsAggregatorHandle", name: name, keywords: dict); + return op.output; + } + + /// + /// Produces a summary of any statistics recorded by the given statistics manager. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StatsAggregatorSummary'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor stats_aggregator_summary(Tensor iterator, string name = "StatsAggregatorSummary") + { + var dict = new Dictionary(); + dict["iterator"] = iterator; + var op = tf.OpDefLib._apply_op_helper("StatsAggregatorSummary", name: name, keywords: dict); + return op.output; + } + + /// + /// Stops gradient computation. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StopGradient'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// When executed in a graph, this op outputs its input tensor as-is. + /// + /// When building ops to compute gradients, this op prevents the contribution of + /// its inputs to be taken into account. Normally, the gradient generator adds ops + /// to a graph to compute the derivatives of a specified 'loss' by recursively + /// finding out inputs that contributed to its computation. If you insert this op + /// in the graph it inputs are masked from the gradient generator. They are not + /// taken into account for computing gradients. + /// + /// This is useful any time you want to compute a value with TensorFlow but need + /// to pretend that the value was a constant. Some examples include: + /// + /// * The *EM* algorithm where the *M-step* should not involve backpropagation + /// through the output of the *E-step*. + /// * Contrastive divergence training of Boltzmann machines where, when + /// differentiating the energy function, the training must not backpropagate + /// through the graph that generated the samples from the model. + /// * Adversarial training, where no backprop should happen through the adversarial + /// example generation process. + /// + public static Tensor stop_gradient(Tensor input, string name = "StopGradient") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("StopGradient", name: name, keywords: dict); + return op.output; + } + + /// + /// Return a strided slice from input. + /// + /// + /// + /// + /// begin[k] specifies the offset into the kth range specification. + /// The exact dimension this corresponds to will be determined by context. + /// Out-of-bounds values will be silently clamped. If the kth bit of + /// begin_mask then begin[k] is ignored and the full range of the + /// appropriate dimension is used instead. Negative values causes indexing + /// to start from the highest element e.g. If foo==[1,2,3] then foo[-1]==3. + /// + /// + /// end[i] is like begin with the exception that end_mask is + /// used to determine full ranges. + /// + /// + /// strides[i] specifies the increment in the ith specification + /// after extracting a given element. Negative indices will reverse + /// the original order. Out or range values are + /// clamped to [0,dim[i]) if slice[i]&gt;0 or [-1,dim[i]-1] if slice[i] &lt; 0 + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StridedSlice'. + /// + /// + /// a bitmask where a bit i being 1 means to ignore the begin + /// value and instead use the largest interval possible. At runtime + /// begin[i] will be replaced with [0, n-1) if stride[i] &gt; 0 or + /// [-1, n-1] if stride[i] &lt; 0 + /// + /// + /// analogous to begin_mask + /// + /// + /// a bitmask where bit i being 1 means the ith + /// position is actually an ellipsis. One bit at most can be 1. + /// If ellipsis_mask == 0, then an implicit ellipsis mask of 1 &lt;&lt; (m+1) + /// is provided. This means that foo[3:5] == foo[3:5, ...]. An ellipsis + /// implicitly creates as many range specifications as necessary to fully + /// specify the sliced range for every dimension. For example for a 4-dimensional + /// tensor foo the slice foo[2, ..., 5:8] implies foo[2, :, :, 5:8]. + /// + /// + /// a bitmask where bit i being 1 means the ith + /// specification creates a new shape 1 dimension. For example + /// foo[:4, tf.newaxis, :2] would produce a shape (4, 1, 2) tensor. + /// + /// + /// a bitmask where bit i implies that the ith + /// specification should shrink the dimensionality. begin and end + /// must imply a slice of size 1 in the dimension. For example in + /// python one might do foo[:, 3, :] which would result in + /// shrink_axis_mask being 2. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Note, most python users will want to use the Python Tensor.__getitem__ + /// or Variable.__getitem__ rather than this op directly. + /// + /// The goal of this op is to produce a new tensor with a subset of + /// the elements from the n dimensional input tensor. The subset is chosen using + /// a sequence of m sparse range specifications encoded into the arguments + /// of this function. Note, in some cases + /// m could be equal to n, but this need not be the case. Each + /// range specification entry can be one of the following: + /// + /// - An ellipsis (...). Ellipses are used to imply zero or more + /// dimensions of full-dimension selection and are produced using + /// ellipsis_mask. For example, foo[...] is the identity slice. + /// + /// - A new axis. This is used to insert a new shape=1 dimension and is + /// produced using new_axis_mask. For example, foo[:, ...] where + /// foo is shape (3, 4) produces a (1, 3, 4) tensor. + /// + /// + /// - A range begin:end:stride. This is used to specify how much to choose from + /// a given dimension. stride can be any integer but 0. begin is an integer + /// which represents the index of the first value to select while end represents + /// the index of the last value to select. The number of values selected in each + /// dimension is end - begin if stride &gt; 0 and begin - end if stride &lt; 0. + /// begin and end can be negative where -1 is the last element, -2 is + /// the second to last. begin_mask controls whether to replace the explicitly + /// given begin with an implicit effective value of 0 if stride &gt; 0 and + /// -1 if stride &lt; 0. end_mask is analogous but produces the number + /// required to create the largest open interval. For example, given a shape + /// (3,) tensor foo[:], the effective begin and end are 0 and 3. Do + /// not assume this is equivalent to foo[0:-1] which has an effective begin + /// and end of 0 and 2. Another example is foo[-2::-1] which reverses the + /// first dimension of a tensor while dropping the last two (in the original + /// order elements). For example foo = [1,2,3,4]; foo[-2::-1] is [4,3]. + /// + /// - A single index. This is used to keep only elements that have a given + /// index. For example (foo[2, :] on a shape (5,6) tensor produces a + /// shape (6,) tensor. This is encoded in begin and end and + /// shrink_axis_mask. + /// + /// Each conceptual range specification is encoded in the op's argument. This + /// encoding is best understand by considering a non-trivial example. In + /// particular, + /// foo[1, 2:4, None, ..., :-3:-1, :] will be encoded as + /// + /// + /// begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0) + /// end = [2, 4, x, x, -3, x] + /// strides = [1, 1, x, x, -1, 1] + /// begin_mask = 1&lt;&lt;4 | 1 &lt;&lt; 5 = 48 + /// end_mask = 1&lt;&lt;5 = 32 + /// ellipsis_mask = 1&lt;&lt;3 = 8 + /// new_axis_mask = 1&lt;&lt;2 4 + /// shrink_axis_mask = 1&lt;&lt;0 + /// + /// + /// In this case if foo.shape is (5, 5, 5, 5, 5, 5) the final shape of + /// the slice becomes (2, 1, 5, 5, 2, 5). + /// Let us walk step by step through each argument specification. + /// + /// 1. The first argument in the example slice is turned into begin = 1 and + /// end = begin + 1 = 2. To disambiguate from the original spec 2:4 we + /// also set the appropriate bit in shrink_axis_mask. + /// + /// 2. 2:4 is contributes 2, 4, 1 to begin, end, and stride. All masks have + /// zero bits contributed. + /// + /// 3. None is a synonym for tf.newaxis. This means insert a dimension of size 1 + /// dimension in the final shape. Dummy values are contributed to begin, + /// end and stride, while the new_axis_mask bit is set. + /// + /// 4. ... grab the full ranges from as many dimensions as needed to + /// fully specify a slice for every dimension of the input shape. + /// + /// 5. :-3:-1 shows the use of negative indices. A negative index i associated + /// with a dimension that has shape s is converted to a positive index + /// s + i. So -1 becomes s-1 (i.e. the last element). This conversion + /// is done internally so begin, end and strides receive x, -3, and -1. + /// The appropriate begin_mask bit is set to indicate the start range is the + /// full range (ignoring the x). + /// + /// 6. : indicates that the entire contents of the corresponding dimension + /// is selected. This is equivalent to :: or 0::1. begin, end, and strides + /// receive 0, 0, and 1, respectively. The appropriate bits in begin_mask and + /// end_mask are also set. + /// + /// *Requirements*: + /// 0 != strides[i] for i in [0, m) + /// ellipsis_mask must be a power of two (only one ellipsis) + /// + public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides, int? begin_mask = null, int? end_mask = null, int? ellipsis_mask = null, int? new_axis_mask = null, int? shrink_axis_mask = null, string name = "StridedSlice") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["begin"] = begin; + dict["end"] = end; + dict["strides"] = strides; + if (begin_mask.HasValue) + dict["begin_mask"] = begin_mask.Value; + if (end_mask.HasValue) + dict["end_mask"] = end_mask.Value; + if (ellipsis_mask.HasValue) + dict["ellipsis_mask"] = ellipsis_mask.Value; + if (new_axis_mask.HasValue) + dict["new_axis_mask"] = new_axis_mask.Value; + if (shrink_axis_mask.HasValue) + dict["shrink_axis_mask"] = shrink_axis_mask.Value; + var op = tf.OpDefLib._apply_op_helper("StridedSlice", name: name, keywords: dict); + return op.output; + } + + /// + /// Assign value to the sliced l-value reference of ref. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StridedSliceAssign'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The values of value are assigned to the positions in the variable + /// ref that are selected by the slice parameters. The slice parameters + /// begin, end, strides, etc. work exactly as in StridedSlice. + /// + /// NOTE this op currently does not support broadcasting and so value's + /// shape must be exactly the shape produced by the slice of ref. + /// + public static Tensor strided_slice_assign(Tensor referecne, Tensor begin, Tensor end, Tensor strides, Tensor value, int? begin_mask = null, int? end_mask = null, int? ellipsis_mask = null, int? new_axis_mask = null, int? shrink_axis_mask = null, string name = "StridedSliceAssign") + { + var dict = new Dictionary(); + dict["ref"] = referecne; + dict["begin"] = begin; + dict["end"] = end; + dict["strides"] = strides; + dict["value"] = value; + if (begin_mask.HasValue) + dict["begin_mask"] = begin_mask.Value; + if (end_mask.HasValue) + dict["end_mask"] = end_mask.Value; + if (ellipsis_mask.HasValue) + dict["ellipsis_mask"] = ellipsis_mask.Value; + if (new_axis_mask.HasValue) + dict["new_axis_mask"] = new_axis_mask.Value; + if (shrink_axis_mask.HasValue) + dict["shrink_axis_mask"] = shrink_axis_mask.Value; + var op = tf.OpDefLib._apply_op_helper("StridedSliceAssign", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the gradient of StridedSlice. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StridedSliceGrad'. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Since StridedSlice cuts out pieces of its input which is size + /// shape, its gradient will have the same shape (which is passed here + /// as shape). The gradient will be zero in any element that the slice + /// does not select. + /// + /// Arguments are the same as StridedSliceGrad with the exception that + /// dy is the input gradient to be propagated and shape is the + /// shape of StridedSlice's input. + /// + public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, int? begin_mask = null, int? end_mask = null, int? ellipsis_mask = null, int? new_axis_mask = null, int? shrink_axis_mask = null, string name = "StridedSliceGrad") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["begin"] = begin; + dict["end"] = end; + dict["strides"] = strides; + dict["dy"] = dy; + if (begin_mask.HasValue) + dict["begin_mask"] = begin_mask.Value; + if (end_mask.HasValue) + dict["end_mask"] = end_mask.Value; + if (ellipsis_mask.HasValue) + dict["ellipsis_mask"] = ellipsis_mask.Value; + if (new_axis_mask.HasValue) + dict["new_axis_mask"] = new_axis_mask.Value; + if (shrink_axis_mask.HasValue) + dict["shrink_axis_mask"] = shrink_axis_mask.Value; + var op = tf.OpDefLib._apply_op_helper("StridedSliceGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Joins the strings in the given list of string tensors into one tensor; + /// + /// + /// A list of string tensors. The tensors must all have the same shape, + /// or be scalars. Scalars may be mixed in; these will be broadcast to the shape + /// of non-scalar inputs. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StringJoin'. + /// + /// + /// string, an optional join separator. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// with the given separator (default is an empty separator). + /// + public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin") + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "StringJoin", name, inputs, "separator", separator)); + return result[0]; + } + var dict = new Dictionary(); + dict["inputs"] = inputs; + if (separator != null) + dict["separator"] = separator; + var op = tf.OpDefLib._apply_op_helper("StringJoin", name: name, keywords: dict); + return op.output; + } + + /// + /// Split elements of input based on delimiter into a SparseTensor. + /// + /// + /// 1-D. Strings to split. + /// + /// + /// 0-D. Delimiter characters (bytes), or empty string. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StringSplit'. + /// + /// + /// A bool. If True, skip the empty strings from the result. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// indices : A dense matrix of int64 representing the indices of the sparse tensor. + /// values : A vector of strings corresponding to the splited values. + /// shape : a length-2 vector of int64 representing the shape of the sparse + /// tensor, where the first value is N and the second value is the maximum number + /// of tokens in a single input entry. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Let N be the size of source (typically N will be the batch size). Split each + /// element of input based on delimiter and return a SparseTensor + /// containing the splitted tokens. Empty tokens are ignored. + /// + /// delimiter can be empty, or a string of split characters. If delimiter is an + /// empty string, each element of input is split into individual single-byte + /// character strings, including splitting of UTF-8 multibyte sequences. Otherwise + /// every character of delimiter is a potential split point. + /// + /// For example: + /// N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output + /// will be + /// + /// indices = [0, 0; + /// 0, 1; + /// 1, 0; + /// 1, 1; + /// 1, 2] + /// shape = [2, 3] + /// values = ['hello', 'world', 'a', 'b', 'c'] + /// + public static (Tensor indices, Tensor values, Tensor shape) string_split(Tensor input, Tensor delimiter, bool? skip_empty = null, string name = "StringSplit") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["delimiter"] = delimiter; + if (skip_empty.HasValue) + dict["skip_empty"] = skip_empty.Value; + var op = tf.OpDefLib._apply_op_helper("StringSplit", name: name, keywords: dict); + int _idx = 0; + var indices = op.outputs[_idx++]; + var values = op.outputs[_idx++]; + var shape = op.outputs[_idx++]; + return (indices, values, shape); + } + + /// + /// Split elements of source based on sep into a SparseTensor. + /// + /// + /// 1-D string Tensor, the strings to split. + /// + /// + /// 0-D string Tensor, the delimiter character. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StringSplitV2'. + /// + /// + /// An int. If maxsplit &gt; 0, limit of the split of the result. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// indices : + /// values : + /// shape : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Let N be the size of source (typically N will be the batch size). Split each + /// element of source based on sep and return a SparseTensor + /// containing the split tokens. Empty tokens are ignored. + /// + /// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', + /// then the output will be + /// + /// st.indices = [0, 0; + /// 0, 1; + /// 1, 0; + /// 1, 1; + /// 1, 2] + /// st.shape = [2, 3] + /// st.values = ['hello', 'world', 'a', 'b', 'c'] + /// + /// + /// If sep is given, consecutive delimiters are not grouped together and are + /// deemed to delimit empty strings. For example, source of "1&lt;&gt;2&lt;&gt;&lt;&gt;3" and + /// sep of "&lt;&gt;" returns ["1", "2", "", "3"]. If sep is None or an empty + /// string, consecutive whitespace are regarded as a single separator, and the + /// result will contain no empty strings at the startor end if the string has + /// leading or trailing whitespace. + /// + /// Note that the above mentioned behavior matches python's str.split. + /// + public static (Tensor indices, Tensor values, Tensor shape) string_split_v2(Tensor input, Tensor sep, int? maxsplit = null, string name = "StringSplitV2") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["sep"] = sep; + if (maxsplit.HasValue) + dict["maxsplit"] = maxsplit.Value; + var op = tf.OpDefLib._apply_op_helper("StringSplitV2", name: name, keywords: dict); + int _idx = 0; + var indices = op.outputs[_idx++]; + var values = op.outputs[_idx++]; + var shape = op.outputs[_idx++]; + return (indices, values, shape); + } + + /// + /// Strip leading and trailing whitespaces from the Tensor. + /// + /// + /// A string Tensor of any shape. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StringStrip'. + /// + /// + /// A string Tensor of the same shape as the input. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor string_strip(Tensor input, string name = "StringStrip") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("StringStrip", name: name, keywords: dict); + return op.output; + } + + /// + /// Converts each string in the input Tensor to its hash mod by a number of buckets. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StringToHashBucket'. + /// + /// + /// Optional argument + /// The number of buckets. + /// + /// + /// A Tensor of the same shape as the input string_tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The hash function is deterministic on the content of the string within the + /// process. + /// + /// Note that the hash function may change from time to time. + /// This functionality will be deprecated and it's recommended to use + /// tf.string_to_hash_bucket_fast() or tf.string_to_hash_bucket_strong(). + /// + public static Tensor string_to_hash_bucket(Tensor string_tensor, int num_buckets, string name = "StringToHashBucket") + { + var dict = new Dictionary(); + dict["string_tensor"] = string_tensor; + dict["num_buckets"] = num_buckets; + var op = tf.OpDefLib._apply_op_helper("StringToHashBucket", name: name, keywords: dict); + return op.output; + } + + /// + /// Converts each string in the input Tensor to its hash mod by a number of buckets. + /// + /// + /// The strings to assign a hash bucket. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StringToHashBucketFast'. + /// + /// + /// Optional argument + /// The number of buckets. + /// + /// + /// A Tensor of the same shape as the input string_tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The hash function is deterministic on the content of the string within the + /// process and will never change. However, it is not suitable for cryptography. + /// This function may be used when CPU time is scarce and inputs are trusted or + /// unimportant. There is a risk of adversaries constructing inputs that all hash + /// to the same bucket. To prevent this problem, use a strong hash function with + /// tf.string_to_hash_bucket_strong. + /// + public static Tensor string_to_hash_bucket_fast(Tensor input, int num_buckets, string name = "StringToHashBucketFast") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["num_buckets"] = num_buckets; + var op = tf.OpDefLib._apply_op_helper("StringToHashBucketFast", name: name, keywords: dict); + return op.output; + } + + /// + /// Converts each string in the input Tensor to its hash mod by a number of buckets. + /// + /// + /// The strings to assign a hash bucket. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StringToHashBucketStrong'. + /// + /// + /// Optional argument + /// The number of buckets. + /// + /// + /// Optional argument + /// The key for the keyed hash function passed as a list of two uint64 + /// elements. + /// + /// + /// A Tensor of the same shape as the input string_tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The hash function is deterministic on the content of the string within the + /// process. The hash function is a keyed hash function, where attribute key + /// defines the key of the hash function. key is an array of 2 elements. + /// + /// A strong hash is important when inputs may be malicious, e.g. URLs with + /// additional components. Adversaries could try to make their inputs hash to the + /// same bucket for a denial-of-service attack or to skew the results. A strong + /// hash prevents this by making it difficult, if not infeasible, to compute inputs + /// that hash to the same bucket. This comes at a cost of roughly 4x higher compute + /// time than tf.string_to_hash_bucket_fast. + /// + public static Tensor string_to_hash_bucket_strong(Tensor input, int num_buckets, int[] key, string name = "StringToHashBucketStrong") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["num_buckets"] = num_buckets; + dict["key"] = key; + var op = tf.OpDefLib._apply_op_helper("StringToHashBucketStrong", name: name, keywords: dict); + return op.output; + } + + /// + /// Converts each string in the input Tensor to the specified numeric type. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'StringToNumber'. + /// + /// + /// The numeric type to interpret each string in string_tensor as. + /// + /// + /// A Tensor of the same shape as the input string_tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// (Note that int32 overflow results in an error while float overflow + /// results in a rounded value.) + /// + public static Tensor string_to_number(Tensor string_tensor, TF_DataType? out_type = null, string name = "StringToNumber") + { + var dict = new Dictionary(); + dict["string_tensor"] = string_tensor; + if (out_type.HasValue) + dict["out_type"] = out_type.Value; + var op = tf.OpDefLib._apply_op_helper("StringToNumber", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns x - y element-wise. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Sub'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// *NOTE*: Subtract supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor sub(Tensor x, Tensor y, string name = "Sub") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("Sub", name: name, keywords: dict); + return op.output; + } + + /// + /// Return substrings from Tensor of strings. + /// + /// + /// Tensor of strings + /// + /// + /// Scalar defining the position of first character in each substring + /// + /// + /// Scalar defining the number of characters to include in each substring + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Substr'. + /// + /// + /// Tensor of substrings + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// For each string in the input Tensor, creates a substring starting at index + /// pos with a total length of len. + /// + /// If len defines a substring that would extend beyond the length of the input + /// string, then as many characters as possible are used. + /// + /// A negative pos indicates distance within the string backwards from the end. + /// + /// If pos specifies an index which is out of range for any of the input strings, + /// then an InvalidArgumentError is thrown. + /// + /// pos and len must have the same shape, otherwise a ValueError is thrown on + /// Op creation. + /// + /// *NOTE*: Substr supports broadcasting up to two dimensions. More about + /// broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + /// --- + /// + /// Examples + /// + /// Using scalar pos and len: + /// + /// + /// input = [b'Hello', b'World'] + /// position = 1 + /// length = 3 + /// + /// output = [b'ell', b'orl'] + /// + /// + /// Using pos and len with same shape as input: + /// + /// + /// input = [[b'ten', b'eleven', b'twelve'], + /// [b'thirteen', b'fourteen', b'fifteen'], + /// [b'sixteen', b'seventeen', b'eighteen']] + /// position = [[1, 2, 3], + /// [1, 2, 3], + /// [1, 2, 3]] + /// length = [[2, 3, 4], + /// [4, 3, 2], + /// [5, 5, 5]] + /// + /// output = [[b'en', b'eve', b'lve'], + /// [b'hirt', b'urt', b'te'], + /// [b'ixtee', b'vente', b'hteen']] + /// + /// + /// Broadcasting pos and len onto input: + /// + /// + /// input = [[b'ten', b'eleven', b'twelve'], + /// [b'thirteen', b'fourteen', b'fifteen'], + /// [b'sixteen', b'seventeen', b'eighteen'], + /// [b'nineteen', b'twenty', b'twentyone']] + /// position = [1, 2, 3] + /// length = [1, 2, 3] + /// + /// output = [[b'e', b'ev', b'lve'], + /// [b'h', b'ur', b'tee'], + /// [b'i', b've', b'hte'], + /// [b'i', b'en', b'nty']] + /// + /// + /// Broadcasting input onto pos and len: + /// + /// + /// input = b'thirteen' + /// position = [1, 5, 7] + /// length = [3, 2, 1] + /// + /// output = [b'hir', b'ee', b'n'] + /// + /// + public static Tensor substr(Tensor input, Tensor pos, Tensor len, string name = "Substr") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["pos"] = pos; + dict["len"] = len; + var op = tf.OpDefLib._apply_op_helper("Substr", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the sum of elements across dimensions of a tensor. + /// + /// + /// The tensor to reduce. + /// + /// + /// The dimensions to reduce. Must be in the range + /// [-rank(input), rank(input)). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Sum'. + /// + /// + /// If true, retain reduced dimensions with length 1. + /// + /// + /// The reduced tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Reduces input along the dimensions given in axis. Unless + /// keep_dims is true, the rank of the tensor is reduced by 1 for each entry in + /// axis. If keep_dims is true, the reduced dimensions are + /// retained with length 1. + /// + public static Tensor sum(Tensor input, Tensor reduction_indices, bool? keep_dims = null, string name = "Sum") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["reduction_indices"] = reduction_indices; + if (keep_dims.HasValue) + dict["keep_dims"] = keep_dims.Value; + var op = tf.OpDefLib._apply_op_helper("Sum", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the singular value decompositions of one or more matrices. + /// + /// + /// A tensor of shape [..., M, N] whose inner-most 2 dimensions + /// form matrices of size [M, N]. Let P be the minimum of M and N. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Svd'. + /// + /// + /// If true, left and right singular vectors will be + /// computed and returned in u and v, respectively. + /// If false, u and v are not set and should never referenced. + /// + /// + /// If true, compute full-sized u and v. If false + /// (the default), compute only the leading P singular vectors. + /// Ignored if compute_uv is False. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// s : Singular values. Shape is [..., P]. + /// u : Left singular vectors. If full_matrices is False then shape is + /// [..., M, P]; if full_matrices is True then shape is + /// [..., M, M]. Undefined if compute_uv is False. + /// v : Left singular vectors. If full_matrices is False then shape is + /// [..., N, P]. If full_matrices is True then shape is [..., N, N]. + /// Undefined if compute_uv is false. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Computes the SVD of each inner matrix in input such that + /// input[..., :, :] = u[..., :, :] * diag(s[..., :, :]) * transpose(v[..., :, :]) + /// + /// + /// # a is a tensor containing a batch of matrices. + /// # s is a tensor of singular values for each matrix. + /// # u is the tensor containing of left singular vectors for each matrix. + /// # v is the tensor containing of right singular vectors for each matrix. + /// s, u, v = svd(a) + /// s, _, _ = svd(a, compute_uv=False) + /// + /// + public static (Tensor s, Tensor u, Tensor v) svd(Tensor input, bool? compute_uv = null, bool? full_matrices = null, string name = "Svd") + { + var dict = new Dictionary(); + dict["input"] = input; + if (compute_uv.HasValue) + dict["compute_uv"] = compute_uv.Value; + if (full_matrices.HasValue) + dict["full_matrices"] = full_matrices.Value; + var op = tf.OpDefLib._apply_op_helper("Svd", name: name, keywords: dict); + int _idx = 0; + var s = op.outputs[_idx++]; + var u = op.outputs[_idx++]; + var v = op.outputs[_idx++]; + return (s, u, v); + } + + /// + /// Forwards data to the output port determined by pred. + /// + /// + /// The tensor to be forwarded to the appropriate output. + /// + /// + /// A scalar that specifies which output port will receive data. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Switch'. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_false : If pred is false, data will be forwarded to this output. + /// output_true : If pred is true, data will be forwarded to this output. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// If pred is true, the data input is forwarded to output_true. Otherwise, + /// the data goes to output_false. + /// + /// See also RefSwitch and Merge. + /// + public static (Tensor output_false, Tensor output_true) switch_(Tensor data, Tensor pred, string name = "Switch") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["pred"] = pred; + var op = tf.OpDefLib._apply_op_helper("Switch", name: name, keywords: dict); + int _idx = 0; + var output_false = op.outputs[_idx++]; + var output_true = op.outputs[_idx++]; + return (output_false, output_true); + } + + /// + /// Creates a dataset that emits the records from one or more TFRecord files. + /// + /// + /// A scalar or vector containing the name(s) of the file(s) to be + /// read. + /// + /// + /// A scalar containing either (i) the empty string (no + /// compression), (ii) "ZLIB", or (iii) "GZIP". + /// + /// + /// A scalar representing the number of bytes to buffer. A value of + /// 0 means no buffering will be performed. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TFRecordDataset'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor t_f_record_dataset(Tensor filenames, Tensor compression_type, Tensor buffer_size, string name = "TFRecordDataset") + { + var dict = new Dictionary(); + dict["filenames"] = filenames; + dict["compression_type"] = compression_type; + dict["buffer_size"] = buffer_size; + var op = tf.OpDefLib._apply_op_helper("TFRecordDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// A Reader that outputs the records from a TensorFlow Records file. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TFRecordReader'. + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// + /// + /// The handle to reference the Reader. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor t_f_record_reader(string container = null, string shared_name = null, string compression_type = null, string name = "TFRecordReader") + { + var dict = new Dictionary(); + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + if (compression_type != null) + dict["compression_type"] = compression_type; + var op = tf.OpDefLib._apply_op_helper("TFRecordReader", name: name, keywords: dict); + return op.output; + } + + /// + /// A Reader that outputs the records from a TensorFlow Records file. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TFRecordReaderV2'. + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// + /// + /// The handle to reference the Reader. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor t_f_record_reader_v2(string container = null, string shared_name = null, string compression_type = null, string name = "TFRecordReaderV2") + { + var dict = new Dictionary(); + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + if (compression_type != null) + dict["compression_type"] = compression_type; + var op = tf.OpDefLib._apply_op_helper("TFRecordReaderV2", name: name, keywords: dict); + return op.output; + } + + /// + /// An op enabling differentiation of TPU Embeddings. + /// + /// + /// A trainable variable, enabling optimizers to find this op. + /// + /// + /// The embedding activations Tensor to return. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TPUEmbeddingActivations'. + /// + /// + /// Optional argument + /// The id of the table in the embedding layer configuration from which + /// these activations were computed. + /// + /// + /// Optional argument + /// Identifier of the set of embedding indices which produced these + /// activations. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op simply returns its first input, which is assumed to have been sliced + /// from the Tensors returned by TPUEmbeddingDequeueActivations. The presence of this + /// op, and its first argument being a trainable Variable, enables automatic + /// differentiation of graphs containing embeddings via the TPU Embedding Python + /// libraries. + /// + public static Tensor t_p_u_embedding_activations(Tensor embedding_variable, Tensor sliced_activations, int table_id, int lookup_id, string name = "TPUEmbeddingActivations") + { + var dict = new Dictionary(); + dict["embedding_variable"] = embedding_variable; + dict["sliced_activations"] = sliced_activations; + dict["table_id"] = table_id; + dict["lookup_id"] = lookup_id; + var op = tf.OpDefLib._apply_op_helper("TPUEmbeddingActivations", name: name, keywords: dict); + return op.output; + } + + /// + /// An op that feeds a batch of embedding indices and weights to the TPU. + /// + /// + /// A list of rank 1 Tensors specifying row indices of the COO + /// sparse matrix representing the embedding lookups for each table. + /// + /// + /// A list of rank 1 Tensors specifying column indices of the + /// COO sparse matrix representing the embedding lookups for each table. + /// + /// + /// A list of rank 1 Tensors specifying the nonzero values + /// of the COO sparse matrix representing the embedding lookups for each table. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TPUEmbeddingEnqueueSparseBatch'. + /// + /// + /// The TPU device to use. This should be -1 when the Op + /// is running on a TPU device, and &gt;= 0 when the Op is running on the CPU + /// device. + /// + /// + /// Returns the description of the operation + /// + /// + /// Embedding lookups are equivalent to sparse-dense matrix multiplications: the + /// sparse matrix contains nonzeros in column j in order to retrieve row j from the + /// embedding table. + /// + /// The three Tensor list arguments (sample_indices, embedding_indices, and + /// aggregation_weights) represent these sparse matrices in COO format. The Tensor + /// lists each have one entry for each embedding table specified in the model. + /// For the kth embedding table, the three Tensors at position k in the list + /// specify a COO-format sparse matrix. For the kth table, the row indices, + /// column indices, and nonzero values of the COO sparse matrix are specified by + /// sample_indices[k], embedding_indices[k], and aggregation_weights[k], + /// respectively. Entries must be sorted by row index, then by column index. + /// + /// There should be at most one TPUEmbeddingEnqueueSparseBatch op in a signle + /// training step per TPU shard. + /// + public static Operation t_p_u_embedding_enqueue_sparse_batch(Tensor[] sample_indices, Tensor[] embedding_indices, Tensor[] aggregation_weights, int? device_ordinal = null, string name = "TPUEmbeddingEnqueueSparseBatch") + { + var dict = new Dictionary(); + dict["sample_indices"] = sample_indices; + dict["embedding_indices"] = embedding_indices; + dict["aggregation_weights"] = aggregation_weights; + if (device_ordinal.HasValue) + dict["device_ordinal"] = device_ordinal.Value; + var op = tf.OpDefLib._apply_op_helper("TPUEmbeddingEnqueueSparseBatch", name: name, keywords: dict); + return op; + } + + /// + /// Load an embedding table shard into TensorNode memories for use with Adagrad. + /// + /// + /// The shard of the embedding table resident on the host executing this + /// op. For single-TPU models, this is the entire embedding table. + /// + /// + /// Shard of the Adagrad accumulators resident on the host executing + /// this op. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TPUEmbeddingLoadAdagradParameters'. + /// + /// + /// Optional argument + /// Serialized TPUEmbeddingConfiguration proto. + /// + /// + /// Optional argument + /// The id of the table specified in the embedding_config. + /// + /// + /// Optional argument + /// The number of CPU hosts in the distributed training job. + /// + /// + /// Optional argument + /// Which CPU host in the distributed training job will execute this op. + /// + /// + /// Returns the description of the operation + /// + /// + /// TPU embeddings use dedicated per-optimizer Ops for loading and retrieving + /// trainable variables and optimizer state from TPU memory. This op enables + /// functionality equivalent to AdagradOptimizer. + /// + public static Operation t_p_u_embedding_load_adagrad_parameters(Tensor parameters, Tensor accumulators, string tpu_embedding_config, int table_id, int num_hosts, int host_id, string name = "TPUEmbeddingLoadAdagradParameters") + { + var dict = new Dictionary(); + dict["parameters"] = parameters; + dict["accumulators"] = accumulators; + dict["tpu_embedding_config"] = tpu_embedding_config; + dict["table_id"] = table_id; + dict["num_hosts"] = num_hosts; + dict["host_id"] = host_id; + var op = tf.OpDefLib._apply_op_helper("TPUEmbeddingLoadAdagradParameters", name: name, keywords: dict); + return op; + } + + /// + /// Load an embedding table shard into TPU memory for use with GradientDescent. + /// + /// + /// The shard of the embedding table resident on the host executing this + /// op. For single-TPU models, this is the entire embedding table. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TPUEmbeddingLoadGradientDescentParameters'. + /// + /// + /// Optional argument + /// Serialized TPUEmbeddingConfiguration proto. + /// + /// + /// Optional argument + /// The id of the table specified in the tpu_embedding_config. + /// + /// + /// Optional argument + /// The number of CPU hosts in the distributed training job. + /// + /// + /// Optional argument + /// Which CPU host in the distributed training job will execute this op. + /// + /// + /// Returns the description of the operation + /// + /// + /// TPU embeddings use dedicated per-optimizer Ops for loading and retrieving + /// trainable variables and optimizer state from TPU memory. This op enables + /// functionality equivalent to GradientDescentOptimizer. + /// + public static Operation t_p_u_embedding_load_gradient_descent_parameters(Tensor parameters, string tpu_embedding_config, int table_id, int num_hosts, int host_id, string name = "TPUEmbeddingLoadGradientDescentParameters") + { + var dict = new Dictionary(); + dict["parameters"] = parameters; + dict["tpu_embedding_config"] = tpu_embedding_config; + dict["table_id"] = table_id; + dict["num_hosts"] = num_hosts; + dict["host_id"] = host_id; + var op = tf.OpDefLib._apply_op_helper("TPUEmbeddingLoadGradientDescentParameters", name: name, keywords: dict); + return op; + } + + /// + /// An op that receives embedding activations on the TPU. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TPUEmbeddingReceiveActivations'. + /// + /// + /// Optional argument + /// The number of output activation tensors, equal to the number of + /// embedding tables in the model. + /// + /// + /// Optional argument + /// Serialized TPUEmbeddingConfiguration proto. + /// + /// + /// A TensorList of embedding activations containing one Tensor per + /// embedding table in the model. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The TPU system performs the embedding lookups and aggregations specified by + /// the arguments to TPUEmbeddingEnqueueSparseBatch. The results of these + /// aggregations are visible to the Tensorflow Graph as the outputs of a + /// TPUEmbeddingDequeueActivations Op. This op returns a list containing one + /// Tensor of activations per table specified in the model. There can be at most + /// one ReceieveActivations op in the TPU graph. + /// + public static Tensor[] t_p_u_embedding_receive_activations(int num_tables, string tpu_embedding_config, string name = "TPUEmbeddingReceiveActivations") + { + var dict = new Dictionary(); + dict["num_tables"] = num_tables; + dict["tpu_embedding_config"] = tpu_embedding_config; + var op = tf.OpDefLib._apply_op_helper("TPUEmbeddingReceiveActivations", name: name, keywords: dict); + int _idx = 0; + var outputs = Enumerable.Range(0, op.OutputListLength("outputs")).Select(_ => op.outputs[_idx++]).ToArray(); + return (outputs); + } + + /// + /// Retrieve an embedding table shard from TPU memory. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TPUEmbeddingRetrieveAdagradParameters'. + /// + /// + /// Optional argument + /// Serialized TPUEmbeddingConfiguration proto. + /// + /// + /// Optional argument + /// The id of the table specified in the embedding_config_json. + /// + /// + /// Optional argument + /// The number of CPU hosts in the distributed training job. + /// + /// + /// Optional argument + /// Which CPU host in the distributed training job will execute this op. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// parameters : + /// accumulators : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// TPU embeddings use dedicated per-optimizer Ops for loading and retrieving + /// trainable variables and optimizer state from TPU memory. This op enables + /// functionality equivalent to AdagradOptimizer. + /// + public static (Tensor parameters, Tensor accumulators) t_p_u_embedding_retrieve_adagrad_parameters(string tpu_embedding_config, int table_id, int num_hosts, int host_id, string name = "TPUEmbeddingRetrieveAdagradParameters") + { + var dict = new Dictionary(); + dict["tpu_embedding_config"] = tpu_embedding_config; + dict["table_id"] = table_id; + dict["num_hosts"] = num_hosts; + dict["host_id"] = host_id; + var op = tf.OpDefLib._apply_op_helper("TPUEmbeddingRetrieveAdagradParameters", name: name, keywords: dict); + int _idx = 0; + var parameters = op.outputs[_idx++]; + var accumulators = op.outputs[_idx++]; + return (parameters, accumulators); + } + + /// + /// Retrieve an embedding table shard from TPU memory. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TPUEmbeddingRetrieveGradientDescentParameters'. + /// + /// + /// Optional argument + /// Serialized TPUEmbeddingConfiguration proto. + /// + /// + /// Optional argument + /// The id of the table specified in tpu_embedding_config. + /// + /// + /// Optional argument + /// The number of CPU hosts in the distributed training job. + /// + /// + /// Optional argument + /// Which CPU host in the distributed training job will execute this op. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// TPU embeddings use dedicated per-optimizer Ops for loading and retrieving + /// trainable variables and optimizer state from TPU memory. This op enables + /// functionality equivalent to GradientDescentOptimizer. + /// + public static Tensor t_p_u_embedding_retrieve_gradient_descent_parameters(string tpu_embedding_config, int table_id, int num_hosts, int host_id, string name = "TPUEmbeddingRetrieveGradientDescentParameters") + { + var dict = new Dictionary(); + dict["tpu_embedding_config"] = tpu_embedding_config; + dict["table_id"] = table_id; + dict["num_hosts"] = num_hosts; + dict["host_id"] = host_id; + var op = tf.OpDefLib._apply_op_helper("TPUEmbeddingRetrieveGradientDescentParameters", name: name, keywords: dict); + return op.output; + } + + /// + /// An op that performs gradient updates of embedding tables. + /// + /// + /// A TensorList of gradients with which to update embedding tables. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TPUEmbeddingSendGradients'. + /// + /// + /// Optional argument + /// Serialized TPUEmbeddingConfiguration proto. + /// + /// + /// Returns the description of the operation + /// + /// + /// The TensorList argument has the same length and shapes as the return value of + /// TPUEmbeddingReceiveActivations, but contains gradients of the model's loss + /// with respect to the embedding activations. The embedding tables are updated + /// from these gradients via the optimizer specified in the configuration given + /// to tpu.initialize_system. + /// + public static Operation t_p_u_embedding_send_gradients(Tensor[] gradients, string tpu_embedding_config, string name = "TPUEmbeddingSendGradients") + { + var dict = new Dictionary(); + dict["gradients"] = gradients; + dict["tpu_embedding_config"] = tpu_embedding_config; + var op = tf.OpDefLib._apply_op_helper("TPUEmbeddingSendGradients", name: name, keywords: dict); + return op; + } + + /// + /// Operator that connects N unreplicated inputs to an N-way replicated TPU computation. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TPUReplicatedInput'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor t_p_u_replicated_input(Tensor[] inputs, string name = "TPUReplicatedInput") + { + var dict = new Dictionary(); + dict["inputs"] = inputs; + var op = tf.OpDefLib._apply_op_helper("TPUReplicatedInput", name: name, keywords: dict); + return op.output; + } + + /// + /// Operator that connects the output of an N-way replicated TPU computation to N separate outputs. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TPUReplicatedOutput'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor[] t_p_u_replicated_output(Tensor input, int num_replicas, string name = "TPUReplicatedOutput") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["num_replicas"] = num_replicas; + var op = tf.OpDefLib._apply_op_helper("TPUReplicatedOutput", name: name, keywords: dict); + int _idx = 0; + var outputs = Enumerable.Range(0, op.OutputListLength("outputs")).Select(_ => op.outputs[_idx++]).ToArray(); + return (outputs); + } + + /// + /// Creates a dataset that contains count elements from the input_dataset. + /// + /// + /// + /// + /// A scalar representing the number of elements from the input_dataset + /// that should be taken. A value of -1 indicates that all of input_dataset + /// is taken. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TakeDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor take_dataset(Tensor input_dataset, Tensor count, TF_DataType[] output_types, Shape[] output_shapes, string name = "TakeDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["count"] = count; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("TakeDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Read SparseTensors from a SparseTensorsMap and concatenate them. + /// + /// + /// 1-D, The N serialized SparseTensor objects. + /// Shape: [N]. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TakeManySparseFromTensorsMap'. + /// + /// + /// Optional argument + /// The dtype of the SparseTensor objects stored in the + /// SparseTensorsMap. + /// + /// + /// The container name for the SparseTensorsMap read by this op. + /// + /// + /// The shared name for the SparseTensorsMap read by this op. + /// It should not be blank; rather the shared_name or unique Operation name + /// of the Op that created the original SparseTensorsMap should be used. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sparse_indices : 2-D. The indices of the minibatch SparseTensor. + /// sparse_values : 1-D. The values of the minibatch SparseTensor. + /// sparse_shape : 1-D. The shape of the minibatch SparseTensor. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// The input sparse_handles must be an int64 matrix of shape [N, 1] where + /// N is the minibatch size and the rows correspond to the output handles of + /// AddSparseToTensorsMap or AddManySparseToTensorsMap. The ranks of the + /// original SparseTensor objects that went into the given input ops must all + /// match. When the final SparseTensor is created, it has rank one + /// higher than the ranks of the incoming SparseTensor objects + /// (they have been concatenated along a new row dimension on the left). + /// + /// The output SparseTensor object's shape values for all dimensions but the + /// first are the max across the input SparseTensor objects' shape values + /// for the corresponding dimensions. Its first shape value is N, the minibatch + /// size. + /// + /// The input SparseTensor objects' indices are assumed ordered in + /// standard lexicographic order. If this is not the case, after this + /// step run SparseReorder to restore index ordering. + /// + /// For example, if the handles represent an input, which is a [2, 3] matrix + /// representing two original SparseTensor objects: + /// + /// + /// index = [ 0] + /// [10] + /// [20] + /// values = [1, 2, 3] + /// shape = [50] + /// + /// + /// and + /// + /// + /// index = [ 2] + /// [10] + /// values = [4, 5] + /// shape = [30] + /// + /// + /// then the final SparseTensor will be: + /// + /// + /// index = [0 0] + /// [0 10] + /// [0 20] + /// [1 2] + /// [1 10] + /// values = [1, 2, 3, 4, 5] + /// shape = [2 50] + /// + /// + public static (Tensor sparse_indices, Tensor sparse_values, Tensor sparse_shape) take_many_sparse_from_tensors_map(Tensor sparse_handles, TF_DataType dtype, string container = null, string shared_name = null, string name = "TakeManySparseFromTensorsMap") + { + var dict = new Dictionary(); + dict["sparse_handles"] = sparse_handles; + dict["dtype"] = dtype; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("TakeManySparseFromTensorsMap", name: name, keywords: dict); + int _idx = 0; + var sparse_indices = op.outputs[_idx++]; + var sparse_values = op.outputs[_idx++]; + var sparse_shape = op.outputs[_idx++]; + return (sparse_indices, sparse_values, sparse_shape); + } + + /// + /// Computes tan of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Tan'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tan(Tensor x, string name = "Tan") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Tan", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes hyperbolic tangent of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Tanh'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tanh(Tensor x, string name = "Tanh") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("Tanh", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the gradient for the tanh of x wrt its input. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TanhGrad'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Specifically, grad = dy * (1 - y*y), where y = tanh(x), and dy + /// is the corresponding input gradient. + /// + public static Tensor tanh_grad(Tensor y, Tensor dy, string name = "TanhGrad") + { + var dict = new Dictionary(); + dict["y"] = y; + dict["dy"] = dy; + var op = tf.OpDefLib._apply_op_helper("TanhGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns a tensor that may be mutated, but only persists within a single step. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TemporaryVariable'. + /// + /// + /// Optional argument + /// The shape of the variable tensor. + /// + /// + /// Optional argument + /// The type of elements in the variable tensor. + /// + /// + /// Overrides the name used for the temporary variable resource. Default + /// value is the name of the 'TemporaryVariable' op (which is guaranteed unique). + /// + /// + /// A reference to the variable tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This is an experimental op for internal use only and it is possible to use this + /// op in unsafe ways. DO NOT USE unless you fully understand the risks. + /// + /// It is the caller's responsibility to ensure that 'ref' is eventually passed to a + /// matching 'DestroyTemporaryVariable' op after all other uses have completed. + /// + /// Outputs a ref to the tensor state so it may be read or modified. + /// + /// E.g. + /// var = state_ops._temporary_variable([1, 2], types.float_) + /// var_name = var.op.name + /// var = state_ops.assign(var, [[4.0, 5.0]]) + /// var = state_ops.assign_add(var, [[6.0, 7.0]]) + /// final = state_ops._destroy_temporary_variable(var, var_name=var_name) + /// + public static Tensor temporary_variable(Shape shape, TF_DataType dtype, string var_name = null, string name = "TemporaryVariable") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["dtype"] = dtype; + if (var_name != null) + dict["var_name"] = var_name; + var op = tf.OpDefLib._apply_op_helper("TemporaryVariable", name: name, keywords: dict); + return op.output; + } + + /// + /// Deprecated. Use TensorArrayCloseV3 + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayCloseV2'. + /// + /// + /// Returns the description of the operation + /// + public static Operation tensor_array_close_v2(Tensor handle, string name = "TensorArrayCloseV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + var op = tf.OpDefLib._apply_op_helper("TensorArrayCloseV2", name: name, keywords: dict); + return op; + } + + /// + /// Delete the TensorArray from its resource container. + /// + /// + /// The handle to a TensorArray (output of TensorArray or TensorArrayGrad). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayCloseV3'. + /// + /// + /// Returns the description of the operation + /// + /// + /// This enables the user to close and release the resource in the middle + /// of a step/run. + /// + public static Operation tensor_array_close_v3(Tensor handle, string name = "TensorArrayCloseV3") + { + var dict = new Dictionary(); + dict["handle"] = handle; + var op = tf.OpDefLib._apply_op_helper("TensorArrayCloseV3", name: name, keywords: dict); + return op; + } + + /// + /// Deprecated. Use TensorArrayConcatV3 + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayConcatV2'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// value : + /// lengths : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + public static (Tensor value, Tensor lengths) tensor_array_concat_v2(Tensor handle, Tensor flow_in, TF_DataType dtype, Shape element_shape_except0 = null, string name = "TensorArrayConcatV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["flow_in"] = flow_in; + dict["dtype"] = dtype; + if (element_shape_except0 != null) + dict["element_shape_except0"] = element_shape_except0; + var op = tf.OpDefLib._apply_op_helper("TensorArrayConcatV2", name: name, keywords: dict); + int _idx = 0; + var value = op.outputs[_idx++]; + var lengths = op.outputs[_idx++]; + return (value, lengths); + } + + /// + /// Concat the elements from the TensorArray into value value. + /// + /// + /// The handle to a TensorArray. + /// + /// + /// A float scalar that enforces proper chaining of operations. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayConcatV3'. + /// + /// + /// Optional argument + /// The type of the elem that is returned. + /// + /// + /// The expected shape of an element, if known, + /// excluding the first dimension. Used to validate the shapes of + /// TensorArray elements. If this shape is not fully specified, concatenating + /// zero-size TensorArrays is an error. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// value : All of the elements in the TensorArray, concatenated along the first + /// axis. + /// lengths : A vector of the row sizes of the original T elements in the + /// value output. In the example above, this would be the values: + /// (n1, n2, ..., n(T-1)). + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Takes T elements of shapes + /// + /// + /// (n0 x d0 x d1 x ...), (n1 x d0 x d1 x ...), ..., (n(T-1) x d0 x d1 x ...) + /// + /// + /// and concatenates them into a Tensor of shape: + /// + /// + /// (n0 + n1 + ... + n(T-1) x d0 x d1 x ...) + /// + /// + /// All elements must have the same shape (excepting the first dimension). + /// + public static (Tensor value, Tensor lengths) tensor_array_concat_v3(Tensor handle, Tensor flow_in, TF_DataType dtype, Shape element_shape_except0 = null, string name = "TensorArrayConcatV3") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["flow_in"] = flow_in; + dict["dtype"] = dtype; + if (element_shape_except0 != null) + dict["element_shape_except0"] = element_shape_except0; + var op = tf.OpDefLib._apply_op_helper("TensorArrayConcatV3", name: name, keywords: dict); + int _idx = 0; + var value = op.outputs[_idx++]; + var lengths = op.outputs[_idx++]; + return (value, lengths); + } + + /// + /// Deprecated. Use TensorArrayGatherV3 + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayGatherV2'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_array_gather_v2(Tensor handle, Tensor indices, Tensor flow_in, TF_DataType dtype, Shape element_shape = null, string name = "TensorArrayGatherV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["indices"] = indices; + dict["flow_in"] = flow_in; + dict["dtype"] = dtype; + if (element_shape != null) + dict["element_shape"] = element_shape; + var op = tf.OpDefLib._apply_op_helper("TensorArrayGatherV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Gather specific elements from the TensorArray into output value. + /// + /// + /// The handle to a TensorArray. + /// + /// + /// The locations in the TensorArray from which to read tensor elements. + /// + /// + /// A float scalar that enforces proper chaining of operations. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayGatherV3'. + /// + /// + /// Optional argument + /// The type of the elem that is returned. + /// + /// + /// The expected shape of an element, if known. Used to + /// validate the shapes of TensorArray elements. If this shape is not + /// fully specified, gathering zero-size TensorArrays is an error. + /// + /// + /// All of the elements in the TensorArray, concatenated along a new + /// axis (the new dimension 0). + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// All elements selected by indices must have the same shape. + /// + public static Tensor tensor_array_gather_v3(Tensor handle, Tensor indices, Tensor flow_in, TF_DataType dtype, Shape element_shape = null, string name = "TensorArrayGatherV3") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["indices"] = indices; + dict["flow_in"] = flow_in; + dict["dtype"] = dtype; + if (element_shape != null) + dict["element_shape"] = element_shape; + var op = tf.OpDefLib._apply_op_helper("TensorArrayGatherV3", name: name, keywords: dict); + return op.output; + } + + /// + /// Deprecated. Use TensorArrayGradV3 + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayGradV2'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_array_grad_v2(Tensor handle, Tensor flow_in, string source, string name = "TensorArrayGradV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["flow_in"] = flow_in; + dict["source"] = source; + var op = tf.OpDefLib._apply_op_helper("TensorArrayGradV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a TensorArray for storing the gradients of values in the given handle. + /// + /// + /// The handle to the forward TensorArray. + /// + /// + /// A float scalar that enforces proper chaining of operations. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayGradV3'. + /// + /// + /// Optional argument + /// The gradient source string, used to decide which gradient TensorArray + /// to return. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// grad_handle : + /// flow_out : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// If the given TensorArray gradient already exists, returns a reference to it. + /// + /// Locks the size of the original TensorArray by disabling its dynamic size flag. + /// + /// **A note about the input flow_in:** + /// + /// The handle flow_in forces the execution of the gradient lookup to occur + /// only after certain other operations have occurred. For example, when + /// the forward TensorArray is dynamically sized, writes to this TensorArray + /// may resize the object. The gradient TensorArray is statically sized based + /// on the size of the forward TensorArray when this operation executes. + /// Furthermore, the size of the forward TensorArray is frozen by this call. + /// As a result, the flow is used to ensure that the call to generate the gradient + /// TensorArray only happens after all writes are executed. + /// + /// In the case of dynamically sized TensorArrays, gradient computation should + /// only be performed on read operations that have themselves been chained via + /// flow to occur only after all writes have executed. That way the final size + /// of the forward TensorArray is known when this operation is called. + /// + /// **A note about the source attribute:** + /// + /// TensorArray gradient calls use an accumulator TensorArray object. If + /// multiple gradients are calculated and run in the same session, the multiple + /// gradient nodes may accidentally flow through the same accumulator TensorArray. + /// This double counts and generally breaks the TensorArray gradient flow. + /// + /// The solution is to identify which gradient call this particular + /// TensorArray gradient is being called in. This is performed by identifying + /// a unique string (e.g. "gradients", "gradients_1", ...) from the input + /// gradient Tensor's name. This string is used as a suffix when creating + /// the TensorArray gradient object here (the attribute source). + /// + /// The attribute source is added as a suffix to the forward TensorArray's + /// name when performing the creation / lookup, so that each separate gradient + /// calculation gets its own TensorArray accumulator. + /// + public static (Tensor grad_handle, Tensor flow_out) tensor_array_grad_v3(Tensor handle, Tensor flow_in, string source, string name = "TensorArrayGradV3") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["flow_in"] = flow_in; + dict["source"] = source; + var op = tf.OpDefLib._apply_op_helper("TensorArrayGradV3", name: name, keywords: dict); + int _idx = 0; + var grad_handle = op.outputs[_idx++]; + var flow_out = op.outputs[_idx++]; + return (grad_handle, flow_out); + } + + /// + /// Creates a TensorArray for storing multiple gradients of values in the given handle. + /// + /// + /// The handle to the forward TensorArray. + /// + /// + /// A float scalar that enforces proper chaining of operations. + /// + /// + /// An int32 vector representing a shape. Elements in the gradient accumulator will + /// have shape which is this shape_to_prepend value concatenated with shape of the + /// elements in the TensorArray corresponding to the input handle. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayGradWithShape'. + /// + /// + /// Optional argument + /// The gradient source string, used to decide which gradient TensorArray + /// to return. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// grad_handle : + /// flow_out : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Similar to TensorArrayGradV3. However it creates an accumulator with an + /// expanded shape compared to the input TensorArray whose gradient is being + /// computed. This enables multiple gradients for the same TensorArray to be + /// calculated using the same accumulator. + /// + public static (Tensor grad_handle, Tensor flow_out) tensor_array_grad_with_shape(Tensor handle, Tensor flow_in, Tensor shape_to_prepend, string source, string name = "TensorArrayGradWithShape") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["flow_in"] = flow_in; + dict["shape_to_prepend"] = shape_to_prepend; + dict["source"] = source; + var op = tf.OpDefLib._apply_op_helper("TensorArrayGradWithShape", name: name, keywords: dict); + int _idx = 0; + var grad_handle = op.outputs[_idx++]; + var flow_out = op.outputs[_idx++]; + return (grad_handle, flow_out); + } + + /// + /// Deprecated. Use TensorArrayReadV3 + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayReadV2'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_array_read_v2(Tensor handle, Tensor index, Tensor flow_in, TF_DataType dtype, string name = "TensorArrayReadV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["index"] = index; + dict["flow_in"] = flow_in; + dict["dtype"] = dtype; + var op = tf.OpDefLib._apply_op_helper("TensorArrayReadV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Read an element from the TensorArray into output value. + /// + /// + /// The handle to a TensorArray. + /// + /// + /// + /// + /// A float scalar that enforces proper chaining of operations. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayReadV3'. + /// + /// + /// Optional argument + /// The type of the elem that is returned. + /// + /// + /// The tensor that is read from the TensorArray. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_array_read_v3(Tensor handle, Tensor index, Tensor flow_in, TF_DataType dtype, string name = "TensorArrayReadV3") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["index"] = index; + dict["flow_in"] = flow_in; + dict["dtype"] = dtype; + var op = tf.OpDefLib._apply_op_helper("TensorArrayReadV3", name: name, keywords: dict); + return op.output; + } + + /// + /// Deprecated. Use TensorArrayScatterV3 + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayScatterV2'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_array_scatter_v2(Tensor handle, Tensor indices, Tensor value, Tensor flow_in, string name = "TensorArrayScatterV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["indices"] = indices; + dict["value"] = value; + dict["flow_in"] = flow_in; + var op = tf.OpDefLib._apply_op_helper("TensorArrayScatterV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Scatter the data from the input value into specific TensorArray elements. + /// + /// + /// The handle to a TensorArray. + /// + /// + /// The locations at which to write the tensor elements. + /// + /// + /// The concatenated tensor to write to the TensorArray. + /// + /// + /// A float scalar that enforces proper chaining of operations. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayScatterV3'. + /// + /// + /// A float scalar that enforces proper chaining of operations. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// indices must be a vector, its length must match the first dim of value. + /// + public static Tensor tensor_array_scatter_v3(Tensor handle, Tensor indices, Tensor value, Tensor flow_in, string name = "TensorArrayScatterV3") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["indices"] = indices; + dict["value"] = value; + dict["flow_in"] = flow_in; + var op = tf.OpDefLib._apply_op_helper("TensorArrayScatterV3", name: name, keywords: dict); + return op.output; + } + + /// + /// Deprecated. Use TensorArraySizeV3 + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArraySizeV2'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_array_size_v2(Tensor handle, Tensor flow_in, string name = "TensorArraySizeV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["flow_in"] = flow_in; + var op = tf.OpDefLib._apply_op_helper("TensorArraySizeV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Get the current size of the TensorArray. + /// + /// + /// The handle to a TensorArray (output of TensorArray or TensorArrayGrad). + /// + /// + /// A float scalar that enforces proper chaining of operations. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArraySizeV3'. + /// + /// + /// The current size of the TensorArray. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_array_size_v3(Tensor handle, Tensor flow_in, string name = "TensorArraySizeV3") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["flow_in"] = flow_in; + var op = tf.OpDefLib._apply_op_helper("TensorArraySizeV3", name: name, keywords: dict); + return op.output; + } + + /// + /// Deprecated. Use TensorArraySplitV3 + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArraySplitV2'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_array_split_v2(Tensor handle, Tensor value, Tensor lengths, Tensor flow_in, string name = "TensorArraySplitV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["value"] = value; + dict["lengths"] = lengths; + dict["flow_in"] = flow_in; + var op = tf.OpDefLib._apply_op_helper("TensorArraySplitV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Split the data from the input value into TensorArray elements. + /// + /// + /// The handle to a TensorArray. + /// + /// + /// The concatenated tensor to write to the TensorArray. + /// + /// + /// The vector of lengths, how to split the rows of value into the + /// TensorArray. + /// + /// + /// A float scalar that enforces proper chaining of operations. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArraySplitV3'. + /// + /// + /// A float scalar that enforces proper chaining of operations. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Assuming that lengths takes on values + /// + /// + /// (n0, n1, ..., n(T-1)) + /// + /// + /// and that value has shape + /// + /// + /// (n0 + n1 + ... + n(T-1) x d0 x d1 x ...), + /// + /// this splits values into a TensorArray with T tensors. + /// + /// TensorArray index t will be the subtensor of values with starting position + /// + /// + /// (n0 + n1 + ... + n(t-1), 0, 0, ...) + /// + /// + /// and having size + /// + /// + /// nt x d0 x d1 x ... + /// + /// + public static Tensor tensor_array_split_v3(Tensor handle, Tensor value, Tensor lengths, Tensor flow_in, string name = "TensorArraySplitV3") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["value"] = value; + dict["lengths"] = lengths; + dict["flow_in"] = flow_in; + var op = tf.OpDefLib._apply_op_helper("TensorArraySplitV3", name: name, keywords: dict); + return op.output; + } + + /// + /// Deprecated. Use TensorArrayV3 + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayV2'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_array_v2(Tensor size, TF_DataType dtype, Shape element_shape = null, bool? dynamic_size = null, bool? clear_after_read = null, string tensor_array_name = null, string name = "TensorArrayV2") + { + var dict = new Dictionary(); + dict["size"] = size; + dict["dtype"] = dtype; + if (element_shape != null) + dict["element_shape"] = element_shape; + if (dynamic_size.HasValue) + dict["dynamic_size"] = dynamic_size.Value; + if (clear_after_read.HasValue) + dict["clear_after_read"] = clear_after_read.Value; + if (tensor_array_name != null) + dict["tensor_array_name"] = tensor_array_name; + var op = tf.OpDefLib._apply_op_helper("TensorArrayV2", name: name, keywords: dict); + return op.output; + } + + /// + /// An array of Tensors of given size. + /// + /// + /// The size of the array. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayV3'. + /// + /// + /// Optional argument + /// The type of the elements on the tensor_array. + /// + /// + /// The expected shape of an element, if known. Used to + /// validate the shapes of TensorArray elements. If this shape is not + /// fully specified, gathering zero-size TensorArrays is an error. + /// + /// + /// A boolean that determines whether writes to the TensorArray + /// are allowed to grow the size. By default, this is not allowed. + /// + /// + /// If true (default), Tensors in the TensorArray are cleared + /// after being read. This disables multiple read semantics but allows early + /// release of memory. + /// + /// + /// If true (default is false), then all + /// elements in the TensorArray will be expected to have have identical shapes. + /// This allows certain behaviors, like dynamically checking for + /// consistent shapes on write, and being able to fill in properly + /// shaped zero tensors on stack -- even if the element_shape attribute + /// is not fully defined. + /// + /// + /// Overrides the name used for the temporary tensor_array + /// resource. Default value is the name of the 'TensorArray' op (which + /// is guaranteed unique). + /// + /// + /// Returns a tuple with multiple values, as follows: + /// handle : The handle to the TensorArray. + /// flow : A scalar used to control gradient flow. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Write data via Write and read via Read or Pack. + /// + public static (Tensor handle, Tensor flow) tensor_array_v3(Tensor size, TF_DataType dtype, Shape element_shape = null, bool? dynamic_size = null, bool? clear_after_read = null, bool? identical_element_shapes = null, string tensor_array_name = null, string name = "TensorArrayV3") + { + var dict = new Dictionary(); + dict["size"] = size; + dict["dtype"] = dtype; + if (element_shape != null) + dict["element_shape"] = element_shape; + if (dynamic_size.HasValue) + dict["dynamic_size"] = dynamic_size.Value; + if (clear_after_read.HasValue) + dict["clear_after_read"] = clear_after_read.Value; + if (identical_element_shapes.HasValue) + dict["identical_element_shapes"] = identical_element_shapes.Value; + if (tensor_array_name != null) + dict["tensor_array_name"] = tensor_array_name; + var op = tf.OpDefLib._apply_op_helper("TensorArrayV3", name: name, keywords: dict); + int _idx = 0; + var handle = op.outputs[_idx++]; + var flow = op.outputs[_idx++]; + return (handle, flow); + } + + /// + /// Deprecated. Use TensorArrayGradV3 + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayWriteV2'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_array_write_v2(Tensor handle, Tensor index, Tensor value, Tensor flow_in, string name = "TensorArrayWriteV2") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["index"] = index; + dict["value"] = value; + dict["flow_in"] = flow_in; + var op = tf.OpDefLib._apply_op_helper("TensorArrayWriteV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Push an element onto the tensor_array. + /// + /// + /// The handle to a TensorArray. + /// + /// + /// The position to write to inside the TensorArray. + /// + /// + /// The tensor to write to the TensorArray. + /// + /// + /// A float scalar that enforces proper chaining of operations. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorArrayWriteV3'. + /// + /// + /// A float scalar that enforces proper chaining of operations. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_array_write_v3(Tensor handle, Tensor index, Tensor value, Tensor flow_in, string name = "TensorArrayWriteV3") + { + var dict = new Dictionary(); + dict["handle"] = handle; + dict["index"] = index; + dict["value"] = value; + dict["flow_in"] = flow_in; + var op = tf.OpDefLib._apply_op_helper("TensorArrayWriteV3", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that emits components as a tuple of tensors once. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorDataset'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_dataset(Tensor[] components, Shape[] output_shapes, string name = "TensorDataset") + { + var dict = new Dictionary(); + dict["components"] = components; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("TensorDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// The shape of the elements of the given list, as a tensor. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorListElementShape'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// input_handle: the list + /// element_shape: the shape of elements of the list + /// + public static Tensor tensor_list_element_shape(Tensor input_handle, TF_DataType shape_type, string name = "TensorListElementShape") + { + var dict = new Dictionary(); + dict["input_handle"] = input_handle; + dict["shape_type"] = shape_type; + var op = tf.OpDefLib._apply_op_helper("TensorListElementShape", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a TensorList which, when stacked, has the value of tensor. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorListFromTensor'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Each tensor in the result list corresponds to one row of the input tensor. + /// + /// tensor: The input tensor. + /// output_handle: The list. + /// + public static Tensor tensor_list_from_tensor(Tensor tensor, Tensor element_shape, string name = "TensorListFromTensor") + { + var dict = new Dictionary(); + dict["tensor"] = tensor; + dict["element_shape"] = element_shape; + var op = tf.OpDefLib._apply_op_helper("TensorListFromTensor", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a Tensor by indexing into the TensorList. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorListGather'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Each row in the produced Tensor corresponds to the element in the TensorList + /// specified by the given index (see tf.gather). + /// + /// input_handle: The input tensor list. + /// indices: The indices used to index into the list. + /// values: The tensor. + /// + public static Tensor tensor_list_gather(Tensor input_handle, Tensor indices, TF_DataType element_dtype, string name = "TensorListGather") + { + var dict = new Dictionary(); + dict["input_handle"] = input_handle; + dict["indices"] = indices; + dict["element_dtype"] = element_dtype; + var op = tf.OpDefLib._apply_op_helper("TensorListGather", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the item in the list with the given index. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorListGetItem'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// input_handle: the list + /// index: the position in the list from which an element will be retrieved + /// item: the element at that position + /// + /// + /// + public static Tensor tensor_list_get_item(Tensor input_handle, Tensor index, TF_DataType element_dtype, string name = "TensorListGetItem") + { + var dict = new Dictionary(); + dict["input_handle"] = input_handle; + dict["index"] = index; + dict["element_dtype"] = element_dtype; + var op = tf.OpDefLib._apply_op_helper("TensorListGetItem", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the number of tensors in the input tensor list. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorListLength'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// input_handle: the input list + /// length: the number of tensors in the list + /// + public static Tensor tensor_list_length(Tensor input_handle, string name = "TensorListLength") + { + var dict = new Dictionary(); + dict["input_handle"] = input_handle; + var op = tf.OpDefLib._apply_op_helper("TensorListLength", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the last element of the input list as well as a list with all but that element. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorListPopBack'. + /// + /// + /// Optional argument + /// + /// + /// Returns a tuple with multiple values, as follows: + /// output_handle : + /// tensor : + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// Fails if the list is empty. + /// + /// input_handle: the input list + /// tensor: the withdrawn last element of the list + /// element_dtype: the type of elements in the list + /// element_shape: the shape of the output tensor + /// + public static (Tensor output_handle, Tensor tensor) tensor_list_pop_back(Tensor input_handle, TF_DataType element_dtype, string name = "TensorListPopBack") + { + var dict = new Dictionary(); + dict["input_handle"] = input_handle; + dict["element_dtype"] = element_dtype; + var op = tf.OpDefLib._apply_op_helper("TensorListPopBack", name: name, keywords: dict); + int _idx = 0; + var output_handle = op.outputs[_idx++]; + var tensor = op.outputs[_idx++]; + return (output_handle, tensor); + } + + /// + /// Returns a list list which has the passed-in Tensor as last element and the other elements of the given list in input_handle. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorListPushBack'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// tensor: The tensor to put on the list. + /// input_handle: The old list. + /// output_handle: A list with the elements of the old list followed by tensor. + /// element_dtype: the type of elements in the list. + /// element_shape: a shape compatible with that of elements in the list. + /// + public static Tensor tensor_list_push_back(Tensor input_handle, Tensor tensor, string name = "TensorListPushBack") + { + var dict = new Dictionary(); + dict["input_handle"] = input_handle; + dict["tensor"] = tensor; + var op = tf.OpDefLib._apply_op_helper("TensorListPushBack", name: name, keywords: dict); + return op.output; + } + + /// + /// List of the given size with empty elements. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorListReserve'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// element_shape: the shape of the future elements of the list + /// num_elements: the number of elements to reserve + /// handle: the output list + /// element_dtype: the desired type of elements in the list. + /// + public static Tensor tensor_list_reserve(Tensor element_shape, Tensor num_elements, TF_DataType element_dtype, string name = "TensorListReserve") + { + var dict = new Dictionary(); + dict["element_shape"] = element_shape; + dict["num_elements"] = num_elements; + dict["element_dtype"] = element_dtype; + var op = tf.OpDefLib._apply_op_helper("TensorListReserve", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a TensorList by indexing into a Tensor. + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorListScatter'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Each member of the TensorList corresponds to one row of the input tensor, + /// specified by the given index (see tf.gather). + /// + /// tensor: The input tensor. + /// indices: The indices used to index into the list. + /// element_shape: The shape of the elements in the list (can be less specified than + /// the shape of the tensor). + /// output_handle: The TensorList. + /// + public static Tensor tensor_list_scatter(Tensor tensor, Tensor indices, Tensor element_shape, string name = "TensorListScatter") + { + var dict = new Dictionary(); + dict["tensor"] = tensor; + dict["indices"] = indices; + dict["element_shape"] = element_shape; + var op = tf.OpDefLib._apply_op_helper("TensorListScatter", name: name, keywords: dict); + return op.output; + } + + /// + /// Sets the index-th position of the list to contain the given tensor. + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorListSetItem'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// input_handle: the list + /// index: the position in the list to which the tensor will be assigned + /// item: the element to be assigned to that position + /// output_handle: the new list, with the element in the proper position + /// + /// + public static Tensor tensor_list_set_item(Tensor input_handle, Tensor index, Tensor item, string name = "TensorListSetItem") + { + var dict = new Dictionary(); + dict["input_handle"] = input_handle; + dict["index"] = index; + dict["item"] = item; + var op = tf.OpDefLib._apply_op_helper("TensorListSetItem", name: name, keywords: dict); + return op.output; + } + + /// + /// Stacks all tensors in the list. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorListStack'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Requires that all tensors have the same shape. + /// + /// input_handle: the input list + /// tensor: the gathered result + /// num_elements: optional. If not -1, the number of elements in the list. + /// + /// + public static Tensor tensor_list_stack(Tensor input_handle, TF_DataType element_dtype, int? num_elements = null, string name = "TensorListStack") + { + var dict = new Dictionary(); + dict["input_handle"] = input_handle; + dict["element_dtype"] = element_dtype; + if (num_elements.HasValue) + dict["num_elements"] = num_elements.Value; + var op = tf.OpDefLib._apply_op_helper("TensorListStack", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that emits each dim-0 slice of components once. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorSliceDataset'. + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_slice_dataset(Tensor[] components, Shape[] output_shapes, string name = "TensorSliceDataset") + { + var dict = new Dictionary(); + dict["components"] = components; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("TensorSliceDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs a Summary protocol buffer with a tensor. + /// + /// + /// A tensor to serialize. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorSummary'. + /// + /// + /// A json-encoded SummaryDescription proto. + /// + /// + /// An unused list of strings. + /// + /// + /// An unused string. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This op is being phased out in favor of TensorSummaryV2, which lets callers pass + /// a tag as well as a serialized SummaryMetadata proto string that contains + /// plugin-specific data. We will keep this op to maintain backwards compatibility. + /// + public static Tensor tensor_summary(Tensor tensor, string description = null, string[] labels = null, string display_name = null, string name = "TensorSummary") + { + var dict = new Dictionary(); + dict["tensor"] = tensor; + if (description != null) + dict["description"] = description; + if (labels != null) + dict["labels"] = labels; + if (display_name != null) + dict["display_name"] = display_name; + var op = tf.OpDefLib._apply_op_helper("TensorSummary", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs a Summary protocol buffer with a tensor and per-plugin data. + /// + /// + /// A string attached to this summary. Used for organization in TensorBoard. + /// + /// + /// A tensor to serialize. + /// + /// + /// A serialized SummaryMetadata proto. Contains plugin + /// data. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TensorSummaryV2'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor tensor_summary_v2(Tensor tag, Tensor tensor, Tensor serialized_summary_metadata, string name = "TensorSummaryV2") + { + var dict = new Dictionary(); + dict["tag"] = tag; + dict["tensor"] = tensor; + dict["serialized_summary_metadata"] = serialized_summary_metadata; + var op = tf.OpDefLib._apply_op_helper("TensorSummaryV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that emits the lines of one or more text files. + /// + /// + /// A scalar or a vector containing the name(s) of the file(s) to be + /// read. + /// + /// + /// A scalar containing either (i) the empty string (no + /// compression), (ii) "ZLIB", or (iii) "GZIP". + /// + /// + /// A scalar containing the number of bytes to buffer. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TextLineDataset'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor text_line_dataset(Tensor filenames, Tensor compression_type, Tensor buffer_size, string name = "TextLineDataset") + { + var dict = new Dictionary(); + dict["filenames"] = filenames; + dict["compression_type"] = compression_type; + dict["buffer_size"] = buffer_size; + var op = tf.OpDefLib._apply_op_helper("TextLineDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// A Reader that outputs the lines of a file delimited by '\n'. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TextLineReader'. + /// + /// + /// Number of lines to skip from the beginning of every file. + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// The handle to reference the Reader. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor text_line_reader(int? skip_header_lines = null, string container = null, string shared_name = null, string name = "TextLineReader") + { + var dict = new Dictionary(); + if (skip_header_lines.HasValue) + dict["skip_header_lines"] = skip_header_lines.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("TextLineReader", name: name, keywords: dict); + return op.output; + } + + /// + /// A Reader that outputs the lines of a file delimited by '\n'. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TextLineReaderV2'. + /// + /// + /// Number of lines to skip from the beginning of every file. + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// The handle to reference the Reader. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor text_line_reader_v2(int? skip_header_lines = null, string container = null, string shared_name = null, string name = "TextLineReaderV2") + { + var dict = new Dictionary(); + if (skip_header_lines.HasValue) + dict["skip_header_lines"] = skip_header_lines.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("TextLineReaderV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Generates labels for candidate sampling with a learned unigram distribution. + /// + /// + /// A batch_size * num_true matrix, in which each row contains the + /// IDs of the num_true target_classes in the corresponding original label. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ThreadUnsafeUnigramCandidateSampler'. + /// + /// + /// Optional argument + /// Number of true labels per context. + /// + /// + /// Optional argument + /// Number of candidates to randomly sample. + /// + /// + /// Optional argument + /// If unique is true, we sample with rejection, so that all sampled + /// candidates in a batch are unique. This requires some approximation to + /// estimate the post-rejection sampling probabilities. + /// + /// + /// Optional argument + /// The sampler will sample integers from the interval [0, range_max). + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// An second seed to avoid seed collision. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sampled_candidates : A vector of length num_sampled, in which each element is + /// the ID of a sampled candidate. + /// true_expected_count : A batch_size * num_true matrix, representing + /// the number of times each candidate is expected to occur in a batch + /// of sampled candidates. If unique=true, then this is a probability. + /// sampled_expected_count : A vector of length num_sampled, for each sampled + /// candidate representing the number of times the candidate is expected + /// to occur in a batch of sampled candidates. If unique=true, then this is a + /// probability. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// See explanations of candidate sampling and the data formats at + /// go/candidate-sampling. + /// + /// For each batch, this op picks a single set of sampled candidate labels. + /// + /// The advantages of sampling candidates per-batch are simplicity and the + /// possibility of efficient dense matrix multiplication. The disadvantage is that + /// the sampled candidates must be chosen independently of the context and of the + /// true labels. + /// + public static (Tensor sampled_candidates, Tensor true_expected_count, Tensor sampled_expected_count) thread_unsafe_unigram_candidate_sampler(Tensor true_classes, int num_true, int num_sampled, bool unique, int range_max, int? seed = null, int? seed2 = null, string name = "ThreadUnsafeUnigramCandidateSampler") + { + var dict = new Dictionary(); + dict["true_classes"] = true_classes; + dict["num_true"] = num_true; + dict["num_sampled"] = num_sampled; + dict["unique"] = unique; + dict["range_max"] = range_max; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("ThreadUnsafeUnigramCandidateSampler", name: name, keywords: dict); + int _idx = 0; + var sampled_candidates = op.outputs[_idx++]; + var true_expected_count = op.outputs[_idx++]; + var sampled_expected_count = op.outputs[_idx++]; + return (sampled_candidates, true_expected_count, sampled_expected_count); + } + + /// + /// Constructs a tensor by tiling a given tensor. + /// + /// + /// 1-D or higher. + /// + /// + /// 1-D. Length must be the same as the number of dimensions in input + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Tile'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation creates a new tensor by replicating input multiples times. + /// The output tensor's i'th dimension has input.dims(i) * multiples[i] elements, + /// and the values of input are replicated multiples[i] times along the 'i'th + /// dimension. For example, tiling [a b c d] by [2] produces + /// [a b c d a b c d]. + /// + public static Tensor tile(Tensor input, Tensor multiples, string name = "Tile") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["multiples"] = multiples; + var op = tf.OpDefLib._apply_op_helper("Tile", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the gradient of Tile. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TileGrad'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Since Tile takes an input and repeats the input multiples times + /// along each dimension, TileGrad takes in multiples and aggregates + /// each repeated tile of input into output. + /// + public static Tensor tile_grad(Tensor input, Tensor multiples, string name = "TileGrad") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["multiples"] = multiples; + var op = tf.OpDefLib._apply_op_helper("TileGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Provides the time since epoch in seconds. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Timestamp'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Returns the timestamp as a float64 for seconds since the Unix epoch. + /// + /// Note: the timestamp is computed when the op is executed, not when it is added + /// to the graph. + /// + public static Tensor timestamp(string name = "Timestamp") + { + var dict = new Dictionary(); + var op = tf.OpDefLib._apply_op_helper("Timestamp", name: name, keywords: dict); + return op.output; + } + + /// + /// Finds values and indices of the k largest elements for the last dimension. + /// + /// + /// 1-D or higher with last dimension at least k. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TopK'. + /// + /// + /// Optional argument + /// Number of top elements to look for along the last dimension (along each + /// row for matrices). + /// + /// + /// If true the resulting k elements will be sorted by the values in + /// descending order. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// values : The k largest elements along each last dimensional slice. + /// indices : The indices of values within the last dimension of input. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// If the input is a vector (rank-1), finds the k largest entries in the vector + /// and outputs their values and indices as vectors. Thus values[j] is the + /// j-th largest entry in input, and its index is indices[j]. + /// + /// For matrices (resp. higher rank input), computes the top k entries in each + /// row (resp. vector along the last dimension). Thus, + /// + /// values.shape = indices.shape = input.shape[:-1] + [k] + /// + /// If two elements are equal, the lower-index element appears first. + /// + /// If k varies dynamically, use TopKV2 below. + /// + public static (Tensor values, Tensor indices) top_k(Tensor input, int k, bool? sorted = null, string name = "TopK") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["k"] = k; + if (sorted.HasValue) + dict["sorted"] = sorted.Value; + var op = tf.OpDefLib._apply_op_helper("TopK", name: name, keywords: dict); + int _idx = 0; + var values = op.outputs[_idx++]; + var indices = op.outputs[_idx++]; + return (values, indices); + } + + /// + /// Finds values and indices of the k largest elements for the last dimension. + /// + /// + /// 1-D or higher with last dimension at least k. + /// + /// + /// 0-D. Number of top elements to look for along the last dimension (along each + /// row for matrices). + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TopKV2'. + /// + /// + /// If true the resulting k elements will be sorted by the values in + /// descending order. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// values : The k largest elements along each last dimensional slice. + /// indices : The indices of values within the last dimension of input. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// If the input is a vector (rank-1), finds the k largest entries in the vector + /// and outputs their values and indices as vectors. Thus values[j] is the + /// j-th largest entry in input, and its index is indices[j]. + /// + /// For matrices (resp. higher rank input), computes the top k entries in each + /// row (resp. vector along the last dimension). Thus, + /// + /// values.shape = indices.shape = input.shape[:-1] + [k] + /// + /// If two elements are equal, the lower-index element appears first. + /// + public static (Tensor values, Tensor indices) top_k_v2(Tensor input, Tensor k, bool? sorted = null, string name = "TopKV2") + { + var dict = new Dictionary(); + dict["input"] = input; + dict["k"] = k; + if (sorted.HasValue) + dict["sorted"] = sorted.Value; + var op = tf.OpDefLib._apply_op_helper("TopKV2", name: name, keywords: dict); + int _idx = 0; + var values = op.outputs[_idx++]; + var indices = op.outputs[_idx++]; + return (values, indices); + } + + /// + /// Shuffle dimensions of x according to a permutation. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Transpose'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The output y has the same rank as x. The shapes of x and y satisfy: + /// y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1] + /// + public static Tensor transpose(Tensor x, Tensor perm, string name = "Transpose") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["perm"] = perm; + var op = tf.OpDefLib._apply_op_helper("Transpose", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns x / y element-wise for integer types. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TruncateDiv'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Truncation designates that negative numbers will round fractional quantities + /// toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different + /// than Python semantics. See FloorDiv for a division function that matches + /// Python Semantics. + /// + /// *NOTE*: TruncateDiv supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor truncate_div(Tensor x, Tensor y, string name = "TruncateDiv") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("TruncateDiv", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns element-wise remainder of division. This emulates C semantics in that + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TruncateMod'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// the result here is consistent with a truncating divide. E.g. truncate(x / y) * + /// y + truncate_mod(x, y) = x. + /// + /// *NOTE*: TruncateMod supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor truncate_mod(Tensor x, Tensor y, string name = "TruncateMod") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["y"] = y; + var op = tf.OpDefLib._apply_op_helper("TruncateMod", name: name, keywords: dict); + return op.output; + } + + /// + /// Outputs random values from a truncated normal distribution. + /// + /// + /// The shape of the output tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TruncatedNormal'. + /// + /// + /// Optional argument + /// The type of the output. + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// A second seed to avoid seed collision. + /// + /// + /// A tensor of the specified shape filled with random truncated normal + /// values. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The generated values follow a normal distribution with mean 0 and standard + /// deviation 1, except that values whose magnitude is more than 2 standard + /// deviations from the mean are dropped and re-picked. + /// + public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = null, int? seed2 = null, string name = "TruncatedNormal") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["dtype"] = dtype; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("TruncatedNormal", name: name, keywords: dict); + return op.output; + } + + /// + /// Perform batches of RPC requests. + /// + /// + /// 0-D or 1-D. The address (i.e. host_name:port) of the RPC server. + /// If this tensor has more than 1 element, then multiple parallel rpc requests + /// are sent. This argument broadcasts with method and request. + /// + /// + /// 0-D or 1-D. The method address on the RPC server. + /// If this tensor has more than 1 element, then multiple parallel rpc requests + /// are sent. This argument broadcasts with address and request. + /// + /// + /// 0-D or 1-D. Serialized proto strings: the rpc request argument. + /// If this tensor has more than 1 element, then multiple parallel rpc requests + /// are sent. This argument broadcasts with address and method. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'TryRpc'. + /// + /// + /// RPC protocol to use. Empty string means use the default protocol. + /// Options include 'grpc'. + /// + /// + /// boolean. If true (default), then failures to connect + /// (i.e., the server does not immediately respond) cause an RPC failure. + /// + /// + /// int. If 0 (default), then the kernel will run the RPC + /// request and only time out if the RPC deadline passes or the session times out. + /// If this value is greater than 0, then the op will raise an exception if + /// the RPC takes longer than timeout_in_ms. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// response : Same shape as request. Serialized proto strings: the rpc responses. + /// status_code : Same shape as request. Values correspond to tensorflow Status enum codes. + /// status_message : Same shape as request. Values correspond to Status messages + /// returned from the RPC calls. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// This op asynchronously performs either a single RPC request, or a batch + /// of requests. RPC requests are defined by three main parameters: + /// + /// - address (the host+port or BNS address of the request) + /// - method (the method name for the request) + /// - request (the serialized proto string, or vector of strings, + /// of the RPC request argument). + /// + /// For example, if you have an RPC service running on port localhost:2345, + /// and its interface is configured with the following proto declaration: + /// + /// + /// service MyService { + /// rpc MyMethod(MyRequestProto) returns (MyResponseProto) { + /// } + /// }; + /// + /// + /// then call this op with arguments: + /// + /// + /// address = "localhost:2345" + /// method = "MyService/MyMethod" + /// + /// + /// The request tensor is a string tensor representing serialized MyRequestProto + /// strings; and the output string tensor response will have the same shape + /// and contain (upon successful completion) corresponding serialized + /// MyResponseProto strings. + /// + /// For example, to send a single, empty, MyRequestProto, call + /// this op with request = "". To send 5 **parallel** empty requests, + /// call this op with request = ["", "", "", "", ""]. + /// + /// More generally, one can create a batch of MyRequestProto serialized protos + /// from regular batched tensors using the encode_proto op, and convert + /// the response MyResponseProto serialized protos to batched tensors + /// using the decode_proto op. + /// + /// **NOTE** Working with serialized proto strings is faster than instantiating + /// actual proto objects in memory, so no performance degradation is expected + /// compared to writing custom kernels for this workflow. + /// + /// Unlike the standard Rpc op, if the connection fails or the remote worker + /// returns an error status, this op does **not** reraise the exception. + /// Instead, the status_code and status_message entry for the corresponding RPC + /// call is set with the error returned from the RPC call. The response tensor + /// will contain valid response values for those minibatch entries whose RPCs did + /// not fail; the rest of the entries will have empty strings. + /// + public static (Tensor response, Tensor status_code, Tensor status_message) try_rpc(Tensor address, Tensor method, Tensor request, string protocol = null, bool? fail_fast = null, int? timeout_in_ms = null, string name = "TryRpc") + { + var dict = new Dictionary(); + dict["address"] = address; + dict["method"] = method; + dict["request"] = request; + if (protocol != null) + dict["protocol"] = protocol; + if (fail_fast.HasValue) + dict["fail_fast"] = fail_fast.Value; + if (timeout_in_ms.HasValue) + dict["timeout_in_ms"] = timeout_in_ms.Value; + var op = tf.OpDefLib._apply_op_helper("TryRpc", name: name, keywords: dict); + int _idx = 0; + var response = op.outputs[_idx++]; + var status_code = op.outputs[_idx++]; + var status_message = op.outputs[_idx++]; + return (response, status_code, status_message); + } + + /// + /// Reverses the operation of Batch for a single output Tensor. + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Unbatch'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// An instance of Unbatch either receives an empty batched_tensor, in which case it + /// asynchronously waits until the values become available from a concurrently + /// running instance of Unbatch with the same container and shared_name, or receives + /// a non-empty batched_tensor in which case it finalizes all other concurrently + /// running instances and outputs its own element from the batch. + /// + /// batched_tensor: The possibly transformed output of Batch. The size of the first + /// dimension should remain unchanged by the transformations for the operation to + /// work. + /// batch_index: The matching batch_index obtained from Batch. + /// id: The id scalar emitted by Batch. + /// unbatched_tensor: The Tensor corresponding to this execution. + /// timeout_micros: Maximum amount of time (in microseconds) to wait to receive the + /// batched input tensor associated with a given invocation of the op. + /// container: Container to control resource sharing. + /// shared_name: Instances of Unbatch with the same container and shared_name are + /// assumed to possibly belong to the same batch. If left empty, the op name will + /// be used as the shared name. + /// + public static Tensor unbatch(Tensor batched_tensor, Tensor batch_index, Tensor id, int timeout_micros, string container = null, string shared_name = null, string name = "Unbatch") + { + var dict = new Dictionary(); + dict["batched_tensor"] = batched_tensor; + dict["batch_index"] = batch_index; + dict["id"] = id; + dict["timeout_micros"] = timeout_micros; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("Unbatch", name: name, keywords: dict); + return op.output; + } + + /// + /// A dataset that splits the elements of its input into multiple elements. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'UnbatchDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor unbatch_dataset(Tensor input_dataset, TF_DataType[] output_types, Shape[] output_shapes, string name = "UnbatchDataset") + { + var dict = new Dictionary(); + dict["input_dataset"] = input_dataset; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("UnbatchDataset", name: name, keywords: dict); + return op.output; + } + + /// + /// Gradient of Unbatch. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'UnbatchGrad'. + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Acts like Batch but using the given batch_index index of batching things as they + /// become available. This ensures that the gradients are propagated back in the + /// same session which did the forward pass. + /// + /// original_input: The input to the Unbatch operation this is the gradient of. + /// batch_index: The batch_index given to the Unbatch operation this is the gradient + /// of. + /// grad: The downstream gradient. + /// id: The id scalar emitted by Batch. + /// batched_grad: The return value, either an empty tensor or the batched gradient. + /// container: Container to control resource sharing. + /// shared_name: Instances of UnbatchGrad with the same container and shared_name + /// are assumed to possibly belong to the same batch. If left empty, the op name + /// will be used as the shared name. + /// + public static Tensor unbatch_grad(Tensor original_input, Tensor batch_index, Tensor grad, Tensor id, string container = null, string shared_name = null, string name = "UnbatchGrad") + { + var dict = new Dictionary(); + dict["original_input"] = original_input; + dict["batch_index"] = batch_index; + dict["grad"] = grad; + dict["id"] = id; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("UnbatchGrad", name: name, keywords: dict); + return op.output; + } + + /// + /// Generates labels for candidate sampling with a uniform distribution. + /// + /// + /// A batch_size * num_true matrix, in which each row contains the + /// IDs of the num_true target_classes in the corresponding original label. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'UniformCandidateSampler'. + /// + /// + /// Optional argument + /// Number of true labels per context. + /// + /// + /// Optional argument + /// Number of candidates to randomly sample. + /// + /// + /// Optional argument + /// If unique is true, we sample with rejection, so that all sampled + /// candidates in a batch are unique. This requires some approximation to + /// estimate the post-rejection sampling probabilities. + /// + /// + /// Optional argument + /// The sampler will sample integers from the interval [0, range_max). + /// + /// + /// If either seed or seed2 are set to be non-zero, the random number + /// generator is seeded by the given seed. Otherwise, it is seeded by a + /// random seed. + /// + /// + /// An second seed to avoid seed collision. + /// + /// + /// Returns a tuple with multiple values, as follows: + /// sampled_candidates : A vector of length num_sampled, in which each element is + /// the ID of a sampled candidate. + /// true_expected_count : A batch_size * num_true matrix, representing + /// the number of times each candidate is expected to occur in a batch + /// of sampled candidates. If unique=true, then this is a probability. + /// sampled_expected_count : A vector of length num_sampled, for each sampled + /// candidate representing the number of times the candidate is expected + /// to occur in a batch of sampled candidates. If unique=true, then this is a + /// probability. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// See explanations of candidate sampling and the data formats at + /// go/candidate-sampling. + /// + /// For each batch, this op picks a single set of sampled candidate labels. + /// + /// The advantages of sampling candidates per-batch are simplicity and the + /// possibility of efficient dense matrix multiplication. The disadvantage is that + /// the sampled candidates must be chosen independently of the context and of the + /// true labels. + /// + public static (Tensor sampled_candidates, Tensor true_expected_count, Tensor sampled_expected_count) uniform_candidate_sampler(Tensor true_classes, int num_true, int num_sampled, bool unique, int range_max, int? seed = null, int? seed2 = null, string name = "UniformCandidateSampler") + { + var dict = new Dictionary(); + dict["true_classes"] = true_classes; + dict["num_true"] = num_true; + dict["num_sampled"] = num_sampled; + dict["unique"] = unique; + dict["range_max"] = range_max; + if (seed.HasValue) + dict["seed"] = seed.Value; + if (seed2.HasValue) + dict["seed2"] = seed2.Value; + var op = tf.OpDefLib._apply_op_helper("UniformCandidateSampler", name: name, keywords: dict); + int _idx = 0; + var sampled_candidates = op.outputs[_idx++]; + var true_expected_count = op.outputs[_idx++]; + var sampled_expected_count = op.outputs[_idx++]; + return (sampled_candidates, true_expected_count, sampled_expected_count); + } + + /// + /// Finds unique elements in a 1-D tensor. + /// + /// + /// 1-D. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Unique'. + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// y : 1-D. + /// idx : 1-D. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// This operation returns a tensor y containing all of the unique elements of x + /// sorted in the same order that they occur in x. This operation also returns a + /// tensor idx the same size as x that contains the index of each value of x + /// in the unique output y. In other words: + /// + /// y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1] + /// + /// For example: + /// + /// + /// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] + /// y, idx = unique(x) + /// y ==&gt; [1, 2, 4, 7, 8] + /// idx ==&gt; [0, 0, 1, 2, 2, 2, 3, 4, 4] + /// + /// + public static (Tensor y, Tensor idx) unique(Tensor x, TF_DataType? out_idx = null, string name = "Unique") + { + var dict = new Dictionary(); + dict["x"] = x; + if (out_idx.HasValue) + dict["out_idx"] = out_idx.Value; + var op = tf.OpDefLib._apply_op_helper("Unique", name: name, keywords: dict); + int _idx = 0; + var y = op.outputs[_idx++]; + var idx = op.outputs[_idx++]; + return (y, idx); + } + + /// + /// Finds unique elements along an axis of a tensor. + /// + /// + /// A Tensor. + /// + /// + /// A Tensor of type int32 (default: None). The axis of the Tensor to + /// find the unique elements. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'UniqueV2'. + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// y : A Tensor. Unique elements along the axis of Tensor x. + /// idx : A 1-D Tensor. Has the same type as x that contains the index of each + /// value of x in the output y. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// This operation either returns a tensor y containing unique elements + /// along the axis of a tensor. The returned unique elements is sorted + /// in the same order as they occur along axis in x. + /// This operation also returns a tensor idx that is the same size as + /// the number of the elements in x along the axis dimension. It + /// contains the index in the unique output y. + /// In other words, for an 1-D tensor x with axis = None: + /// + /// y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1] + /// + /// For example: + /// + /// + /// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] + /// y, idx = unique(x) + /// y ==&gt; [1, 2, 4, 7, 8] + /// idx ==&gt; [0, 0, 1, 2, 2, 2, 3, 4, 4] + /// + /// + /// For an 2-D tensor x with axis = 0: + /// + /// + /// # tensor 'x' is [[1, 0, 0], + /// # [1, 0, 0], + /// # [2, 0, 0]] + /// y, idx = unique(x, axis=0) + /// y ==&gt; [[1, 0, 0], + /// [2, 0, 0]] + /// idx ==&gt; [0, 0, 1] + /// + /// + /// For an 2-D tensor x with axis = 1: + /// + /// + /// # tensor 'x' is [[1, 0, 0], + /// # [1, 0, 0], + /// # [2, 0, 0]] + /// y, idx = unique(x, axis=1) + /// y ==&gt; [[1, 0], + /// [1, 0], + /// [2, 0]] + /// idx ==&gt; [0, 1, 1] + /// + /// + public static (Tensor y, Tensor idx) unique_v2(Tensor x, Tensor axis, TF_DataType? out_idx = null, string name = "UniqueV2") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["axis"] = axis; + if (out_idx.HasValue) + dict["out_idx"] = out_idx.Value; + var op = tf.OpDefLib._apply_op_helper("UniqueV2", name: name, keywords: dict); + int _idx = 0; + var y = op.outputs[_idx++]; + var idx = op.outputs[_idx++]; + return (y, idx); + } + + /// + /// Finds unique elements in a 1-D tensor. + /// + /// + /// 1-D. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'UniqueWithCounts'. + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// y : 1-D. + /// idx : 1-D. + /// count : 1-D. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// This operation returns a tensor y containing all of the unique elements of x + /// sorted in the same order that they occur in x. This operation also returns a + /// tensor idx the same size as x that contains the index of each value of x + /// in the unique output y. Finally, it returns a third tensor count that + /// contains the count of each element of y in x. In other words: + /// + /// y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1] + /// + /// For example: + /// + /// + /// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] + /// y, idx, count = unique_with_counts(x) + /// y ==&gt; [1, 2, 4, 7, 8] + /// idx ==&gt; [0, 0, 1, 2, 2, 2, 3, 4, 4] + /// count ==&gt; [2, 1, 3, 1, 2] + /// + /// + public static (Tensor y, Tensor idx, Tensor count) unique_with_counts(Tensor x, TF_DataType? out_idx = null, string name = "UniqueWithCounts") + { + var dict = new Dictionary(); + dict["x"] = x; + if (out_idx.HasValue) + dict["out_idx"] = out_idx.Value; + var op = tf.OpDefLib._apply_op_helper("UniqueWithCounts", name: name, keywords: dict); + int _idx = 0; + var y = op.outputs[_idx++]; + var idx = op.outputs[_idx++]; + var count = op.outputs[_idx++]; + return (y, idx, count); + } + + /// + /// Finds unique elements along an axis of a tensor. + /// + /// + /// A Tensor. + /// + /// + /// A Tensor of type int32 (default: None). The axis of the Tensor to + /// find the unique elements. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'UniqueWithCountsV2'. + /// + /// + /// + /// + /// Returns a tuple with multiple values, as follows: + /// y : A Tensor. Unique elements along the axis of Tensor x. + /// idx : A 1-D Tensor. Has the same type as x that contains the index of each + /// value of x in the output y. + /// count : A 1-D Tensor. The count of each value of x in the output y. + /// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. + /// + /// + /// This operation either returns a tensor y containing unique elements + /// along the axis of a tensor. The returned unique elements is sorted + /// in the same order as they occur along axis in x. + /// This operation also returns a tensor idx and a tensor count + /// that are the same size as the number of the elements in x along the + /// axis dimension. The idx contains the index in the unique output y + /// and the count contains the count in the unique output y. + /// In other words, for an 1-D tensor x with axis = None: + /// + /// y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1] + /// + /// For example: + /// + /// + /// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] + /// y, idx, count = unique_with_counts(x) + /// y ==&gt; [1, 2, 4, 7, 8] + /// idx ==&gt; [0, 0, 1, 2, 2, 2, 3, 4, 4] + /// count ==&gt; [2, 1, 3, 1, 2] + /// + /// + /// For an 2-D tensor x with axis = 0: + /// + /// + /// # tensor 'x' is [[1, 0, 0], + /// # [1, 0, 0], + /// # [2, 0, 0]] + /// y, idx, count = unique_with_counts(x, axis=0) + /// y ==&gt; [[1, 0, 0], + /// [2, 0, 0]] + /// idx ==&gt; [0, 0, 1] + /// count ==&gt; [2, 1] + /// + /// + /// For an 2-D tensor x with axis = 1: + /// + /// + /// # tensor 'x' is [[1, 0, 0], + /// # [1, 0, 0], + /// # [2, 0, 0]] + /// y, idx, count = unique_with_counts(x, axis=1) + /// y ==&gt; [[1, 0], + /// [1, 0], + /// [2, 0]] + /// idx ==&gt; [0, 1, 1] + /// count ==&gt; [1, 2] + /// + /// + public static (Tensor y, Tensor idx, Tensor count) unique_with_counts_v2(Tensor x, Tensor axis, TF_DataType? out_idx = null, string name = "UniqueWithCountsV2") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["axis"] = axis; + if (out_idx.HasValue) + dict["out_idx"] = out_idx.Value; + var op = tf.OpDefLib._apply_op_helper("UniqueWithCountsV2", name: name, keywords: dict); + int _idx = 0; + var y = op.outputs[_idx++]; + var idx = op.outputs[_idx++]; + var count = op.outputs[_idx++]; + return (y, idx, count); + } + + /// + /// Unpacks a given dimension of a rank-R tensor into num rank-(R-1) tensors. + /// + /// + /// 1-D or higher, with axis dimension size equal to num. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Unpack'. + /// + /// + /// Optional argument + /// + /// + /// Dimension along which to unpack. Negative values wrap around, so the + /// valid range is [-R, R). + /// + /// + /// The list of tensors unpacked from value. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Unpacks num tensors from value by chipping it along the axis dimension. + /// For example, given a tensor of shape (A, B, C, D); + /// + /// If axis == 0 then the i'th tensor in output is the slice value[i, :, :, :] + /// and each tensor in output will have shape (B, C, D). (Note that the + /// dimension unpacked along is gone, unlike split). + /// + /// If axis == 1 then the i'th tensor in output is the slice value[:, i, :, :] + /// and each tensor in output will have shape (A, C, D). + /// Etc. + /// + /// This is the opposite of pack. + /// + public static Tensor[] unpack(Tensor value, int num, int? axis = null, string name = "Unpack") + { + var dict = new Dictionary(); + dict["value"] = value; + dict["num"] = num; + if (axis.HasValue) + dict["axis"] = axis.Value; + var op = tf.OpDefLib._apply_op_helper("Unpack", name: name, keywords: dict); + int _idx = 0; + var output = Enumerable.Range(0, op.OutputListLength("output")).Select(_ => op.outputs[_idx++]).ToArray(); + return (output); + } + + /// + /// Converts a flat index or array of flat indices into a tuple of + /// + /// + /// An 0-D or 1-D int Tensor whose elements are indices into the + /// flattened version of an array of dimensions dims. + /// + /// + /// An 1-D int Tensor. The shape of the array to use for unraveling + /// indices. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'UnravelIndex'. + /// + /// + /// An 2-D (or 1-D if indices is 0-D) tensor where each row has the + /// same shape as the indices array. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// coordinate arrays. + /// + /// @compatibility(numpy) + /// Equivalent to np.unravel_index + /// @end_compatibility + /// + public static Tensor unravel_index(Tensor indices, Tensor dims, string name = "UnravelIndex") + { + var dict = new Dictionary(); + dict["indices"] = indices; + dict["dims"] = dims; + var op = tf.OpDefLib._apply_op_helper("UnravelIndex", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the maximum along segments of a tensor. + /// + /// + /// + /// + /// A tensor whose shape is a prefix of data.shape.END + /// } + /// out_arg { + /// name: "output" + /// description: &lt;&lt;END + /// Has same shape as data, except for the first segment_ids.rank + /// dimensions, which are replaced with a single dimension which has size + /// num_segments. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'UnsortedSegmentMax'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + /// This operator is similar to the unsorted segment sum operator found + /// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). + /// Instead of computing the sum over segments, it computes the maximum such that: + /// + /// \\(output_i = \max_{j...} data[j...]\\) where max is over tuples j... such + /// that segment_ids[j...] == i. + /// + /// If the maximum is empty for a given segment ID i, it outputs the smallest + /// possible value for the specific numeric type, + /// output[i] = numeric_limits&lt;T&gt;::lowest(). + /// + /// If the given segment ID i is negative, then the corresponding value is + /// dropped, and will not be included in the result. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor unsorted_segment_max(Tensor data, Tensor segment_ids, Tensor num_segments, string name = "UnsortedSegmentMax") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["segment_ids"] = segment_ids; + dict["num_segments"] = num_segments; + var op = tf.OpDefLib._apply_op_helper("UnsortedSegmentMax", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the minimum along segments of a tensor. + /// + /// + /// + /// + /// A tensor whose shape is a prefix of data.shape. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'UnsortedSegmentMin'. + /// + /// + /// Has same shape as data, except for the first segment_ids.rank + /// dimensions, which are replaced with a single dimension which has size + /// num_segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation) + /// for an explanation of segments. + /// + /// This operator is similar to the unsorted segment sum operator found + /// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). + /// Instead of computing the sum over segments, it computes the minimum such that: + /// + /// \\(output_i = \min_{j...} data_[j...]\\) where min is over tuples j... such + /// that segment_ids[j...] == i. + /// + /// If the minimum is empty for a given segment ID i, it outputs the largest + /// possible value for the specific numeric type, + /// output[i] = numeric_limits&lt;T&gt;::max(). + /// + /// If the given segment ID i is negative, then the corresponding value is + /// dropped, and will not be included in the result. + /// + public static Tensor unsorted_segment_min(Tensor data, Tensor segment_ids, Tensor num_segments, string name = "UnsortedSegmentMin") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["segment_ids"] = segment_ids; + dict["num_segments"] = num_segments; + var op = tf.OpDefLib._apply_op_helper("UnsortedSegmentMin", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the product along segments of a tensor. + /// + /// + /// + /// + /// A tensor whose shape is a prefix of data.shape. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'UnsortedSegmentProd'. + /// + /// + /// Has same shape as data, except for the first segment_ids.rank + /// dimensions, which are replaced with a single dimension which has size + /// num_segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation) + /// for an explanation of segments. + /// + /// This operator is similar to the unsorted segment sum operator found + /// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). + /// Instead of computing the sum over segments, it computes the product of all + /// entries belonging to a segment such that: + /// + /// \\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples + /// j... such that segment_ids[j...] == i. + /// + /// If there is no entry for a given segment ID i, it outputs 1. + /// + /// If the given segment ID i is negative, then the corresponding value is + /// dropped, and will not be included in the result. + /// + public static Tensor unsorted_segment_prod(Tensor data, Tensor segment_ids, Tensor num_segments, string name = "UnsortedSegmentProd") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["segment_ids"] = segment_ids; + dict["num_segments"] = num_segments; + var op = tf.OpDefLib._apply_op_helper("UnsortedSegmentProd", name: name, keywords: dict); + return op.output; + } + + /// + /// Computes the sum along segments of a tensor. + /// + /// + /// + /// + /// A tensor whose shape is a prefix of data.shape. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'UnsortedSegmentSum'. + /// + /// + /// Has same shape as data, except for the first segment_ids.rank + /// dimensions, which are replaced with a single dimension which has size + /// num_segments. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Read + /// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation) + /// for an explanation of segments. + /// + /// Computes a tensor such that + /// \\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples j... such + /// that segment_ids[j...] == i. Unlike SegmentSum, segment_ids + /// need not be sorted and need not cover all values in the full + /// range of valid values. + /// + /// If the sum is empty for a given segment ID i, output[i] = 0. + /// If the given segment ID i is negative, the value is dropped and will not be + /// added to the sum of the segment. + /// + /// num_segments should equal the number of distinct segment IDs. + /// + /// &lt;div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"&gt; + /// &lt;img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentSum.png" alt&gt; + /// &lt;/div&gt; + /// + public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string name = "UnsortedSegmentSum") + { + var dict = new Dictionary(); + dict["data"] = data; + dict["segment_ids"] = segment_ids; + dict["num_segments"] = num_segments; + var op = tf.OpDefLib._apply_op_helper("UnsortedSegmentSum", name: name, keywords: dict); + return op.output; + } + + /// + /// Op is similar to a lightweight Dequeue. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Unstage'. + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The basic functionality is similar to dequeue with many fewer + /// capabilities and options. This Op is optimized for performance. + /// + public static Tensor[] unstage(TF_DataType[] dtypes, int? capacity = null, int? memory_limit = null, string container = null, string shared_name = null, string name = "Unstage") + { + var dict = new Dictionary(); + dict["dtypes"] = dtypes; + if (capacity.HasValue) + dict["capacity"] = capacity.Value; + if (memory_limit.HasValue) + dict["memory_limit"] = memory_limit.Value; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("Unstage", name: name, keywords: dict); + int _idx = 0; + var values = Enumerable.Range(0, op.OutputListLength("values")).Select(_ => op.outputs[_idx++]).ToArray(); + return (values); + } + + /// + /// Creates a handle to a Variable resource. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'VarHandleOp'. + /// + /// + /// Optional argument + /// the type of this variable. Must agree with the dtypes + /// of all ops using this variable. + /// + /// + /// Optional argument + /// The (possibly partially specified) shape of this variable. + /// + /// + /// the container this variable is placed in. + /// + /// + /// the name by which this variable is referred to. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor var_handle_op(TF_DataType dtype, Shape shape, string container = null, string shared_name = null, string name = "VarHandleOp") + { + var dict = new Dictionary(); + dict["dtype"] = dtype; + dict["shape"] = shape; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("VarHandleOp", name: name, keywords: dict); + return op.output; + } + + /// + /// Checks whether a resource handle-based variable has been initialized. + /// + /// + /// the input resource handle. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'VarIsInitializedOp'. + /// + /// + /// a scalar boolean which is true if the variable has been + /// initialized. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor var_is_initialized_op(Tensor resource, string name = "VarIsInitializedOp") + { + var dict = new Dictionary(); + dict["resource"] = resource; + var op = tf.OpDefLib._apply_op_helper("VarIsInitializedOp", name: name, keywords: dict); + return op.output; + } + + /// + /// Use VariableV2 instead. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Variable'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor variable(Shape shape, TF_DataType dtype, string container = null, string shared_name = null, string name = "Variable") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["dtype"] = dtype; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("Variable", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns the shape of the variable pointed to by resource. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'VariableShape'. + /// + /// + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation returns a 1-D integer tensor representing the shape of input. + /// + /// For example: + /// + /// + /// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] + /// shape(t) ==&gt; [2, 2, 3] + /// + /// + public static Tensor variable_shape(Tensor input, TF_DataType? out_type = null, string name = "VariableShape") + { + var dict = new Dictionary(); + dict["input"] = input; + if (out_type.HasValue) + dict["out_type"] = out_type.Value; + var op = tf.OpDefLib._apply_op_helper("VariableShape", name: name, keywords: dict); + return op.output; + } + + /// + /// Holds state in the form of a tensor that persists across steps. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'VariableV2'. + /// + /// + /// Optional argument + /// The shape of the variable tensor. + /// + /// + /// Optional argument + /// The type of elements in the variable tensor. + /// + /// + /// If non-empty, this variable is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this variable is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// A reference to the variable tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Outputs a ref to the tensor state so it may be read or modified. + /// TODO(zhifengc/mrry): Adds a pointer to a more detail document + /// about sharing states in tensorflow. + /// + public static Tensor variable_v2(Shape shape, TF_DataType dtype, string container = null, string shared_name = null, string name = "VariableV2") + { + var dict = new Dictionary(); + dict["shape"] = shape; + dict["dtype"] = dtype; + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("VariableV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Returns locations of nonzero / true values in a tensor. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Where'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// This operation returns the coordinates of true elements in condition. The + /// coordinates are returned in a 2-D tensor where the first dimension (rows) + /// represents the number of true elements, and the second dimension (columns) + /// represents the coordinates of the true elements. Keep in mind, the shape of + /// the output tensor can vary depending on how many true values there are in + /// condition. Indices are output in row-major order. + /// + /// For example: + /// + /// + /// # 'input' tensor is [[True, False] + /// # [True, False]] + /// # 'input' has two true values, so output has two coordinates. + /// # 'input' has rank of 2, so coordinates have two indices. + /// where(input) ==&gt; [[0, 0], + /// [1, 0]] + /// + /// # condition tensor is [[[True, False] + /// # [True, False]] + /// # [[False, True] + /// # [False, True]] + /// # [[False, False] + /// # [False, True]]] + /// # 'input' has 5 true values, so output has 5 coordinates. + /// # 'input' has rank of 3, so coordinates have three indices. + /// where(input) ==&gt; [[0, 0, 0], + /// [0, 1, 0], + /// [1, 0, 1], + /// [1, 1, 1], + /// [2, 1, 1]] + /// + /// # condition tensor is [[[1.5, 0.0] + /// # [-0.5, 0.0]] + /// # [[0.0, 0.25] + /// # [0.0, 0.75]] + /// # [[0.0, 0.0] + /// # [0.0, 0.01]]] + /// # 'input' has 5 nonzero values, so output has 5 coordinates. + /// # 'input' has rank of 3, so coordinates have three indices. + /// where(input) ==&gt; [[0, 0, 0], + /// [0, 1, 0], + /// [1, 0, 1], + /// [1, 1, 1], + /// [2, 1, 1]] + /// + /// # condition tensor is [[[1.5 + 0.0j, 0.0 + 0.0j] + /// # [0.0 + 0.5j, 0.0 + 0.0j]] + /// # [[0.0 + 0.0j, 0.25 + 1.5j] + /// # [0.0 + 0.0j, 0.75 + 0.0j]] + /// # [[0.0 + 0.0j, 0.0 + 0.0j] + /// # [0.0 + 0.0j, 0.01 + 0.0j]]] + /// # 'input' has 5 nonzero magnitude values, so output has 5 coordinates. + /// # 'input' has rank of 3, so coordinates have three indices. + /// where(input) ==&gt; [[0, 0, 0], + /// [0, 1, 0], + /// [1, 0, 1], + /// [1, 1, 1], + /// [2, 1, 1]] + /// + /// + public static Tensor where(Tensor input, string name = "Where") + { + var dict = new Dictionary(); + dict["input"] = input; + var op = tf.OpDefLib._apply_op_helper("Where", name: name, keywords: dict); + return op.output; + } + + /// + /// A Reader that outputs the entire contents of a file as a value. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'WholeFileReader'. + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// The handle to reference the Reader. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// To use, enqueue filenames in a Queue. The output of ReaderRead will + /// be a filename (key) and the contents of that file (value). + /// + public static Tensor whole_file_reader(string container = null, string shared_name = null, string name = "WholeFileReader") + { + var dict = new Dictionary(); + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("WholeFileReader", name: name, keywords: dict); + return op.output; + } + + /// + /// A Reader that outputs the entire contents of a file as a value. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'WholeFileReaderV2'. + /// + /// + /// If non-empty, this reader is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// If non-empty, this reader is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// The handle to reference the Reader. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// To use, enqueue filenames in a Queue. The output of ReaderRead will + /// be a filename (key) and the contents of that file (value). + /// + public static Tensor whole_file_reader_v2(string container = null, string shared_name = null, string name = "WholeFileReaderV2") + { + var dict = new Dictionary(); + if (container != null) + dict["container"] = container; + if (shared_name != null) + dict["shared_name"] = shared_name; + var op = tf.OpDefLib._apply_op_helper("WholeFileReaderV2", name: name, keywords: dict); + return op.output; + } + + /// + /// Worker heartbeat op. + /// + /// + /// A string tensor containing a serialized WorkerHeartbeatRequest + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'WorkerHeartbeat'. + /// + /// + /// A string tensor containing a serialized WorkerHeartbeatResponse + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Heartbeats may be sent periodically to indicate the coordinator is still active, + /// to retrieve the current worker status and to expedite shutdown when necessary. + /// + public static Tensor worker_heartbeat(Tensor request, string name = "WorkerHeartbeat") + { + var dict = new Dictionary(); + dict["request"] = request; + var op = tf.OpDefLib._apply_op_helper("WorkerHeartbeat", name: name, keywords: dict); + return op.output; + } + + /// + /// Writes contents to the file at input filename. Creates file and recursively + /// + /// + /// scalar. The name of the file to which we write the contents. + /// + /// + /// scalar. The content to be written to the output file. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'WriteFile'. + /// + /// + /// Returns the description of the operation + /// + /// + /// creates directory if not existing. + /// + public static Operation write_file(Tensor filename, Tensor contents, string name = "WriteFile") + { + var dict = new Dictionary(); + dict["filename"] = filename; + dict["contents"] = contents; + var op = tf.OpDefLib._apply_op_helper("WriteFile", name: name, keywords: dict); + return op; + } + + /// + /// Returns a tensor of zeros with the same shape and type as x. + /// + /// + /// a tensor of type T. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ZerosLike'. + /// + /// + /// a tensor of the same shape and type as x but filled with zeros. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor zeros_like(Tensor x, string name = "ZerosLike") + { + var dict = new Dictionary(); + dict["x"] = x; + var op = tf.OpDefLib._apply_op_helper("ZerosLike", name: name, keywords: dict); + return op.output; + } + + /// + /// Compute the Hurwitz zeta function \\(\zeta(x, q)\\). + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Zeta'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// The Hurwitz zeta function is defined as: + /// + /// + /// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\) + /// + public static Tensor zeta(Tensor x, Tensor q, string name = "Zeta") + { + var dict = new Dictionary(); + dict["x"] = x; + dict["q"] = q; + var op = tf.OpDefLib._apply_op_helper("Zeta", name: name, keywords: dict); + return op.output; + } + + /// + /// Creates a dataset that zips together input_datasets. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ZipDataset'. + /// + /// + /// Optional argument + /// + /// + /// Optional argument + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + public static Tensor zip_dataset(Tensor[] input_datasets, TF_DataType[] output_types, Shape[] output_shapes, string name = "ZipDataset") + { + var dict = new Dictionary(); + dict["input_datasets"] = input_datasets; + dict["output_types"] = output_types; + dict["output_shapes"] = output_shapes; + var op = tf.OpDefLib._apply_op_helper("ZipDataset", name: name, keywords: dict); + return op.output; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs new file mode 100644 index 000000000..a6cc47182 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs @@ -0,0 +1,118 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ +using static Tensorflow.ApiDef.Types; +using System.Reflection; +using static Tensorflow.Binding; +using System.Xml.Linq; + +namespace Tensorflow +{ + public class gen_random_ops + { + /// + /// Outputs random values from a normal distribution. + /// + /// + /// + /// + /// + /// + /// + public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = null) + => tf.Context.ExecuteOp("RandomStandardNormal", name, new ExecuteOpArgs(shape) + .SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 })); + + /// + /// Outputs random integers from a uniform distribution. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor random_uniform_int(Tensor shape, Tensor minval, Tensor maxval, int? seed = 0, int? seed2 = 0, string name = null) + => tf.Context.ExecuteOp("RandomUniformInt", name, new ExecuteOpArgs(shape, minval, maxval) + .SetAttributes(new { seed = seed ?? 0, seed2 = seed2 ?? 0 })); + + /// + /// Outputs random values from a uniform distribution. + /// + /// + /// + /// + /// + /// + /// + public static Tensor random_uniform(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null) + => tf.Context.ExecuteOp("RandomUniform", name, new ExecuteOpArgs(shape) + .SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 })); + + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor random_shuffle(Tensor value, int? seed = 0, int? seed2 = 0, + string name = null) + => tf.Context.ExecuteOp("RandomShuffle", name, new ExecuteOpArgs(value) + .SetAttributes(new { seed = seed ?? 0, seed2 = seed2 ?? 0 })); + + /// + /// Outputs random values from a truncated normal distribution. + /// + /// + /// + /// + /// + /// + /// + public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, + int? seed2 = 0, string name = null) + => tf.Context.ExecuteOp("TruncatedNormal", name, new ExecuteOpArgs(shape) + .SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 })); + public static Tensor stateless_random_normal_v2(Tensor shape, Tensor key, Tensor counter, + int alg, TF_DataType dtype, string name = null) + => tf.Context.ExecuteOp("StatelessRandomNormalV2", name, + new ExecuteOpArgs(shape, key, counter, alg) + .SetAttributes(new { dtype })); + + public static Tensors stateless_random_get_key_counter(int[] seed, string name = null) + => tf.Context.ExecuteOp("StatelessRandomGetKeyCounter", name, + new ExecuteOpArgs(seed)); + + public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0, + int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null) + { + if (!seed.HasValue) + seed = 0; + if (!seed2.HasValue) + seed2 = 0; + if (output_dtype == TF_DataType.DtInvalid) + output_dtype = TF_DataType.TF_INT64; + + var _op = tf.OpDefLib._apply_op_helper("Multinomial", + name: name, + args: new { logits, num_samples, seed, seed2, output_dtype }); + + return _op.output; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs new file mode 100644 index 000000000..db5f6813c --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs @@ -0,0 +1,1523 @@ +/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/ + +using Tensorflow.Eager; +using Tensorflow.Contexts; +using Tensorflow.Exceptions; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static class gen_resource_variable_ops +{ + /// + /// Adds a value to the current value of a variable. + /// + /// + /// + /// Any ReadVariableOp with a control dependency on this op is guaranteed to + /// see the incremented value or a subsequent newer one. + /// + /// + /// + /// + /// + public static Operation assign_add_variable_op(Tensor resource, Tensor value, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AssignAddVariableOp", name) { args = new object[] { resource, value }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return assign_add_variable_op_eager_fallback(resource, value, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["value"] = value; + var _op = tf.OpDefLib._apply_op_helper("AssignAddVariableOp", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype") }; + _execute.record_gradient("AssignAddVariableOp", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation assign_add_variable_op_eager_fallback(Tensor resource, Tensor value, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource, value }; + object[] _attrs = new object[] { "dtype", value.dtype }; + var _result = _execute.execute("AssignAddVariableOp", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AssignAddVariableOp", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Subtracts a value from the current value of a variable. + /// + /// + /// + /// Any ReadVariableOp with a control dependency on this op is guaranteed to + /// see the decremented value or a subsequent newer one. + /// + /// + /// + /// + /// + public static Operation assign_sub_variable_op(Tensor resource, Tensor value, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AssignSubVariableOp", name) { args = new object[] { resource, value }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return assign_sub_variable_op_eager_fallback(resource, value, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["value"] = value; + var _op = tf.OpDefLib._apply_op_helper("AssignSubVariableOp", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype") }; + _execute.record_gradient("AssignSubVariableOp", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation assign_sub_variable_op_eager_fallback(Tensor resource, Tensor value, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource, value }; + object[] _attrs = new object[] { "dtype", value.dtype }; + var _result = _execute.execute("AssignSubVariableOp", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AssignSubVariableOp", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Assigns a new value to a variable. + /// + /// + /// + /// Any ReadVariableOp with a control dependency on this op is guaranteed to return + /// this value or a subsequent newer value of the variable. + /// + /// + /// + /// + /// + /// + public static Operation assign_variable_op(Tensor resource, Tensor value, bool validate_shape = false, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AssignVariableOp", name) { args = new object[] { resource, value }, attrs = new Dictionary() { ["validate_shape"] = validate_shape } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return assign_variable_op_eager_fallback(resource, value, validate_shape: validate_shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["value"] = value; + keywords["validate_shape"] = validate_shape; + var _op = tf.OpDefLib._apply_op_helper("AssignVariableOp", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "validate_shape", _op._get_attr_bool("validate_shape") }; + _execute.record_gradient("AssignVariableOp", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation assign_variable_op_eager_fallback(Tensor resource, Tensor value, bool validate_shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource, value }; + object[] _attrs = new object[] { "dtype", value.dtype, "validate_shape", validate_shape }; + var _result = _execute.execute("AssignVariableOp", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("AssignVariableOp", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// This op consumes a lock created by `MutexLock`. + /// + /// + /// + /// This op exists to consume a tensor created by `MutexLock` (other than + /// direct control dependencies). It should be the only that consumes the tensor, + /// and will raise an error if it is not. Its only purpose is to keep the + /// mutex lock tensor alive until it is consumed by this op. + /// + /// **NOTE**: This operation must run on the same device as its input. This may + /// be enforced via the `colocate_with` mechanism. + /// + /// + /// + /// + public static Operation consume_mutex_lock(Tensor mutex_lock, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ConsumeMutexLock", name) { args = new object[] { mutex_lock }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return consume_mutex_lock_eager_fallback(mutex_lock, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["mutex_lock"] = mutex_lock; + var _op = tf.OpDefLib._apply_op_helper("ConsumeMutexLock", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("ConsumeMutexLock", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation consume_mutex_lock_eager_fallback(Tensor mutex_lock, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { mutex_lock }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("ConsumeMutexLock", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ConsumeMutexLock", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Deletes the resource specified by the handle. + /// + /// + /// + /// All subsequent operations using the resource will result in a NotFound + /// error status. + /// + /// + /// + /// + /// + /// whether to ignore the error when the resource + /// doesn't exist. + /// + /// + /// + public static Operation destroy_resource_op(Tensor resource, bool ignore_lookup_error = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DestroyResourceOp", name) { args = new object[] { resource }, attrs = new Dictionary() { ["ignore_lookup_error"] = ignore_lookup_error } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return destroy_resource_op_eager_fallback(resource, ignore_lookup_error: ignore_lookup_error, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["ignore_lookup_error"] = ignore_lookup_error; + var _op = tf.OpDefLib._apply_op_helper("DestroyResourceOp", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "ignore_lookup_error", _op._get_attr_bool("ignore_lookup_error") }; + _execute.record_gradient("DestroyResourceOp", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation destroy_resource_op_eager_fallback(Tensor resource, bool ignore_lookup_error, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource }; + object[] _attrs = new object[] { "ignore_lookup_error", ignore_lookup_error }; + var _result = _execute.execute("DestroyResourceOp", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DestroyResourceOp", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Turns off the copy-on-read mode. + /// + /// + /// + /// Turns off the copy-on-read mode of a resource variable. If the variable is not in copy-on-read mode, this op has no effect. + /// + /// + /// + /// + public static Operation disable_copy_on_read(Tensor resource, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DisableCopyOnRead", name) { args = new object[] { resource }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return disable_copy_on_read_eager_fallback(resource, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + var _op = tf.OpDefLib._apply_op_helper("DisableCopyOnRead", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("DisableCopyOnRead", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation disable_copy_on_read_eager_fallback(Tensor resource, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("DisableCopyOnRead", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DisableCopyOnRead", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Locks a mutex resource. The output is the lock. So long as the lock tensor + /// + /// + /// + /// is alive, any other request to use `MutexLock` with this mutex will wait. + /// + /// This is particularly useful for creating a critical section when used in + /// conjunction with `MutexLockIdentity`: + /// + /// ```python + /// + /// mutex = mutex_v2( + /// shared_name=handle_name, container=container, name=name) + /// + /// def execute_in_critical_section(fn, *args, **kwargs): + /// lock = gen_resource_variable_ops.mutex_lock(mutex) + /// + /// with ops.control_dependencies([lock]): + /// r = fn(*args, **kwargs) + /// + /// with ops.control_dependencies(nest.flatten(r)): + /// with ops.colocate_with(mutex): + /// ensure_lock_exists = mutex_lock_identity(lock) + /// + /// # Make sure that if any element of r is accessed, all of + /// # them are executed together. + /// r = nest.map_structure(tf.identity, r) + /// + /// with ops.control_dependencies([ensure_lock_exists]): + /// return nest.map_structure(tf.identity, r) + /// ``` + /// + /// While `fn` is running in the critical section, no other functions which wish to + /// use this critical section may run. + /// + /// Often the use case is that two executions of the same graph, in parallel, + /// wish to run `fn`; and we wish to ensure that only one of them executes + /// at a time. This is especially important if `fn` modifies one or more + /// variables at a time. + /// + /// It is also useful if two separate functions must share a resource, but we + /// wish to ensure the usage is exclusive. + /// + /// + /// + /// + public static Tensor mutex_lock(Tensor mutex, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MutexLock", name) { args = new object[] { mutex }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return mutex_lock_eager_fallback(mutex, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["mutex"] = mutex; + var _op = tf.OpDefLib._apply_op_helper("MutexLock", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("MutexLock", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor mutex_lock_eager_fallback(Tensor mutex, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { mutex }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("MutexLock", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MutexLock", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Creates a Mutex resource that can be locked by `MutexLock`. + /// + /// + /// + /// If non-empty, this variable is placed in the given container. + /// Otherwise, a default container is used. + /// + /// + /// + /// + /// If non-empty, this variable is named in the given bucket + /// with this shared_name. Otherwise, the node name is used instead. + /// + /// + /// + public static Tensor mutex_v2(string container = "", string shared_name = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MutexV2", name) { args = new object[] { }, attrs = new Dictionary() { ["container"] = container, ["shared_name"] = shared_name } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return mutex_v2_eager_fallback(container: container, shared_name: shared_name, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (container is null) + { + container = ""; + } + if (shared_name is null) + { + shared_name = ""; + } + Dictionary keywords = new(); + keywords["container"] = container; + keywords["shared_name"] = shared_name; + var _op = tf.OpDefLib._apply_op_helper("MutexV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "container", _op.get_attr("container"), "shared_name", _op.get_attr("shared_name") }; + _execute.record_gradient("MutexV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor mutex_v2_eager_fallback(string container, string shared_name, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "container", container, "shared_name", shared_name }; + var _result = _execute.execute("MutexV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("MutexV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Reads the value of a variable. + /// + /// + /// + /// The tensor returned by this operation is immutable. + /// + /// The value returned by this operation is guaranteed to be influenced by all the + /// writes on which this operation depends directly or indirectly, and to not be + /// influenced by any of the writes which depend directly or indirectly on this + /// operation. + /// + /// + /// + /// + /// + /// the dtype of the value. + /// + /// + /// + public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ReadVariableOp", name) { args = new object[] { resource }, attrs = new Dictionary() { ["dtype"] = dtype } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return read_variable_op_eager_fallback(resource, dtype: dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["dtype"] = dtype; + var _op = tf.OpDefLib._apply_op_helper("ReadVariableOp", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype") }; + _execute.record_gradient("ReadVariableOp", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor read_variable_op_eager_fallback(Tensor resource, TF_DataType dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource }; + object[] _attrs = new object[] { "dtype", dtype }; + var _result = _execute.execute("ReadVariableOp", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ReadVariableOp", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Gather slices from the variable pointed to by `resource` according to `indices`. + /// + /// + /// + /// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). + /// Produces an output tensor with shape `indices.shape + params.shape[1:]` where: + /// + /// ```python + /// # Scalar indices + /// output[:, ..., :] = params[indices, :, ... :] + /// + /// # Vector indices + /// output[i, :, ..., :] = params[indices[i], :, ... :] + /// + /// # Higher rank indices + /// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] + /// ``` + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor resource_gather(Tensor resource, Tensor indices, TF_DataType dtype, int batch_dims = 0, bool validate_indices = true, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ResourceGather", name) { args = new object[] { resource, indices }, attrs = new Dictionary() { ["batch_dims"] = batch_dims, ["validate_indices"] = validate_indices, ["dtype"] = dtype } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return resource_gather_eager_fallback(resource, indices, batch_dims: batch_dims, validate_indices: validate_indices, dtype: dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["indices"] = indices; + keywords["batch_dims"] = batch_dims; + keywords["validate_indices"] = validate_indices; + keywords["dtype"] = dtype; + var _op = tf.OpDefLib._apply_op_helper("ResourceGather", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "batch_dims", _op._get_attr_int("batch_dims"), "validate_indices", _op._get_attr_bool("validate_indices"), "dtype", _op._get_attr_type("dtype"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("ResourceGather", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor resource_gather_eager_fallback(Tensor resource, Tensor indices, int batch_dims, bool validate_indices, TF_DataType dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource, indices }; + object[] _attrs = new object[] { "batch_dims", batch_dims, "validate_indices", validate_indices, "dtype", dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("ResourceGather", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ResourceGather", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// + /// + /// + /// + /// + /// + public static Tensor resource_gather_nd(Tensor resource, Tensor indices, TF_DataType dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ResourceGatherNd", name) { args = new object[] { resource, indices }, attrs = new Dictionary() { ["dtype"] = dtype } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return resource_gather_nd_eager_fallback(resource, indices, dtype: dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["indices"] = indices; + keywords["dtype"] = dtype; + var _op = tf.OpDefLib._apply_op_helper("ResourceGatherNd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("ResourceGatherNd", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor resource_gather_nd_eager_fallback(Tensor resource, Tensor indices, TF_DataType dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource, indices }; + object[] _attrs = new object[] { "dtype", dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("ResourceGatherNd", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ResourceGatherNd", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Adds sparse updates to the variable referenced by `resource`. + /// + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] += updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] += updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] + /// + /// Duplicate entries are handled correctly: if multiple `indices` reference + /// the same location, their contributions add. + /// + /// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + /// + ///
+ /// + ///
+ /// + ///
+ /// + /// + /// + /// + public static Operation resource_scatter_add(Tensor resource, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ResourceScatterAdd", name) { args = new object[] { resource, indices, updates }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return resource_scatter_add_eager_fallback(resource, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("ResourceScatterAdd", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("ResourceScatterAdd", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation resource_scatter_add_eager_fallback(Tensor resource, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource, indices, updates }; + object[] _attrs = new object[] { "dtype", updates.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("ResourceScatterAdd", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ResourceScatterAdd", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Divides sparse updates into the variable referenced by `resource`. + /// + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] /= updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] /= updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] + /// + /// Duplicate entries are handled correctly: if multiple `indices` reference + /// the same location, their contributions multiply. + /// + /// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + /// + ///
+ /// + ///
+ /// + ///
+ /// + /// + /// + /// + public static Operation resource_scatter_div(Tensor resource, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ResourceScatterDiv", name) { args = new object[] { resource, indices, updates }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return resource_scatter_div_eager_fallback(resource, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("ResourceScatterDiv", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("ResourceScatterDiv", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation resource_scatter_div_eager_fallback(Tensor resource, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource, indices, updates }; + object[] _attrs = new object[] { "dtype", updates.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("ResourceScatterDiv", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ResourceScatterDiv", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Reduces sparse updates into the variable referenced by `resource` using the `max` operation. + /// + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] = max(ref[indices, ...], updates[...]) + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + /// + /// Duplicate entries are handled correctly: if multiple `indices` reference + /// the same location, their contributions are combined. + /// + /// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + /// + ///
+ /// + ///
+ /// + ///
+ /// + /// + /// + /// + public static Operation resource_scatter_max(Tensor resource, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ResourceScatterMax", name) { args = new object[] { resource, indices, updates }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return resource_scatter_max_eager_fallback(resource, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("ResourceScatterMax", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("ResourceScatterMax", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation resource_scatter_max_eager_fallback(Tensor resource, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource, indices, updates }; + object[] _attrs = new object[] { "dtype", updates.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("ResourceScatterMax", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ResourceScatterMax", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Reduces sparse updates into the variable referenced by `resource` using the `min` operation. + /// + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] = min(ref[indices, ...], updates[...]) + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + /// + /// Duplicate entries are handled correctly: if multiple `indices` reference + /// the same location, their contributions are combined. + /// + /// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + /// + ///
+ /// + ///
+ /// + ///
+ /// + /// + /// + /// + public static Operation resource_scatter_min(Tensor resource, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ResourceScatterMin", name) { args = new object[] { resource, indices, updates }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return resource_scatter_min_eager_fallback(resource, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("ResourceScatterMin", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("ResourceScatterMin", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation resource_scatter_min_eager_fallback(Tensor resource, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource, indices, updates }; + object[] _attrs = new object[] { "dtype", updates.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("ResourceScatterMin", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ResourceScatterMin", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Multiplies sparse updates into the variable referenced by `resource`. + /// + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] *= updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] *= updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] + /// + /// Duplicate entries are handled correctly: if multiple `indices` reference + /// the same location, their contributions multiply. + /// + /// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + /// + ///
+ /// + ///
+ /// + ///
+ /// + /// + /// + /// + public static Operation resource_scatter_mul(Tensor resource, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ResourceScatterMul", name) { args = new object[] { resource, indices, updates }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return resource_scatter_mul_eager_fallback(resource, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("ResourceScatterMul", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("ResourceScatterMul", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation resource_scatter_mul_eager_fallback(Tensor resource, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource, indices, updates }; + object[] _attrs = new object[] { "dtype", updates.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("ResourceScatterMul", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ResourceScatterMul", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Subtracts sparse updates from the variable referenced by `resource`. + /// + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] -= updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] -= updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] + /// + /// Duplicate entries are handled correctly: if multiple `indices` reference + /// the same location, their contributions add. + /// + /// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + /// + ///
+ /// + ///
+ /// + ///
+ /// + /// + /// + /// + public static Operation resource_scatter_sub(Tensor resource, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ResourceScatterSub", name) { args = new object[] { resource, indices, updates }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return resource_scatter_sub_eager_fallback(resource, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("ResourceScatterSub", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("ResourceScatterSub", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation resource_scatter_sub_eager_fallback(Tensor resource, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource, indices, updates }; + object[] _attrs = new object[] { "dtype", updates.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("ResourceScatterSub", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ResourceScatterSub", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Assigns sparse updates to the variable referenced by `resource`. + /// + /// + /// + /// This operation computes + /// + /// # Scalar indices + /// ref[indices, ...] = updates[...] + /// + /// # Vector indices (for each i) + /// ref[indices[i], ...] = updates[i, ...] + /// + /// # High rank indices (for each i, ..., j) + /// ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] + /// + /// + /// + /// + /// + /// + public static Operation resource_scatter_update(Tensor resource, Tensor indices, Tensor updates, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ResourceScatterUpdate", name) { args = new object[] { resource, indices, updates }, attrs = new Dictionary() { } }); + return null; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return resource_scatter_update_eager_fallback(resource, indices, updates, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + keywords["indices"] = indices; + keywords["updates"] = updates; + var _op = tf.OpDefLib._apply_op_helper("ResourceScatterUpdate", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "Tindices", _op._get_attr_type("Tindices") }; + _execute.record_gradient("ResourceScatterUpdate", _op.inputs, _attrs, _result); + } + return _op; + } + + public static Operation resource_scatter_update_eager_fallback(Tensor resource, Tensor indices, Tensor updates, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource, indices, updates }; + object[] _attrs = new object[] { "dtype", updates.dtype, "Tindices", indices.dtype }; + var _result = _execute.execute("ResourceScatterUpdate", 0, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ResourceScatterUpdate", _inputs_flat, _attrs, _result); + } + return null; + } + /// + /// Creates a handle to a Variable resource. + /// + /// + /// + /// the container this variable is placed in. + /// + /// + /// + /// + /// the name by which this variable is referred to. + /// + /// + /// + /// + /// the type of this variable. Must agree with the dtypes + /// of all ops using this variable. + /// + /// + /// + /// + /// The (possibly partially specified) shape of this variable. + /// + /// + /// + /// + /// DEPRECATED. The allowed devices containing the resource variable. Set when the + /// output ResourceHandle represents a per-replica/partitioned resource variable. + /// + /// + /// + public static Tensor var_handle_op(TF_DataType dtype, Shape shape, string container = "", string shared_name = "", string[] allowed_devices = null, string? name = null) + { + var _ctx = tf.Context; + if (allowed_devices is null) + { + allowed_devices = new string[] { }; + } + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "VarHandleOp", name) { args = new object[] { }, attrs = new Dictionary() { ["container"] = container, ["shared_name"] = shared_name, ["dtype"] = dtype, ["shape"] = shape, ["allowed_devices"] = allowed_devices } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return var_handle_op_eager_fallback(container: container, shared_name: shared_name, dtype: dtype, shape: shape, allowed_devices: allowed_devices, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (container is null) + { + container = ""; + } + if (shared_name is null) + { + shared_name = ""; + } + Dictionary keywords = new(); + keywords["container"] = container; + keywords["shared_name"] = shared_name; + keywords["dtype"] = dtype; + keywords["shape"] = shape; + keywords["allowed_devices"] = allowed_devices; + var _op = tf.OpDefLib._apply_op_helper("VarHandleOp", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "container", _op.get_attr("container"), "shared_name", _op.get_attr("shared_name"), "dtype", _op._get_attr_type("dtype"), "shape", _op.get_attr("shape"), "allowed_devices", _op.get_attr("allowed_devices") }; + _execute.record_gradient("VarHandleOp", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor var_handle_op_eager_fallback(string container, string shared_name, TF_DataType dtype, Shape shape, string[] allowed_devices, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "container", container, "shared_name", shared_name, "dtype", dtype, "shape", shape, "allowed_devices", allowed_devices }; + var _result = _execute.execute("VarHandleOp", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("VarHandleOp", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Checks whether a resource handle-based variable has been initialized. + /// + /// + /// + public static Tensor var_is_initialized_op(Tensor resource, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "VarIsInitializedOp", name) { args = new object[] { resource }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return var_is_initialized_op_eager_fallback(resource, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["resource"] = resource; + var _op = tf.OpDefLib._apply_op_helper("VarIsInitializedOp", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("VarIsInitializedOp", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor var_is_initialized_op_eager_fallback(Tensor resource, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { resource }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("VarIsInitializedOp", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("VarIsInitializedOp", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Returns the shape of the variable pointed to by `resource`. + /// + /// + /// + /// This operation returns a 1-D integer tensor representing the shape of `input`. + /// + /// For example: + /// + /// ``` + /// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] + /// shape(t) ==> [2, 2, 3] + /// ``` + /// + /// + /// + /// + /// + public static Tensor variable_shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "VariableShape", name) { args = new object[] { input }, attrs = new Dictionary() { ["out_type"] = out_type } }); + return _fast_path_result[0]; + } + catch (NotOkStatusException ex) + { + throw ex; + } + catch (Exception) + { + } + try + { + return variable_shape_eager_fallback(input, out_type: out_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["out_type"] = out_type; + var _op = tf.OpDefLib._apply_op_helper("VariableShape", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "out_type", _op._get_attr_type("out_type") }; + _execute.record_gradient("VariableShape", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor variable_shape_eager_fallback(Tensor input, TF_DataType out_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "out_type", out_type }; + var _result = _execute.execute("VariableShape", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("VariableShape", _inputs_flat, _attrs, _result); + } + return _result[0]; + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs b/src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs new file mode 100644 index 000000000..73829b29c --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs @@ -0,0 +1,71 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class gen_sparse_ops + { + /// + /// Converts a sparse representation into a dense tensor. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor sparse_to_dense(Tensor sparse_indices, + int[] output_shape, + T sparse_values, + T default_value, + bool validate_indices = true, + string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("SparseToDense", name, args: new + { + sparse_indices, + output_shape, + sparse_values, + default_value, + validate_indices + }); + + return _op.output; + } + + public static Tensor sparse_to_dense(Tensor sparse_indices, + Tensor output_shape, + Tensor sparse_values, + T default_value = default, + bool validate_indices = true, + string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("SparseToDense", name, args: new + { + sparse_indices, + output_shape, + sparse_values, + default_value, + validate_indices + }); + + return _op.output; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/handle_data_util.cs b/src/TensorFlowNET.Core/Operations/handle_data_util.cs new file mode 100644 index 000000000..363d3144e --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/handle_data_util.cs @@ -0,0 +1,60 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Eager; +using static Tensorflow.CppShapeInferenceResult.Types; + +namespace Tensorflow.Operations +{ + public static class handle_data_util + { + public static void copy_handle_data(Tensor source_t, Tensor target_t) + { + if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) + { + HandleData handle_data; + if(source_t is EagerTensor) + { + handle_data = source_t.HandleData; + } + else + { + handle_data = ops.get_resource_handle_data(source_t); + } + if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null + && handle_data.ShapeAndType.Count > 0) + { + set_handle_data(target_t, handle_data); + } + } + } + + public static HandleData create_handle_data(Shape shape, TF_DataType dtype) + { + HandleData handle_data = new(); + handle_data.IsSet = true; + handle_data.ShapeAndType.Add(new HandleShapeAndType() + { + Shape = shape.as_proto(), + Dtype = dtype.as_datatype_enum() + }); + return handle_data; + } + + public static void set_handle_data(Tensor target_t, HandleData handle_data) + { + if(target_t is EagerTensor) + { + target_t.HandleData = handle_data; + return; + } + Status status = new(); + var proto = handle_data.ToByteArray(); + c_api.TF_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status); + status.Check(true); + } + + public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op); + } +} diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs new file mode 100644 index 000000000..f1aff28ee --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -0,0 +1,2242 @@ +/***************************************************************************** + Copyright 2020 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Framework; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class image_ops_impl + { + internal static Operation _assert(Tensor cond, Type ex_type, string msg) + { + if (_is_tensor(cond)) + return control_flow_ops.Assert(cond, new object[] { msg }); + else + if (cond != null) + { + Exception ex_type2 = (Exception)Activator.CreateInstance(ex_type, msg, ex_type); + throw ex_type2; + } + else + { + Operation x = null; + return x; + } + } + + internal static bool _is_tensor(object x) + { + if (isinstance(x, typeof(Tensor))) + return true; + else if (isinstance(x, typeof(IVariableV1))) + return true; + else + return false; + } + + internal static long[] _ImageDimensions(Tensor image, int rank) + { + if (image.shape.IsFullyDefined) + return image.shape.dims; + else + { + var static_shape = image.shape.with_rank(rank).dims; + var dynamic_shape = array_ops.unstack(array_ops.shape(image), rank); + + long[] ss_storage = null; + long[] ds_storage = null; + // var sd = static_shape.Zip(dynamic_shape, (first, second) => storage[storage.Length] = first; + var sd = static_shape.Zip(dynamic_shape, (ss, ds) => + { + ss_storage[ss_storage.Length] = ss; + ds_storage[ds_storage.Length] = (long)ds; + return true; + }); + + if (ss_storage != null) + return ss_storage; + else + return ds_storage; + } + } + + internal static Tensor _AssertAtLeast3DImage(Tensor image) + => control_flow_ops.with_dependencies(_CheckAtLeast3DImage(image, require_static: false), image); + + internal static Operation[] _CheckAtLeast3DImage(Tensor image, bool require_static) + { + Shape image_shape; + try + { + if (image.shape.ndim == Unknown) + { + image_shape = image.shape.with_rank(3); + } + else + { + image_shape = image.shape.with_rank_at_least(3); + } + } + catch (ValueError) + { + throw new ValueError("'image' must be at least three-dimensional."); + } + if (require_static & !image_shape.IsFullyDefined) + { + throw new ValueError("\'image\' must be fully defined."); + } + var dims = new Shape(new[] { + image_shape.dims[image_shape.dims.Length - 3], + image_shape.dims[image_shape.dims.Length - 2], + image_shape.dims[image_shape.dims.Length - 1]}); + foreach (var dim in dims.dims) + { + if (dim == 0) + { + throw new ValueError("inner 3 dimensions of \'image\' must be > 0: " + image_shape); + } + } + + var image_shape_last_three_elements = new Shape(new[] { + image_shape.dims[image_shape.dims.Length - 3], + image_shape.dims[image_shape.dims.Length - 2], + image_shape.dims[image_shape.dims.Length - 1]}); + if (!image_shape_last_three_elements.IsFullyDefined) + { + Tensor image_shape_ = array_ops.shape(image); + var image_shape_return = tf.slice(image_shape_, new[] { Math.Max(image_shape.dims.Length - 3, 0) }, new[] { 3 }); + + //var image_shape_return = tf.constant(new[] { + // image_shape_.dims[image_shape_.dims.Length - 3], + // image_shape_.dims[image_shape_.dims.Length - 2], + // image_shape_.dims[image_shape_.dims.Length - 1]}); + + return new Operation[] { + check_ops.assert_positive( + image_shape_return, + new object[] {"inner 3 dims of 'image.shape must be > 0."} + ), + check_ops.assert_greater_equal( + x: array_ops.rank(image), + y: tf.constant(3), + message: "'image' must be at least three-dimensional." + ) + }; + } + else + { + return new Operation[] { }; + } + } + + internal static Tensor fix_image_flip_shape(Tensor image, Tensor result) + { + Shape image_shape = image.shape; + if (image_shape == image_shape.unknown_shape()) + { + // c# defaults null types to 0 anyhow, so this should be a pretty equivalent port + result.shape = new long[] { 0, 0, 0 }; + } + else + { + result.shape = image_shape; + } + return result; + } + + public static Tensor random_flip_up_down(Tensor image, int seed = 0) + => _random_flip(image: image, + flip_index: 0, + seed: seed, + scope_name: "random_flip_up_down"); + + public static Tensor random_flip_left_right(Tensor image, int seed = 0) + => _random_flip(image: image, + flip_index: 1, + seed: seed, + scope_name: "random_flip_left_right"); + + internal static Tensor _random_flip(Tensor image, int flip_index, int seed, string scope_name) + { + return tf_with(ops.name_scope(null, scope_name, new[] { image }), scope => + { + image = ops.convert_to_tensor(image, name: "image"); + image = _AssertAtLeast3DImage(image); + Shape shape = image.shape; + if (shape.ndim == 3 || shape.ndim == Unknown) + { + Tensor uniform_random = random_ops.random_uniform(new int[] { }, 0f, 1.0f, seed: seed); + var mirror_cond = gen_math_ops.less(uniform_random, ops.convert_to_tensor(.5)); + + var result = control_flow_ops.cond( + pred: mirror_cond, + true_fn: () => gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index })), + false_fn: () => image, + name: scope + ); + return fix_image_flip_shape(image, result); + } + else if (shape.ndim == 4) + { + var batch_size = array_ops.shape(image); + var uniform_random = random_ops.random_uniform(batch_size.shape, + 0f, + 1.0f, + seed: seed); + var flips = math_ops.round( + array_ops.reshape(uniform_random, shape: array_ops.constant(value: new object[] { batch_size[0], 1, 1, 1 }))); + flips = math_ops.cast(flips, image.dtype); + var flipped_input = gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index + 1 })); + return flips * flipped_input + (1 - flips) * image; + } + else + { + throw new ValueError(String.Format("\'image\' {0} must have either 3 or 4 dimensions.", shape)); + } + }); + } + + public static Tensor flip_left_right(Tensor image) + => _flip(image, 1, "flip_left_right"); + + public static Tensor flip_up_down(Tensor image) + => _flip(image, 0, "flip_up_down"); + + internal static Tensor _flip(Tensor image, int flip_index, string scope_name) + { + return tf_with(ops.name_scope(null, scope_name, new { image }), delegate + { + image = ops.convert_to_tensor(image, name: "image"); + image = _AssertAtLeast3DImage(image); + Shape shape = image.shape; + if (shape.ndim == 3 || shape.ndim == Unknown) + { + return fix_image_flip_shape(image, gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new int[] { flip_index }))); + } + else if (shape.ndim == 4) + { + return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { flip_index + 1 })); + } + else + { + throw new ValueError("\'image\' must have either 3 or 4 dimensions."); + } + }); + } + + public static Tensor rot90(Tensor image, int k = 1, string name = null) + { + return tf_with(ops.name_scope(name, "rot90", new[] { image, tf.constant(k) }), scope => + { + image = ops.convert_to_tensor(image, name: "image"); + image = _AssertAtLeast3DImage(image); + + // can't get k to convert to tensor without throwing error about it being an int--- + // might rework later. for now, k2 == k as Tensor + Tensor k2 = ops.convert_to_tensor(k, dtype: dtypes.int32, name: "k"); + k2.shape.assert_has_rank(0); + k2 = gen_ops.mod(k2, tf.constant(4)); + + Shape shape = image.shape; + if (shape.ndim == 3 || shape.ndim == Unknown) + { + return _rot90_3D(image, k, scope); + } + else if (shape.ndim == 4) + { + return _rot90_3D(image, k, scope); + } + else + { + throw new ValueError(String.Format("\'image\' {0} must have either 3 or 4 dimensions.", shape)); + } + }); + } + + internal static Tensor _rot90_3D(Tensor image, int k, string name_scope) + { + Tensor _rot90() + { + return array_ops.transpose(gen_array_ops.reverse(image, ops.convert_to_tensor(new[] { 1, 0, 2 })), new int[] { 1 }); + }; + Tensor _rot180() + { + return gen_array_ops.reverse(image, ops.convert_to_tensor(new[] { 0, 1 })); + }; + Tensor _rot270() + { + return gen_array_ops.reverse(array_ops.transpose(image, new[] { 1, 0, 2 }), ops.convert_to_tensor(new[] { 1 })); + }; + + var cases = new[] {math_ops.equal(k, 1), _rot90(), + math_ops.equal(k, 2), _rot180(), + math_ops.equal(k, 3), _rot270()}; + + var result = control_flow_ops.case_v2(cases, callable_default: () => new Tensor[] { image }, exclusive: true, name: name_scope); + result.shape = new long[] { -1, -1, image.shape.dims[2] }; + return result; + } + + public static Tensor transpose(Tensor image, string name = null) + { + using (ops.name_scope(name, "transpose", new[] { image })) + return tf_with(ops.name_scope(name, "transpose", new[] { image }), delegate + { + image = ops.convert_to_tensor(image, name: "image"); + image = _AssertAtLeast3DImage(image); + Shape shape = image.shape; + if (shape.ndim == 3 || shape.ndim == Unknown) + { + return array_ops.transpose(image, new[] { 1, 0, 2 }, name: name); + } + else if (shape.ndim == 4) + { + return array_ops.transpose(image, new[] { 0, 2, 1, 3 }, name: name); + } + else + { + throw new ValueError(String.Format("\'image\' {0} must have either 3 or 4 dimensions.")); + } + }); + } + + public static Tensor central_crop(Tensor image, float central_fraction) + { + using (ops.name_scope(null, "central_crop", new[] { image })) + { + image = ops.convert_to_tensor(image, name: "image"); + if (central_fraction <= 0.0 || central_fraction > 1.0) + throw new ValueError("central_fraction must be within (0, 1]"); + if (central_fraction == 1.0) + return image; + + _AssertAtLeast3DImage(image); + var rank = image.shape.ndim; + if (rank != 3 && rank != 4) + throw new ValueError(String.Format(@"`image` should either be a Tensor with rank = 3 +or rank = 4. Had rank = {0}", rank)); + + object[] _get_dim(Tensor tensor, int idx) + { + var static_shape = tensor.shape.dims[idx]; + if (static_shape != (int)None) + return new object[2] { static_shape, false }; + return new object[2] { array_ops.shape(tensor)[idx], true }; + }; + + object[] h, w; + int d, bs = 0; + if (rank == 3) + { + h = _get_dim(image, 0); // img_h == h[0], dynamic_h == h[1] + w = _get_dim(image, 1); + d = (int)image.shape[3]; + } + else + { + bs = (int)image.shape[0]; + h = _get_dim(image, 1); + w = _get_dim(image, 2); + d = (int)image.shape[3]; + } + + object hd, bbox_h_start; + if ((bool)h[1]) + { + hd = math_ops.cast((IVariableV1)h[0], dtypes.float64); + bbox_h_start = ((int)hd - (int)hd * central_fraction) / 2; + } + else + { + hd = (float)w[0]; + bbox_h_start = (int)(((int)hd - (int)hd * central_fraction) / 2); + } + + object wd, bbox_w_start; + if ((bool)w[1]) + { + wd = math_ops.cast((IVariableV1)w[0], dtypes.float64); + bbox_w_start = ((int)wd - (int)wd * central_fraction) / 2; + } + else + { + wd = (float)w[0]; + bbox_w_start = (int)(((int)wd - (int)wd * central_fraction) / 2); + } + + var bbox_h_size = (int)h[0] - (int)bbox_h_start * 2; + var bbox_w_size = (int)w[0] - (int)bbox_w_start * 2; + + Tensor bbox_begin, bbox_size; + if (rank == 3) + { + bbox_begin = array_ops.stack(ops.convert_to_tensor(new[] { bbox_h_start, bbox_w_start, 0 })); + bbox_size = array_ops.stack(ops.convert_to_tensor(new[] { bbox_h_size, bbox_w_size, -1 })); + } + else + { + bbox_begin = array_ops.stack(ops.convert_to_tensor(new[] { 0, bbox_h_start, bbox_w_start, 0 })); + bbox_size = array_ops.stack(ops.convert_to_tensor(new[] { -1, bbox_h_size, bbox_w_size, -1 })); + } + + image = array_ops.slice(image, bbox_begin, bbox_size); + + int arg1() + { + if ((bool)h[1]) + { + // 0 == null for nullable ints anyways + return 0; + } + else + { + return bbox_h_size; + } + }; + int arg2() + { + if ((bool)w[1]) + { + return 0; + } + else + { + return bbox_w_size; + } + }; + if (rank == 3) + { + var _arg1 = arg1(); + var _arg2 = arg2(); + + image.set_shape(ops.convert_to_tensor(new object[ + _arg1, _arg2, d + ])); + } + else + { + var _arg1 = arg1(); + var _arg2 = arg2(); + image.set_shape(ops.convert_to_tensor(new object[] { + bs, _arg1, _arg2, d + })); + } + } + + return image; + } + + public static Tensor pad_to_bounding_box(Tensor image, int offset_height, int offset_width, + int target_height, int target_width) + { + return tf_with(ops.name_scope(null, "pad_to_bounding_box", new[] { image }), delegate + { + image = ops.convert_to_tensor(image, name: "image"); + + bool is_batch = true; + Shape image_shape = image.shape; + if (image_shape.ndim == 3) + { + is_batch = false; + image = array_ops.expand_dims(image, 0); + } + else if (image_shape.ndim == Unknown) + { + is_batch = false; + image = array_ops.expand_dims(image, 0); + image.shape = new Shape(0, 0, 0, 0); + } + else if (image_shape.ndim != 4) + { + throw new ValueError(String.Format("\'image\' {0} must have either 3 or 4 dimensions.", + image_shape)); + } + + var assert_ops = _CheckAtLeast3DImage(image, require_static: false); + + // batch: [0], height: [1], width: [2], depth: [3] + var bhwd = _ImageDimensions(image, rank: 4); + + var after_padding_width = target_width - offset_width - bhwd[2]; + + var after_padding_height = target_height - offset_height - bhwd[1]; + + assert_ops[assert_ops.Length] = _assert(check_ops.assert_greater_equal(tf.constant(offset_height), + tf.constant(0)), typeof(ValueError), + "offset_height must be >= 0"); + assert_ops[assert_ops.Length] = _assert(check_ops.assert_greater_equal(tf.constant(offset_width), + tf.constant(0)), typeof(ValueError), + "offset_width must be >= 0"); + assert_ops[assert_ops.Length] = _assert(check_ops.assert_greater_equal(tf.constant(after_padding_width), + tf.constant(0)), typeof(ValueError), + "width must be <= target - offset"); + assert_ops[assert_ops.Length] = _assert(check_ops.assert_greater_equal(tf.constant(after_padding_height), + tf.constant(0)), typeof(ValueError), + "height must be <= target - offset"); + image = control_flow_ops.with_dependencies(assert_ops, image); + + var paddings = array_ops.reshape( + array_ops.stack(new[] { + 0, 0, offset_height, after_padding_height, offset_width, + after_padding_width, 0, 0 + }), new[] { 4, 2 } + ); + var padded = array_ops.pad(image, paddings); + + Shape padded_shape_result() + { + long[] i_remnants = { }; + foreach (var i in new[] { bhwd[0], target_height, target_width, bhwd[3] }) + if (_is_tensor(i)) + return null; + else + i_remnants[i_remnants.Length] = i; + return new Shape(i_remnants); + }; + Shape padded_shape = padded_shape_result(); + padded.shape = padded_shape; + + if (!is_batch) + { + padded = array_ops.squeeze(padded, axis: new int[] { 0 }); + } + + return padded; + }); + } + + public static Tensor crop_to_bounding_box(Tensor image, int offset_height, int offset_width, + int target_height, int target_width) + { + return tf_with(ops.name_scope(null, "crop_to_bounding_box", new[] { image }), delegate + { + image = ops.convert_to_tensor(image, name: "image"); + + bool is_batch = true; + Shape image_shape = image.shape; + if (image_shape.ndim == 3) + { + is_batch = false; + image = array_ops.expand_dims(image, 0); + } + else if (image_shape.ndim == Unknown) + { + is_batch = false; + image = array_ops.expand_dims(image, 0); + image.shape = new long[] { 0, 0, 0, 0 }; + } + else if (image_shape.ndim != 4) + { + throw new ValueError(String.Format("\'image\' {0} must have either 3 or 4 dimensions.", + image_shape)); + } + + var assert_ops = _CheckAtLeast3DImage(image, require_static: false).ToList(); + + // batch: [0], height: [1], width: [2], depth: [3] + var bhwd = _ImageDimensions(image, rank: 4); + + assert_ops.Add(_assert(check_ops.assert_greater_equal(tf.constant(offset_height), + tf.constant(0)), typeof(ValueError), + "offset_height must be >= 0.")); + assert_ops.Add(_assert(check_ops.assert_greater_equal(tf.constant(offset_width), + tf.constant(0)), typeof(ValueError), + "offset_width must be >= 0.")); + assert_ops.Add(_assert(check_ops.assert_less(tf.constant(0), + tf.constant(target_width)), typeof(ValueError), + "target_width must be > 0.")); + assert_ops.Add(_assert(check_ops.assert_less(tf.constant(0), + tf.constant(target_height)), typeof(ValueError), + "target_height must be > 0.")); + assert_ops.Add(_assert(check_ops.assert_greater_equal(tf.constant(bhwd[2]), + tf.constant(target_width + offset_width)), + typeof(ValueError), + "width must be >= target + offset.")); + assert_ops.Add(_assert(check_ops.assert_greater_equal(tf.constant(bhwd[1]), + tf.constant(target_height + offset_height)), + typeof(ValueError), + "height must be >= target + offset.")); + image = control_flow_ops.with_dependencies(assert_ops.ToArray(), image); + + var cropped = array_ops.slice( + image, array_ops.stack(new[] { 0, offset_height, offset_width, 0 }), + array_ops.stack(new[] { -1, target_height, target_width, -1 })); + + Shape cropped_shape_result() + { + long[] i_remnants = new long[4]; + int idx = 0; + foreach (var i in new[] { bhwd[0], target_height, target_width, bhwd[3] }) + { + if (_is_tensor(i)) + i_remnants[idx] = -1; + else + i_remnants[idx] = i; + idx++; + } + return new Shape(i_remnants); + }; + var cropped_shape = cropped_shape_result(); + cropped.shape = cropped_shape; + + if (!is_batch) + { + cropped = array_ops.squeeze(cropped, axis: new int[] { 0 }); + } + + return cropped; + }); + } + + public static Tensor resize_image_with_crop_or_pad(Tensor image, object target_height, object target_width) + { + using (ops.name_scope(null, "resize_image_with_crop_or_pad", new[] { image })) + return tf_with(ops.name_scope(null, "resize_image_with_crop_or_pad", new[] { image }), delegate + { + image = ops.convert_to_tensor(image, name: "image"); + Shape image_shape = image.shape; + bool is_batch = true; + if (image_shape.ndim == 3) + { + is_batch = false; + image = array_ops.expand_dims(image, 0); + } + else if (image_shape.ndim == Unknown) + { + is_batch = false; + image = array_ops.expand_dims(image, 0); + image.shape = new long[] { 0, 0, 0, 0 }; + } + else if (image_shape.ndim != 4) + { + throw new ValueError(String.Format("\'image\' {0} must have either 3 or 4 dimensions.", + image_shape)); + } + + var assert_ops = _CheckAtLeast3DImage(image, require_static: false); + assert_ops[assert_ops.Length] = _assert(check_ops.assert_less(tf.constant(0), + tf.constant(target_width)), + typeof(ValueError), + "target_width must be > 0."); + assert_ops[assert_ops.Length] = _assert(check_ops.assert_less(tf.constant(0), + tf.constant(target_height)), + typeof(ValueError), + "target_height must be > 0."); + + image = control_flow_ops.with_dependencies(assert_ops, image); + + if (_is_tensor(target_height)) + { + target_height = control_flow_ops.with_dependencies( + assert_ops, tf.constant(target_height)); + } + if (_is_tensor(target_width)) + { + target_width = control_flow_ops.with_dependencies( + assert_ops, tf.constant(target_width)); + } + + + object max_(object x, object y) + { + if (_is_tensor(x) || _is_tensor(y)) + return math_ops.maximum(x, y); + else + return Math.Max((int)x, (int)y); + } + + object min_(object x, object y) + { + if (_is_tensor(x) || _is_tensor(y)) + return math_ops.minimum(x, y); + else + return Math.Min((int)x, (int)y); + } + + object equal_(object x, object y) + { + if (_is_tensor(x) || _is_tensor(y)) + return math_ops.equal(x, y); + else + return x == y; + } + + var _hw_ = _ImageDimensions(image, rank: 4); + var width_diff = (long)target_width - _hw_[2]; + int offset_crop_width = (int)max_(Math.Floor(Math.Abs((decimal)width_diff) / 2), 0); + int offset_pad_width = (int)max_(Math.Floor((decimal)width_diff / 2), 0); + + var height_diff = (long)target_height - _hw_[1]; + int offset_crop_height = (int)max_(Math.Floor(Math.Abs((decimal)height_diff) / 2), 0); + int offset_pad_height = (int)max_(Math.Floor((decimal)height_diff / 2), 0); + + Tensor cropped = crop_to_bounding_box(image, offset_crop_height, offset_crop_width, + (int)min_(target_height, _hw_[1]), + (int)min_(target_width, _hw_[2])); + + Tensor resized = pad_to_bounding_box(cropped, offset_pad_height, offset_pad_width, + (int)target_height, (int)target_width); + + if (resized.shape.ndim == Unknown) + throw new ValueError("resized contains no shape."); + + var _rhrw_ = _ImageDimensions(resized, rank: 4); + + assert_ops = new Operation[2]; + assert_ops[0] = _assert( + (Tensor)equal_(_rhrw_[1], target_height), typeof(ValueError), + "resized height is not correct."); + assert_ops[1] = _assert( + (Tensor)equal_(_rhrw_[2], target_width), typeof(ValueError), + "resized width is not correct."); + + resized = control_flow_ops.with_dependencies(assert_ops, resized); + + if (!is_batch) + { + resized = array_ops.squeeze(resized, axis: new int[] { 0 }); + } + + return resized; + }); + } + + internal static Tensor _resize_images_common(Tensor images, Func resizer_fn, + Tensor size, bool preserve_aspect_ratio, string name, bool skip_resize_if_same) + { + return tf_with(ops.name_scope(name, "resize", new[] { images, size }), delegate + { + if (images.shape.ndim == Unknown) + throw new ValueError("\'images\' contains no shape."); + bool is_batch = true; + if (images.shape.ndim == 3) + { + is_batch = false; + images = array_ops.expand_dims(images, 0); + } + else if (images.shape.ndim != 4) + throw new ValueError("\'images\' must have either 3 or 4 dimensions."); + + var (height, width) = (images.dims[1], images.dims[2]); + + if (!size.shape.is_compatible_with(new[] { 2 })) + throw new ValueError(@"\'size\' must be a 1-D Tensor of 2 elements: +new_height, new_width"); + + if (preserve_aspect_ratio) + { + var _chcw_ = _ImageDimensions(images, rank: 4); + + var scale_factor_height = + math_ops.cast(size[0], dtypes.float32) / _chcw_[1]; + var scale_factor_width = + math_ops.cast(size[1], dtypes.float32) / _chcw_[2]; + var scale_factor = math_ops.minimum(scale_factor_height, scale_factor_width); + var scaled_height_const = math_ops.cast( + math_ops.round(scale_factor * _chcw_[1]), + dtypes.int32); + var scaled_width_const = math_ops.cast( + math_ops.round(scale_factor * _chcw_[2]), + dtypes.int32); + + size = ops.convert_to_tensor(new[] { scaled_height_const, scaled_width_const }, + dtypes.int32, + name: "size"); + } + + var size_const_as_shape = tensor_util.constant_value_as_shape(size); + var new_height_const = tensor_shape.dimension_at_index(size_const_as_shape, + 0).value; + var new_width_const = tensor_shape.dimension_at_index(size_const_as_shape, + 1).value; + + bool x_null = true; + if (skip_resize_if_same) + { + foreach (int x in new[] { new_width_const, width, new_height_const, height }) + { + if (width != new_width_const && height == new_height_const) + { + break; + } + if (x != 0) + { + x_null = false; + } + } + if (!x_null) + images = array_ops.squeeze(images, axis: new int[] { 0 }); + return images; + } + + images = resizer_fn(images, size); + + images.shape = new Shape(Unknown, new_height_const, new_width_const, Unknown); + + if (!is_batch) + images = array_ops.squeeze(images, axis: new int[] { 0 }); + return images; + }); + } + + public static Tensor resize_images(Tensor images, Tensor size, string method = ResizeMethod.BILINEAR, + bool preserve_aspect_ratio = false, bool antialias = false, string name = null) + { + Tensor resize_fn(Tensor images_t, Tensor new_size) + { + var scale_and_translate_methods = new string[] { + ResizeMethod.LANCZOS3, ResizeMethod.LANCZOS5, ResizeMethod.GAUSSIAN, + ResizeMethod.MITCHELLCUBIC + }; + + Tensor resize_with_scale_and_translate(string method) + { + var scale = new Tensor[] { + math_ops.cast(new_size, dtype: dtypes.float32), + // does this need to be reworked into only elements 1-3 being + // passed like it is in the tensorflow code? or does it matter? + math_ops.cast(array_ops.shape(images_t), dtype: dtypes.float32) + }; + return gen_ops.scale_and_translate( + images_t, + new_size, + scale, + array_ops.zeros(new[] { 2 }), + kernel_type: method, + antialias: antialias + ); + } + + if (method == ResizeMethod.BILINEAR) + if (antialias) + return resize_with_scale_and_translate("triangle"); + else + return gen_image_ops.resize_bilinear(images_t, + new_size, + half_pixel_centers: true); + else if (method == ResizeMethod.NEAREST_NEIGHBOR) + return gen_image_ops.resize_nearest_neighbor(images_t, + new_size, + half_pixel_centers: true); + else if (method == ResizeMethod.BICUBIC) + if (antialias) + return resize_with_scale_and_translate("keyscubic"); + else + return gen_image_ops.resize_bicubic(images_t, + new_size, + half_pixel_centers: true); + else if (method == ResizeMethod.AREA) + return gen_ops.resize_area(images_t, new_size); + else if (Array.Exists(scale_and_translate_methods, method => method == method)) + return resize_with_scale_and_translate(method); + else + throw new ValueError(String.Format("Resize method is not implemented: {0}", + method)); + } + + return _resize_images_common( + images, + resize_fn, + size, + preserve_aspect_ratio: preserve_aspect_ratio, + name: name, + skip_resize_if_same: false + ); + } + + internal static Tensor _resize_image_with_pad_common(Tensor image, int target_height, int target_width, + Func resize_fn) + { + using (ops.name_scope(null, "resize_image_with_pad", new[] { image })) + return tf_with(ops.name_scope(null, "resize_image_with_pad", new[] { image }), delegate + { + image = ops.convert_to_tensor(image, name: "tensor"); + var image_shape = image.shape; + bool is_batch = true; + if (image_shape.ndim == 3) + { + is_batch = false; + image = array_ops.expand_dims(image, 0); + } + else if (image_shape.ndim == Unknown) + { + is_batch = false; + image = array_ops.expand_dims(image, 0); + image.shape = new Shape(Unknown, Unknown, Unknown, Unknown); + } + else if (image_shape.ndim != 4) + { + throw new ValueError(String.Format("\'image\' {0} must have either 3 or 4 dimensions.", + image_shape)); + } + + var assert_ops = _CheckAtLeast3DImage(image, require_static: false); + assert_ops[assert_ops.Length] = _assert(check_ops.assert_less(tf.constant(0), + tf.constant(target_width)), + typeof(ValueError), + "target_width must be > 0."); + assert_ops[assert_ops.Length] = _assert(check_ops.assert_less(tf.constant(0), + tf.constant(target_height)), + typeof(ValueError), + "target_height must be > 0."); + + image = control_flow_ops.with_dependencies(assert_ops, image); + + object max_(object x, object y) + { + if (_is_tensor(x) || _is_tensor(y)) + return math_ops.maximum(x, y); + else + return Math.Max((int)x, (int)y); + } + + var _hw_ = _ImageDimensions(image, rank: 4); + + var f_height = _hw_[1]; + var f_width = _hw_[2]; + var f_target_height = target_height; + var f_target_width = target_width; + + var ratio = (Tensor)max_(f_width / f_target_width, f_height / f_target_height); + var resized_height_float = f_height / ratio; + var resized_width_float = f_width / ratio; + var resized_height = math_ops.cast( + gen_math_ops.floor(resized_height_float), dtype: dtypes.int32); + var resized_width = math_ops.cast( + gen_math_ops.floor(resized_width_float), dtype: dtypes.int32); + + var padding_height = (f_target_height - resized_height_float) / 2; + var padding_width = (f_target_width - resized_width_float) / 2; + var f_padding_height = gen_math_ops.floor(padding_height); + var f_padding_width = gen_math_ops.floor(padding_width); + int p_height = (int)max_(0, math_ops.cast(f_padding_height, dtype: dtypes.int32)); + int p_width = (int)max_(0, math_ops.cast(f_padding_width, dtype: dtypes.int32)); + + var resized = resize_fn(image, array_ops.concat(new[] { resized_height, resized_width }, 0)); + + var padded = pad_to_bounding_box(resized, p_height, p_width, target_height, + target_width); + + if (padded.shape.ndim == Unknown) + throw new ValueError("padded contains no shape."); + + _ImageDimensions(padded, rank: 4); + + if (!is_batch) + { + padded = array_ops.squeeze(padded, axis: new int[] { 0 }); + } + + return padded; + }); + } + + public static Tensor resize_images_with_pad(Tensor image, int target_height, int target_width, + string method, bool antialias) + { + Tensor _resize_fn(Tensor im, Tensor new_size) + { + return resize_images(im, new_size, method, antialias: antialias); + } + + return _resize_image_with_pad_common(image, target_height, target_width, + _resize_fn); + } + + public static Tensor per_image_standardization(Tensor image) + { + return tf_with(ops.name_scope(null, "per_image_standardization", new[] { image }), scope => + { + image = ops.convert_to_tensor(image, name: "image"); + image = _AssertAtLeast3DImage(image); + + var orig_dtype = image.dtype; + if (Array.Exists(new[] { dtypes.float16, dtypes.float32 }, orig_dtype => orig_dtype == orig_dtype)) + image = convert_image_dtype(image, dtypes.float32); + + var x = image.shape["-3:"]; + var num_pixels = math_ops.reduce_prod(x); + + Tensor image_mean = math_ops.reduce_mean(image, axis: new(-1, -2, -3), keepdims: true); + + var stddev = math_ops.reduce_std(image, axis: new(-1, -2, -3), keepdims: true); + var min_stddev = math_ops.rsqrt(math_ops.cast(num_pixels, image.dtype)); + var adjusted_stddev = math_ops.maximum(stddev, min_stddev); + + image = image - image_mean; + image = tf.div(image, adjusted_stddev, name: scope); // name: scope in python version + return convert_image_dtype(image, orig_dtype, saturate: true); + }); + } + + public static Tensor random_brightness(Tensor image, float max_delta, int seed = 0) + { + if (max_delta < 0) + throw new ValueError("max_delta must be non-negative."); + + var delta = random_ops.random_uniform(new int[] { }, max_delta * -1, max_delta, seed: seed); + return adjust_brightness(image, delta); + } + + public static Tensor random_contrast(Tensor image, float lower, float upper, int seed = 0) + { + if (upper <= lower) + throw new ValueError("upper must be > lower."); + + if (lower < 0) + throw new ValueError("lower must be non-negative."); + + var contrast_factor = random_ops.random_uniform(new int[] { }, lower, upper, seed: seed); + return adjust_contrast(image, contrast_factor); + } + + public static Tensor adjust_brightness(Tensor image, Tensor delta) + { + return tf_with(ops.name_scope(null, "adjust_brightness", new[] { image, delta }), name => + { + image = ops.convert_to_tensor(image, name: "image"); + var orig_dtype = image.dtype; + + Tensor flt_image; + if (Array.Exists(new[] { dtypes.float16, dtypes.float32 }, orig_dtype => orig_dtype == orig_dtype)) + { + flt_image = image; + } + else + { + flt_image = convert_image_dtype(image, dtypes.float32); + } + + var adjusted = math_ops.add( + flt_image, math_ops.cast(delta, flt_image.dtype), name: name); + + return convert_image_dtype(adjusted, orig_dtype, saturate: true); + }); + } + + public static Tensor adjust_contrast(Tensor images, Tensor contrast_factor) + { + return tf_with(ops.name_scope(null, "adjust_brightness", new[] { images, contrast_factor }), name => + { + images = ops.convert_to_tensor(images, name: "images"); + var orig_dtype = images.dtype; + + Tensor flt_images; + if (Array.Exists(new[] { dtypes.float16, dtypes.float32 }, orig_dtype => orig_dtype == orig_dtype)) + { + flt_images = images; + } + else + { + flt_images = convert_image_dtype(images, dtypes.float32); + } + + var adjusted = gen_ops.adjust_contrastv2( + flt_images, contrast_factor: contrast_factor, name: name); + + return convert_image_dtype(adjusted, orig_dtype, saturate: true); + }); + } + + public static Tensor adjust_gamma(Tensor image, int gamma = 1, int gain = 1) + { + return tf_with(ops.name_scope(null, "adjust_gamma", new[] {image, + tf.constant(gamma), tf.constant(gain)}), name => + { + image = ops.convert_to_tensor(image, name: "image"); + var orig_dtype = image.dtype; + + Tensor flt_image; + if (Array.Exists(new[] { dtypes.float16, dtypes.float32 }, orig_dtype => orig_dtype == orig_dtype)) + { + flt_image = image; + } + else + { + flt_image = convert_image_dtype(image, dtypes.float32); + } + + var assert_op = _assert(ops.convert_to_tensor(gamma >= 0), typeof(ValueError), + "Gamma should be a non-negative real number."); + + // python code has this if as: + // `if (assert_op)` + // + // given that assert_op is an Operation, that comparison can't be done here, + // so this just checks to see if it's empty, as that's what _assert returns + // if it fails to continue down the line of the assert + Tensor gamma_as_tensor; + if (assert_op != null) + gamma_as_tensor = control_flow_ops.with_dependencies(new[] { assert_op }, tf.constant(gamma)); + else + gamma_as_tensor = tf.constant(gamma); + + var adjusted_img = gain * math_ops.pow(flt_image, gamma_as_tensor); + + return convert_image_dtype(adjusted_img, orig_dtype, saturate: true); + }); + } + + public static Tensor rgb_to_grayscale(Tensor images, string name = null) + { + return tf_with(ops.name_scope(name, "rgb_to_grayscale", new[] { images }), name => + { + images = ops.convert_to_tensor(images, name: "images"); + var orig_dtype = images.dtype; + var flt_image = convert_image_dtype(images, dtypes.float32); + + var rgb_weights = new Tensor(new double[] { 0.2989, 0.5870, 0.1140 }); + var gray_float = math_ops.tensordot(flt_image, rgb_weights, new[] { -1, -1 }); + gray_float = array_ops.expand_dims(gray_float, -1); + return convert_image_dtype(gray_float, orig_dtype, name: name); + }); + } + + public static Tensor grayscale_to_rgb(Tensor images, string name = null) + { + return tf_with(ops.name_scope(name, "grayscale_to_rgb", new[] { images }), name => + { + images = _AssertAtLeast3DImage(images); + + images = ops.convert_to_tensor(images, name: "images"); + var rank_1 = array_ops.expand_dims(array_ops.rank(images) - 1, 0); + var shape_list = (array_ops.ones(rank_1, dtype: dtypes.int32) + + array_ops.expand_dims(tf.constant(3), 0)); + var multiples = array_ops.concat(new Tensor[] { shape_list }, 0); + var rgb = array_ops.tile(images, multiples, name: name); + int[] rgb_temp = images.shape.dims.Take(images.shape.ndim - 1).Select(x => (int)x).ToArray(); + rgb.set_shape(array_ops.concat(new Tensor[] { ops.convert_to_tensor(rgb_temp) }, 3)); + return rgb; + }); + } + + public static Tensor random_hue(Tensor image, float max_delta, int seed = 0) + { + if (max_delta > 0.5) + throw new ValueError("max_delta must be <= 0.5."); + + if (max_delta < 0) + throw new ValueError("max_delta must be non-negative."); + + var delta = random_ops.random_uniform(new int[] { }, max_delta * -1, max_delta, seed: seed); + return adjust_hue(image, delta); + } + + public static Tensor adjust_hue(Tensor image, Tensor delta, string name = null) + { + return tf_with(ops.name_scope(name, "adjust_hue", new[] { image }), name => + { + image = ops.convert_to_tensor(image, name: "image"); + var orig_dtype = image.dtype; + + Tensor flt_image; + if (Array.Exists(new[] { dtypes.float16, dtypes.float32 }, orig_dtype => orig_dtype == orig_dtype)) + flt_image = image; + else + flt_image = convert_image_dtype(image, dtypes.float32); + + var rgb_altered = gen_ops.adjust_hue(flt_image, delta); + + return convert_image_dtype(rgb_altered, orig_dtype); + }); + } + + public static Tensor random_jpeg_quality(Tensor image, float min_jpeg_quality, float max_jpeg_quality, + int seed = 0) + { + if (min_jpeg_quality < 0 || max_jpeg_quality < 0 || min_jpeg_quality > 100 || + max_jpeg_quality > 100) + throw new ValueError("jpeg encoding range must be between 0 and 100."); + + if (min_jpeg_quality >= max_jpeg_quality) + throw new ValueError("`min_jpeg_quality` must be less than `max_jpeg_quality`."); + + var jpeg_quality = random_ops.random_uniform(new int[] { }, + min_jpeg_quality, + max_jpeg_quality, + seed: seed, + dtype: dtypes.int32); + return adjust_jpeg_quality(image, jpeg_quality); + } + + public static Tensor adjust_jpeg_quality(Tensor image, Tensor jpeg_quality, string name = null) + { + return tf_with(ops.name_scope(name, "adjust_jpeg_quality", new[] { image }), delegate + { + image = ops.convert_to_tensor(image, name: "image"); + var channels = image.shape[image.shape.dims.Length - 1]; + var orig_dtype = image.dtype; + // python code checks to ensure jpeq_quality is a tensor; unnecessary here since + // it is passed as a tensor + image = gen_ops.encode_jpeg_variable_quality(image, quality: jpeg_quality); + + image = gen_ops.decode_jpeg(image, channels: (int)channels); + return convert_image_dtype(image, orig_dtype, saturate: true); + }); + } + + public static Tensor random_saturation(Tensor image, float lower, float upper, int seed = 0) + { + if (upper <= lower) + throw new ValueError("upper must be > lower."); + + if (lower < 0) + throw new ValueError("lower must be non-negative"); + + var saturation_factor = random_ops.random_uniform(new int[] { }, lower, upper, seed: seed); + return adjust_saturation(image, saturation_factor); + } + + public static Tensor adjust_saturation(Tensor image, Tensor saturation_factor, string name = null) + { + return tf_with(ops.name_scope(name, "adjust_saturation", new[] { image }), name => + { + image = ops.convert_to_tensor(image, name: "image"); + var orig_dtype = image.dtype; + + Tensor flt_image; + if (Array.Exists(new[] { dtypes.float16, dtypes.float32 }, orig_dtype => orig_dtype == orig_dtype)) + flt_image = image; + else + flt_image = convert_image_dtype(image, dtypes.float32); + + var adjusted = gen_ops.adjust_saturation(flt_image, saturation_factor); + + return convert_image_dtype(adjusted, orig_dtype); + }); + } + + public static Tensor total_variation(Tensor images, string name = null) + { + /* + return tf_with(ops.name_scope(name, "total_variation"), delegate + { + + }); + */ + throw new NotImplementedException(""); + } + + public static (Tensor begin, Tensor size, Tensor bboxes) sample_distorted_bounding_box_v2(Tensor image_size, Tensor bounding_boxes, int seed = 0, + Tensor min_object_covered = null, float[] aspect_ratio_range = null, float[] area_range = null, int max_attempts = 100, + bool use_image_if_no_bounding_boxes = false, string name = null) + { + // set default values that couldn't be set in function declaration, if necessary + if (min_object_covered == null) + min_object_covered = ops.convert_to_tensor(0.1); + if (aspect_ratio_range == null) + aspect_ratio_range = new float[] { 0.75f, 1.33f }; + if (area_range == null) + area_range = new float[] { 0.05f, 1f }; + + int? seed1, seed2; + if (seed != 0) + (seed1, seed2) = random_seed.get_seed(seed); + else + (seed1, seed2) = (0, 0); + + return sample_distorted_bounding_box(image_size, bounding_boxes, seed1, seed2, + min_object_covered, aspect_ratio_range, + area_range, max_attempts, + use_image_if_no_bounding_boxes, name); + } + + internal static (Tensor begin, Tensor size, Tensor bboxes) sample_distorted_bounding_box(Tensor image_size, Tensor bounding_boxes, int? seed = 0, int? seed2 = 0, + Tensor min_object_covered = null, float[] aspect_ratio_range = null, float[] area_range = null, int max_attempts = 100, + bool use_image_if_no_bounding_boxes = false, string name = null) + { + return tf_with(ops.name_scope(name, "sample_distorted_bounding_box"), delegate + { + return gen_ops.sample_distorted_bounding_box_v2( + image_size, + bounding_boxes, + seed: seed, + seed2: seed2, + min_object_covered: min_object_covered, + aspect_ratio_range: aspect_ratio_range, + area_range: area_range, + max_attempts: max_attempts, + use_image_if_no_bounding_boxes: use_image_if_no_bounding_boxes, + name: name); + }); + } + + public static Tensor non_max_suppression(Tensor boxes, Tensor scores, Tensor max_output_size, float iou_threshold = 0.5f, + float score_threshold = -1f / 0f, string name = null) + { + return tf_with(ops.name_scope(name, "non_max_suppression,"), delegate + { + Tensor iou_threshold_tensor = ops.convert_to_tensor(iou_threshold, name: "iou_threshold"); + Tensor score_threshold_tensor = ops.convert_to_tensor(score_threshold, name: "score_threshold"); + return gen_ops.non_max_suppression_v3(boxes, scores, max_output_size, + iou_threshold_tensor, score_threshold_tensor); + }); + } + + public static (Tensor, Tensor) non_max_suppression_with_scores(Tensor boxes, Tensor scores, Tensor max_output_size, + float iou_threshold = 0.5f, float score_threshold = -1f / 0f, /*float soft_nms_sigma = 0.0f,*/ string name = null) + { + return tf_with(ops.name_scope(name, "non_max_suppression_with_scores"), delegate + { + Tensor iou_threshold_tensor = ops.convert_to_tensor(iou_threshold, name: "iou_threshold"); + Tensor score_threshold_tensor = ops.convert_to_tensor(score_threshold, name: "score_threshold"); + + // non_max_suppression_v5 apparently doesn't exist yet, so use v4 + // and adapt the arguments to fit + + // Tensor soft_nms_sigma_tensor = ops.convert_to_tensor(soft_nms_sigma, name: "soft_nms_sigma"); + (Tensor selected_indices, Tensor selected_scores) = gen_ops.non_max_suppression_v4( + boxes, + scores, + max_output_size, + iou_threshold_tensor, + score_threshold_tensor, + // soft_nms_sigma_tensor, + false + ); + return (selected_indices, selected_scores); + }); + } + + public static Tensor non_max_suppression_with_overlaps(Tensor overlaps, Tensor scores, Tensor max_output_size, + float overlap_threshold = 0.5f, float score_threshold = -1f / 0f, string name = null) + { + return tf_with(ops.name_scope(name, "non_max_suppression_overlaps"), delegate + { + Tensor overlap_threshold_tensor = ops.convert_to_tensor(overlap_threshold, name: "overlap_threshold"); + return gen_ops.non_max_suppression_with_overlaps( + overlaps, scores, max_output_size, overlap_threshold_tensor, ops.convert_to_tensor(score_threshold)); + }); + } + + public static Tensor rgb_to_yiq(Tensor images) + { + images = ops.convert_to_tensor(images, name: "images"); + var _rgb_to_yiq_kernel = new float[,] { {0.299f, 0.59590059f, 0.2115f}, + {0.587f, -0.27455667f, -0.52273617f}, + {0.114f, -0.32134392f, 0.31119955f}}; + Tensor kernel = ops.convert_to_tensor(_rgb_to_yiq_kernel, dtype: images.dtype, name: "kernel"); + var ndims = images.shape.ndim; + return math_ops.tensordot(images, kernel, axes: new int[] { ndims - 1, 0 }); + } + + public static Tensor yiq_to_rgb(Tensor images) + { + images = ops.convert_to_tensor(images, name: "images"); + var _yiq_to_rgb_kernel = new float[,] { {1f, 1f, 1f}, + {0.95598634f, -0.27201283f, -1.10674021f}, + {0.6208248f, -0.64720424f, 1.70423049f}}; + Tensor kernel = ops.convert_to_tensor(_yiq_to_rgb_kernel, dtype: images.dtype, name: "kernel"); + var ndims = images.shape.ndim; + return math_ops.tensordot(images, kernel, axes: new int[] { ndims - 1, 0 }); + } + + public static Tensor rgb_to_yuv(Tensor images) + { + images = ops.convert_to_tensor(images, name: "images"); + var _rgb_to_yuv_kernel = new float[,] { {0.299f, -0.14714119f, 0.61497538f}, + {0.587f, -0.28886916f, -0.51496512f}, + {0.114f, 0.43601035f, -0.10001026f}}; + Tensor kernel = ops.convert_to_tensor(_rgb_to_yuv_kernel, dtype: images.dtype, name: "kernel"); + var ndims = images.shape.ndim; + return math_ops.tensordot(images, kernel, axes: new int[] { ndims - 1, 0 }); + } + + public static Tensor yuv_to_rgb(Tensor images) + { + images = ops.convert_to_tensor(images, name: "images"); + var _yuv_to_rgb_kernel = new float[,] { {1f, 1f, 1f,}, + {0f, -0.394642334f, 2.03206185f}, + {1.13988303f, -0.58062185f, 0f}}; + Tensor kernel = ops.convert_to_tensor(_yuv_to_rgb_kernel, dtype: images.dtype, name: "kernel"); + var ndims = images.shape.ndim; + return math_ops.tensordot(images, kernel, axes: new int[] { ndims - 1, 0 }); + } + + internal static (Tensor, Tensor, Operation[]) _verify_compatible_image_shapes(Tensor img1, Tensor img2) + { + Shape shape1 = img1.shape.with_rank_at_least(3); + Shape shape2 = img2.shape.with_rank_at_least(3); + shape1 = new Shape(shape1.dims.Skip(shape1.dims.Length - 3).Take(shape1.dims.Length - (shape1.dims.Length - 3)).ToArray()); + tensor_shape.assert_is_compatible_with(self: new Tensor(shape1.dims), other: new Tensor(shape2.dims.Skip(shape2.dims.Length - 3).Take(shape2.dims.Length - (shape2.dims.Length - 3)).ToArray())); + + if (shape1.ndim != -1 && shape2.ndim != -1) + { + var shape1_temp = shape1.dims.Skip(shape1.dims.Length - 3).Take(shape1.dims.Length - (shape1.dims.Length - 3)).ToArray(); + var shape2_temp = shape2.dims.Skip(shape2.dims.Length - 3).Take(shape2.dims.Length - (shape1.dims.Length - 3)).ToArray(); + Array.Reverse(shape1_temp); + Array.Reverse(shape2_temp); + foreach (var (dim1, dim2) in shape1_temp.Zip(shape2_temp, Tuple.Create)) + { + if (dim1 != 1 || dim2 != 1 /*|| !dim1.is_compatible_with(dim2)*/) + throw new ValueError(String.Format("Two images are not compatible: {0} and {1}", shape1, shape2)); + } + } + + Tensor shape1_tensor = gen_array_ops.shape_n(new Tensor[] { img1, img2 })[0]; + Tensor shape2_tensor = gen_array_ops.shape_n(new Tensor[] { img1, img2 })[1]; + Operation[] checks = new Operation[] { }; + checks.append( + control_flow_ops.Assert( + gen_math_ops.greater_equal(array_ops.size(shape1_tensor), ops.convert_to_tensor(3)), new[] { shape1, shape2 }, + summarize: 10)); + checks.append( + control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(shape1_tensor.dims.Skip(shape1_tensor.dims.Length - 3).Take(shape1_tensor.dims.Length - (shape1_tensor.dims.Length - 3)).ToArray(), + shape2_tensor.dims.Skip(shape1_tensor.dims.Length - 3).Take(shape1_tensor.dims.Length - (shape1_tensor.dims.Length - 3)))), + new[] { shape1, shape2 }, + summarize: 10)); + return (shape1_tensor, shape2_tensor, checks); + } + + public static Tensor psnr(Tensor a, Tensor b, Tensor max_val, string name = null) + { + return tf_with(ops.name_scope(name, "PSNR", new[] { a, b }), delegate + { + max_val = math_ops.cast(max_val, a.dtype); + max_val = convert_image_dtype(max_val, dtypes.float32); + a = convert_image_dtype(a, dtypes.float32); + b = convert_image_dtype(b, dtypes.float32); + Tensor mse = math_ops.reduce_mean(gen_math_ops.squared_difference(a, b), new(-3, -2, -1)); + var psnr_val = math_ops.subtract( + (20 * math_ops.log(max_val)) / math_ops.log(ops.convert_to_tensor(10.0)), + math_ops.cast(10 / math_ops.log(ops.convert_to_tensor(10)), dtypes.float32) * math_ops.log(mse), + name: "psnr"); + + (object _a, object _b, Operation[] checks) = _verify_compatible_image_shapes(a, b); + return tf_with(ops.control_dependencies(checks), delegate + { + return array_ops.identity(psnr_val); + }); + }); + } + + internal static (Tensor, Tensor) _ssim_helper(Tensor x, Tensor y, Func reducer, float max_val, + float compensation = 1.0f, float k1 = 0.01f, float k2 = 0.03f) + { + var c1 = Math.Pow((k1 * max_val), 2); + var c2 = Math.Pow((k2 * max_val), 2); + + var mean0 = reducer(x); + var mean1 = reducer(y); + var num0 = mean0 * mean1 * 2.0; + var den0 = math_ops.square(mean0) + math_ops.square(mean1); + var luminance = (num0 + c1) / (den0 + c1); + + var num1 = reducer(x * y) * 2.0; + var den1 = reducer(math_ops.square(x) + math_ops.square(y)); + c2 = c2 * compensation; + var cs = (num1 - num0 + c2) / (den1 - den0 + c2); + + return (luminance, cs); + } + + internal static Tensor _fspecial_gauss(Tensor size, Tensor sigma) + { + size = ops.convert_to_tensor(size, dtypes.int32); + sigma = ops.convert_to_tensor(sigma); + + var coords = math_ops.cast(math_ops.range(size), sigma.dtype); + coords = coords - math_ops.cast(size - 1, sigma.dtype) / 2.0; + + var g = math_ops.square(coords); + g = g * -0.5 / math_ops.square(sigma); + + g = array_ops.reshape(g, shape: new int[] { 1, -1 }) + array_ops.reshape(g, shape: new int[] { -1, 1 }); + g = array_ops.reshape(g, shape: new int[] { 1, -1 }); + g = nn_ops.softmax(g); + + // shape takes an int, python code passes size, a Tensor. NDims is the only int type + // i could think of a Tensor having. it might be incorrect tho, so keep that in mind. + return array_ops.reshape(g, shape: new int[] { size.ndim, size.ndim, 1, 1 }); + } + + internal static (Tensor, Tensor) _ssim_per_channel(Tensor img1, Tensor img2, float max_val = 1f, + float filter_size = 11f, float filter_sigma = 1.5f, float k1 = 0.01f, float k2 = 0.03f) + { + Tensor filter_size_tensor = constant_op.constant(filter_size, dtype: dtypes.int32); + Tensor filter_sigma_tensor = constant_op.constant(filter_sigma, dtype: img1.dtype); + + Tensor shape1_tensor = gen_array_ops.shape_n(new Tensor[] { img1, img2 })[0]; + Tensor shape2_tensor = gen_array_ops.shape_n(new Tensor[] { img1, img2 })[1]; + Operation[] checks = new Operation[] { + control_flow_ops.Assert( + math_ops.reduce_all( + gen_math_ops.greater_equal(new Tensor(shape1_tensor.dims.Skip(shape1_tensor.dims.Length - 3).Take(shape1_tensor.dims.Length - (shape1_tensor.dims.Length - 3 - 1)).ToArray()), filter_size_tensor)), + new object[] {shape1_tensor, filter_size}, + summarize: 8), + control_flow_ops.Assert( + math_ops.reduce_all( + gen_math_ops.greater_equal(new Tensor(shape2_tensor.dims.Skip(shape2_tensor.dims.Length - 3).Take(shape2_tensor.dims.Length - (shape2_tensor.dims.Length - 3 - 1)).ToArray()), filter_size_tensor)), + new object[] {shape2_tensor, filter_size}, + summarize: 8) + }; + + using (ops.control_dependencies(checks)) + img1 = array_ops.identity(img1); + + var kernel = _fspecial_gauss(filter_size_tensor, filter_sigma_tensor); + kernel = array_ops.tile(kernel, multiples: new Tensor(new int[] { 1, 1, (int)shape1_tensor.dims[shape1_tensor.dims.Length - 2], 1 })); + + float compensation = 1.0f; + + Tensor reducer(Tensor x) + { + var shape = array_ops.shape(x); + x = array_ops.reshape(x, shape: array_ops.concat(new Tensor[] { new Tensor(-1), new Tensor(shape1_tensor.dims.Skip(shape1_tensor.dims.Length - 3).Take(shape1_tensor.dims.Length - (shape1_tensor.dims.Length - 3 - 1)).ToArray()) }, 0)); + var y = gen_ops.depthwise_conv2d_native(x, kernel, strides: new int[] { 1, 1, 1, 1 }, padding: "VALID"); + return array_ops.reshape( + y, array_ops.concat(new Tensor[] { new Tensor(shape.dims.Take(shape.dims.Length - 3).ToArray()), new Tensor(array_ops.shape(y).dims.Skip(1).Take(array_ops.shape(y).dims.Length - 2).ToArray()) }, 0)); + } + + (Tensor luminance, Tensor cs) = _ssim_helper(img1, img2, reducer, max_val, compensation, k1, k2); + + var axes = constant_op.constant(new[] { -3, -2 }, dtype: dtypes.int32); + var ssim_val = math_ops.reduce_mean(luminance * cs, axes.dims); + cs = math_ops.reduce_mean(cs, axes.dims); + return (ssim_val, cs); + } + + public static Tensor ssim(Tensor img1, Tensor img2, float max_val = 1f, float filter_size = 11f, float filter_sigma = 1.5f, + float k1 = 0.01f, float k2 = 0.03f) + { + return tf_with(ops.name_scope(null, "SSIM", new[] { img1, img2 }), delegate + { + img1 = ops.convert_to_tensor(img1, name: "img1"); + img2 = ops.convert_to_tensor(img2, name: "img2"); + + (Tensor _, Tensor __, Operation[] checks) = _verify_compatible_image_shapes(img1, img2); + using (ops.control_dependencies(checks)) + img1 = array_ops.identity(img1); + + Tensor max_val_tensor = constant_op.constant(max_val, img1.dtype); + max_val_tensor = convert_image_dtype(max_val_tensor, dtypes.float32); + img1 = convert_image_dtype(img1, dtypes.float32); + img2 = convert_image_dtype(img2, dtypes.float32); + (Tensor ssim_per_channel, Tensor ___) = _ssim_per_channel(img1, img2, max_val, filter_size, + filter_sigma, k1, k2); + + return math_ops.reduce_mean(ssim_per_channel, new(-1)); + }); + } + + public static Tensor ssim_multiscale(Tensor img1, Tensor img2, float max_val, float[] power_factors = null, float filter_size = 11f, + float filter_sigma = 1.5f, float k1 = 0.01f, float k2 = 0.03f) + { + if (power_factors == null) + power_factors = new float[] { 0.0448f, 0.2856f, 0.3001f, 0.2363f, 0.1333f }; + + return tf_with(ops.name_scope(null, "MS-SSIM", new[] { img1, img2 }), delegate + { + img1 = ops.convert_to_tensor(img1, name: "img1"); + img2 = ops.convert_to_tensor(img2, name: "img2"); + + (Tensor shape1, Tensor shape2, Operation[] checks) = _verify_compatible_image_shapes(img1, img2); + using (ops.control_dependencies(checks)) + img1 = array_ops.identity(img1); + + Tensor max_val_tensor = constant_op.constant(max_val); + max_val_tensor = convert_image_dtype(max_val_tensor, dtypes.float32); + img1 = convert_image_dtype(img1, dtypes.float32); + img2 = convert_image_dtype(img2, dtypes.float32); + + var imgs = new[] { img1, img2 }; + var shapes = new[] { shape1, shape2 }; + + Tensor[] heads = new Tensor[] { }; + Tensor[] tails = new Tensor[] { }; + foreach (Tensor s in shapes) + { + heads[heads.Length] = new Tensor(s.dims.Take(s.dims.Length - 3).ToArray()); + tails[tails.Length] = new Tensor(s.dims.Skip(s.dims.Length - 3).Take(s.dims.Length - (s.dims.Length - 3)).ToArray()); + } + + var divisor = new[] { 1, 2, 2, 1 }; + var divisor_tensor = constant_op.constant(divisor.Skip(1).Take(divisor.Length - 1).ToArray(), dtype: dtypes.int32); + + Tensor[] do_pad(Tensor[] images, Tensor remainder) + { + var padding = array_ops.expand_dims(remainder, -1); + padding = array_ops.pad(padding, new Tensor(new int[,] { { 1, 0 }, { 1, 0 } })); + + Tensor[] x_arr = new Tensor[] { }; + foreach (Tensor x in images) + { + x_arr[x_arr.Length] = array_ops.pad(x, padding, mode: "SYMMETRIC"); + } + return x_arr; + } + + var mcs = new Tensor[] { }; + var ssim_per_channel = new Tensor(new int[] { }); + var cs = ssim_per_channel; + foreach (var k in range(0, len(power_factors))) + { + using (ops.name_scope(null, String.Format("Scale{0}", k), imgs)) + { + if (k > 0) + { + // handle flat_imgs + Tensor[] flat_imgs = new Tensor[] { }; + foreach ((Tensor x, Tensor t) in imgs.Zip(tails, Tuple.Create)) + { + flat_imgs[flat_imgs.Length] = array_ops.reshape(x, array_ops.concat(new Tensor[] { constant_op.constant(-1), t }, 0)); + } + + var remainder = tails[0] % divisor_tensor; + var need_padding = math_ops.reduce_any(math_ops.not_equal(remainder, 0)); + + Tensor[] padded_func_pass() { return do_pad(flat_imgs, remainder); } + var padded = control_flow_ops.cond(need_padding, + true_fn: () => padded_func_pass(), + false_fn: () => flat_imgs); + + // handle downscaled + Tensor[] downscaled = new Tensor[] { }; + foreach (Tensor x in padded) + { + downscaled[downscaled.Length] = gen_ops.avg_pool(x, ksize: divisor, strides: divisor, padding: "VALID"); + } + + // handle tails + tails = new Tensor[] { }; + foreach (Tensor x in gen_array_ops.shape_n(downscaled)) + { + tails[tails.Length] = new Tensor(x.dims.Skip(1).Take(tails.Length - 1).ToArray()); + } + + imgs = new Tensor[] { }; + // tuples weren't working; this is hacky, but should work similarly. + // zip loads the values into a tuple (Tensor, Tensor, Tensor) for each + // zip entry; this just gets the length of the longest array, and loops + // that many times, getting values (like zip) and using them similarly. + for (int x = 0; x < Math.Max(Math.Max(downscaled.Length, heads.Length), tails.Length); x++) + { + imgs[imgs.Length] = array_ops.reshape(downscaled[x], array_ops.concat(new Tensor[] { heads[x], tails[x] }, 0)); + } + } + } + + // python code uses * to unpack imgs; how to replicate that here? + // don't think that this is doing the same thing as the python code. + (ssim_per_channel, cs) = _ssim_per_channel( + img1: imgs[0], + img2: imgs[1], + max_val: max_val, + filter_size: filter_size, + filter_sigma: filter_sigma, + k1: k1, + k2: k2); + mcs.append(gen_nn_ops.relu(cs)); + } + + mcs = mcs.Skip(1).ToArray(); + var mcs_and_ssim = array_ops.stack( + math_ops.add(mcs, new[] { gen_nn_ops.relu(ssim_per_channel) }), axis: -1); + var ms_ssim = math_ops.reduce_prod( + math_ops.pow(mcs_and_ssim, power_factors), new(-1)); + + return math_ops.reduce_mean(ms_ssim, new(-1)); + }); + } + + public static (Tensor, Tensor) image_gradients(Tensor image) + { + if (image.shape.ndim != 4) + throw new ValueError(String.Format(@"image_gradients expects a 4D tensor [batch_size, h, w, d], not {0}.", image.shape)); + + var image_shape = array_ops.shape(image); + var bs_h_w_d = array_ops.unstack(image_shape); + Tensor dy; //= image[:, 1:, :, :] - image[:, :-1, :, :]; + Tensor dx = new Tensor(new int[] { }); //= image[:, :, 1:, :] - image[:, :, :-1, :]; + + var shape = array_ops.stack(new Tensor[] { bs_h_w_d[0], constant_op.constant(1), bs_h_w_d[2], bs_h_w_d[3] }); + dy = array_ops.concat(new Tensor[] { dx, array_ops.zeros(shape, image.dtype) }, 2); + dy = array_ops.reshape(dy, image_shape); + + shape = array_ops.stack(new Tensor[] { bs_h_w_d[0], bs_h_w_d[1], constant_op.constant(1), bs_h_w_d[3] }); + dx = array_ops.concat(new Tensor[] { dx, array_ops.zeros(shape, image.dtype) }, 2); + dx = array_ops.reshape(dx, image_shape); + + return (dx, dy); + } + + public static Tensor sobel_edges(Tensor image) + { + var static_image_shape = image.shape; + var image_shape = array_ops.shape(image); + var kernels = new Tensor(new int[,] {{-1, -2, -1}, {0, 0, 0}, {1, 2, 1}, + {-1, 0, 1}, {-2, 0, 2}, {-1, 0, 1}}); + var num_kernels = len(kernels); + // kernels.dims != np.asarray(kernels) ? + kernels = array_ops.transpose(kernels.dims, (1, 2, 0)); + kernels = array_ops.expand_dims(kernels, -2); + var kernels_tf = constant_op.constant(kernels, dtype: image.dtype); + + kernels_tf = array_ops.tile( + kernels_tf, new Tensor(new int[] { 1, 1, (int)image_shape.dims[image_shape.dims.Length - 2], 1 }), name: "sobel_filters"); + + var pad_sizes = new int[,] { { 0, 0 }, { 1, 1 }, { 1, 1 }, { 0, 0 } }; + var padded = array_ops.pad(image, new Tensor(pad_sizes), mode: "reflect"); + + var strides = new int[] { 1, 1, 1, 1 }; + var output = gen_ops.depthwise_conv2d_native(padded, kernels_tf, strides, "VALID"); + + var shape = array_ops.concat(new Tensor[] { image_shape, ops.convert_to_tensor(num_kernels) }, 0); + output = array_ops.reshape(output, shape: shape); + output.shape = static_image_shape.concatenate(new int[] { num_kernels }); + return output; + } + + public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, + string name = null, bool expand_animations = true) + { + var scope = ops.name_scope(name, "decode_image"); + scope.__enter__(); + + var result = gen_image_ops.decode_image(contents, + channels: channels, + dtype: dtype, + expand_animations: expand_animations); + + scope.__exit__(); + return result; + } + + public static Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method, float extrapolation_value, string name) + { + var _op = tf.OpDefLib._apply_op_helper("CropAndResize", name: name, args: new + { + image, + boxes, + box_ind, + crop_size, + method, + extrapolation_value + }); + + return _op.outputs[0]; + } + + public static Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true, + bool uniform_noise = true, string name = null) + { + return gen_ops.extract_glimpse( + input: input, + size: size, + offsets: offsets, + centered: centered, + normalized: normalized, + uniform_noise: uniform_noise, + name: name); + } + + public static (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression(Tensor boxes, Tensor scores, Tensor max_output_size_per_class, + Tensor max_total_size, float iou_threshold = 0.5f, float score_threshold = -1f / 0f, bool pad_per_class = false, bool clip_boxes = true, + string name = null) + { + return tf_with(ops.name_scope(null, "combined_non_max_suppression"), delegate + { + Tensor iou_threshold_tensor = ops.convert_to_tensor( + iou_threshold, dtype: dtypes.float32, name: "iou_threshold"); + Tensor score_threshold_tensor = ops.convert_to_tensor( + score_threshold, dtype: dtypes.float32, name: "score_threshold"); + return gen_image_ops.combined_non_max_suppression( + boxes, scores, max_output_size_per_class, max_total_size, iou_threshold_tensor, + score_threshold_tensor, pad_per_class, clip_boxes); + }); + } + + internal static (Tensor, Tensor, Tensor, Tensor) _cross_suppression(Tensor boxes, Tensor box_slice, Tensor iou_threshold, Tensor inner_idx, int tile_size) + { + var batch_size = array_ops.shape(boxes)[0]; + var new_slice = array_ops.slice( + boxes, new Tensor[] { ops.convert_to_tensor(0), ops.convert_to_tensor(inner_idx * tile_size), ops.convert_to_tensor(0) }, + new Tensor[] { ops.convert_to_tensor(batch_size), ops.convert_to_tensor(tile_size), ops.convert_to_tensor(4) }); + var iou = _bbox_overlap(new_slice, box_slice); + var box_slice_after_suppression = array_ops.expand_dims( + math_ops.cast(math_ops.reduce_all(iou < iou_threshold, new(1)), + box_slice.dtype), + 2) * box_slice; + return (boxes, box_slice_after_suppression, iou_threshold, inner_idx + 1); + } + + internal static Tensor _bbox_overlap(Tensor boxes_a, Tensor boxes_b) + { + return tf_with(ops.name_scope("bbox_overlap"), delegate + { + // a_y_min: [0], a_x_min: [1], a_y_max: [2], a_x_max[3] + var a_xy_minmax = array_ops.split( + value: boxes_a, num_or_size_splits: 4, axis: ops.convert_to_tensor(2)); + // b_y_min: [0], b_x_min: [1], b_y_max: [2], b_x_max[3] + var b_xy_minmax = array_ops.split( + value: boxes_b, num_or_size_splits: 4, axis: ops.convert_to_tensor(2)); + + var i_xmin = math_ops.maximum( + a_xy_minmax[1], array_ops.transpose(b_xy_minmax[1], new[] { 0, 2, 1 })); + var i_xmax = math_ops.minimum( + a_xy_minmax[3], array_ops.transpose(b_xy_minmax[3], new[] { 0, 2, 1 })); + var i_ymin = math_ops.maximum( + a_xy_minmax[0], array_ops.transpose(b_xy_minmax[0], new[] { 0, 2, 1 })); + var i_ymax = math_ops.minimum( + a_xy_minmax[3], array_ops.transpose(b_xy_minmax[3], new[] { 0, 2, 1 })); + var i_area = math_ops.maximum( + (i_xmax - i_xmin), 0) * math_ops.maximum((i_ymax - i_ymin), 0); + + var a_area = (a_xy_minmax[2] - a_xy_minmax[0]) * (a_xy_minmax[3] - a_xy_minmax[1]); + var b_area = (b_xy_minmax[2] - b_xy_minmax[0]) * (b_xy_minmax[3] - b_xy_minmax[1]); + double EPSILON = 1e-8; + + var u_area = a_area + array_ops.transpose(b_area, new[] { 0, 2, 1 }) - i_area + EPSILON; + + var intersection_over_union = i_area / u_area; + + return intersection_over_union; + }); + } + + internal static (Tensor, float, Tensor, int) _suppression_loop_body(Tensor boxes, float iou_threshold, Tensor output_size, int idx, int tile_size) + { + using (ops.name_scope("suppression_loop_body")) + { + var num_tiles = array_ops.shape(boxes).dims[1] / tile_size; + var batch_size = array_ops.shape(boxes).dims[0]; + + (Tensor, Tensor, Tensor, Tensor) cross_suppression_func(Tensor boxes, Tensor box_slice, Tensor iou_threshold, Tensor inner_idx, int tile_size) + => _cross_suppression(boxes, box_slice, iou_threshold, inner_idx, tile_size); + + var box_slice = array_ops.slice(boxes, new Tensor[]{ ops.convert_to_tensor(0), ops.convert_to_tensor(idx * tile_size), ops.convert_to_tensor(0) }, + new Tensor[] { ops.convert_to_tensor(batch_size), ops.convert_to_tensor(tile_size), ops.convert_to_tensor(4) }); + + var iou = _bbox_overlap(box_slice, box_slice); + var mask = array_ops.expand_dims( + array_ops.reshape( + math_ops.range(tile_size), new[] { 1, -1 }) > array_ops.reshape( + math_ops.range(tile_size), new[] { -1, 1 }), 0); + iou = iou * math_ops.cast( + math_ops.logical_and(mask, iou >= iou_threshold), iou.dtype); + + /* + I have no idea what's going on here. Not even going to try to port it yet. + var suppressed_iou = control_flow_ops.while_loop( + todo + ) + */ + var suppressed_iou = new Tensor(new int[] { }); + var suppressed_box = math_ops.reduce_sum(suppressed_iou, constant_op.constant(1)) > 0; + box_slice = box_slice * array_ops.expand_dims( + 1.0f - math_ops.cast(suppressed_box, box_slice.dtype), 2); + + mask = array_ops.reshape( + math_ops.cast( + math_ops.equal(math_ops.range(num_tiles), idx), boxes.dtype), + new[] { 1, -1, 1, 1 }); + boxes = array_ops.tile(array_ops.expand_dims( + box_slice, 1), ops.convert_to_tensor(new[] { 1, num_tiles, 1, 1 }) * mask + array_ops.reshape( + boxes, new[] { batch_size, num_tiles, tile_size, 4 }) * (1 - mask)); + boxes = array_ops.reshape(boxes, new[] { batch_size, -1, 4 }); + + output_size = output_size + math_ops.reduce_sum( + math_ops.cast( + math_ops.reduce_any(box_slice > 0, new(2)), dtypes.int32), constant_op.constant(new int[] { 1 })); + } + return (boxes, iou_threshold, output_size, idx + 1); + } + + public static (Tensor, Tensor) non_max_suppression_padded(Tensor boxes, Tensor scores, Tensor max_output_size, float iou_threshold = 0.5f, float score_threshold = -1f / 0f, + bool pad_to_max_output_size = false, string name = null, bool sorted_input = false, bool canonicalized_coordinates = false, int tile_size = 512) + { + if (!sorted_input && !canonicalized_coordinates && tile_size == 512 /*&& !compat.forward_compatible(2020, 6, 23)*/) + return non_max_suppression_padded_v1( + boxes, scores, max_output_size, iou_threshold, score_threshold, + pad_to_max_output_size, name); + else + { + return tf_with(ops.name_scope(name, "non_max_suppression_padded"), delegate + { + if (!pad_to_max_output_size) + if (boxes.shape.ndim != -1 && boxes.shape.ndim > 2) + throw new ValueError(String.Format( + "'pad_to_max_output_size' (value {0}) must be true for 'batched input'", pad_to_max_output_size)); + if (name == null) + name = ""; + (Tensor idx, Tensor num_valid) = non_max_suppression_padded_v2( + boxes, scores, max_output_size, iou_threshold, score_threshold, + sorted_input, canonicalized_coordinates, tile_size); + if (!pad_to_max_output_size) + // idx = idx[0, :num_valid], passes: + // 0, slice(None, num_valid, None) + // which is what I tried to replicate below, but i don't think that Unknown is the exact + // equivalent to None, and don't know about the slice function bit. + idx = idx[0, slice(Unknown, num_valid.shape.ndim, Unknown).ToArray()[0]]; + else + { + var batch_dims = array_ops.concat(new Tensor[] { + new Tensor(array_ops.shape(boxes).dims.Take(boxes.shape.dims.Length - 2).ToArray()), + array_ops.expand_dims(max_output_size, 0) + }, 0); + idx = array_ops.reshape(idx, batch_dims); + } + return (idx, num_valid); + }); + } + } + + public static (Tensor, Tensor) non_max_suppression_padded_v2(Tensor boxes, Tensor scores, Tensor max_output_size, float iou_threshold = 0.5f, float score_threshold = -1f / 0f, + bool sorted_input = false, bool canonicalized_coordinates = false, int tile_size = 512) + { + (Tensor, Tensor, Tensor) _sort_scores_and_boxes(Tensor scores, Tensor boxes) + { + int batch_size, num_boxes; + Tensor index_offsets, indices, sorted_scores, sorted_boxes, sorted_scores_indices; + using (ops.name_scope("sort_scores_and_boxes")) + { + batch_size = (int)array_ops.shape(boxes).dims[0]; + num_boxes = (int)array_ops.shape(boxes).dims[1]; + sorted_scores_indices = null; /*sort_ops.argsort( + scores, axis: 1, direction: "DESCENDING); */ + index_offsets = math_ops.range(batch_size) * num_boxes; + indices = array_ops.reshape( + sorted_scores_indices + array_ops.expand_dims(index_offsets, 1), new[] { -1 }); + sorted_scores = array_ops.reshape( + array_ops.gather(array_ops.reshape(boxes, new[] { -1, 4 }), indices), + new[] { batch_size, -1 }); + sorted_boxes = array_ops.reshape( + array_ops.gather(array_ops.reshape(boxes, new[] { -1, 4 }), indices), + new[] { batch_size, -1, 4 }); + }; + + return (sorted_scores, sorted_boxes, sorted_scores_indices); + } + + var batch_dims = array_ops.shape(boxes).dims.Take(boxes.shape.dims.Length - 2).ToArray(); + var num_boxes = array_ops.shape(boxes).dims[boxes.shape.dims.Length - 2]; + boxes = array_ops.reshape(boxes, new[] { -1, num_boxes, 4 }); + scores = array_ops.reshape(scores, new[] { -1, num_boxes }); + var batch_size = array_ops.shape(boxes).dims[0]; + + // initialization for later + Tensor sorted_indices; + + if (score_threshold != -1f / 0f) + using (ops.name_scope("filter_by_score")) + { + var score_mask = math_ops.cast(scores > score_threshold, scores.dtype); + scores = scores * score_mask; + var box_mask = array_ops.expand_dims( + math_ops.cast(score_mask, boxes.dtype), 2); + boxes = boxes * box_mask; + } + + if (!canonicalized_coordinates) + using (ops.name_scope("canonicalize_coordinates")) + { + // y_1 = [0], x_1 = [1], y_2 = [2], x_2 = [3] + var yx = array_ops.split(value: boxes, num_or_size_splits: 4, axis: ops.convert_to_tensor(2)); + var y_1_is_min = math_ops.reduce_all( + gen_math_ops.less_equal(yx[0][0, 0, 0], yx[2][0, 0, 0])); + var y_minmax = control_flow_ops.cond( + y_1_is_min, true_fn: () => yx[0] /*yx[2]*/, false_fn: () => yx[2] /*yx[0]*/); + var x_1_is_min = math_ops.reduce_all( + gen_math_ops.less_equal(yx[1][0, 0, 0], yx[3][0, 0, 0])); + var x_minmax = control_flow_ops.cond( + x_1_is_min, true_fn: () => yx[1] /*yx[3]*/, false_fn: () => yx[3] /*yx[1]*/); + boxes = array_ops.concat(new Tensor[] { y_minmax, x_minmax }, axis: 2); + } + + if (!sorted_input) + (scores, boxes, sorted_indices) = _sort_scores_and_boxes(scores, boxes); + else + sorted_indices = array_ops.zeros_like(scores, dtype: dtypes.int32); + + var pad = math_ops.cast( + gen_math_ops.ceil( + math_ops.cast( + math_ops.maximum(num_boxes, max_output_size), dtypes.float32) / tile_size), + dtypes.int32) * tile_size - num_boxes; + boxes = array_ops.pad( + math_ops.cast(scores, dtypes.float32), ops.convert_to_tensor(new object[,] { { 0, 0 }, { 0, pad }, { 0, 0 } })); + scores = array_ops.pad( + math_ops.cast(scores, dtypes.float32), ops.convert_to_tensor(new object[,] { { 0, 0 }, { 0, pad } })); + var num_boxes_after_padding = num_boxes + pad; + var num_iterations = math_ops.floordiv(num_boxes_after_padding, ops.convert_to_tensor(tile_size)); + + // Tensor unused_boxes, Tensor unused_threshold, Tensor output_size, Tensor idx go into args + Tensor _loop_cond(object[] args) + => /*new object[] {*/math_ops.logical_and( + math_ops.reduce_min((Tensor)args[2]) < max_output_size, + (Tensor)args[3] < num_iterations); + + // Tensor boxes, Tensor iou_threshold, Tensor output_size, Tensor idx go into args + object[] suppression_loop_body(object[] args) + { + (Tensor a, float b, Tensor c, int d) = _suppression_loop_body((Tensor)args[0], (float)args[1], (Tensor)args[2], (int)args[3], tile_size); + return new object[] { a, b, c, d }; + } + + object[] selboxes__output_size_ = null; + /* + errors here regarding the while loop and types + + object[] selboxes__output_size_= control_flow_ops.while_loop( + cond: (Tensor[] args) => _loop_cond(args), + body: (Tensor[] args) => suppression_loop_body(args), + loop_vars: new object[] { + boxes, iou_threshold, + array_ops.zeros(new Shape(batch_size), dtypes.int32), + constant_op.constant(0) + }, + shape_invariants: new Shape[] { + new Shape(new int[] {Unknown, Unknown, 4}), + new Shape(new int[] {}), + new Shape(new int[] {Unknown}), + new Shape(new int[] {}) + } + ); + */ + var num_valid = math_ops.minimum(selboxes__output_size_[2], max_output_size); + + (Tensor values, Tensor indices) = gen_ops.top_k_v2( + math_ops.cast(math_ops.reduce_any( + (Tensor)selboxes__output_size_[0] > 0, new(2)), dtypes.int32) * + array_ops.expand_dims( + math_ops.range(num_boxes_after_padding, 0, -1), 0), + max_output_size); + Tensor idx = num_boxes_after_padding - values.shape.as_int_list()[0]; + idx = math_ops.minimum(idx, num_boxes - 1); + + if (!sorted_input) + { + var index_offsets = math_ops.range(batch_size) * num_boxes; + var gather_idx = array_ops.reshape( + idx + array_ops.expand_dims(index_offsets, 1), new[] { -1 }); + idx = array_ops.reshape( + array_ops.gather(array_ops.reshape(sorted_indices, new[] { -1 }), + gather_idx), + new[] { batch_size, -1 }); + } + var invalid_index = array_ops.fill(new Shape((int)batch_size, (int)max_output_size), 0); + var idx_index = array_ops.expand_dims(math_ops.range(max_output_size), 0); + var num_valid_expanded = array_ops.expand_dims(num_valid, 1); + idx = array_ops.where(idx_index < num_valid_expanded, + idx, invalid_index); + num_valid = array_ops.reshape(num_valid, batch_dims); + return (idx, num_valid); + } + + internal static (Tensor, Tensor) non_max_suppression_padded_v1(Tensor boxes, Tensor scores, Tensor max_output_size, float iou_threshold = 0.5f, + float score_threshold = -1f / 0f, bool pad_to_max_output_size = false, string name = null) + { + return tf_with(ops.name_scope(name, "non_max_supression_padded"), delegate + { + var iou_threshold_tensor = ops.convert_to_tensor(iou_threshold, name: "iou_threshold"); + var score_threshold_tensor = ops.convert_to_tensor(score_threshold, name: "score_threshold"); + return gen_ops.non_max_suppression_v4(boxes, scores, max_output_size, iou_threshold_tensor, score_threshold_tensor, pad_to_max_output_size); + }); + } + + public static Tensor encode_jpeg(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "encode_jpeg"), scope => + { + return gen_ops.encode_jpeg(contents, name:name); + }); + } + + public static Tensor encode_png(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "encode_png"), scope => + { + return gen_ops.encode_png(contents, name: name); + }); + } + + public static Tensor is_jpeg(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "is_jpeg"), scope => + { + var substr = tf.strings.substr(contents, 0, 3); + var jpg = tf.constant(new byte[] { 0xff, 0xd8, 0xff }, TF_DataType.TF_STRING); + var result = math_ops.equal(substr, jpg, name: name); + return result; + }); + } + + static Tensor is_png(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "is_png"), scope => + { + var substr = tf.strings.substr(contents, 0, 3); + return math_ops.equal(substr, @"\211PN", name: name); + }); + } + + static Tensor is_gif(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "is_gif"), scope => + { + var substr = tf.strings.substr(contents, 0, 3); + var gif = tf.constant(new byte[] { 0x47, 0x49, 0x46 }, TF_DataType.TF_STRING); + var result = math_ops.equal(substr, gif, name: name); + return result; + }); + } + + public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, + string name = null) + { + image = ops.convert_to_tensor(image, name: "image"); + // var tf_dtype = dtypes.as_dtype(dtype); + if (!dtype.is_floating() && !dtype.is_integer()) + throw new TypeError("dtype must be either floating point or integer"); + if (dtype == image.dtype) + return array_ops.identity(image, name: name); + + // declarations for later + Tensor cast; + + return tf_with(ops.name_scope(name, "convert_image", new[] { image }), name => + { + if (image.dtype.is_integer() && dtype.is_integer()) + { + var scale_in = image.dtype.max(); + var scale_out = dtype.max(); + if (scale_in > scale_out) + { + var scale = Math.Floor((decimal)(scale_in + 1) / (scale_out + 1)); + var scaled = math_ops.floordiv(image, ops.convert_to_tensor(scale)); + + if (saturate) + return math_ops.saturate_cast(scaled, dtype, name: name); + else + return math_ops.cast(scaled, dtype, name: name); + } + else + { + if (saturate) + cast = math_ops.saturate_cast(image, dtype); + else + cast = math_ops.cast(image, dtype); + var scale = Math.Floor((decimal)(scale_in + 1) / (scale_out + 1)); + return math_ops.multiply(cast, scale, name: name); + } + } + else if (image.dtype.is_floating() && dtype.is_floating()) + return math_ops.cast(image, dtype, name: name); + else + { + if (image.dtype.is_integer()) + { + cast = math_ops.cast(image, dtype); + var scale = 1 / image.dtype.max(); + return math_ops.multiply(cast, scale, name: name); + } + else + { + var scale = dtype.max() + 0.5; + var scaled = math_ops.multiply(image, scale); + if (saturate) + return math_ops.saturate_cast(scaled, dtype, name: name); + else + return math_ops.cast(scaled, dtype, name: name); + } + } + }); + } + + /// + /// Resize `images` to `size` using the specified `method`. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor resize_images_v2(Tensor images, T size, string method = ResizeMethod.BILINEAR, + bool preserve_aspect_ratio = false, + bool antialias = false, + string name = null) + { + Func resize_fn = (images, size) => + { + if (method == ResizeMethod.BILINEAR) + return gen_image_ops.resize_bilinear(images, size, half_pixel_centers: true); + else if (method == ResizeMethod.NEAREST_NEIGHBOR) + return gen_image_ops.resize_nearest_neighbor(images, size, half_pixel_centers: true); + + throw new NotImplementedException("resize_images_v2"); + }; + + var size_tensor = ops.convert_to_tensor(size, dtype: tf.int32); + return _resize_images_common(images, resize_fn, size_tensor, + preserve_aspect_ratio: preserve_aspect_ratio, + skip_resize_if_same: false, + name: name); + } + + /// + /// Resize `images` to `size` using nearest neighbor interpolation. + /// + /// + /// + /// + /// + /// + /// + public static Tensor resize_nearest_neighbor(Tensor images, Tsize size, bool align_corners = false, + string name = null, bool half_pixel_centers = false) + => gen_image_ops.resize_nearest_neighbor(images: images, + size: size, + align_corners: align_corners, + half_pixel_centers: half_pixel_centers, + name: name); + + public static Tensor draw_bounding_boxes(Tensor images, Tensor boxes, Tensor colors = null, string name = null) + { + if (colors == null) + return gen_ops.draw_bounding_boxes(images, boxes, name); + return gen_ops.draw_bounding_boxes(images, boxes, /*colors,*/ name); + } + + // TOOD: implement arguments, gen_ops + public static Tensor generate_bounding_box_proposals() + { + throw new NotImplementedException("generate_bounding_box_propsosals"); + } + } + + public class ResizeMethod + { + public ResizeMethod() + { + } + + public const string BILINEAR = "bilinear"; + public const string NEAREST_NEIGHBOR = "nearest"; + public const string BICUBIC = "bicubic"; + public const string AREA = "area"; + public const string LANCZOS3 = "lanczos3"; + public const string LANCZOS5 = "lanczos5"; + public const string GAUSSIAN = "gaussian"; + public const string MITCHELLCUBIC = "mitchellcubic"; + } +} diff --git a/src/TensorFlowNET.Core/Operations/io_ops.cs b/src/TensorFlowNET.Core/Operations/io_ops.cs new file mode 100644 index 000000000..0b77689d5 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/io_ops.cs @@ -0,0 +1,91 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class io_ops + { + public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + var result = tf.Runner.TFE_FastPathExecute( + new FastPathOpExecInfo(tf.Context, "SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors })); + result = null; + return null; + } + catch (System.Exception) + { + return save_v2_eager_fallback(prefix, tensor_names, shape_and_slices, tensors, name, ctx); + } + } + var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); + + return _op; + } + + public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name, Context ctx) + { + DataType[] attr_dtypes; + (attr_dtypes, tensors) = _execute.onvert_to_mixed_eager_tensors(tensors, ctx); + prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); + var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); + var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); + var inputs_flat = tensors.Concat(new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }).ToArray(); + var attrs = new object[] { "dtypes", attr_dtypes }; + + var result = _execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name); + result = null; + return null; + } + + public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) + { + // Note: this implementation is not correct in many cases, please consider using `gen_ops.restore_v2`. + var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); + + return _op.outputs; + } + + public Tensor read_file(T filename, string name = null) + { + if (tf.Context.executing_eagerly()) + { + return read_file_eager_fallback(filename, name: name, tf.Context); + } + + var _op = tf.OpDefLib._apply_op_helper("ReadFile", name: name, args: new { filename }); + + return _op.outputs[0]; + } + + private Tensor read_file_eager_fallback(T filename, string name = null, Context ctx = null) + { + var filename_tensor = ops.convert_to_tensor(filename, TF_DataType.TF_STRING); + var _inputs_flat = new[] { filename_tensor }; + + return tf.Runner.Execute(ctx, "ReadFile", 1, _inputs_flat, null, name: name)[0]; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/linalg_ops.cs b/src/TensorFlowNET.Core/Operations/linalg_ops.cs new file mode 100644 index 000000000..42da1a279 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/linalg_ops.cs @@ -0,0 +1,140 @@ +using System; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class linalg_ops + { + public Tensor eye(int num_rows, + int num_columns = -1, + Shape batch_shape = null, + TF_DataType dtype = TF_DataType.TF_DOUBLE, + string name = null) + { + return tf_with(ops.name_scope(name, default_name: "eye", new { num_rows, num_columns, batch_shape }), scope => + { + if (num_columns == -1) + num_columns = num_rows; + + bool is_square = num_columns == num_rows; + var diag_size = Math.Min(num_rows, num_columns); + if (batch_shape == null) + batch_shape = new Shape(new int[0]); + var batch_shape_tensor = ops.convert_to_tensor(batch_shape, dtype: tf.int32, name: "shape"); + var diag_shape = array_ops.concat(new[] { batch_shape_tensor, tf.constant(new int[] { diag_size }) }, axis: 0); + + Tensor shape = null; + if (!is_square) + shape = array_ops.concat(new[] { batch_shape_tensor, tf.constant(new int[] { num_rows, num_columns }) }, axis: 0); + + var diag_ones = array_ops.ones(diag_shape, dtype: dtype); + if (is_square) + return array_ops.matrix_diag(diag_ones); + else + { + var zero_matrix = array_ops.zeros(shape, dtype: dtype); + return array_ops.matrix_set_diag(zero_matrix, diag_ones); + } + }); + } + + public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null) + => tf.Context.ExecuteOp("MatrixInverse", name, + new ExecuteOpArgs(input).SetAttributes(new + { + adjoint + })); + + public Tensor matrix_solve_ls(Tensor matrix, Tensor rhs, + Tensor l2_regularizer = null, bool fast = true, string name = null) + { + return _composite_impl(matrix, rhs, l2_regularizer: l2_regularizer); + } + + public Tensor norm(Tensor tensor, string ord = "euclidean", Axis axis = null, string name = null, bool keepdims = true) + { + var is_matrix_norm = axis != null && len(axis) == 2; + return tf_with(ops.name_scope(name, default_name: "norm", tensor), scope => + { + if (is_matrix_norm) + throw new NotImplementedException(""); + var result = math_ops.sqrt(math_ops.reduce_sum(tensor * math_ops.conj(tensor), axis, keepdims: true)); + + if(!keepdims) + result = array_ops.squeeze(result, axis); + return result; + }); + } + + Tensor _composite_impl(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null) + { + Shape matrix_shape = matrix.shape.dims.Skip(matrix.shape.ndim - 2).ToArray(); + if (matrix_shape.IsFullyDefined) + { + if (matrix_shape[-2] >= matrix_shape[-1]) + return _overdetermined(matrix, rhs, l2_regularizer); + else + return _underdetermined(matrix, rhs, l2_regularizer); + } + + throw new NotImplementedException(""); + } + + Tensor _overdetermined(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null) + { + var chol = _RegularizedGramianCholesky(matrix, l2_regularizer: l2_regularizer, first_kind: true); + return cholesky_solve(chol, math_ops.matmul(matrix, rhs, adjoint_a: true)); + } + + Tensor _underdetermined(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null) + { + var chol = _RegularizedGramianCholesky(matrix, l2_regularizer: l2_regularizer, first_kind: false); + return math_ops.matmul(matrix, cholesky_solve(chol, rhs), adjoint_a: true); + } + + Tensor _RegularizedGramianCholesky(Tensor matrix, Tensor l2_regularizer, bool first_kind) + { + var gramian = math_ops.matmul(matrix, matrix, adjoint_a: first_kind, adjoint_b: !first_kind); + + if (l2_regularizer != null) + { + var matrix_shape = array_ops.shape(matrix); + var batch_shape = matrix_shape[":-2"]; + var small_dim = first_kind ? matrix_shape[-1] : matrix_shape[-2]; + var identity = eye(small_dim.numpy(), batch_shape: batch_shape.shape, dtype: matrix.dtype); + var small_dim_static = matrix.shape[first_kind ? -1 : -2]; + identity.shape = matrix.shape.dims.Take(matrix.shape.ndim - 2).ToArray().concat(new[] { small_dim_static, small_dim_static }); + gramian += l2_regularizer * identity; + } + + return cholesky(gramian); + } + + public Tensor cholesky(Tensor input, string name = null) + => tf.Context.ExecuteOp("Cholesky", name, new ExecuteOpArgs(input)); + + public Tensor cholesky_solve(Tensor chol, Tensor rhs, string name = null) + => tf_with(ops.name_scope(name, default_name: "eye", new { chol, rhs }), scope => + { + var y = matrix_triangular_solve(chol, rhs, adjoint: false, lower: true); + var x = matrix_triangular_solve(chol, y, adjoint: true, lower: true); + return x; + }); + + public Tensor matrix_triangular_solve(Tensor matrix, Tensor rhs, bool lower = true, bool adjoint = false, string name = null) + => tf.Context.ExecuteOp("MatrixTriangularSolve", name, + new ExecuteOpArgs(matrix, rhs).SetAttributes(new + { + lower, + adjoint + })); + + public Tensors qr(Tensor input, bool full_matrices = false, string name = null) + => tf.Context.ExecuteOp("Qr", name, + new ExecuteOpArgs(input).SetAttributes(new + { + full_matrices + })); + } +} diff --git a/src/TensorFlowNET.Core/Operations/list_ops.cs b/src/TensorFlowNET.Core/Operations/list_ops.cs new file mode 100644 index 000000000..3791a2c19 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/list_ops.cs @@ -0,0 +1,111 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Eager; + +namespace Tensorflow.Operations +{ + internal class list_ops + { + private static void _set_handle_data(Tensor list_handle, Shape element_shape, TF_DataType element_dtype) + { + if(list_handle is EagerTensor eagerTensor) + { + var handle_data = new CppShapeInferenceResult.Types.HandleData(); + handle_data.IsSet = true; + handle_data.ShapeAndType.Add(new CppShapeInferenceResult.Types.HandleShapeAndType() + { + Shape = element_shape.as_proto(), + Dtype = element_dtype.as_datatype_enum(), + Type = new FullTypeDef() { TypeId = FullTypeId.TftArray } + }); + list_handle.HandleData = handle_data; + } + } + + private static Tensor _build_element_shape(Shape? shape) + { + if(shape is null || shape.IsNull) + { + return ops.convert_to_tensor(-1); + } + else + { + return ops.convert_to_tensor(shape, dtype: dtypes.int32); + } + } + + public static Tensor tensor_list_reserve(Shape? shape, Tensor num_elements, TF_DataType element_dtype, string name = null) + { + var result = gen_list_ops.tensor_list_reserve(_build_element_shape(shape), num_elements, element_dtype, name); + _set_handle_data(result, shape, element_dtype); + return result; + } + + public static Tensor tensor_list_from_tensor(Tensor tensor, Shape element_shape, string? name = null) + { + var result = gen_list_ops.tensor_list_from_tensor(tensor, _build_element_shape(element_shape), name); + _set_handle_data(result, tensor.shape, tensor.dtype); + return result; + } + + public static Tensor tensor_list_get_item(Tensor input_handle, Tensor index, TF_DataType element_dtype, + Shape? element_shape = null, string? name = null) + { + return gen_list_ops.tensor_list_get_item(input_handle, index, _build_element_shape(element_shape), + element_dtype, name); + } + + public static Tensor tensor_list_set_item(Tensor input_handle, Tensor index, Tensor item, + bool resize_if_index_out_of_bounds = false, string? name = null) + { + if (resize_if_index_out_of_bounds) + { + var input_list_size = gen_list_ops.tensor_list_length(input_handle); + input_handle = control_flow_ops.cond(index >= input_list_size, + () => gen_list_ops.tensor_list_resize(input_handle, index + 1), + () => input_handle); + } + var output_handle = gen_list_ops.tensor_list_set_item(input_handle, index, item, name); + handle_data_util.copy_handle_data(input_handle, output_handle); + return output_handle; + } + + public static Tensor tensor_list_stack(Tensor input_handle, TF_DataType element_dtype, int num_elements = -1, + Shape? element_shape = null, string? name = null) + { + return gen_list_ops.tensor_list_stack(input_handle, _build_element_shape(element_shape), element_dtype, num_elements, name); + } + + public static Tensor tensor_list_gather(Tensor input_handle, Tensor indices, TF_DataType element_dtype, + Shape? element_shape = null, string? name = null) + { + return gen_list_ops.tensor_list_gather(input_handle, indices, _build_element_shape(element_shape), element_dtype, name); + } + + public static Tensor tensor_list_scatter(Tensor tensor, Tensor indices, Shape? element_shape = null, Tensor? input_handle = null, + string? name = null) + { + if(input_handle is not null) + { + var output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(input_handle, tensor, indices, name); + handle_data_util.copy_handle_data(input_handle, output_handle); + return output_handle; + } + else + { + var output_handle = gen_list_ops.tensor_list_scatter_v2(tensor, indices, _build_element_shape(element_shape), + constant_op.constant(-1), name); + _set_handle_data(output_handle, element_shape, tensor.dtype); + return output_handle; + } + } + + public static Tensor empty_tensor_list(Shape? element_shape, TF_DataType element_dtype, int max_num_elements = -1, + string? name = null) + { + return gen_list_ops.empty_tensor_list(_build_element_shape(element_shape), element_dtype: element_dtype, + max_num_elements: ops.convert_to_tensor(max_num_elements, dtype: dtypes.int32), name: name); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/logging_ops.cs b/src/TensorFlowNET.Core/Operations/logging_ops.cs new file mode 100644 index 000000000..3303cadc3 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/logging_ops.cs @@ -0,0 +1,36 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Contexts; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class logging_ops + { + public Tensor print_v2(Tensor input, string output_stream = "stderr", string end = "\n", string name = null) + { + var formatted_string = tf.strings.format("{}", + new[] { input }, + placeholder: "{}", + summarize: 3, + name: name); + + return tf.Context.ExecuteOp("PrintV2", name, new ExecuteOpArgs(formatted_string) + .SetAttributes(new { output_stream, end })).SingleOrNull; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/map_fn.cs b/src/TensorFlowNET.Core/Operations/map_fn.cs new file mode 100644 index 000000000..a754f230a --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/map_fn.cs @@ -0,0 +1,185 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Framework; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow +{ +#pragma warning disable CS0659 // 'Operation' overrides Object.Equals(object o) but does not override Object.GetHashCode() + public partial class Operation +#pragma warning restore CS0659 // 'Operation' overrides Object.Equals(object o) but does not override Object.GetHashCode() + { + /// + /// map on the list of tensors unpacked from `elems` on dimension 0. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A tensor or (possibly nested) sequence of tensors. + public static Tensor map_fn(Func fn, + Tensor elems, + TF_DataType dtype = TF_DataType.DtInvalid, + int parallel_iterations = 10, + bool back_prop = true, + bool swap_memory = false, + bool infer_shape = true, + string name = null) + { + bool input_is_sequence = nest.is_sequence(elems); + Tensor[] input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x).ToArray() : new[] { x }; + Tensor input_pack(Tensor[] x) => input_is_sequence ? (Tensor)nest.pack_sequence_as(elems, x) : x[0]; + + bool output_is_sequence; + Func output_flatten; + Func output_pack; + if (dtype == TF_DataType.DtInvalid) + { + output_is_sequence = input_is_sequence; + output_flatten = input_flatten; + output_pack = input_pack; + } + else + { + output_is_sequence = nest.is_sequence(dtype); + output_flatten = (x) => output_is_sequence ? nest.flatten(x).ToArray() : new[] { x }; + output_pack = (x) => output_is_sequence ? (Tensor)nest.pack_sequence_as(dtype, x) : x[0]; + } + + var elems_flat = input_flatten(elems); + return tf_with(ops.name_scope(name, "map", elems_flat), delegate + { + //if in_graph_mode: + //# Any get_variable calls in fn will cache the first call locally + //# and not issue repeated network I/O requests for each iteration. + //varscope = vs.get_variable_scope() + //varscope_caching_device_was_none = False + //if varscope.caching_device is None: + // # TODO(ebrevdo): Change to using colocate_with here and in other + // # methods. + // varscope.set_caching_device(lambda op: op.device) + // varscope_caching_device_was_none = True + + elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")) + .ToArray(); + + dtype = elems_flat.Select(elem => elem.dtype).First(); + var dtype_flat = new[] { dtype }; + + // Convert elems to tensor array. n may be known statically. + var static_shape = elems_flat[0].shape; + + var n = static_shape[0]; + + // TensorArrays are always flat + var elems_ta = elems_flat.Select(elem => tf.TensorArray(dtype: elem.dtype, + size: Convert.ToInt32(n), + dynamic_size: false, + infer_shape: true)).ToArray(); + + // Unpack elements + var elems_ta_1 = new List(); + foreach (var (elem_ta, elem) in zip(elems_ta, elems_flat)) + elems_ta_1.Add(elem_ta.unstack(elem)); + + elems_ta = elems_ta_1.ToArray(); + + var i = constant_op.constant(0); + + var accs_ta = dtype_flat.Select(dt => tf.TensorArray(dtype: dt, + size: Convert.ToInt32(n), + dynamic_size: false, + infer_shape: infer_shape)).ToArray(); + + + BodyItem compute(BodyItem item) + { + var packed_values = input_pack(elems_ta.Select(elem_ta => elem_ta.read(item.I)).ToArray()); + var packed_fn_values = fn(packed_values); + //nest.assert_same_structure(dtype or elems, packed_fn_values) + + var flat_fn_values = output_flatten(packed_fn_values); + for (int j = 0; j < item.Accs_ta.Length; j++) + { + item.Accs_ta[j].write(item.I, flat_fn_values[j]); + } + + return new BodyItem(item.I + 1, item.Accs_ta); + } + + var r_a = control_flow_ops.while_loop( + (x) => x.I < n, + compute, + new BodyItem(i, accs_ta), + parallel_iterations: parallel_iterations, + back_prop: back_prop, + swap_memory: swap_memory, + maximum_iterations: tf.constant(n)); + var results_flat = r_a.Accs_ta.Select(r => r.stack()).ToArray(); + + var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape.with_rank_at_least(1).dims[0])); + + foreach (var elem in elems_flat.Skip(1)) + { + n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape.with_rank_at_least(1).dims[0]))); + } + + foreach (Tensor r in results_flat) + { + r.shape = new Shape(n_static).concatenate(r.dims.Skip(1).ToArray()); + } + + // todo get working when the above caching_device is fixed + //if (in_graph_mode && varscope_caching_device_was_none) { + // varscope.set_caching_device(None); + //} + + return output_pack(results_flat); + }); + } + + internal class BodyItem : ICanBeFlattened, IPackable, IFromMergeVars + { + public Tensor I { get; set; } + public TensorArray[] Accs_ta { get; set; } + + public BodyItem() + { + } + + public BodyItem(Tensor i, TensorArray[] accs_ta) + { + I = i; + Accs_ta = accs_ta; + } + + public object[] Flatten() + { + var elements = new List { I }; + elements.AddRange(Accs_ta); + return elements.ToArray(); + } + + public BodyItem Pack(object[] sequences) + { + I = sequences[0] as Tensor; + Accs_ta = new[] { sequences[1] as TensorArray }; + + return new BodyItem(I, Accs_ta); + } + + public BodyItem FromMergeVars(ITensorOrTensorArray[] merge_vars) + { + I = (Tensor)merge_vars[1]; + Accs_ta = new[] { (TensorArray)merge_vars[2] }; + return this; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs new file mode 100644 index 000000000..e77df702f --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -0,0 +1,1122 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Framework; +using static Tensorflow.Binding; +using Tensorflow.Operations; +using System.Runtime.CompilerServices; + +namespace Tensorflow +{ + /// + /// python\ops\math_ops.py + /// + public class math_ops + { + public static Tensor abs(Tensor x, string name = null) + { + return tf_with(ops.name_scope(name, "Abs", new { x }), scope => + { + name = scope; + x = ops.convert_to_tensor(x, name: "x"); + if (x.dtype.is_complex()) + { + return gen_ops.complex_abs(x, Tout: x.dtype.real_dtype(), name: name); + } + return gen_math_ops.abs(x, name: name); + }); + } + + public static Tensor add(Tx x, Ty y, string name = null) + => gen_math_ops.add(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); + + public static Tensor add_v2(Tensor x, Tensor y, string name = null) + => tf.Context.ExecuteOp("AddV2", name, new ExecuteOpArgs(x, y)); + + public static Tensor add_v2(Tx x, Ty y, string name = null) + => gen_math_ops.add_v2(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); + + /// + /// Adds all input tensors element-wise. + /// + /// + /// + /// + public static Tensor add_n(Tensor[] inputs, string name = null) + { + inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs); + + if (inputs.Length == 1) + { + var values = inputs[0]; + if (name != null) + return array_ops.identity(values, name: name); + return values; + } + + return gen_math_ops.add_n(inputs, name: name); + } + + public static Tensor argmax(Tensor input, Axis dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) + => gen_math_ops.arg_max(input, dimension, output_type: output_type, name: name); + + public static Tensor argmin(Tensor input, Axis dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) + => gen_math_ops.arg_min(input, dimension, output_type: output_type, name: name); + + public static Tensor round(Tensor x, string name = null) + { + x = ops.convert_to_tensor(x, name: "x"); + if (x.dtype.is_integer()) + return x; + else + return gen_math_ops.round(x, name: name); + } + + public static Tensor cast(IVariableV1 x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) + { + var base_type = dtype.as_base_dtype(); + if (base_type == x.dtype) + return x.AsTensor(); + + return tf_with(ops.name_scope(name, "Cast", new { x }), scope => + { + name = scope; + var t_x = ops.convert_to_tensor(x, name: "x"); + if (t_x.dtype.as_base_dtype() != base_type) + t_x = gen_math_ops.cast(t_x, base_type, name: name); + + return x.AsTensor(); + }); + } + + public static ResourceVariable cast(ResourceVariable x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) + { + var base_type = dtype.as_base_dtype(); + if (base_type == x.dtype) + return x; + + return tf_with(ops.name_scope(name, "Cast", new { x }), scope => + { + name = scope; + var t_x = ops.convert_to_tensor(x, name: "x"); + if (t_x.dtype.as_base_dtype() != base_type) + t_x = gen_math_ops.cast(t_x, base_type, name: name); + + return x; + }); + } + + public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) + { + var base_type = dtype.as_base_dtype(); + if (base_type == x.dtype) + return x; + + return tf_with(ops.name_scope(name, "Cast", new { x }), scope => + { + name = scope; + if (x.dtype.as_base_dtype() != base_type) + x = gen_math_ops.cast(x, base_type, name: name); + + return x; + }); + } + + public static Tensor cos(Tensor x, string name = null) + => tf.Context.ExecuteOp("Cos", name, new ExecuteOpArgs(x)); + + public static Tensor saturate_cast(Tensor value, TF_DataType dtype, string name = null) + { + return tf_with(ops.name_scope(name, "saturate_cast", new[] { value }), name => + { + value = ops.convert_to_tensor(value, name: "value"); + // dtype = dtypes.as_dtype(dtype).as_base_dtype(); + if (value.dtype.min() < dtype.min()) + value = gen_math_ops.maximum( + value, + ops.convert_to_tensor(dtype.min(), dtype: value.dtype, name: "min")); + if (value.dtype.max() > dtype.max()) + value = gen_math_ops.minimum( + value, + ops.convert_to_tensor(dtype.max(), dtype: value.dtype, name: "max")); + return cast(value, dtype, name: name); + }); + } + + public static Tensor cumsum(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null) + => tf_with(ops.name_scope(name, "Cumsum", new { x }), scope => + { + name = scope; + return tf.Context.ExecuteOp("Cumsum", name, new ExecuteOpArgs(x, axis) + .SetAttributes(new { exclusive, reverse })); + }); + + /// + /// Computes Psi, the derivative of Lgamma (the log of the absolute value of + /// `Gamma(x)`), element-wise. + /// + /// + /// + /// + public static Tensor digamma(Tensor x, string name = null) + => gen_math_ops.digamma(x, name: name); + + /// + /// Divide two values using Python 2 semantics. Used for Tensor.__div__. + /// + /// `Tensor` numerator of real numeric type. + /// `Tensor` denominator of real numeric type. + /// A name for the operation + /// `x / y` returns the quotient of x and y. + public static Tensor div(Tensor x, Tensor y, string name = null) + { + return tf_with(ops.name_scope(name, "div", (x, y)), name_scope => + { + name = name_scope; + x = ops.convert_to_tensor(x, name: "x"); + y = ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name = "y"); + var x_dtype = x.dtype.as_base_dtype(); + var y_dtype = y.dtype.as_base_dtype(); + if (x_dtype != y_dtype) + throw new TypeError($"x and y must have the same dtype, got {x_dtype} != {y_dtype}"); + if (x_dtype.is_floating() || x_dtype.is_complex()) + return gen_math_ops.real_div(x, y, name: name); + else + return gen_math_ops.floor_div(x, y, name: name); + }); + } + + /// + /// Returns 0 if the denominator is zero. + /// + /// + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DivNoNan'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// + /// *NOTE*: DivNoNan supports broadcasting. More about broadcasting + /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + /// + public static Tensor div_no_nan(Tensor x, Tensor y, string name = null) + { + return tf_with(ops.name_scope(name, "div_no_nan", (x, y)), name_scope => + { + name = name_scope; + x = ops.convert_to_tensor(x, name: "x"); + y = ops.convert_to_tensor(y, name: "y", dtype: x.dtype.as_base_dtype()); + var x_dtype = x.dtype.as_base_dtype(); + var y_dtype = y.dtype.as_base_dtype(); + if (x_dtype != y_dtype) + throw new TypeError($"x and y must have the same dtype, got {x_dtype} != {y_dtype}"); + return gen_math_ops.div_no_nan(x, y, name: name); + }); + } + + public static Tensor einsum(string equation, Tensors inputs, string name = null) + { + return tf_with(ops.name_scope(name, "einsum", inputs), scope => + { + name = scope; + return tf.Context.ExecuteOp("Einsum", name, new ExecuteOpArgs + { + OpInputArgs = new object[] { inputs.ToArray() }, + GetGradientAttrs = (op) => new + { + equation = op.get_attr("equation"), + N = op.get_attr("N"), + T = op.get_attr("T") + } + }.SetAttributes(new + { + equation = equation + })); + }); + } + + public static Tensor greater_equal(Tx x, Ty y, string name = null) + => gen_math_ops.greater_equal(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); + public static Tensor equal(Tx x, Ty y, string name = null) + => gen_math_ops.equal(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); + + /// + /// Computes the Gauss error function of `x` element-wise. + /// + /// + /// + /// + public static Tensor erf(Tensor x, string name = null) + => tf.Context.ExecuteOp("Erf", name, new ExecuteOpArgs(x)); + + public static Tensor sqrt(Tensor x, string name = null) + => tf.Context.ExecuteOp("Sqrt", name, new ExecuteOpArgs(x)); + + public static Tensor multiply(Tensor x, Tensor y, string name = null) + => tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(x, y)); + + public static Tensor multiply(Tx x, Ty y, string name = null) + => gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); + + public static Tensor not_equal(Tx x, Ty y, string name = null) + => gen_math_ops.not_equal(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); + + public static Tensor mul_no_nan(Tx x, Ty y, string name = null) + => gen_math_ops.mul_no_nan(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); + + public static Tensor scalar_mul(Tscale scale, Tx x, string name = null) + => tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(scale, x)); + + public static Tensor real(Tensor input, string name = null) + { + return tf_with(ops.name_scope(name, "Real", new[] { input }), scope => + { + // name = scope; + input = ops.convert_to_tensor(input, name: "input"); + if (input.dtype.is_complex()) + { + var real_dtype = input.dtype.real_dtype(); + return real(input, name: scope); + } + else + { + return input; + } + }); + } + + /// + /// Computes the mean of elements across dimensions of a tensor. + /// Reduces `input_tensor` along the dimensions given in `axis`. + /// Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + /// entry in `axis`. If `keepdims` is true, the reduced dimensionsare retained with length 1. + /// If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. + /// + /// The tensor to reduce. Should have numeric type. + /// The dimensions to reduce. If `None` (the default), reduces all + /// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`. + /// If true, retains reduced dimensions with length 1. + /// A name for the operation (optional). + public static Tensor reduce_mean(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) + { + var r = _ReductionDims(input_tensor, axis); + var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis); + var m = gen_math_ops.mean(input_tensor, axis_tensor, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis_tensor, m); + } + + /// + /// Computes the product of elements across dimensions of a tensor. + /// + /// + /// + /// + /// + /// + public static Tensor reduce_prod(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) + { + var r = _ReductionDims(input_tensor, axis); + if (axis == null) + { + var m = gen_math_ops.prod(input_tensor, r, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis, m); + } + else + { + var m = gen_math_ops.prod(input_tensor, axis, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis, m); + } + } + + public static Tensor reduce_std(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) + { + if (name == null) + name = "reduce_std"; + // else {name = name;} + + return tf_with(ops.name_scope(name, "reduce_std", new[] { input_tensor }), scope => + { + var variance = reduce_variance(input_tensor, axis: axis, keepdims: keepdims); + return gen_math_ops.sqrt(variance); + }); + } + + public static Tensor reduce_variance(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) + { + if (name == null) + name = "reduce_variance"; + // else {name = name;} + + return tf_with(ops.name_scope(name, "reduce_variance", new[] { input_tensor }), scope => + { + var means = reduce_mean(input_tensor, axis: axis, keepdims: true); + if (means.dtype.is_integer()) + throw new TypeError("Input must be either real or complex"); + var diff = input_tensor - means; + + Tensor squared_deviations; + if (diff.dtype.is_complex()) + { + var real_dtype = diff.dtype.real_dtype(); + squared_deviations = real( + gen_math_ops.mul(conj(diff), diff)); + } + else + { + squared_deviations = gen_math_ops.square(diff); + } + return reduce_mean(squared_deviations, axis: axis, keepdims: keepdims); + }); + } + + public static Tensor sigmoid(T x, string name = null) + => tf_with(ops.name_scope(name, "Sigmoid", x), scope => + { + name = scope; + var x_tensor = ops.convert_to_tensor(x, name: "x"); + return gen_math_ops.sigmoid(x_tensor, name: name); + }); + + public static Tensor sign(T x, string name = null) + => gen_math_ops.sign(ops.convert_to_tensor(x), name: name); + + public static Tensor sin(Tensor x, string name = null) + => tf.Context.ExecuteOp("Sin", name, new ExecuteOpArgs(x)); + + /// + /// Returns (x - y)(x - y) element-wise. + /// + /// A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`. + /// A `Tensor`. Must have the same type as `x`. + /// A name for the operation (optional). + /// A `Tensor`. Has the same type as `x`. + public static Tensor square_difference(Tensor x, Tensor y, string name = null) + { + var m = gen_math_ops.squared_difference(x, y); + return m; + } + + public static Tensor square(Tensor x, string name = null) + { + return gen_math_ops.square(x, name); + } + + public static Tensor subtract(Tx x, Ty y, string name = null) + { + return gen_math_ops.sub(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); + } + + public static Tensor log(Tensor x, string name = null) + { + return gen_math_ops.log(x, name); + } + + public static Tensor logical_and(Tensor x, Tensor y, string name = null) + => gen_math_ops.logical_and(x, y, name: name); + + public static Tensor lgamma(Tensor x, string name = null) + => gen_math_ops.lgamma(x, name: name); + + public static Tensor linspace(Tensor start, Tensor stop, int num = 50, string name = null, int axis = 0) + { + return tf_with(ops.name_scope(name, "linspace", new { start, stop }), scope => + { + var num_int_tensor = array_ops.constant(num); + var num_tensor = array_ops.constant(num, dtype: start.dtype); + + var broadcast_shape = array_ops.broadcast_dynamic_shape(array_ops.shape(start), array_ops.shape(stop)); + start = gen_array_ops.broadcast_to(start, broadcast_shape); + stop = gen_array_ops.broadcast_to(stop, broadcast_shape); + + var expanded_start = array_ops.expand_dims(start, axis: axis); + var expanded_stop = array_ops.expand_dims(stop, axis: axis); + + var shape = array_ops.shape(expanded_start); + var ndims = array_ops.shape(shape)[0]; + + var axis_tensor = array_ops.where_v2(constant_op.constant(axis >= 0), x: axis, y: ndims + axis); + + // The purpose is to avoid having negative values when repeating. + var num_fill = gen_math_ops.maximum(num_int_tensor - 2, ops.convert_to_tensor(0)); + var n_steps = gen_math_ops.maximum(num_int_tensor - 1, ops.convert_to_tensor(1)); + var delta = (expanded_stop - expanded_start) / cast(n_steps, expanded_stop.dtype); + + var range_end = array_ops.where_v2(num_int_tensor >= 0, n_steps, -1); + var desired_range = cast(range(1, range_end, dtype: dtypes.int64), delta.dtype); + var mask = gen_math_ops.equal(axis_tensor, range(ndims)); + var desired_range_shape = array_ops.where_v2(mask, num_fill, 1); + desired_range = array_ops.reshape(desired_range, desired_range_shape); + var res = expanded_start + delta * desired_range; + + // Add the start and endpoints to the result, and slice out the desired + // portion. + var all_tensors = new[] { expanded_start, res, expanded_stop }; + var concatenated = array_ops.concat(all_tensors, axis: axis); + var begin = array_ops.zeros_like(shape); + var size = array_ops.where_v2(mask, num_int_tensor, shape); + + return array_ops.slice(concatenated, begin, size); + }); + + throw new NotImplementedException(""); + } + + /// + /// Helper function for reduction ops. + /// + /// 1-D Tensor, the shape of the Tensor being reduced. + /// 1-D Tensor, the reduction axes. + /// A 1-D Tensor, the output shape as if keepdims were set to True. + public static Tensor reduced_shape(Tensor input_shape, Tensor axes) + { + if (tf.Context.executing_eagerly()) + { + var input_shape_val = input_shape.numpy(); + foreach (var axes_val in axes.ToArray()) + input_shape_val[axes_val] = 1; + return tf.constant(input_shape_val); + } + + input_shape = to_int32(input_shape); + axes = to_int32(axes); + + var input_rank = array_ops.size(input_shape); + axes = (axes + input_rank) % input_rank; + var axes_shape = array_ops.shape(axes); + var rng = math_ops.range(input_rank); + var a1 = new Tensor[] { rng, axes }; + var fill = gen_array_ops.fill(axes_shape, ops.convert_to_tensor(1)); + var a2 = new Tensor[] { input_shape, fill }; + + return gen_data_flow_ops.dynamic_stitch(a1, a2); + } + + /// + /// Computes the reciprocal of x element-wise. + /// + /// + /// + /// + public static Tensor reciprocal(Tensor x, string name = null) + => gen_math_ops.reciprocal(x, name: name); + + /// + /// Computes the "logical and" of elements across dimensions of a tensor. + /// + /// + /// + /// + /// + /// + public static Tensor reduce_all(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) + { + var all = gen_math_ops.all(input_tensor, + _ReductionDims(input_tensor, axis), + keepdims, + name: name); + + return _may_reduce_to_scalar(keepdims, axis, all); + } + + public static Tensor realdiv(Tensor x, Tensor y, string name = null) + => gen_math_ops.real_div(x, y, name: name); + + /// + /// Computes log(sum(exp(elements across dimensions of a tensor))). + /// Reduces `input_tensor` along the dimensions given in `axis`. + /// Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + /// entry in `axis`. If `keepdims` is true, the reduced dimensions + /// are retained with length 1. + /// + /// If `axis` has no entries, all dimensions are reduced, and a + /// tensor with a single element is returned. + /// + /// This function is more numerically stable than log(sum(exp(input))). It avoids + /// overflows caused by taking the exp of large inputs and underflows caused by + /// taking the log of small inputs. + /// + /// The tensor to reduce. Should have numeric type. + /// The dimensions to reduce. If `None` (the default), reduces all + /// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`. + /// + /// The reduced tensor. + public static Tensor reduce_logsumexp(Tensor input_tensor, Axis axis = null, bool keepdims = false, string name = null) + { + return tf_with(ops.name_scope(name, "ReduceLogSumExp", new { input_tensor }), scope => + { + var raw_max = reduce_max(input_tensor, axis, true); + var my_max = array_ops.stop_gradient(array_ops.where(gen_math_ops.is_finite(raw_max), raw_max, array_ops.zeros_like(raw_max))); + var result = gen_math_ops.log( + reduce_sum( + gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)), + constant_op.constant(axis[0]), + keepdims)); + if (!keepdims) + { + my_max = array_ops.reshape(my_max, array_ops.shape(result)); + } + result = gen_math_ops.add(result, my_max); + return _may_reduce_to_scalar(keepdims, axis, result); + }); + } + + public static Tensor reduce_any(Tensor input_tensor, Axis axis = null, bool keepdims = false, string name = null) + { + var r = _ReductionDims(input_tensor, axis); + var max = (axis != null) ? gen_math_ops.any(input_tensor, axis, keepdims, name) : + gen_math_ops.any(input_tensor, r, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis, max); + } + + public static Tensor reduce_euclidean_norm(Tensor input_tensor, Axis axis = null, bool keepdims = false, string name = null) + { + var r = _ReductionDims(input_tensor, axis); + var distance = tf.Context.ExecuteOp("EuclideanNorm", name, + new ExecuteOpArgs(input_tensor, r).SetAttributes(new + { + keep_dims = keepdims + })); + return _may_reduce_to_scalar(keepdims, axis, distance); + } + + public static Tensor reduce_max(Tensor input_tensor, Axis axis = null, bool keepdims = false, string name = null) + { + var r = _ReductionDims(input_tensor, axis); + var max = (axis != null) ? gen_math_ops.max(input_tensor, axis, keepdims, name) : + gen_math_ops.max(input_tensor, r, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis, max); + } + + public static Tensor reduce_min(Tensor input_tensor, Axis axis = null, bool keepdims = false, string name = null) + { + var r = _ReductionDims(input_tensor, axis); + var min = gen_math_ops.min(input_tensor, r, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis, min); + } + + /// + /// Computes the sum along segments of a tensor. + /// + /// + /// + /// + /// + /// + public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string name = null) + => gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments, name: name); + + /// + /// Casts a tensor to type `int32`. + /// + /// A `Tensor` or `SparseTensor` or `IndexedSlices`. + /// A name for the operation (optional). + /// A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with type `int32`. + private static Tensor to_int32(Tensor x, string name = "ToInt32") + { + return __case__(x, TF_DataType.TF_INT32, name: name); + } + + /// + /// Casts a tensor to a new type. + /// + /// + /// + /// + /// A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` and same type as `dtype`. + public static Tensor __case__(Tensor x, TF_DataType dtype, string name = null) + { + var base_type = dtype.as_base_dtype(); + if (x is Tensor && base_type == x.dtype) + return x; + + // math_ops.py cast + throw new NotImplementedException(); + } + + public static Tensor reduce_sum(Tensor input_tensor, Tensor axis = null, bool keepdims = false, string name = null) + { + var r = _ReductionDims(input_tensor, axis); + var m = gen_math_ops.sum(input_tensor, r, keep_dims: keepdims, name: name); + return _may_reduce_to_scalar(keepdims, axis, m); + } + + private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor axis, Tensor output) + { + if (!common_shapes.has_fully_defined_shape(output) && + !keepdims && + axis == null) + // We want set_shape to be reflected in the C API graph for when we run it. + output.shape = new int[0]; + return output; + } + + private static Tensor _may_reduce_to_scalar(bool keepdims, Axis axis, Tensor output) + { + if (!common_shapes.has_fully_defined_shape(output) && + !keepdims && + axis == null) + output.shape = new int[0]; + return output; + } + + private static Tensor _may_reduce_to_scalar(bool keepdims, int? axis, Tensor output) + { + return output; + } + + private static Tensor _ReductionDims(Tensor x, Tensor axis) + { + if (axis != null) + { + return axis; + } + else + { + var rank = array_ops.rank(x); + return range(0, rank, 1); + } + } + + private static Tensor _ReductionDims(Tensor x, Axis? axis) + { + if (axis != null) + { + // should return axis. or check before. + return ops.convert_to_tensor(axis, TF_DataType.TF_INT32); + } + else + { + var rank = common_shapes.rank(x); + + // we rely on Range and Rank to do the right thing at run-time. + if (rank == -1) return range(0, array_ops.rank(x)); + + return range(0, rank, 1); + } + } + + /// + /// Computes reciprocal of square root of x element-wise. + /// + /// + /// + /// + public static Tensor rsqrt(Tensor x, string name = null) + => gen_math_ops.rsqrt(x, name: name); + + public static Tensor pow(Tx x, Ty y, string name = null) + => tf_with(ops.name_scope(name, "Pow", new { x, y }), scope => + { + name = scope; + var x_tensor = ops.convert_to_tensor(x, name: "x"); + var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype()); + + return tf.Context.ExecuteOp("Pow", name, new ExecuteOpArgs(x_tensor, y_tensor)); + }); + + public static Tensor range(object start, object limit = null, object delta = null, TF_DataType? dtype = null, string name = "range") + { + if (limit == null) + { + limit = start; + start = 0; + } + + var dtype1 = dtype ?? limit.GetDataType(); + + return tf_with(ops.name_scope(name, "Range", new { start, limit, delta }), scope => + { + name = scope; + var start1 = ops.convert_to_tensor(start, name: "start", dtype: dtype1); + var limit1 = ops.convert_to_tensor(limit, name: "limit", dtype: dtype1); + var delta1 = ops.convert_to_tensor(delta ?? 1, name: "delta", dtype: dtype1); + return gen_math_ops.range(start1, limit1, delta1, name); + }); + } + public static Tensor floor(Tensor x, string name = null) + => tf.Context.ExecuteOp("Floor", name, new ExecuteOpArgs(x)); + + public static Tensor floordiv(Tensor x, Tensor y, string name = null) + { + return tf_with(ops.name_scope(name, "floordiv", new { x, y }), scope => + { + return gen_math_ops.floor_div(x, y, scope); + }); + } + + public static Tensor minimum(Tx x, Ty y, string name = null) + => gen_math_ops.minimum(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); + + public static Tensor maximum(Tx x, Ty y, string name = null) + => gen_math_ops.maximum(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); + + /// + /// Multiplies matrix `a` by matrix `b`, producing `a` * `b`. + /// + /// + /// + /// If `True`, `a` is transposed before multiplication. + /// If `True`, `b` is transposed before multiplication. + /// If `True`, `a` is conjugated and transposed before multiplication. + /// If `True`, `b` is conjugated and transposed before multiplication. + /// If `True`, `a` is treated as a sparse matrix. + /// If `True`, `b` is treated as a sparse matrix. + /// Name for the operation (optional). + /// + /// A `Tensor` of the same type as `a` and `b` where each inner-most matrix is + /// the product of the corresponding matrices in `a` and `b`, e.g. if all + /// transpose or adjoint attributes are `False`: + /// + public static Tensor matmul(Tensor a, Tensor b, + bool transpose_a = false, bool transpose_b = false, + bool adjoint_a = false, bool adjoint_b = false, + bool a_is_sparse = false, bool b_is_sparse = false, + string name = null) + => tf_with(ops.name_scope(name, "MatMul", (a, b)), scope => + { + name = scope; + + if (transpose_a && adjoint_a) + throw new ValueError("Only one of transpose_a and adjoint_a can be True."); + if (transpose_b && adjoint_b) + throw new ValueError("Only one of transpose_b and adjoint_b can be True."); + + if(adjoint_a) + { + a = conj(a); + transpose_a = true; + } + + if (adjoint_b) + { + b = conj(b); + transpose_b = true; + } + + return tf.Context.ExecuteOp("MatMul", name, new ExecuteOpArgs(a, b) + .SetAttributes(new { transpose_a, transpose_b })); + }); + + public static Tensor batch_matmul(Tensor x, Tensor y, + bool adj_x = false, bool adj_y = false, + string name = null) + => tf_with(ops.name_scope(name, "MatMul", new Tensor[] { x, y }), scope => + { + name = scope; + + x = ops.convert_to_tensor(x, name: "a"); + y = ops.convert_to_tensor(y, name: "b"); + + return tf.Context.ExecuteOp("BatchMatMul", name, new ExecuteOpArgs(x, y) + .SetAttributes(new { adj_x, adj_y })); + }); + + public static Tensor count_nonzero_v2(Tensor input, + Axis? axis, + bool keepdims = false, + string name = null, + TF_DataType dtype = TF_DataType.TF_INT64) + => tf_with(ops.name_scope(name, "count_nonzero", input), scope => + { + name = scope; + var zero = array_ops.zeros(Shape.Scalar, dtype: input.dtype); + return reduce_sum(cast(gen_math_ops.not_equal(input, zero), dtype), axis: axis, keepdims: keepdims); + }); + + public static Tensor bincount(Tensor arr, Tensor weights = null, + Tensor minlength = null, + Tensor maxlength = null, + TF_DataType dtype = TF_DataType.TF_INT32, + string name = null, + Shape axis = null, + bool binary_output = false) + => tf_with(ops.name_scope(name, "bincount"), scope => + { + name = scope; + if(!binary_output && axis == null) + { + var array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0; + var output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * (math_ops.reduce_max(arr) + 1); + if (minlength != null) + output_size = math_ops.maximum(minlength, output_size); + if (maxlength != null) + output_size = math_ops.minimum(maxlength, output_size); + weights = weights ?? constant_op.constant(new int[0], dtype: dtype); + return tf.Context.ExecuteOp("Bincount", name, new ExecuteOpArgs(arr, output_size, weights)); + } + else + { + var array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0; + var output_size = math_ops.cast(array_is_nonempty, arr.dtype) * (math_ops.reduce_max(arr) + 1); + if (minlength != null) + output_size = math_ops.maximum(minlength, output_size); + if (maxlength != null) + output_size = math_ops.minimum(maxlength, output_size); + weights = weights ?? array_ops.constant(new int[0], dtype: dtype); + + return tf.Context.ExecuteOp("DenseBincount", name, + new ExecuteOpArgs(arr, output_size, weights, binary_output) + .SetAttributes(new { binary_output })); + } + + throw new NotImplementedException(""); + }); + + /// + /// Returns the complex conjugate of a complex number. + /// + /// `Tensor` to conjugate. Must have numeric or variant type. + /// A name for the operation (optional). + /// A `Tensor` that is the conjugate of `x` (with the same type). + public static Tensor conj(Tensor x, string name = null) + { + var dt = x.dtype; + if (dt.is_floating() || dt.is_integer()) + return x; + + return tf_with(ops.name_scope(name, "Conj", new List { x }), scope => + { + + return x; + }); + } + + public static Tensor tanh(Tensor x, string name = null) + => gen_math_ops.tanh(x, name); + + public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = null) + { + return tf_with(ops.name_scope(name, "Tensordot", new { a, b, axes }), scope => + { + name = scope; + var (a_axes, b_axes) = _tensordot_axes(a, axes); + var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes); + var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true); + var ab_matmul = matmul(a_reshape, b_reshape); + if(a_free_dims is int[] a_free_dims_list && b_free_dims is int[] b_free_dims_list) + { + var total_free_dims = a_free_dims_list.Concat(b_free_dims_list).ToArray(); + if (ab_matmul.shape.IsFullyDefined && ab_matmul.shape.as_int_list().SequenceEqual(total_free_dims)) + { + return ab_matmul; + } + else + { + return array_ops.reshape(ab_matmul, ops.convert_to_tensor(total_free_dims), name); + } + } + else + { + var a_free_dims_tensor = ops.convert_to_tensor(a_free_dims, dtype: dtypes.int32); + var b_free_dims_tensor = ops.convert_to_tensor(b_free_dims, dtype: dtypes.int32); + var product = array_ops.reshape(ab_matmul, array_ops.concat(new[] { a_free_dims_tensor, b_free_dims_tensor }, 0), name); + if(a_free_dims_static is not null && b_free_dims_static is not null) + { + product.shape = new Shape(a_free_dims_static.Concat(b_free_dims_static).ToArray()); + } + return product; + } + }); + } + + static (int[], int[]) _tensordot_axes(Tensor a, NDArray axes) + { + if (axes.rank == 0) + { + int axe = axes; + if (axe > a.shape.ndim) + throw new ValueError("`axes` must not be larger than the number of " + + $"dimensions of tensor {a}. Received {axes}, vs " + + $"tensor dimensions {a.ndim}."); + return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(), + Binding.range(0, axe).ToArray()); + } + else if(axes.rank == 1) + { + if (axes.shape[0] != 2) + { + throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}."); + } + (int a_axe, int b_axe) = (axes[0], axes[1]); + return (new[] { a_axe }, new[] { b_axe }); + } + else if(axes.rank == 2) + { + if (axes.shape[0] != 2) + { + throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}."); + } + int[] a_axes = new int[axes.shape[1]]; + int[] b_axes = new int[axes.shape[1]]; + for(int i = 0; i < a_axes.Length; i++) + { + a_axes[i] = axes[0, i]; + b_axes[i] = axes[1, i]; + if (a_axes[i] == -1 || b_axes[i] == -1) + { + throw new ValueError($"Different number of contraction axes `a` and `b`," + + $"{len(a_axes)} != {len(b_axes)}."); + } + } + return (a_axes, b_axes); + } + else + { + throw new ValueError($"Invalid rank {axes.rank} to make tensor dot."); + } + } + + static (Tensor, object, int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false) + { + if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple)))) + { + var shape_a = a.shape.as_int_list(); + + // axes + axes = axes.Select(i => i >= 0 ? i : i + len(shape_a)).ToArray(); + + // free + int[] free = Binding.range(a.shape.ndim).Where(i => !axes.Contains(i)).ToArray(); + + // free_dims + int[] free_dims = free.Select(i => shape_a[i]).ToArray(); + + int prod_free = np.prod(free_dims); + + // prod_axes + int prod_axes = np.prod(axes.Select(i => shape_a[i]).ToArray()); + + // perm + List perm = new List(); + if (flipped) + { + perm.AddRange(axes); + perm.AddRange(free); + } + else + { + perm.AddRange(free); + perm.AddRange(axes); + } + + // new_shape + Shape new_shape; + if (flipped) + new_shape = new Shape(new int[] { prod_axes, prod_free }); + else + new_shape = new Shape(new int[] { prod_free, prod_axes }); + var a_trans = a; + var reshaped_a = array_ops.reshape(a_trans, new_shape); + return (reshaped_a, free_dims, free_dims); + } + else + { + int[] free_dims_static; + Tensor converted_shape_a, converted_axes, converted_free; + if (a.shape.ndim != -1) + { + var shape_a = a.shape.as_int_list(); + for(int i = 0; i < axes.Length; i++) + { + if (axes[i] < 0) + { + axes[i] += shape_a.Length; + } + } + var free = Enumerable.Range(0, shape_a.Length).Where(i => !axes.Contains(i)).ToArray(); + + var axes_dims = axes.Select(i => shape_a[i]); + var free_dims = free.Select(i => shape_a[i]).ToArray(); + free_dims_static = free_dims; + converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes"); + converted_free = ops.convert_to_tensor(free, dtypes.int32, "free"); + converted_shape_a = array_ops.shape(a); + } + else + { + free_dims_static = null; + converted_shape_a = array_ops.shape(a); + var rank_a = array_ops.rank(a); + converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes"); + converted_axes = array_ops.where_v2(converted_axes >= 0, converted_axes, converted_axes + rank_a); + (converted_free, var _) = gen_ops.list_diff(gen_math_ops.range(ops.convert_to_tensor(0), rank_a, ops.convert_to_tensor(1)), + converted_axes, dtypes.int32); + } + var converted_free_dims = array_ops.gather(converted_shape_a, converted_free); + var converted_axes_dims = array_ops.gather(converted_shape_a, converted_axes); + var prod_free_dims = reduce_prod(converted_free_dims); + var prod_axes_dims = reduce_prod(converted_axes_dims); + Tensor reshaped_a; + if (flipped) + { + var perm = array_ops.concat(new[] { converted_axes, converted_free }, 0); + var new_shape = array_ops.stack(new[] { prod_axes_dims, prod_free_dims }); + reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape); + } + else + { + var perm = array_ops.concat(new[] { converted_free, converted_axes }, 0); + var new_shape = array_ops.stack(new[] { prod_free_dims, prod_axes_dims }); + reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape); + } + return (reshaped_a, converted_free_dims, free_dims_static); + } + + throw new NotImplementedException("_tensordot_reshape"); + } + + + public static Tensor truediv(Tensor x, Tensor y, string name = null) + => _truediv_python3(x, y, name); + + public static Tensor _truediv_python3(Tensor x, Tensor y, string name = null) + { + return tf_with(ops.name_scope(name, "truediv", new { x, y }), scope => + { + name = scope; + var x_dtype = x.dtype.as_base_dtype(); + var y_dtype = y.dtype.as_base_dtype(); + + if (x_dtype != y_dtype) + throw new TypeError($"x and y must have the same dtype, got {x_dtype} != {y_dtype}"); + + var dtype = x_dtype switch + { + TF_DataType.TF_UINT8 => TF_DataType.TF_FLOAT, + TF_DataType.TF_INT8 => TF_DataType.TF_FLOAT, + TF_DataType.TF_INT16 => TF_DataType.TF_FLOAT, + TF_DataType.TF_UINT16 => TF_DataType.TF_FLOAT, + TF_DataType.TF_INT32 => TF_DataType.TF_DOUBLE, + TF_DataType.TF_INT64 => TF_DataType.TF_DOUBLE, + _ => x_dtype + }; + x = cast(x, dtype); + y = cast(y, dtype); + + return gen_math_ops.real_div(x, y, name: name); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs new file mode 100644 index 000000000..ca4b885f7 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs @@ -0,0 +1,258 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class nn_impl + { + public static Tensor conv2d_transpose(Tensor value = null, + IVariableV1 filter = null, + Tensor output_shape = null, + Shape strides = null, + string padding = "SAME", + string data_format = "NHWC", + string name = null, + Shape dilations = null) + { + if (dilations == null) + dilations = (1, 1, 1, 1); + return tf_with(ops.name_scope(name, "conv2d_transpose", new { value, filter, output_shape }), scope => + { + return gen_nn_ops.conv2d_backprop_input( + input_sizes: output_shape, + filter: filter.AsTensor(), + out_backprop: value, + strides: strides, + padding: padding, + data_format: data_format, + dilations: dilations, + name: name); + }); + } + + /// + /// Normalizes along dimension `axis` using an L2 norm. + /// + /// + /// + /// + /// + /// + public static Tensor l2_normalize(Tensor x, + int axis = 0, + Tensor epsilon =null, + string name = null) + { + return tf_with(ops.name_scope(name, "l2_normalize", new { x }), scope => + { + x = ops.convert_to_tensor(x, name: "x"); + var sq = math_ops.square(x); + var square_sum = math_ops.reduce_sum(sq, axis: constant_op.constant(axis), keepdims: true); + var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon == null ? tf.Variable(1e-12f) : epsilon)); + return math_ops.multiply(x, x_inv_norm, name: name); + }); + } + + /// + /// Calculate the mean and variance of `x` + /// + /// A `Tensor`. + /// Array of ints. Axes along which to compute mean and variance. + /// Name used to scope the operations that compute the moments. + /// Produce moments with the same dimensionality as the input. + /// Two `Tensor` objects: `mean` and `variance`. + public static (Tensor, Tensor) moments(Tensor x, + Axis axes, + string name = null, + bool keep_dims = false) + { + return tf_with(ops.name_scope(name, "moments", new { x, axes }), scope => + { + // The dynamic range of fp16 is too limited to support the collection of + // sufficient statistics. As a workaround we simply perform the operations + // on 32-bit floats before converting the mean and variance back to fp16 + var y = math_ops.cast(x, TF_DataType.TF_FLOAT); + // Compute true mean while keeping the dims for proper broadcasting. + var mean = math_ops.reduce_mean(y, axes, true, name = "mean"); + // Sample variance, not unbiased variance + // Note: stop_gradient does not change the gradient that gets + // backpropagated to the mean from the variance calculation, + // because that gradient is zero + var variance = math_ops.reduce_mean(math_ops.square_difference(y, array_ops.stop_gradient(mean)), axes, true, name = "Variance"); + if (!keep_dims) + { + mean = array_ops.squeeze(mean, axes); + variance = array_ops.squeeze(variance, axes); + } + // TODO: if x.dtype == dtypes.float16: + if (x.dtype == TF_DataType.TF_HALF) + return (math_ops.cast(mean, x.dtype), math_ops.cast(variance, x.dtype)); + else + return (mean, variance); + }); + } + + public static Tensor normalize(Tensor tensor, string ord = "euclidean", Axis axis = null, string name = null) + { + return tf_with(ops.name_scope(name, "normalize", tensor), scope => + { + var norm = tf.linalg.norm(tensor, ord: ord, axis: axis, name: name); + var normalized = tensor / norm; + return normalized; + }); + } + + public static Tensor batch_normalization(Tensor x, + Tensor mean, + Tensor variance, + Tensor offset, + Tensor scale, + float variance_epsilon = 0.001f, + string name = null) + { + return tf_with(ops.name_scope(name, "batchnorm", new { x, mean, variance, scale, offset }), scope => + { + var inv = math_ops.rsqrt(variance + variance_epsilon); + inv *= scale; + return x * math_ops.cast(inv, x.dtype) + math_ops.cast( + offset == null ? (-mean * inv) : (offset - mean * inv), x.dtype); + }); + } + + /// + /// Batch normalization. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] fused_batch_norm(Tensor x, + Tensor scale, + Tensor offset, + Tensor mean = null, + Tensor variance = null, + float epsilon = 0.001f, + string data_format = "NHWC", + bool is_training = true, + string name = null, + float exponential_avg_factor = 1.0f) + { + mean = mean ?? constant_op.constant(new float[0]); + variance = variance ?? constant_op.constant(new float[0]); + var min_epsilon = 1.001e-5f; + epsilon = epsilon > min_epsilon ? epsilon : min_epsilon; + + var results = gen_nn_ops.fused_batch_norm_v3(x, + scale, + offset, + mean, + variance, + epsilon: epsilon, + exponential_avg_factor: exponential_avg_factor, + data_format: data_format, + is_training: is_training, + name: name); + + var y = results[0]; + var running_mean = results[1]; + var running_var = results[2]; + + return new[] { y, running_mean, running_var }; + } + + /// + /// Same as math_ops.count_nonzero. + /// The reduction is done in dtype, which can be faster for 32-bit dtypes. + /// + /// The numeric tensor. + /// The reduction dtype. + /// number of nonzero values with type dtype + private static Tensor _count_nonzero(Tensor input_tensor, TF_DataType dtype = TF_DataType.TF_INT64) + { + return tf_with(ops.name_scope("count_nonzero", "count_nonzero", new { input_tensor }), scope => + { + var zero = array_ops.zeros(Shape.Null, dtype: input_tensor.dtype); + var nonzero_count = math_ops.reduce_sum( + math_ops.cast(gen_math_ops.not_equal(input_tensor, zero), dtype: dtype), name: "nonzero_count"); + return nonzero_count; + }); + } + + public static Tensor sigmoid_cross_entropy_with_logits(Tensor labels, Tensor logits, string name = null) + { + return tf_with(ops.name_scope(name, "logistic_loss", new { logits, labels }), scope => + { + name = scope; + logits = ops.convert_to_tensor(logits, name: "logits"); + labels = ops.convert_to_tensor(labels, name: "labels"); + labels.shape.merge_with(logits.shape); + + var zeros = array_ops.zeros_like(logits, dtype: logits.dtype); + var cond = (logits >= zeros); + var relu_logits = array_ops.where(cond, logits, zeros); + var neg_abs_logits = array_ops.where(cond, -logits, logits); + + return math_ops.add( + relu_logits - logits * labels, + gen_math_ops.log1p(gen_math_ops.exp(neg_abs_logits)), + name: name); + }); + } + + /// + /// Returns the fraction of zeros in value. + /// + /// A tensor of numeric type. + /// A name for the operation (optional). + /// The fraction of zeros in value, with type float32. + public static Tensor zero_fraction(Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "zero_fraction", new { value }), scope => + { + value = ops.convert_to_tensor(value, name: "value"); + Tensor size = array_ops.size(value, out_type: dtypes.int64); + Tensor zero_fraction_float32 = null; + + size = gen_math_ops.less_equal(size, ops.convert_to_tensor(dtypes.int32.max())); + Tensor num_nonzero = control_flow_ops.cond( + size, + () => math_ops.cast(_count_nonzero(value, dtype: dtypes.int32), TF_DataType.TF_INT64), + () => _count_nonzero(value, dtype: dtypes.int64) + ); + + tf_with(ops.name_scope("counts_to_fraction"), count_scope => + { + var num_zero = math_ops.subtract(math_ops.cast(size, TF_DataType.TF_INT64), num_nonzero); + var num_zero_float32 = math_ops.cast(num_zero, dtype: dtypes.float32); + var size_float32 = math_ops.cast(size, dtype: dtypes.float32); + zero_fraction_float32 = num_zero_float32 / size_float32; + }); + + return array_ops.identity(zero_fraction_float32, "fraction"); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs new file mode 100644 index 000000000..00d7d316b --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -0,0 +1,325 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class nn_ops + { + public static ConvolutionInternal convolution_internal(string padding, + int[] strides, + int[] dilation_rate, + int rank, + string name = null, + string data_format = null) => new ConvolutionInternal(new ConvolutionalArgs + { + Rank = rank, + Padding = padding, + Strides = strides, + DilationRate = dilation_rate, + DataFormat = data_format, + Name = name + }); + + /// + /// Adds `bias` to `value`. + /// + /// + /// + /// + /// + /// + public static Tensor bias_add(Tensor value, + IVariableV1 bias, + string data_format = null, + string name = null) + { + return tf_with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope => + { + name = scope; + return gen_nn_ops.bias_add(value, ops.convert_to_tensor(bias), data_format: data_format, name: name); + }); + } + + /// + /// Computes dropout. + /// + /// + /// + /// + /// + /// + /// + public static Tensor dropout_v2(Tensor x, Tensor rate, Tensor noise_shape = null, int? seed = null, string name = null) + { + return tf_with(ops.name_scope(name, "dropout", x), scope => + { + name = scope; + x = ops.convert_to_tensor(x, name: "x"); + if (!x.dtype.is_floating()) + throw new NotImplementedException($"x has to be a floating point tensor since it's going to" + + $" be scaled. Got a {x.dtype} tensor instead."); + + var keep_prob = 1 - rate; + var scale = 1 / keep_prob; + var scale_tensor = ops.convert_to_tensor(scale, dtype: x.dtype); + var ret = gen_math_ops.mul(x, scale_tensor); + + noise_shape = _get_noise_shape(x, noise_shape); + + // Sample a uniform distribution on [0.0, 1.0) and select values larger than + // rate. + // + // NOTE: Random uniform actually can only generate 2^23 floats on [1.0, 2.0) + // and subtract 1.0. + var random_tensor = random_ops.random_uniform(noise_shape, seed: seed, dtype: x.dtype); + // NOTE: if (1.0 + rate) - 1 is equal to rate, then we want to consider that + // float to be selected, hence we use a >= comparison. + var keep_mask = random_tensor >= rate; + ret = x * scale * math_ops.cast(keep_mask, x.dtype); + if (!tf.executing_eagerly()) + ret.shape = x.shape; + return ret; + }); + } + + private static Tensor _get_noise_shape(Tensor x, Tensor noise_shape) + { + if (noise_shape == null) + return array_ops.shape(x); + else + return noise_shape; + } + + public static Tensors top_kv2(Tensor input, int k, bool sorted = true, string name = null) + => tf.Context.ExecuteOp("TopKV2", name, new ExecuteOpArgs(input, k) + .SetAttributes(new { sorted })); + + public static Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = null) + { + return tf_with(ops.name_scope(name, "in_top_k"), delegate + { + return gen_nn_ops.in_top_kv2(predictions, targets, ops.convert_to_tensor(k), name: name); + }); + } + + public static Tensor log_softmax(Tensor logits, int axis = -1, string name = null) + { + return _softmax(logits, gen_nn_ops.log_softmax, axis, name); + } + + /// equivalent to `dim` + public static Tensor softmax(Tensor logits, int axis = -1, string name = null) + { + return _softmax(logits, gen_nn_ops.softmax, axis, name); + } + + public static Tensor softplus(Tensor features, string name = null) + => tf.Context.ExecuteOp("Softplus", name, new ExecuteOpArgs(features)); + + public static Tensor l2_loss(Tensor t, string name = null) + => tf.Context.ExecuteOp("L2Loss", name, new ExecuteOpArgs(t)); + + public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) + { + return tf_with(ops.name_scope(name, "LeakyRelu", new { features, alpha }), scope => + { + name = scope; + features = ops.convert_to_tensor(features, name: "features"); + if (features.dtype.is_integer()) + features = math_ops.cast(features, dtypes.float32); + return gen_nn_ops.leaky_relu(features, alpha: alpha, name: name); + //return math_ops.maximum(alpha * features, features, name: name); + }); + } + + /// + /// Performs the max pooling on the input. + /// + /// A 4-D `Tensor` of the format specified by `data_format`. + /// + /// A list or tuple of 4 ints. The size of the window for each dimension + /// of the input tensor. + /// + /// + /// A list or tuple of 4 ints. The stride of the sliding window for + /// each dimension of the input tensor. + /// + /// A string, either `'VALID'` or `'SAME'`. The padding algorithm. + /// A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported. + /// Optional name for the operation. + /// + public static Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null) + { + return tf_with(ops.name_scope(name, "MaxPool", value), scope => + { + name = scope; + value = ops.convert_to_tensor(value, name: "input"); + return gen_nn_ops.max_pool( + value, + ksize: ksize, + strides: strides, + padding: padding, + data_format: data_format, + name: name); + }); + } + + public static Tensor _softmax(Tensor logits, Func compute_op, int dim = -1, string name = null) + { + logits = ops.convert_to_tensor(logits); + + var shape = logits.shape; + bool is_last_dim = dim == -1 || dim == shape.ndim - 1; + if (is_last_dim) + return compute_op(logits, name); + + throw new NotImplementedException("_softmax helper"); + } + + /// + /// Computes sparse softmax cross entropy between `logits` and `labels`. + /// + /// + /// + /// + /// + public static Tensor sparse_softmax_cross_entropy_with_logits(Tensor labels = null, + Tensor logits = null, string name = null) + { + // Reshape logits and labels to rank 2. + return tf_with(ops.name_scope(name, default_name: "SparseSoftmaxCrossEntropyWithLogits", (labels, logits)), delegate + { + labels = ops.convert_to_tensor(labels); + logits = ops.convert_to_tensor(logits); + var precise_logits = logits.dtype == TF_DataType.TF_HALF ? math_ops.cast(logits, dtypes.float32) : logits; + + // Store label shape for result later. + var labels_static_shape = labels.shape; + var labels_shape = array_ops.shape(labels); + /*bool static_shapes_fully_defined = ( + labels_static_shape.is_fully_defined() && + logits.get_shape()[:-1].is_fully_defined());*/ + + // Check if no reshapes are required. + if (logits.shape.ndim == 2) + { + var cost = gen_nn_ops.sparse_softmax_cross_entropy_with_logits( + precise_logits, labels, name: name)[0]; + if (logits.dtype == dtypes.float16) + return math_ops.cast(cost, dtypes.float32); + else + return cost; + } + + // Perform a check of the dynamic shapes if the static shapes are not fully + // defined. + throw new NotImplementedException("sparse_softmax_cross_entropy_with_logits"); + }); + } + + public static Tensor softmax_cross_entropy_with_logits_v2_helper(Tensor labels, + Tensor logits, + int axis = -1, + string name = null) + { + return tf_with(ops.name_scope(name, "softmax_cross_entropy_with_logits", new { logits, labels }), scope => + { + name = scope; + var precise_logits = logits; + var input_rank = array_ops.rank(precise_logits); + var shape = logits.shape; + + if (axis != -1) + throw new NotImplementedException("softmax_cross_entropy_with_logits_v2_helper axis != -1"); + + var input_shape = array_ops.shape(precise_logits); + + // Make precise_logits and labels into matrices. + precise_logits = _flatten_outer_dims(precise_logits); + labels = _flatten_outer_dims(labels); + + // Do the actual op computation. + // The second output tensor contains the gradients. We use it in + // _CrossEntropyGrad() in nn_grad but not here. + + var entropy = gen_nn_ops.softmax_cross_entropy_with_logits(precise_logits, labels, name: name); + var (cost, unused_backprop) = (entropy[0], entropy[1]); + + // The output cost shape should be the input minus axis. + var output_shape = array_ops.slice(input_shape, + new Tensor[] { constant_op.constant(0) }, + new Tensor[] { math_ops.subtract(input_rank, 1) }); + + cost = array_ops.reshape(cost, output_shape); + + return cost; + }); + } + + /// + /// Flattens logits' outer dimensions and keep its last dimension. + /// + /// + /// + private static Tensor _flatten_outer_dims(Tensor logits) + { + var rank = array_ops.rank(logits); + var last_dim_size = array_ops.slice(array_ops.shape(logits), + new[] { math_ops.subtract(rank, 1) }, + new[] { constant_op.constant(1) }); + + var ops = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0); + var output = array_ops.reshape(logits, ops); + + // Set output shape if known. + if (!tf.Context.executing_eagerly()) + { + var shape = logits.shape; + if (shape != null && shape.ndim > 0) + { + var product = 1L; + var product_valid = true; + foreach (var d in shape.dims.Take(shape.ndim - 1)) + { + if (d == -1) + { + product_valid = false; + break; + } + else + { + product *= d; + } + } + + if (product_valid) + { + var output_shape = new[] { product }; + throw new NotImplementedException("_flatten_outer_dims product_valid"); + } + } + } + + return output; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/Operations/ops.cs deleted file mode 100644 index f7eb38535..000000000 --- a/src/TensorFlowNET.Core/Operations/ops.cs +++ /dev/null @@ -1,97 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Runtime.InteropServices; -using System.Text; -using System.Threading; -using Tensorflow; -using node_def_pb2 = Tensorflow; -using Google.Protobuf; - -namespace Tensorflow -{ - public static class ops - { - public static Graph get_default_graph() - { - return tf.Graph(); - } - - public static Tensor convert_to_tensor() - { - return internal_convert_to_tensor(); - } - - private static Tensor internal_convert_to_tensor() - { - return null; - } - - - - public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs) - { - var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); - - // Add inputs - if(inputs != null) - { - foreach (var op_input in inputs) - { - c_api.TF_AddInput(op_desc, op_input._as_tf_output()); - } - } - - var status = new Status(); - - // Add control inputs - - // Add attrs - foreach (var attr in node_def.Attr) - { - var bytes = attr.Value.ToByteArray(); - var proto = Marshal.AllocHGlobal(bytes.Length); - Marshal.Copy(bytes, 0, proto, bytes.Length); - c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle); - - if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message); - } - - var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); - - if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message); - - return c_op; - } - - public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary attrs = null) - { - var node_def = new node_def_pb2.NodeDef(); - node_def.Op = op_type; - node_def.Name = name; - - foreach (var attr in attrs) - { - node_def.Attr.Add(attr.Key, attr.Value); - } - - return node_def; - } - - public static string _name_from_scope_name(string name) - { - if (name.EndsWith("/")) - { - return name.Substring(0, name.Length - 1); - } - else - { - return name; - } - } - - public static int uid() - { - return 1; - } - } -} diff --git a/src/TensorFlowNET.Core/Operations/random_ops.cs b/src/TensorFlowNET.Core/Operations/random_ops.cs new file mode 100644 index 000000000..dddcc05a1 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/random_ops.cs @@ -0,0 +1,205 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class random_ops + { + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor random_normal(Shape shape, + float mean = 0.0f, + float stddev = 1.0f, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) + { + return tf_with(ops.name_scope(name, "random_normal", new { shape, mean, stddev }), scope => + { + name = scope; + var shape_tensor = _ShapeTensor(shape); + var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); + var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name: "stddev"); + var (seed1, seed2) = random_seed.get_seed(seed); + var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2); + var mul = rnd * stddev_tensor; + var value = math_ops.add(mul, mean_tensor, name: name); + // tensor_util.maybe_set_static_shape(value, shape) + return value; + }); + } + + /// + /// Outputs random values from a uniform distribution. + /// + /// + /// + /// + /// The type of the output + /// Used to create a random seed for the distribution. + /// A name for the operation + /// A tensor of the specified shape filled with random uniform values. + public static Tensor random_uniform(int[] shape, + float minval = 0, + float maxval = 1, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) + { + return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope => + { + name = scope; + var (seed1, seed2) = random_seed.get_seed(seed); + var tensorShape = tensor_util.shape_tensor(shape); + var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min"); + var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max"); + var rnd = gen_random_ops.random_uniform(tensorShape, dtype, seed: seed1, seed2: seed2); + return math_ops.add(rnd * (maxTensor - minTensor), minTensor, name: name); + }); + } + + /// + /// Outputs random values from a uniform distribution. + /// + /// + /// + /// + /// The type of the output + /// Used to create a random seed for the distribution. + /// A name for the operation + /// A tensor of the specified shape filled with random uniform values. + public static Tensor random_uniform_int(int[] shape, + int minval = 0, + int maxval = 1, + int? seed = null, + string name = null) + { + return tf_with(ops.name_scope(name, "random_uniform_int", new { shape, minval, maxval }), scope => + { + name = scope; + var (seed1, seed2) = random_seed.get_seed(seed); + var tensorShape = tensor_util.shape_tensor(shape); + var minTensor = ops.convert_to_tensor(minval, name: "min"); + var maxTensor = ops.convert_to_tensor(maxval, name: "max"); + return gen_random_ops.random_uniform_int(tensorShape, minTensor, maxTensor, seed: seed1, seed2: seed2); + }); + } + + public static Tensor random_uniform(Tensor shape, + int minval = 0, + Tensor maxval = null, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) + { + return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope => + { + name = scope; + var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min"); + var maxTensor = ops.convert_to_tensor(maxval == null ? 1 : (int)maxval, dtype: dtype, name: "max"); + var (seed1, seed2) = random_seed.get_seed(seed); + if (dtype.is_integer()) + { + return gen_random_ops.random_uniform_int(shape, minTensor, maxTensor, seed: seed1, seed2: seed2, name: name); + } + else + { + var rnd = gen_random_ops.random_uniform(shape, dtype); + return math_ops.add(rnd * (maxTensor - minTensor), minTensor, name: name); + } + }); + } + + /// + /// Randomly shuffles a tensor along its first dimension. + /// + /// + /// + /// + /// + public static Tensor random_shuffle(Tensor value, int? seed = null, string name = null) + { + var (seed1, seed2) = random_seed.get_seed(seed); + return gen_random_ops.random_shuffle(value, seed: seed1, seed2: seed2, name: name); + } + + public static Tensor truncated_normal(int[] shape, + float mean = 0.0f, + float stddev = 1.0f, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) + { + return tf_with(ops.name_scope(name, "truncated_normal", new { shape, mean, stddev }), scope => + { + name = scope; + var shape_tensor = _ShapeTensor(shape); + var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); + var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name: "stddev"); + var (seed1, seed2) = random_seed.get_seed(seed); + var rnd = gen_random_ops.truncated_normal(shape_tensor, dtype, seed: seed1, seed2: seed2); + var mul = rnd * stddev_tensor; + var value = math_ops.add(mul, mean_tensor, name: name); + return value; + }); + } + + private static Tensor _ShapeTensor(int[] shape) + { + return ops.convert_to_tensor(shape, name: "shape"); + } + + public static Tensor multinomial(Tensor logits, int num_samples, int? seed = null, + string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) + { + return tf_with(ops.name_scope(name, "multinomial", new { logits }), delegate + { + return multinomial_categorical_impl(logits, num_samples, output_dtype, seed); + }); + } + + /// + /// Implementation for random.categorical (v1) and random.categorical (v2). + /// + /// + /// + /// + /// + /// + private static Tensor multinomial_categorical_impl(Tensor logits, int num_samples, TF_DataType dtype = TF_DataType.DtInvalid, + int? seed = null) + { + logits = ops.convert_to_tensor(logits, name: "logits"); + var (seed1, seed2) = random_seed.get_seed(seed); + return gen_random_ops.multinomial(logits, + num_samples, + seed: seed1, + seed2: seed2, + output_dtype: dtype); + } + } +} + diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs new file mode 100644 index 000000000..c06e822d2 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -0,0 +1,320 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Framework; +using Tensorflow.Train; +using Tensorflow.Training.Saving.SavedModel; +using Tensorflow.Variables; +using static Tensorflow.CppShapeInferenceResult.Types; +using static Tensorflow.Binding; +using Tensorflow.Operations; +using System.Buffers; +using Tensorflow.Eager; +using Tensorflow.Graphs; + +namespace Tensorflow +{ + /// + /// tensorflow\python\ops\resource_variable_ops.py + /// + public static class resource_variable_ops + { + public static Operation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null) + { + // TODO(Rinne): deal with `_handle_graph`. + var value_tensor = ops.convert_to_tensor(value); + return gen_resource_variable_ops.assign_variable_op(handle, + value_tensor, + name: name); + } + + public static bool is_resource_variable(object var) + { + return var is BaseResourceVariable; + } + + /// + /// Creates a variable handle with information to do shape inference. + /// + /// + /// + /// + /// + /// + /// + public static Tensor eager_safe_variable_handle(Tensor initial_value, Shape shape, + string shared_name, string name, bool graph_mode) + { + var dtype = initial_value.dtype.as_base_dtype(); + return variable_handle_from_shape_and_dtype( + shape, dtype, shared_name, name, graph_mode, initial_value); + } + + /// + /// Create a new variable handle, optionally copying in `extra_handle_data` + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor variable_handle_from_shape_and_dtype(Shape shape, TF_DataType dtype, + string shared_name, string name, bool graph_mode, Tensor initial_value = null) + { + var container = ops.get_default_graph().Container; + if(container is null) + { + container = ""; + } + if (!graph_mode) + { + if(shared_name is not null) + { + throw new Exception("Using an explicit shared_name is not allowed when executing eagerly."); + } + shared_name = tf.Context.anonymous_name(); + } + var handle = gen_resource_variable_ops.var_handle_op(shape: shape, + dtype: dtype, + shared_name: shared_name, + name: name, + container: container); + + if (initial_value == null) + initial_value = handle; + + if (graph_mode) + { + var full_handle_data = _combine_handle_data(handle, initial_value); + _set_handle_shapes_and_types(handle, full_handle_data, graph_mode); + return handle; + } + else + { + var handle_data = handle_data_util.create_handle_data(shape, dtype); + if (initial_value is not null && initial_value.dtype == dtypes.variant) + { + var extra_handle_data = get_eager_safe_handle_data(initial_value); + if (extra_handle_data is not null && extra_handle_data.IsSet) + { + if (!handle_data.IsSet || handle_data.ShapeAndType.Count != 1) + { + throw new RuntimeError($"Expected VarHandleOp to return a length==1 shape_and_type, " + + $"but saw: '{handle_data}'"); + } + handle_data.ShapeAndType.AddRange(extra_handle_data.ShapeAndType); + } + } + _set_handle_shapes_and_types(handle, handle_data, graph_mode); + return handle; + } + } + + /// + /// Sets the shape inference result HandleData on tensor. + /// + /// + /// + /// + internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) + { + if (!graph_mode) + return; + + var size = handle_data.ShapeAndType.Count; + + var shapes = new IntPtr[size]; + var types = new DataType[size]; + var ranks = new int[size]; + + for (int i = 0; i < size; i++) + { + var shapeAndType = handle_data.ShapeAndType[i]; + types[i] = shapeAndType.Dtype; + ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; + var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); + } + + //tensor.HandleData = handle_data; + //if (!graph_mode) + // return; + + //var shapes = handle_data.ShapeAndType.Select(x => x.Shape); + //var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray(); + //var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray(); + //var converted_shapes = shapes.Select>(s => + //{ + // if (!s.UnknownRank) + // { + // return s.Dim.Select(d => (int)d.Size).ToArray(); + // } + // else + // { + // return Memory.Empty; + // } + //}).ToArray(); + + //List handles = new(); + //IntPtr[] shapes_with_ptr = new IntPtr[converted_shapes.Length]; + //foreach(var (i, m) in enumerate(converted_shapes)) + //{ + // if(m.IsEmpty) + // { + // shapes_with_ptr[i] = IntPtr.Zero; + // } + // else + // { + // var handle = m.Pin(); + // handles.Add(handle); + // shapes_with_ptr[i] = new IntPtr(handle.Pointer); + // } + //} + + //Status status = new(); + //// TODO(Rinne): enable it. + //c_api.TF_GraphSetOutputHandleShapesAndTypes(tensor.op.graph.c_graph, tensor._as_tf_output(), + // shapes_with_ptr.Length, shapes_with_ptr, ranks, types, status); + //handles = null; + } + + /// + /// Concats HandleData from tensors `handle` and `initial_value`. + /// + /// + /// + /// + private static HandleData _combine_handle_data(Tensor handle, Tensor initial_value) + { + var variable_handle_data = get_eager_safe_handle_data(initial_value); + + if (initial_value.dtype != dtypes.variant) + return variable_handle_data; + + throw new NotImplementedException(""); + } + + /// + /// Copies an existing variable to a new graph, with no initializer. + /// + /// + public static UninitializedVariable copy_to_graph_uninitialized(ResourceVariable variable) + { + var new_variable = new UninitializedVariable( + trainable: variable.Trainable, + shape: variable.shape, + dtype: variable.dtype, + name: variable.SharedName, + aggregation: variable.Aggregation, + extra_handle_data: null); + new_variable._maybe_initialize_trackable(); + return new_variable; + } + + /// + /// Writes additional information of the variable into the SavedObject proto. + /// + /// + /// + /// + /// + public static void write_object_proto_for_resource_variable(BaseResourceVariable resource_variable, SavedObject proto, SaveOptions options, bool enforcing_naming = true) + { + // lack of API: `proto.Variable.SetInParent()`. + if(enforcing_naming && !resource_variable.Name.EndsWith(":0")) + { + throw new ValueError($"Cowardly refusing to save variable {resource_variable.Name} because of " + + $"unexpected suffix in the name (expected ':0') which won't be restored."); + } + if(proto.Variable is null) + { + proto.Variable = new SavedVariable(); + } + proto.Variable.Name = meta_graph.op_name(resource_variable.Name); + proto.Variable.Trainable = resource_variable.Trainable; + proto.Variable.Dtype = resource_variable.dtype.as_datatype_enum(); + // TODO: lack of API `proto.Variable.Synchronization = resource_variable.synchronization.value`. + proto.Variable.Aggregation = resource_variable.Aggregation; + proto.Variable.Shape = resource_variable.shape.as_proto(); + + if (options.experimental_variable_policy.save_variable_devices()) + { + if (!string.IsNullOrEmpty(resource_variable.Device)) + { + proto.Variable.Device = resource_variable.Device; + } + } + } + + public static void _maybe_set_handle_data(TF_DataType dtype, Tensor handle, Tensor tensor) + { + if(dtype == dtypes.variant) + { + var handle_data = get_eager_safe_handle_data(handle); + if(handle_data.IsSet && handle_data.ShapeAndType.Count > 1) + { + tensor.HandleData = new HandleData() + { + IsSet = true + }; + tensor.HandleData.ShapeAndType.AddRange(handle_data.ShapeAndType.Skip(1)); + } + } + } + + public static HandleData get_eager_safe_handle_data(Tensor handle) + { + if (handle.Handle == null) + { + var data = new HandleData(); + data.ShapeAndType.Add(new HandleShapeAndType + { + Shape = handle.shape.as_shape_proto(), + Dtype = handle.dtype.as_datatype_enum() + }); + return data; + } + else + { + return HandleData.Parser.ParseFrom(handle.BufferToArray()); + } + //if(handle is EagerTensor) + //{ + // return handle.HandleData; + //} + //else + //{ + // return handle_data_util.get_resource_handle_data(handle); + //} + } + + public static void variable_accessed(IVariableV1 variable) + { + if (ops.get_default_graph() is FuncGraph func_graph) + { + func_graph.watch_variable(variable); + } + if (variable.Trainable) + { + foreach (var tape in tf.GetTapeSet()) + tape.VariableAccessed(variable); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/sort_ops.cs b/src/TensorFlowNET.Core/Operations/sort_ops.cs new file mode 100644 index 000000000..db38a073b --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/sort_ops.cs @@ -0,0 +1,82 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class sort_ops + { + public static Tensor argsort(Tensor values, Axis axis = null, string direction = "ASCENDING", bool stable = false, string name = null) + { + axis = axis ?? new Axis(-1); + var k = array_ops.shape(values)[axis]; + values = -values; + var static_rank = values.shape.ndim; + var top_k_input = values; + if (axis == -1 || axis + 1 == values.shape.ndim) + { + } + else + { + if (axis == 0 && static_rank == 2) + top_k_input = array_ops.transpose(values, new[] { 1, 0 }); + else + throw new NotImplementedException(""); + } + + var (_, indices) = tf.Context.ExecuteOp("TopKV2", name, + new ExecuteOpArgs(top_k_input, k).SetAttributes(new + { + sorted = true + })); + return indices.Single; + } + + public static Tensor sort(Tensor values, Axis axis, string direction = "ASCENDING", string? name = null) + { + var k = array_ops.shape(values)[axis]; + values = -values; + var static_rank = values.shape.ndim; + var top_k_input = values; + if (axis == -1 || axis + 1 == values.shape.ndim) + { + } + else + { + if (axis == 0 && static_rank == 2) + top_k_input = array_ops.transpose(values, new[] { 1, 0 }); + else + throw new NotImplementedException(""); + } + + (values, _) = tf.Context.ExecuteOp("TopKV2", name, + new ExecuteOpArgs(top_k_input, k).SetAttributes(new + { + sorted = true + })); + return -values; + } + + public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null) + => tf.Context.ExecuteOp("MatrixInverse", name, + new ExecuteOpArgs(input).SetAttributes(new + { + adjoint + })); + } +} diff --git a/src/TensorFlowNET.Core/Operations/sparse_ops.cs b/src/TensorFlowNET.Core/Operations/sparse_ops.cs new file mode 100644 index 000000000..37a54f59b --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/sparse_ops.cs @@ -0,0 +1,29 @@ +namespace Tensorflow +{ + public class sparse_ops + { + /// + /// Converts a sparse representation into a dense tensor. + /// + /// + /// + /// + /// + /// + /// + /// + /// Dense `Tensor` of shape `output_shape`. Has the same type as `sparse_values`. + public Tensor sparse_to_dense(Tensor sparse_indices, + int[] output_shape, + T sparse_values, + T default_value = default, + bool validate_indices = true, + string name = null) + => gen_sparse_ops.sparse_to_dense(sparse_indices, + output_shape, + sparse_values, + default_value: default_value, + validate_indices: validate_indices, + name: name); + } +} diff --git a/src/TensorFlowNET.Core/Operations/stateless_random_ops.cs b/src/TensorFlowNET.Core/Operations/stateless_random_ops.cs new file mode 100644 index 000000000..e9718770c --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/stateless_random_ops.cs @@ -0,0 +1,62 @@ +/***************************************************************************** + Copyright 2023 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using static Tensorflow.ApiDef.Types; +using System.Reflection; +using static Tensorflow.Binding; +using System; + +namespace Tensorflow; + +public class stateless_random_ops +{ + public static Tensor stateless_random_normal(Shape shape, + float mean = 0.0f, + float stddev = 1.0f, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int[]? seed = null, + string name = null) + { + return tf_with(ops.name_scope(name, "stateless_random_normal", new { shape, seed, mean, stddev }), scope => + { + name = scope; + var shape_tensor = _ShapeTensor(shape); + var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); + var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name: "stddev"); + + if (seed == null) + { + seed = new[] { new Random().Next(), 0 }; + } + var (key, counter) = _get_key_counter(seed, 3); + var rnd = gen_random_ops.stateless_random_normal_v2(shape: shape_tensor, key: key, counter: counter, dtype: dtype, alg: 3); + var value = math_ops.add(rnd * stddev, mean_tensor, name: name); + // tensor_util.maybe_set_static_shape(value, shape) + return value; + }); + } + + private static Tensor _ShapeTensor(int[] shape) + { + return ops.convert_to_tensor(shape, name: "shape"); + } + + private static (Tensor, Tensor) _get_key_counter(int[] seed, int alg) + { + var results = gen_random_ops.stateless_random_get_key_counter(seed); + return (results[0], results[1]); + } +} diff --git a/src/TensorFlowNET.Core/Operations/string_ops.cs b/src/TensorFlowNET.Core/Operations/string_ops.cs new file mode 100644 index 000000000..1e50c4ad0 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/string_ops.cs @@ -0,0 +1,155 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using Tensorflow.Framework; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class string_ops + { + public Tensor lower(Tensor input, string encoding = "", string name = null) + => tf.Context.ExecuteOp("StringLower", name, new ExecuteOpArgs(input, encoding)); + + public Tensor regex_replace(Tensor input, string pattern, string rewrite, + bool replace_global = true, string name = null) + => tf.Context.ExecuteOp("StaticRegexReplace", name, new ExecuteOpArgs(input) + .SetAttributes(new { pattern, rewrite, replace_global })); + + /// + /// Return substrings from `Tensor` of strings. + /// + /// + /// + /// + /// + /// + /// + public Tensor substr(T input, int pos, int len, + string @uint = "BYTE", string name = null) + => tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len) + .SetAttributes(new { unit = @uint })); + + /// + /// Computes the length of each string given in the input tensor. + /// + /// + /// + /// + /// + public Tensor string_length(Tensor input, string name = null, string unit = "BYTE") + => tf.Context.ExecuteOp("StringLength", name, new ExecuteOpArgs(input) + { + GetGradientAttrs = op => new + { + unit = op.get_attr("unit") + } + }.SetAttributes(new { unit })); + + public Tensor string_format(Tensor[] inputs, string template = "%s", string placeholder = "%s", int summarize = 3, string name = null) + => tf.Context.ExecuteOp("StringFormat", name, new ExecuteOpArgs() + { + OpInputArgs = new object[] { inputs }, + GetGradientAttrs = op => new + { + T = op.get_attr("T"), + template = op.get_attr("template"), + placeholder = op.get_attr("placeholder"), + summarize = op.get_attr("summarize") + } + }.SetAttributes(new { template, placeholder, summarize })); + + public RaggedTensor string_split_v2(Tensor input, string sep = " ", int maxsplit = -1, string name = null) + { + return tf_with(ops.name_scope(name, "StringSplit"), scope => + { + var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING); + if(input.rank == 0) + { + var parts = string_split_v2(array_ops.stack(new[] { input }), + sep: sep, + maxsplit: maxsplit, + name: name); + return parts; + } + + var result = tf.Context.ExecuteOp("StringSplitV2", name, + new ExecuteOpArgs(input, sep) + { + GetGradientAttrs = op => new + { + maxsplit = op.get_attr("maxsplit") + } + }.SetAttributes(new { maxsplit })); + var (indices, values, shape) = (result[0], result[1], result[2]); + indices.shape = new Shape(-1, 2); + values.shape = new Shape(-1); + shape.shape = new Shape(2); + + var sparse_result = new SparseTensor(indices, values, shape); + return RaggedTensor.from_value_rowids(sparse_result.values, + value_rowids: sparse_result.indices[Slice.All, 0], + nrows: sparse_result.dense_shape[0], + validate: false); + }); + } + + public (RaggedTensor, RaggedTensor) unicode_decode_with_offsets(Tensor input, string input_encoding, string errors, + int replacement_char = 0xFFFD, bool replace_control_characters = false, string name = null) + { + return tf_with(ops.name_scope(name, "UnicodeDecodeWithOffsets"), scope => + { + var (codepoints, byte_start_offsets) = _unicode_decode(input, input_encoding, errors, + replacement_char, replace_control_characters, + with_offsets: true, name: name); + return (codepoints, byte_start_offsets); + }); + } + + (RaggedTensor, RaggedTensor) _unicode_decode(Tensor input, string input_encoding, string errors, int replacement_char, + bool replace_control_characters, bool with_offsets, string name = null) + { + if (with_offsets) + { + var flat_result = tf.Context.ExecuteOp("UnicodeDecodeWithOffsets", name, new ExecuteOpArgs(input) + { + GetGradientAttrs = op => new + { + input_encoding = op.get_attr("input_encoding"), + errors = op.get_attr("errors"), + replacement_char = op.get_attr("replacement_char"), + replace_control_characters = op.get_attr("replace_control_characters"), + Tsplits = op.get_attr("Tsplits") + } + }.SetAttributes(new + { + input_encoding, + errors, + replacement_char, + replace_control_characters + })); + + var codepoints = RaggedTensor.from_row_splits(flat_result[1], flat_result[0], validate: false); + + var offsets = RaggedTensor.from_row_splits(flat_result[2], flat_result[0], validate: false); + return (codepoints, offsets); + } + + return (null, null); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs new file mode 100644 index 000000000..6be0706c2 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs @@ -0,0 +1,46 @@ +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class tensor_array_ops + { + /// + /// Builds a TensorArray with a new `flow` tensor. + /// + /// + /// + /// + public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow) + { + if (!tf.Context.executing_eagerly() && old_ta is not _GraphTensorArrayV2 && control_flow_util.EnableControlFlowV2(ops.get_default_graph())) + { + throw new NotImplementedException("Attempting to build a graph-mode TF2-style " + + "TensorArray from either an eager-mode " + + "TensorArray or a TF1-style TensorArray. " + + "This is not currently supported. You may be " + + "attempting to capture a TensorArray " + + "inside a tf.function or tf.data map function. " + + "Instead, construct a new TensorArray inside " + + "the function."); + } + var new_ta = TensorArray.Create(old_ta.dtype, handle: old_ta.handle, flow: flow, infer_shape: old_ta.infer_shape, + colocate_with_first_write_call: old_ta.colocate_with_first_write_call); + new_ta._dynamic_size = old_ta._dynamic_size; + new_ta._size = old_ta._size; + new_ta._colocate_with = old_ta._colocate_with; + new_ta._element_shape = old_ta._element_shape; + return new_ta; + } + + public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow) + { + var new_ta = tf.TensorArray( + dtype: old_ta.dtype, + infer_shape: old_ta.infer_shape, + colocate_with_first_write_call: old_ta.colocate_with_first_write_call); + + return new_ta; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs b/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs new file mode 100644 index 000000000..8453fa259 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs @@ -0,0 +1,43 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class weights_broadcast_ops + { + public static Tensor broadcast_weights(Tensor weights, Tensor values) + { + return tf_with(ops.name_scope(null, "broadcast_weights", (weights, values)), scope => + { + values = ops.convert_to_tensor(values, name: "values"); + weights = ops.convert_to_tensor( + weights, dtype: values.dtype.as_base_dtype(), name: "weights"); + + // Try static check for exact match. + var weights_shape = weights.shape; + var values_shape = values.shape; + if (weights_shape.IsFullyDefined && + values_shape.IsFullyDefined) + return weights; + + return math_ops.multiply( + weights, array_ops.ones_like(values), name: scope); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/while_v2.cs b/src/TensorFlowNET.Core/Operations/while_v2.cs new file mode 100644 index 000000000..aae15b77d --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/while_v2.cs @@ -0,0 +1,400 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using Tensorflow.Common.Extensions; +using Tensorflow.Common.Types; +using Tensorflow.Eager; +using Tensorflow.Framework; +using Tensorflow.Framework.Models; +using Tensorflow.Graphs; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + class _OperationWithOutputs : Operation + { + public _OperationWithOutputs(IntPtr handle, Graph g = null) + { + _handle = handle; + _graph = g; + _outputs = null; + g._add_op(this); + } + } + internal class while_v2 + { + public static Tensor[] while_loop(Func cond, + Func body, + Tensors loop_vars, + int maximum_iterations = -1, + int parallel_iterations = 10, + string name = null, + bool back_prop = true, + bool return_same_structure = true) + { + var orig_loop_vars = loop_vars; + var flat_orig_loop_vars = orig_loop_vars.Flatten().ToArray(); + int len_orig_loop_vars = orig_loop_vars.Length; + + loop_vars = _tensor_array_to_flow(loop_vars); + loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x), loop_vars).ToTensors(); + + var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), loop_vars); + + var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray(); + + if(string.IsNullOrEmpty(name)) + { + name = "while"; + } + + return tf_with(ops.name_scope(name), nameScopeWhile => + { + string scope = (nameScopeWhile as ops.NameScope).scope_name; + string cond_name = control_flow_util.unique_fn_name(scope, "cond"); + string body_name = control_flow_util.unique_fn_name(scope, "body"); + + var maximum_iterations_loop_var = _build_maximum_iterations_loop_var(maximum_iterations); + var loop_counter = constant_op.constant(0, maximum_iterations == -1 ? TF_DataType.DtInvalid : maximum_iterations_loop_var.dtype, + name: "loop_counter"); + loop_vars = new Tensor[] { loop_counter, maximum_iterations_loop_var }.Concat(loop_vars).ToArray(); + + var func_graph_signature = new TensorSpec[] {TensorSpec.FromTensor(loop_counter),TensorSpec.FromTensor(maximum_iterations_loop_var)} + .Concat(loop_vars_signature.Flatten()).ToArray(); + + // TODO(Rinne): possible wrong implemenation here. + var add_control_dependencies = false; + + object[] wrapped_cond(object[] inputs) + { + Tensor loop_counter = (Tensor)inputs[0]; + Tensor maximum_iterations_arg = (Tensor)inputs[1]; + Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray(); + var pred = cond(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args)); + if(pred.shape.IsNull || pred.shape.ndim > 0) + { + pred = array_ops.squeeze(pred); + } + if(maximum_iterations == -1) + { + return new object[] { pred }; + } + else + { + return new object[] { math_ops.logical_and(loop_counter < maximum_iterations_arg, pred) }; + } + } + + var cond_graph = FuncGraph.func_graph_from_func(cond_name, wrapped_cond, null, + null, signature: func_graph_signature, add_control_dependencies: add_control_dependencies); + + bool stateful_parallelism = false; + + object[] wrapped_body(object[] inputs) + { + Tensor loop_counter = (Tensor)inputs[0]; + Tensor maximum_iterations_arg = (Tensor)inputs[1]; + Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray(); + + _copy_handle_data(loop_vars.Flatten().Skip(2), args); + + foreach(var t in cond_graph.external_captures) + { + var graph = (FuncGraph)(ops.get_default_graph()); + graph.capture(t); + } + + var outputs = body(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args)); + outputs = _tensor_array_to_flow(outputs); + + return new object[] { loop_counter + 1, maximum_iterations_arg }.Concat(outputs).ToArray(); + } + + var body_graph = FuncGraph.func_graph_from_func(body_name, wrapped_body, null, null, func_graph_signature, + add_control_dependencies: add_control_dependencies, acd_record_initial_resource_uses: stateful_parallelism); + + // TODO(Rinne): possible wrong implementation here. + NestList loop_vars_list = new(new Tensors[] { loop_vars, body_graph.external_captures.ToTensors() }); + body_graph.Outputs.AddRange(body_graph.internal_captures); + + cond_graph.as_default(); + int num_cond_captures = cond_graph.external_captures.Length; + Debug.Assert(cond_graph.external_captures.SequenceEqual(body_graph.external_captures.Take(num_cond_captures).ToArray())); + _duplicate_body_captures_in_cond(cond_graph, body_graph.external_captures.Skip(num_cond_captures).ToArray()); + cond_graph.Exit(); + + int first_loop_var_index = 2; + + int num_flattened_oututs = orig_loop_vars.Length; + int num_original_outputs = body_graph.Outputs.Length; + if (back_prop && control_flow_util.output_all_intermediates()) + { + var intermediate_tensors = _get_intermediates(body_graph); + + foreach(var intermediate_tensor in intermediate_tensors) + { + var tensor_list = list_ops.empty_tensor_list(intermediate_tensor.shape, intermediate_tensor.dtype, maximum_iterations); + loop_vars_list.Values.Add(tensor_list); + + cond_graph.as_default(); + cond_graph.capture(tensor_list); + cond_graph.Exit(); + + body_graph.as_default(); + var appended_tensor_list = gen_ops.tensor_list_push_back(tensor_list, intermediate_tensor); + body_graph.Outputs.Add(appended_tensor_list); + body_graph.Exit(); + } + } + + List flattened_loop_vars = new(); + foreach(var item in loop_vars_list.Values) + { + flattened_loop_vars.AddRange(item.Flatten()); + } + // skip the check + + // TODO(Rinne): deal with control dependencies + var output_shapes = body_graph.Outputs.Select(t => t.shape).ToArray(); + var span = new Span(output_shapes).Slice(first_loop_var_index, num_flattened_oututs); + for(int i = 0; i < span.Length; i++) + { + span[i] = flat_shape_invariants[i]; + } + + Tensor[] outputs = _build_while_op(flattened_loop_vars.ToArray(), cond_graph, body_graph, output_shapes, parallel_iterations, + (nameScopeWhile as ops.NameScope).scope_name, num_original_outputs, stateful_parallelism); + + if (!ops.get_default_graph().building_function) + { + outputs = outputs.Select(t => array_ops.identity(t)).ToArray(); + } + + var output_loop_vars = outputs.Skip(first_loop_var_index).Take(num_flattened_oututs).ToArray(); + + if (!back_prop) + { + output_loop_vars = output_loop_vars.Select(t => array_ops.stop_gradient(t)).ToArray(); + } + outputs = _pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, output_loop_vars); + + return outputs; + }); + } + + private static Tensors _tensor_array_to_flow(Tensors loop_vars) + { + if(loop_vars.NestType == NestType.Node) + { + if(loop_vars.NodeValue is FakeTensorByTensorArray fake) + { + return new Tensors(fake.TensorArray.flow); + } + else + { + return new Tensors(loop_vars.NodeValue!); + } + } + else if(loop_vars.NestType == NestType.List) + { + List> list = new(); + foreach(var item in loop_vars.ListValue!) + { + if(item.NestType == NestType.Node) + { + var nested = item.AsNest(); + if (nested.NodeValue is FakeTensorByTensorArray fake) + { + list.Add(new Nest(fake.TensorArray.flow)); + } + else + { + list.Add(new Nest(nested.NodeValue!)); + } + } + else + { + list.Add(new Nest(item.AsNest())); + } + } + return Tensors.FromNest(new Nest(list)); + } + else + { + throw new NotImplementedException(); + } + } + + private static Tensor[] _build_while_op(Tensor[] loop_vars, FuncGraph cond_graph, FuncGraph body_graph, + Shape[] output_shapes, int parallel_iterations, string name, int num_original_outputs, bool stateful_parallelism) + { + var cond_stateful_ops = cond_graph.get_operations().Select(x => x.op); + var body_stateful_ops = body_graph.get_operations().Select(x => x.op); + + bool is_stateful = cond_stateful_ops.Count() > 0 || body_stateful_ops.Count() > 0; + + Tensor[] _make_op(Tensor[] inputs) + { + Tensor[] outputs; + if (is_stateful) + { + outputs = gen_functional_ops._while( + inputs, + control_flow_util.create_new_tf_function(cond_graph), + control_flow_util.create_new_tf_function(body_graph), + output_shapes, + parallel_iterations, + name + ); + } + else + { + outputs = gen_functional_ops.stateless_while( + inputs, + control_flow_util.create_new_tf_function(cond_graph), + control_flow_util.create_new_tf_function(body_graph), + output_shapes, + parallel_iterations, + name + ); + } + var (while_op, tensors) = control_flow_util.get_op_and_outputs(outputs); + _copy_handle_data(body_graph.Outputs, tensors); + _set_read_only_resource_inputs_attr(while_op, new FuncGraph[]{cond_graph, body_graph}); + while_op._set_attr("_num_original_outputs", new AttrValue() { I = num_original_outputs }); + while_op._set_attr("_stateful_parallelism", new AttrValue() { B = stateful_parallelism }); + + cond_graph.outer_graph = ops.get_default_graph(); + body_graph.outer_graph = ops.get_default_graph(); + // TODO(Rinne): set the two graphs to while_op + return tensors; + } + + return control_flow_util.run_as_function_for_tape_gradients(_make_op, loop_vars); + } + + /// + /// Sets the list of resource inputs which are read-only. This is used by AutomaticControlDependencies. + /// + /// + /// + private static void _set_read_only_resource_inputs_attr(Operation op, FuncGraph[] branch_graphs) + { + List read_only_indices = Enumerable.Range(0, op.inputs.Length).ToList(); + foreach(var branch_graph in branch_graphs) + { + if (read_only_indices.Count == 0) + { + break; + } + var branch_read_only_indices = auto_control_deps_utils.get_read_only_resource_input_indices_graph(branch_graph); + read_only_indices = read_only_indices.Intersect(branch_read_only_indices).ToList(); + } + AttrValue.Types.ListValue listValue = new(); + listValue.I.AddRange(read_only_indices.OrderBy(x => x).Select(x => (long)x)); + op._set_attr(auto_control_deps_utils.READ_ONLY_RESOURCE_INPUTS_ATTR, new AttrValue() + { + List = listValue + }); + } + + private static Tensors _pack_sequence_as(INestStructure loop_vars_signature, Tensor[] flat_orig_loop_vars, Tensor[] loop_vars) + { + var flattened_loop_vars = zip(loop_vars, flat_orig_loop_vars).Select<(Tensor, Tensor), Tensor>(item => + { + var (flow, y) = item; + if (y is FakeTensorByTensorArray ta) + { + return new FakeTensorByTensorArray(tensor_array_ops.build_ta_with_new_flow(ta.TensorArray, flow)); + } + else + { + return flow; + } + }).ToArray(); + return Nest.PackSequenceAs(loop_vars_signature, flattened_loop_vars).ToTensors(); + } + + private static Tensor[] _get_intermediates(FuncGraph func_graph) + { + List intermediates = new(); + var reversed_captures = func_graph.captures.ToDictionary(x => x.Item2, x => x.Item1); + + foreach(var op in func_graph.get_operations()) + { + Debug.Assert(op is Operation); + var oper = (Operation)op; + if(oper.type == "Identity" || oper.type == "MutexLock") + { + continue; + } + foreach(var o in op.outputs) + { + if(o != func_graph.Inputs[0] && o.dtype != dtypes.resource && !reversed_captures.ContainsKey(o)) + { + intermediates.Add(o); + } + } + } + return intermediates.ToArray(); + } + + private static void _duplicate_body_captures_in_cond(FuncGraph cond_graph, Tensor[] body_graph_captures) + { + var types = body_graph_captures.Select(t => t.dtype).ToList(); + var c_graph = cond_graph.c_graph; + var placeholders = types.Select(x => CreatePlaceholder(c_graph, _build_cond_placeholders_name_prefix(cond_graph), x)).ToList(); + + var placeholder_ops = placeholders.Select(ph => new _OperationWithOutputs(ph.oper, cond_graph)).ToList(); + + List tensors = new(); + foreach(var (op, ph, dtype) in zip(placeholder_ops, placeholders, types)) + { + var tensor = Tensor._create_with_tf_output(op, 0, dtype, ph); + op._outputs = new Tensor[] { tensor }; + tensors.Add(tensor); + } + + var tuples = zip(body_graph_captures, tensors).ToList(); + var keys = body_graph_captures.Select(t => t.Id).ToList(); + cond_graph._captures.Update(zip(keys, tuples).ToDictionary(x => x.Item1, x => x.Item2)); + cond_graph.Inputs.AddRange(tensors); + } + + private static TF_Output CreatePlaceholder(SafeGraphHandle graph, string name, TF_DataType dtype) + { + var desc = c_api.TF_NewOperation(graph, "Placeholder", name); + c_api.TF_SetAttrType(desc, "dtype", dtype); + var op = c_api.TF_FinishOperation(desc, tf.Status); + tf.Status.Check(true); + var output = new TF_Output(); + output.oper = op; + output.index = 0; + return output; + } + + private static string _build_cond_placeholders_name_prefix(FuncGraph cond_graph) + { + return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder"); + } + + private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value) + { + return ops.convert_to_tensor(value, as_ref: false); + } + + private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1) + { + return ops.convert_to_tensor(maximum_iterations, dtypes.int32, "maximum_iterations"); + } + + private static void _copy_handle_data(IEnumerable src_tensors, IEnumerable dst_tensors) + { + foreach(var (src_t, dst_t) in zip(src_tensors, dst_tensors)) + { + handle_data_util.copy_handle_data(src_t, dst_t); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Protobuf/AllocationDescription.cs b/src/TensorFlowNET.Core/Protobuf/AllocationDescription.cs new file mode 100644 index 000000000..bac94eb7e --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/AllocationDescription.cs @@ -0,0 +1,442 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/allocation_description.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/allocation_description.proto + public static partial class AllocationDescriptionReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/allocation_description.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static AllocationDescriptionReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjZ0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2FsbG9jYXRpb25fZGVzY3Jp", + "cHRpb24ucHJvdG8SCnRlbnNvcmZsb3ciowEKFUFsbG9jYXRpb25EZXNjcmlw", + "dGlvbhIXCg9yZXF1ZXN0ZWRfYnl0ZXMYASABKAMSFwoPYWxsb2NhdGVkX2J5", + "dGVzGAIgASgDEhYKDmFsbG9jYXRvcl9uYW1lGAMgASgJEhUKDWFsbG9jYXRp", + "b25faWQYBCABKAMSHAoUaGFzX3NpbmdsZV9yZWZlcmVuY2UYBSABKAgSCwoD", + "cHRyGAYgASgEQpsBChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtCG0FsbG9j", + "YXRpb25EZXNjcmlwdGlvblByb3Rvc1ABWl1naXRodWIuY29tL3RlbnNvcmZs", + "b3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dvL2NvcmUvZnJhbWV3b3JrL2Fs", + "bG9jYXRpb25fZGVzY3JpcHRpb25fZ29fcHJvdG/4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.AllocationDescription), global::Tensorflow.AllocationDescription.Parser, new[]{ "RequestedBytes", "AllocatedBytes", "AllocatorName", "AllocationId", "HasSingleReference", "Ptr" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class AllocationDescription : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AllocationDescription()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.AllocationDescriptionReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AllocationDescription() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AllocationDescription(AllocationDescription other) : this() { + requestedBytes_ = other.requestedBytes_; + allocatedBytes_ = other.allocatedBytes_; + allocatorName_ = other.allocatorName_; + allocationId_ = other.allocationId_; + hasSingleReference_ = other.hasSingleReference_; + ptr_ = other.ptr_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AllocationDescription Clone() { + return new AllocationDescription(this); + } + + /// Field number for the "requested_bytes" field. + public const int RequestedBytesFieldNumber = 1; + private long requestedBytes_; + /// + /// Total number of bytes requested + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long RequestedBytes { + get { return requestedBytes_; } + set { + requestedBytes_ = value; + } + } + + /// Field number for the "allocated_bytes" field. + public const int AllocatedBytesFieldNumber = 2; + private long allocatedBytes_; + /// + /// Total number of bytes allocated if known + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllocatedBytes { + get { return allocatedBytes_; } + set { + allocatedBytes_ = value; + } + } + + /// Field number for the "allocator_name" field. + public const int AllocatorNameFieldNumber = 3; + private string allocatorName_ = ""; + /// + /// Name of the allocator used + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string AllocatorName { + get { return allocatorName_; } + set { + allocatorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "allocation_id" field. + public const int AllocationIdFieldNumber = 4; + private long allocationId_; + /// + /// Identifier of the allocated buffer if known + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllocationId { + get { return allocationId_; } + set { + allocationId_ = value; + } + } + + /// Field number for the "has_single_reference" field. + public const int HasSingleReferenceFieldNumber = 5; + private bool hasSingleReference_; + /// + /// Set if this tensor only has one remaining reference + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool HasSingleReference { + get { return hasSingleReference_; } + set { + hasSingleReference_ = value; + } + } + + /// Field number for the "ptr" field. + public const int PtrFieldNumber = 6; + private ulong ptr_; + /// + /// Address of the allocation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong Ptr { + get { return ptr_; } + set { + ptr_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as AllocationDescription); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(AllocationDescription other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (RequestedBytes != other.RequestedBytes) return false; + if (AllocatedBytes != other.AllocatedBytes) return false; + if (AllocatorName != other.AllocatorName) return false; + if (AllocationId != other.AllocationId) return false; + if (HasSingleReference != other.HasSingleReference) return false; + if (Ptr != other.Ptr) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (RequestedBytes != 0L) hash ^= RequestedBytes.GetHashCode(); + if (AllocatedBytes != 0L) hash ^= AllocatedBytes.GetHashCode(); + if (AllocatorName.Length != 0) hash ^= AllocatorName.GetHashCode(); + if (AllocationId != 0L) hash ^= AllocationId.GetHashCode(); + if (HasSingleReference != false) hash ^= HasSingleReference.GetHashCode(); + if (Ptr != 0UL) hash ^= Ptr.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (RequestedBytes != 0L) { + output.WriteRawTag(8); + output.WriteInt64(RequestedBytes); + } + if (AllocatedBytes != 0L) { + output.WriteRawTag(16); + output.WriteInt64(AllocatedBytes); + } + if (AllocatorName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(AllocatorName); + } + if (AllocationId != 0L) { + output.WriteRawTag(32); + output.WriteInt64(AllocationId); + } + if (HasSingleReference != false) { + output.WriteRawTag(40); + output.WriteBool(HasSingleReference); + } + if (Ptr != 0UL) { + output.WriteRawTag(48); + output.WriteUInt64(Ptr); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (RequestedBytes != 0L) { + output.WriteRawTag(8); + output.WriteInt64(RequestedBytes); + } + if (AllocatedBytes != 0L) { + output.WriteRawTag(16); + output.WriteInt64(AllocatedBytes); + } + if (AllocatorName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(AllocatorName); + } + if (AllocationId != 0L) { + output.WriteRawTag(32); + output.WriteInt64(AllocationId); + } + if (HasSingleReference != false) { + output.WriteRawTag(40); + output.WriteBool(HasSingleReference); + } + if (Ptr != 0UL) { + output.WriteRawTag(48); + output.WriteUInt64(Ptr); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (RequestedBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(RequestedBytes); + } + if (AllocatedBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AllocatedBytes); + } + if (AllocatorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(AllocatorName); + } + if (AllocationId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AllocationId); + } + if (HasSingleReference != false) { + size += 1 + 1; + } + if (Ptr != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(Ptr); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(AllocationDescription other) { + if (other == null) { + return; + } + if (other.RequestedBytes != 0L) { + RequestedBytes = other.RequestedBytes; + } + if (other.AllocatedBytes != 0L) { + AllocatedBytes = other.AllocatedBytes; + } + if (other.AllocatorName.Length != 0) { + AllocatorName = other.AllocatorName; + } + if (other.AllocationId != 0L) { + AllocationId = other.AllocationId; + } + if (other.HasSingleReference != false) { + HasSingleReference = other.HasSingleReference; + } + if (other.Ptr != 0UL) { + Ptr = other.Ptr; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + RequestedBytes = input.ReadInt64(); + break; + } + case 16: { + AllocatedBytes = input.ReadInt64(); + break; + } + case 26: { + AllocatorName = input.ReadString(); + break; + } + case 32: { + AllocationId = input.ReadInt64(); + break; + } + case 40: { + HasSingleReference = input.ReadBool(); + break; + } + case 48: { + Ptr = input.ReadUInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + RequestedBytes = input.ReadInt64(); + break; + } + case 16: { + AllocatedBytes = input.ReadInt64(); + break; + } + case 26: { + AllocatorName = input.ReadString(); + break; + } + case 32: { + AllocationId = input.ReadInt64(); + break; + } + case 40: { + HasSingleReference = input.ReadBool(); + break; + } + case 48: { + Ptr = input.ReadUInt64(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/ApiDef.cs b/src/TensorFlowNET.Core/Protobuf/ApiDef.cs new file mode 100644 index 000000000..b7bc58294 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/ApiDef.cs @@ -0,0 +1,1785 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/api_def.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/api_def.proto + public static partial class ApiDefReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/api_def.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ApiDefReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cid0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2FwaV9kZWYucHJvdG8SCnRl", + "bnNvcmZsb3caKnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvYXR0cl92YWx1", + "ZS5wcm90byLhBQoGQXBpRGVmEhUKDWdyYXBoX29wX25hbWUYASABKAkSGwoT", + "ZGVwcmVjYXRpb25fbWVzc2FnZRgMIAEoCRIbChNkZXByZWNhdGlvbl92ZXJz", + "aW9uGA0gASgFEjEKCnZpc2liaWxpdHkYAiABKA4yHS50ZW5zb3JmbG93LkFw", + "aURlZi5WaXNpYmlsaXR5Ei0KCGVuZHBvaW50GAMgAygLMhsudGVuc29yZmxv", + "dy5BcGlEZWYuRW5kcG9pbnQSJgoGaW5fYXJnGAQgAygLMhYudGVuc29yZmxv", + "dy5BcGlEZWYuQXJnEicKB291dF9hcmcYBSADKAsyFi50ZW5zb3JmbG93LkFw", + "aURlZi5BcmcSEQoJYXJnX29yZGVyGAsgAygJEiUKBGF0dHIYBiADKAsyFy50", + "ZW5zb3JmbG93LkFwaURlZi5BdHRyEg8KB3N1bW1hcnkYByABKAkSEwoLZGVz", + "Y3JpcHRpb24YCCABKAkSGgoSZGVzY3JpcHRpb25fcHJlZml4GAkgASgJEhoK", + "EmRlc2NyaXB0aW9uX3N1ZmZpeBgKIAEoCRpJCghFbmRwb2ludBIMCgRuYW1l", + "GAEgASgJEhIKCmRlcHJlY2F0ZWQYAyABKAgSGwoTZGVwcmVjYXRpb25fdmVy", + "c2lvbhgEIAEoBRo7CgNBcmcSDAoEbmFtZRgBIAEoCRIRCglyZW5hbWVfdG8Y", + "AiABKAkSEwoLZGVzY3JpcHRpb24YAyABKAkaagoEQXR0chIMCgRuYW1lGAEg", + "ASgJEhEKCXJlbmFtZV90bxgCIAEoCRIsCg1kZWZhdWx0X3ZhbHVlGAMgASgL", + "MhUudGVuc29yZmxvdy5BdHRyVmFsdWUSEwoLZGVzY3JpcHRpb24YBCABKAki", + "RwoKVmlzaWJpbGl0eRIWChJERUZBVUxUX1ZJU0lCSUxJVFkQABILCgdWSVNJ", + "QkxFEAESCAoEU0tJUBACEgoKBkhJRERFThADIikKB0FwaURlZnMSHgoCb3AY", + "ASADKAsyEi50ZW5zb3JmbG93LkFwaURlZkJ9ChhvcmcudGVuc29yZmxvdy5m", + "cmFtZXdvcmtCDEFwaURlZlByb3Rvc1ABWk5naXRodWIuY29tL3RlbnNvcmZs", + "b3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dvL2NvcmUvZnJhbWV3b3JrL2Fw", + "aV9kZWZfZ29fcHJvdG/4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ApiDef), global::Tensorflow.ApiDef.Parser, new[]{ "GraphOpName", "DeprecationMessage", "DeprecationVersion", "Visibility", "Endpoint", "InArg", "OutArg", "ArgOrder", "Attr", "Summary", "Description", "DescriptionPrefix", "DescriptionSuffix" }, null, new[]{ typeof(global::Tensorflow.ApiDef.Types.Visibility) }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ApiDef.Types.Endpoint), global::Tensorflow.ApiDef.Types.Endpoint.Parser, new[]{ "Name", "Deprecated", "DeprecationVersion" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ApiDef.Types.Arg), global::Tensorflow.ApiDef.Types.Arg.Parser, new[]{ "Name", "RenameTo", "Description" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ApiDef.Types.Attr), global::Tensorflow.ApiDef.Types.Attr.Parser, new[]{ "Name", "RenameTo", "DefaultValue", "Description" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ApiDefs), global::Tensorflow.ApiDefs.Parser, new[]{ "Op" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Used to specify and override the default API & behavior in the + /// generated code for client languages, from what you would get from + /// the OpDef alone. There will be a set of ApiDefs that are common + /// to all client languages, and another set per client language. + /// The per-client-language ApiDefs will inherit values from the + /// common ApiDefs which it can either replace or modify. + /// + /// We separate the API definition from the OpDef so we can evolve the + /// API while remaining backwards compatible when interpreting old + /// graphs. Overrides go in an "api_def.pbtxt" file with a text-format + /// ApiDefs message. + /// + /// WARNING: Be *very* careful changing the API for any existing op -- + /// you can change the semantics of existing code. These changes may + /// need to wait until a major release of TensorFlow to avoid breaking + /// our compatibility promises. + /// + public sealed partial class ApiDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ApiDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ApiDefReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ApiDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ApiDef(ApiDef other) : this() { + graphOpName_ = other.graphOpName_; + deprecationMessage_ = other.deprecationMessage_; + deprecationVersion_ = other.deprecationVersion_; + visibility_ = other.visibility_; + endpoint_ = other.endpoint_.Clone(); + inArg_ = other.inArg_.Clone(); + outArg_ = other.outArg_.Clone(); + argOrder_ = other.argOrder_.Clone(); + attr_ = other.attr_.Clone(); + summary_ = other.summary_; + description_ = other.description_; + descriptionPrefix_ = other.descriptionPrefix_; + descriptionSuffix_ = other.descriptionSuffix_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ApiDef Clone() { + return new ApiDef(this); + } + + /// Field number for the "graph_op_name" field. + public const int GraphOpNameFieldNumber = 1; + private string graphOpName_ = ""; + /// + /// Name of the op (in the OpDef) to specify the API for. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string GraphOpName { + get { return graphOpName_; } + set { + graphOpName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "deprecation_message" field. + public const int DeprecationMessageFieldNumber = 12; + private string deprecationMessage_ = ""; + /// + /// If this op is deprecated, set deprecation message to the message + /// that should be logged when this op is used. + /// The message should indicate alternative op to use, if any. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string DeprecationMessage { + get { return deprecationMessage_; } + set { + deprecationMessage_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "deprecation_version" field. + public const int DeprecationVersionFieldNumber = 13; + private int deprecationVersion_; + /// + /// Major version when the op will be deleted. For e.g. set this + /// value to 2 if op API should be removed in TensorFlow 2.0 and + /// deprecated in versions before that. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int DeprecationVersion { + get { return deprecationVersion_; } + set { + deprecationVersion_ = value; + } + } + + /// Field number for the "visibility" field. + public const int VisibilityFieldNumber = 2; + private global::Tensorflow.ApiDef.Types.Visibility visibility_ = global::Tensorflow.ApiDef.Types.Visibility.DefaultVisibility; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.ApiDef.Types.Visibility Visibility { + get { return visibility_; } + set { + visibility_ = value; + } + } + + /// Field number for the "endpoint" field. + public const int EndpointFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_endpoint_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.ApiDef.Types.Endpoint.Parser); + private readonly pbc::RepeatedField endpoint_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Endpoint { + get { return endpoint_; } + } + + /// Field number for the "in_arg" field. + public const int InArgFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_inArg_codec + = pb::FieldCodec.ForMessage(34, global::Tensorflow.ApiDef.Types.Arg.Parser); + private readonly pbc::RepeatedField inArg_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField InArg { + get { return inArg_; } + } + + /// Field number for the "out_arg" field. + public const int OutArgFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_outArg_codec + = pb::FieldCodec.ForMessage(42, global::Tensorflow.ApiDef.Types.Arg.Parser); + private readonly pbc::RepeatedField outArg_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OutArg { + get { return outArg_; } + } + + /// Field number for the "arg_order" field. + public const int ArgOrderFieldNumber = 11; + private static readonly pb::FieldCodec _repeated_argOrder_codec + = pb::FieldCodec.ForString(90); + private readonly pbc::RepeatedField argOrder_ = new pbc::RepeatedField(); + /// + /// List of original in_arg names to specify new argument order. + /// Length of arg_order should be either empty to keep current order + /// or match size of in_arg. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ArgOrder { + get { return argOrder_; } + } + + /// Field number for the "attr" field. + public const int AttrFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_attr_codec + = pb::FieldCodec.ForMessage(50, global::Tensorflow.ApiDef.Types.Attr.Parser); + private readonly pbc::RepeatedField attr_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Attr { + get { return attr_; } + } + + /// Field number for the "summary" field. + public const int SummaryFieldNumber = 7; + private string summary_ = ""; + /// + /// One-line human-readable description of what the Op does. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Summary { + get { return summary_; } + set { + summary_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "description" field. + public const int DescriptionFieldNumber = 8; + private string description_ = ""; + /// + /// Additional, longer human-readable description of what the Op does. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Description { + get { return description_; } + set { + description_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "description_prefix" field. + public const int DescriptionPrefixFieldNumber = 9; + private string descriptionPrefix_ = ""; + /// + /// Modify an existing/inherited description by adding text to the beginning + /// or end. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string DescriptionPrefix { + get { return descriptionPrefix_; } + set { + descriptionPrefix_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "description_suffix" field. + public const int DescriptionSuffixFieldNumber = 10; + private string descriptionSuffix_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string DescriptionSuffix { + get { return descriptionSuffix_; } + set { + descriptionSuffix_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ApiDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ApiDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (GraphOpName != other.GraphOpName) return false; + if (DeprecationMessage != other.DeprecationMessage) return false; + if (DeprecationVersion != other.DeprecationVersion) return false; + if (Visibility != other.Visibility) return false; + if(!endpoint_.Equals(other.endpoint_)) return false; + if(!inArg_.Equals(other.inArg_)) return false; + if(!outArg_.Equals(other.outArg_)) return false; + if(!argOrder_.Equals(other.argOrder_)) return false; + if(!attr_.Equals(other.attr_)) return false; + if (Summary != other.Summary) return false; + if (Description != other.Description) return false; + if (DescriptionPrefix != other.DescriptionPrefix) return false; + if (DescriptionSuffix != other.DescriptionSuffix) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (GraphOpName.Length != 0) hash ^= GraphOpName.GetHashCode(); + if (DeprecationMessage.Length != 0) hash ^= DeprecationMessage.GetHashCode(); + if (DeprecationVersion != 0) hash ^= DeprecationVersion.GetHashCode(); + if (Visibility != global::Tensorflow.ApiDef.Types.Visibility.DefaultVisibility) hash ^= Visibility.GetHashCode(); + hash ^= endpoint_.GetHashCode(); + hash ^= inArg_.GetHashCode(); + hash ^= outArg_.GetHashCode(); + hash ^= argOrder_.GetHashCode(); + hash ^= attr_.GetHashCode(); + if (Summary.Length != 0) hash ^= Summary.GetHashCode(); + if (Description.Length != 0) hash ^= Description.GetHashCode(); + if (DescriptionPrefix.Length != 0) hash ^= DescriptionPrefix.GetHashCode(); + if (DescriptionSuffix.Length != 0) hash ^= DescriptionSuffix.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (GraphOpName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(GraphOpName); + } + if (Visibility != global::Tensorflow.ApiDef.Types.Visibility.DefaultVisibility) { + output.WriteRawTag(16); + output.WriteEnum((int) Visibility); + } + endpoint_.WriteTo(output, _repeated_endpoint_codec); + inArg_.WriteTo(output, _repeated_inArg_codec); + outArg_.WriteTo(output, _repeated_outArg_codec); + attr_.WriteTo(output, _repeated_attr_codec); + if (Summary.Length != 0) { + output.WriteRawTag(58); + output.WriteString(Summary); + } + if (Description.Length != 0) { + output.WriteRawTag(66); + output.WriteString(Description); + } + if (DescriptionPrefix.Length != 0) { + output.WriteRawTag(74); + output.WriteString(DescriptionPrefix); + } + if (DescriptionSuffix.Length != 0) { + output.WriteRawTag(82); + output.WriteString(DescriptionSuffix); + } + argOrder_.WriteTo(output, _repeated_argOrder_codec); + if (DeprecationMessage.Length != 0) { + output.WriteRawTag(98); + output.WriteString(DeprecationMessage); + } + if (DeprecationVersion != 0) { + output.WriteRawTag(104); + output.WriteInt32(DeprecationVersion); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (GraphOpName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(GraphOpName); + } + if (Visibility != global::Tensorflow.ApiDef.Types.Visibility.DefaultVisibility) { + output.WriteRawTag(16); + output.WriteEnum((int) Visibility); + } + endpoint_.WriteTo(ref output, _repeated_endpoint_codec); + inArg_.WriteTo(ref output, _repeated_inArg_codec); + outArg_.WriteTo(ref output, _repeated_outArg_codec); + attr_.WriteTo(ref output, _repeated_attr_codec); + if (Summary.Length != 0) { + output.WriteRawTag(58); + output.WriteString(Summary); + } + if (Description.Length != 0) { + output.WriteRawTag(66); + output.WriteString(Description); + } + if (DescriptionPrefix.Length != 0) { + output.WriteRawTag(74); + output.WriteString(DescriptionPrefix); + } + if (DescriptionSuffix.Length != 0) { + output.WriteRawTag(82); + output.WriteString(DescriptionSuffix); + } + argOrder_.WriteTo(ref output, _repeated_argOrder_codec); + if (DeprecationMessage.Length != 0) { + output.WriteRawTag(98); + output.WriteString(DeprecationMessage); + } + if (DeprecationVersion != 0) { + output.WriteRawTag(104); + output.WriteInt32(DeprecationVersion); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (GraphOpName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(GraphOpName); + } + if (DeprecationMessage.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DeprecationMessage); + } + if (DeprecationVersion != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(DeprecationVersion); + } + if (Visibility != global::Tensorflow.ApiDef.Types.Visibility.DefaultVisibility) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Visibility); + } + size += endpoint_.CalculateSize(_repeated_endpoint_codec); + size += inArg_.CalculateSize(_repeated_inArg_codec); + size += outArg_.CalculateSize(_repeated_outArg_codec); + size += argOrder_.CalculateSize(_repeated_argOrder_codec); + size += attr_.CalculateSize(_repeated_attr_codec); + if (Summary.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Summary); + } + if (Description.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Description); + } + if (DescriptionPrefix.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DescriptionPrefix); + } + if (DescriptionSuffix.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DescriptionSuffix); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ApiDef other) { + if (other == null) { + return; + } + if (other.GraphOpName.Length != 0) { + GraphOpName = other.GraphOpName; + } + if (other.DeprecationMessage.Length != 0) { + DeprecationMessage = other.DeprecationMessage; + } + if (other.DeprecationVersion != 0) { + DeprecationVersion = other.DeprecationVersion; + } + if (other.Visibility != global::Tensorflow.ApiDef.Types.Visibility.DefaultVisibility) { + Visibility = other.Visibility; + } + endpoint_.Add(other.endpoint_); + inArg_.Add(other.inArg_); + outArg_.Add(other.outArg_); + argOrder_.Add(other.argOrder_); + attr_.Add(other.attr_); + if (other.Summary.Length != 0) { + Summary = other.Summary; + } + if (other.Description.Length != 0) { + Description = other.Description; + } + if (other.DescriptionPrefix.Length != 0) { + DescriptionPrefix = other.DescriptionPrefix; + } + if (other.DescriptionSuffix.Length != 0) { + DescriptionSuffix = other.DescriptionSuffix; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + GraphOpName = input.ReadString(); + break; + } + case 16: { + Visibility = (global::Tensorflow.ApiDef.Types.Visibility) input.ReadEnum(); + break; + } + case 26: { + endpoint_.AddEntriesFrom(input, _repeated_endpoint_codec); + break; + } + case 34: { + inArg_.AddEntriesFrom(input, _repeated_inArg_codec); + break; + } + case 42: { + outArg_.AddEntriesFrom(input, _repeated_outArg_codec); + break; + } + case 50: { + attr_.AddEntriesFrom(input, _repeated_attr_codec); + break; + } + case 58: { + Summary = input.ReadString(); + break; + } + case 66: { + Description = input.ReadString(); + break; + } + case 74: { + DescriptionPrefix = input.ReadString(); + break; + } + case 82: { + DescriptionSuffix = input.ReadString(); + break; + } + case 90: { + argOrder_.AddEntriesFrom(input, _repeated_argOrder_codec); + break; + } + case 98: { + DeprecationMessage = input.ReadString(); + break; + } + case 104: { + DeprecationVersion = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + GraphOpName = input.ReadString(); + break; + } + case 16: { + Visibility = (global::Tensorflow.ApiDef.Types.Visibility) input.ReadEnum(); + break; + } + case 26: { + endpoint_.AddEntriesFrom(ref input, _repeated_endpoint_codec); + break; + } + case 34: { + inArg_.AddEntriesFrom(ref input, _repeated_inArg_codec); + break; + } + case 42: { + outArg_.AddEntriesFrom(ref input, _repeated_outArg_codec); + break; + } + case 50: { + attr_.AddEntriesFrom(ref input, _repeated_attr_codec); + break; + } + case 58: { + Summary = input.ReadString(); + break; + } + case 66: { + Description = input.ReadString(); + break; + } + case 74: { + DescriptionPrefix = input.ReadString(); + break; + } + case 82: { + DescriptionSuffix = input.ReadString(); + break; + } + case 90: { + argOrder_.AddEntriesFrom(ref input, _repeated_argOrder_codec); + break; + } + case 98: { + DeprecationMessage = input.ReadString(); + break; + } + case 104: { + DeprecationVersion = input.ReadInt32(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the ApiDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum Visibility { + /// + /// Normally this is "VISIBLE" unless you are inheriting a + /// different value from another ApiDef. + /// + [pbr::OriginalName("DEFAULT_VISIBILITY")] DefaultVisibility = 0, + /// + /// Publicly visible in the API. + /// + [pbr::OriginalName("VISIBLE")] Visible = 1, + /// + /// Do not include this op in the generated API. If visibility is + /// set to 'SKIP', other fields are ignored for this op. + /// + [pbr::OriginalName("SKIP")] Skip = 2, + /// + /// Hide this op by putting it into an internal namespace (or whatever + /// is appropriate in the target language). + /// + [pbr::OriginalName("HIDDEN")] Hidden = 3, + } + + /// + /// If you specify any endpoint, this will replace all of the + /// inherited endpoints. The first endpoint should be the + /// "canonical" endpoint, and should not be deprecated (unless all + /// endpoints are deprecated). + /// + public sealed partial class Endpoint : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Endpoint()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ApiDef.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Endpoint() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Endpoint(Endpoint other) : this() { + name_ = other.name_; + deprecated_ = other.deprecated_; + deprecationVersion_ = other.deprecationVersion_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Endpoint Clone() { + return new Endpoint(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// Name should be either like "CamelCaseName" or + /// "Package.CamelCaseName". Client-language-specific ApiDefs may + /// use a snake_case convention instead of CamelCase. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "deprecated" field. + public const int DeprecatedFieldNumber = 3; + private bool deprecated_; + /// + /// Set if this endpoint is deprecated. If set to true, a message suggesting + /// to use a non-deprecated endpoint instead will be printed. If all + /// endpoints are deprecated, set deprecation_message in ApiDef instead. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Deprecated { + get { return deprecated_; } + set { + deprecated_ = value; + } + } + + /// Field number for the "deprecation_version" field. + public const int DeprecationVersionFieldNumber = 4; + private int deprecationVersion_; + /// + /// Major version when an endpoint will be deleted. For e.g. set this + /// value to 2 if endpoint should be removed in TensorFlow 2.0 and + /// deprecated in versions before that. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int DeprecationVersion { + get { return deprecationVersion_; } + set { + deprecationVersion_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Endpoint); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Endpoint other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (Deprecated != other.Deprecated) return false; + if (DeprecationVersion != other.DeprecationVersion) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Deprecated != false) hash ^= Deprecated.GetHashCode(); + if (DeprecationVersion != 0) hash ^= DeprecationVersion.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Deprecated != false) { + output.WriteRawTag(24); + output.WriteBool(Deprecated); + } + if (DeprecationVersion != 0) { + output.WriteRawTag(32); + output.WriteInt32(DeprecationVersion); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Deprecated != false) { + output.WriteRawTag(24); + output.WriteBool(Deprecated); + } + if (DeprecationVersion != 0) { + output.WriteRawTag(32); + output.WriteInt32(DeprecationVersion); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Deprecated != false) { + size += 1 + 1; + } + if (DeprecationVersion != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(DeprecationVersion); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Endpoint other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Deprecated != false) { + Deprecated = other.Deprecated; + } + if (other.DeprecationVersion != 0) { + DeprecationVersion = other.DeprecationVersion; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 24: { + Deprecated = input.ReadBool(); + break; + } + case 32: { + DeprecationVersion = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 24: { + Deprecated = input.ReadBool(); + break; + } + case 32: { + DeprecationVersion = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + public sealed partial class Arg : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Arg()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ApiDef.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Arg() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Arg(Arg other) : this() { + name_ = other.name_; + renameTo_ = other.renameTo_; + description_ = other.description_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Arg Clone() { + return new Arg(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "rename_to" field. + public const int RenameToFieldNumber = 2; + private string renameTo_ = ""; + /// + /// Change the name used to access this arg in the API from what + /// is used in the GraphDef. Note that these names in `backticks` + /// will also be replaced in the summary & description fields. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string RenameTo { + get { return renameTo_; } + set { + renameTo_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "description" field. + public const int DescriptionFieldNumber = 3; + private string description_ = ""; + /// + /// Note: this will replace any inherited arg doc. There is no + /// current way of modifying arg descriptions (other than replacing + /// them entirely) as can be done with op descriptions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Description { + get { return description_; } + set { + description_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Arg); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Arg other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (RenameTo != other.RenameTo) return false; + if (Description != other.Description) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (RenameTo.Length != 0) hash ^= RenameTo.GetHashCode(); + if (Description.Length != 0) hash ^= Description.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (RenameTo.Length != 0) { + output.WriteRawTag(18); + output.WriteString(RenameTo); + } + if (Description.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Description); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (RenameTo.Length != 0) { + output.WriteRawTag(18); + output.WriteString(RenameTo); + } + if (Description.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Description); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (RenameTo.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(RenameTo); + } + if (Description.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Description); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Arg other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.RenameTo.Length != 0) { + RenameTo = other.RenameTo; + } + if (other.Description.Length != 0) { + Description = other.Description; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + RenameTo = input.ReadString(); + break; + } + case 26: { + Description = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + RenameTo = input.ReadString(); + break; + } + case 26: { + Description = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// Description of the graph-construction-time configuration of this + /// Op. That is to say, this describes the attr fields that will + /// be specified in the NodeDef. + /// + public sealed partial class Attr : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Attr()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ApiDef.Descriptor.NestedTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Attr() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Attr(Attr other) : this() { + name_ = other.name_; + renameTo_ = other.renameTo_; + defaultValue_ = other.defaultValue_ != null ? other.defaultValue_.Clone() : null; + description_ = other.description_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Attr Clone() { + return new Attr(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "rename_to" field. + public const int RenameToFieldNumber = 2; + private string renameTo_ = ""; + /// + /// Change the name used to access this attr in the API from what + /// is used in the GraphDef. Note that these names in `backticks` + /// will also be replaced in the summary & description fields. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string RenameTo { + get { return renameTo_; } + set { + renameTo_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "default_value" field. + public const int DefaultValueFieldNumber = 3; + private global::Tensorflow.AttrValue defaultValue_; + /// + /// Specify a new default value to use for this attr. This default + /// will be used when creating new graphs, as opposed to the + /// default in the OpDef, which will be used when interpreting old + /// GraphDefs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.AttrValue DefaultValue { + get { return defaultValue_; } + set { + defaultValue_ = value; + } + } + + /// Field number for the "description" field. + public const int DescriptionFieldNumber = 4; + private string description_ = ""; + /// + /// Note: this will replace any inherited attr doc, there is no current + /// way of modifying attr descriptions as can be done with op descriptions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Description { + get { return description_; } + set { + description_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Attr); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Attr other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (RenameTo != other.RenameTo) return false; + if (!object.Equals(DefaultValue, other.DefaultValue)) return false; + if (Description != other.Description) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (RenameTo.Length != 0) hash ^= RenameTo.GetHashCode(); + if (defaultValue_ != null) hash ^= DefaultValue.GetHashCode(); + if (Description.Length != 0) hash ^= Description.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (RenameTo.Length != 0) { + output.WriteRawTag(18); + output.WriteString(RenameTo); + } + if (defaultValue_ != null) { + output.WriteRawTag(26); + output.WriteMessage(DefaultValue); + } + if (Description.Length != 0) { + output.WriteRawTag(34); + output.WriteString(Description); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (RenameTo.Length != 0) { + output.WriteRawTag(18); + output.WriteString(RenameTo); + } + if (defaultValue_ != null) { + output.WriteRawTag(26); + output.WriteMessage(DefaultValue); + } + if (Description.Length != 0) { + output.WriteRawTag(34); + output.WriteString(Description); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (RenameTo.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(RenameTo); + } + if (defaultValue_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DefaultValue); + } + if (Description.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Description); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Attr other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.RenameTo.Length != 0) { + RenameTo = other.RenameTo; + } + if (other.defaultValue_ != null) { + if (defaultValue_ == null) { + DefaultValue = new global::Tensorflow.AttrValue(); + } + DefaultValue.MergeFrom(other.DefaultValue); + } + if (other.Description.Length != 0) { + Description = other.Description; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + RenameTo = input.ReadString(); + break; + } + case 26: { + if (defaultValue_ == null) { + DefaultValue = new global::Tensorflow.AttrValue(); + } + input.ReadMessage(DefaultValue); + break; + } + case 34: { + Description = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + RenameTo = input.ReadString(); + break; + } + case 26: { + if (defaultValue_ == null) { + DefaultValue = new global::Tensorflow.AttrValue(); + } + input.ReadMessage(DefaultValue); + break; + } + case 34: { + Description = input.ReadString(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + public sealed partial class ApiDefs : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ApiDefs()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ApiDefReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ApiDefs() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ApiDefs(ApiDefs other) : this() { + op_ = other.op_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ApiDefs Clone() { + return new ApiDefs(this); + } + + /// Field number for the "op" field. + public const int OpFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_op_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.ApiDef.Parser); + private readonly pbc::RepeatedField op_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Op { + get { return op_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ApiDefs); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ApiDefs other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!op_.Equals(other.op_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= op_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + op_.WriteTo(output, _repeated_op_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + op_.WriteTo(ref output, _repeated_op_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += op_.CalculateSize(_repeated_op_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ApiDefs other) { + if (other == null) { + return; + } + op_.Add(other.op_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + op_.AddEntriesFrom(input, _repeated_op_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + op_.AddEntriesFrom(ref input, _repeated_op_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/AttrValue.cs b/src/TensorFlowNET.Core/Protobuf/AttrValue.cs index 158179a09..fbccba222 100644 --- a/src/TensorFlowNET.Core/Protobuf/AttrValue.cs +++ b/src/TensorFlowNET.Core/Protobuf/AttrValue.cs @@ -1,8 +1,8 @@ // // Generated by the protocol buffer compiler. DO NOT EDIT! -// source: attr_value.proto +// source: tensorflow/core/framework/attr_value.proto // -#pragma warning disable 1591, 0612, 3021 +#pragma warning disable 1591, 0612, 3021, 8981 #region Designer generated code using pb = global::Google.Protobuf; @@ -11,11 +11,11 @@ using scg = global::System.Collections.Generic; namespace Tensorflow { - /// Holder for reflection information generated from attr_value.proto + /// Holder for reflection information generated from tensorflow/core/framework/attr_value.proto public static partial class AttrValueReflection { #region Descriptor - /// File descriptor for attr_value.proto + /// File descriptor for tensorflow/core/framework/attr_value.proto public static pbr::FileDescriptor Descriptor { get { return descriptor; } } @@ -24,31 +24,34 @@ public static partial class AttrValueReflection { static AttrValueReflection() { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( - "ChBhdHRyX3ZhbHVlLnByb3RvEgp0ZW5zb3JmbG93Ggx0ZW5zb3IucHJvdG8a", - "EnRlbnNvcl9zaGFwZS5wcm90bxoLdHlwZXMucHJvdG8ipgQKCUF0dHJWYWx1", - "ZRILCgFzGAIgASgMSAASCwoBaRgDIAEoA0gAEgsKAWYYBCABKAJIABILCgFi", - "GAUgASgISAASJAoEdHlwZRgGIAEoDjIULnRlbnNvcmZsb3cuRGF0YVR5cGVI", - "ABItCgVzaGFwZRgHIAEoCzIcLnRlbnNvcmZsb3cuVGVuc29yU2hhcGVQcm90", - "b0gAEikKBnRlbnNvchgIIAEoCzIXLnRlbnNvcmZsb3cuVGVuc29yUHJvdG9I", - "ABIvCgRsaXN0GAEgASgLMh8udGVuc29yZmxvdy5BdHRyVmFsdWUuTGlzdFZh", - "bHVlSAASKAoEZnVuYxgKIAEoCzIYLnRlbnNvcmZsb3cuTmFtZUF0dHJMaXN0", - "SAASFQoLcGxhY2Vob2xkZXIYCSABKAlIABrpAQoJTGlzdFZhbHVlEgkKAXMY", - "AiADKAwSDQoBaRgDIAMoA0ICEAESDQoBZhgEIAMoAkICEAESDQoBYhgFIAMo", - "CEICEAESJgoEdHlwZRgGIAMoDjIULnRlbnNvcmZsb3cuRGF0YVR5cGVCAhAB", - "EisKBXNoYXBlGAcgAygLMhwudGVuc29yZmxvdy5UZW5zb3JTaGFwZVByb3Rv", - "EicKBnRlbnNvchgIIAMoCzIXLnRlbnNvcmZsb3cuVGVuc29yUHJvdG8SJgoE", - "ZnVuYxgJIAMoCzIYLnRlbnNvcmZsb3cuTmFtZUF0dHJMaXN0QgcKBXZhbHVl", - "IpIBCgxOYW1lQXR0ckxpc3QSDAoEbmFtZRgBIAEoCRIwCgRhdHRyGAIgAygL", - "MiIudGVuc29yZmxvdy5OYW1lQXR0ckxpc3QuQXR0ckVudHJ5GkIKCUF0dHJF", - "bnRyeRILCgNrZXkYASABKAkSJAoFdmFsdWUYAiABKAsyFS50ZW5zb3JmbG93", - "LkF0dHJWYWx1ZToCOAFCbwoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQg9B", - "dHRyVmFsdWVQcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNv", - "cmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); + "Cip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2F0dHJfdmFsdWUucHJvdG8S", + "CnRlbnNvcmZsb3caJnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvdGVuc29y", + "LnByb3RvGix0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3RlbnNvcl9zaGFw", + "ZS5wcm90bxoldGVuc29yZmxvdy9jb3JlL2ZyYW1ld29yay90eXBlcy5wcm90", + "byKmBAoJQXR0clZhbHVlEgsKAXMYAiABKAxIABILCgFpGAMgASgDSAASCwoB", + "ZhgEIAEoAkgAEgsKAWIYBSABKAhIABIkCgR0eXBlGAYgASgOMhQudGVuc29y", + "Zmxvdy5EYXRhVHlwZUgAEi0KBXNoYXBlGAcgASgLMhwudGVuc29yZmxvdy5U", + "ZW5zb3JTaGFwZVByb3RvSAASKQoGdGVuc29yGAggASgLMhcudGVuc29yZmxv", + "dy5UZW5zb3JQcm90b0gAEi8KBGxpc3QYASABKAsyHy50ZW5zb3JmbG93LkF0", + "dHJWYWx1ZS5MaXN0VmFsdWVIABIoCgRmdW5jGAogASgLMhgudGVuc29yZmxv", + "dy5OYW1lQXR0ckxpc3RIABIVCgtwbGFjZWhvbGRlchgJIAEoCUgAGukBCglM", + "aXN0VmFsdWUSCQoBcxgCIAMoDBINCgFpGAMgAygDQgIQARINCgFmGAQgAygC", + "QgIQARINCgFiGAUgAygIQgIQARImCgR0eXBlGAYgAygOMhQudGVuc29yZmxv", + "dy5EYXRhVHlwZUICEAESKwoFc2hhcGUYByADKAsyHC50ZW5zb3JmbG93LlRl", + "bnNvclNoYXBlUHJvdG8SJwoGdGVuc29yGAggAygLMhcudGVuc29yZmxvdy5U", + "ZW5zb3JQcm90bxImCgRmdW5jGAkgAygLMhgudGVuc29yZmxvdy5OYW1lQXR0", + "ckxpc3RCBwoFdmFsdWUikgEKDE5hbWVBdHRyTGlzdBIMCgRuYW1lGAEgASgJ", + "EjAKBGF0dHIYAiADKAsyIi50ZW5zb3JmbG93Lk5hbWVBdHRyTGlzdC5BdHRy", + "RW50cnkaQgoJQXR0ckVudHJ5EgsKA2tleRgBIAEoCRIkCgV2YWx1ZRgCIAEo", + "CzIVLnRlbnNvcmZsb3cuQXR0clZhbHVlOgI4AUKDAQoYb3JnLnRlbnNvcmZs", + "b3cuZnJhbWV3b3JrQg9BdHRyVmFsdWVQcm90b3NQAVpRZ2l0aHViLmNvbS90", + "ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1l", + "d29yay9hdHRyX3ZhbHVlX2dvX3Byb3Rv+AEBYgZwcm90bzM=")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { global::Tensorflow.TensorReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, - new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.AttrValue), global::Tensorflow.AttrValue.Parser, new[]{ "S", "I", "F", "B", "Type", "Shape", "Tensor", "List", "Func", "Placeholder" }, new[]{ "Value" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.AttrValue.Types.ListValue), global::Tensorflow.AttrValue.Types.ListValue.Parser, new[]{ "S", "I", "F", "B", "Type", "Shape", "Tensor", "Func" }, null, null, null)}), - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.NameAttrList), global::Tensorflow.NameAttrList.Parser, new[]{ "Name", "Attr" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, }) + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.AttrValue), global::Tensorflow.AttrValue.Parser, new[]{ "S", "I", "F", "B", "Type", "Shape", "Tensor", "List", "Func", "Placeholder" }, new[]{ "Value" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.AttrValue.Types.ListValue), global::Tensorflow.AttrValue.Types.ListValue.Parser, new[]{ "S", "I", "F", "B", "Type", "Shape", "Tensor", "Func" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.NameAttrList), global::Tensorflow.NameAttrList.Parser, new[]{ "Name", "Attr" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, }) })); } #endregion @@ -60,23 +63,31 @@ static AttrValueReflection() { /// Comment indicates the corresponding attr type. Only the field matching the /// attr type may be filled. /// - public sealed partial class AttrValue : pb::IMessage { + public sealed partial class AttrValue : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AttrValue()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[0]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public AttrValue() { OnConstruction(); } @@ -84,6 +95,7 @@ public AttrValue() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public AttrValue(AttrValue other) : this() { switch (other.ValueCase) { case ValueOneofCase.S: @@ -122,6 +134,7 @@ public AttrValue(AttrValue other) : this() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public AttrValue Clone() { return new AttrValue(this); } @@ -132,6 +145,7 @@ public AttrValue Clone() { /// "string" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pb::ByteString S { get { return valueCase_ == ValueOneofCase.S ? (pb::ByteString) value_ : pb::ByteString.Empty; } set { @@ -146,6 +160,7 @@ public AttrValue Clone() { /// "int" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public long I { get { return valueCase_ == ValueOneofCase.I ? (long) value_ : 0L; } set { @@ -160,6 +175,7 @@ public long I { /// "float" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public float F { get { return valueCase_ == ValueOneofCase.F ? (float) value_ : 0F; } set { @@ -174,6 +190,7 @@ public float F { /// "bool" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool B { get { return valueCase_ == ValueOneofCase.B ? (bool) value_ : false; } set { @@ -188,8 +205,9 @@ public bool B { /// "type" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public global::Tensorflow.DataType Type { - get { return valueCase_ == ValueOneofCase.Type ? (global::Tensorflow.DataType) value_ : 0; } + get { return valueCase_ == ValueOneofCase.Type ? (global::Tensorflow.DataType) value_ : global::Tensorflow.DataType.DtInvalid; } set { value_ = value; valueCase_ = ValueOneofCase.Type; @@ -202,6 +220,7 @@ public bool B { /// "shape" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public global::Tensorflow.TensorShapeProto Shape { get { return valueCase_ == ValueOneofCase.Shape ? (global::Tensorflow.TensorShapeProto) value_ : null; } set { @@ -216,6 +235,7 @@ public bool B { /// "tensor" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public global::Tensorflow.TensorProto Tensor { get { return valueCase_ == ValueOneofCase.Tensor ? (global::Tensorflow.TensorProto) value_ : null; } set { @@ -230,6 +250,7 @@ public bool B { /// any "list(...)" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public global::Tensorflow.AttrValue.Types.ListValue List { get { return valueCase_ == ValueOneofCase.List ? (global::Tensorflow.AttrValue.Types.ListValue) value_ : null; } set { @@ -247,6 +268,7 @@ public bool B { /// that attr in the instantiation. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public global::Tensorflow.NameAttrList Func { get { return valueCase_ == ValueOneofCase.Func ? (global::Tensorflow.NameAttrList) value_ : null; } set { @@ -267,6 +289,7 @@ public bool B { /// given the value "bar". /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Placeholder { get { return valueCase_ == ValueOneofCase.Placeholder ? (string) value_ : ""; } set { @@ -292,22 +315,26 @@ public enum ValueOneofCase { } private ValueOneofCase valueCase_ = ValueOneofCase.None; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public ValueOneofCase ValueCase { get { return valueCase_; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void ClearValue() { valueCase_ = ValueOneofCase.None; value_ = null; } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as AttrValue); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(AttrValue other) { if (ReferenceEquals(other, null)) { return false; @@ -330,6 +357,7 @@ public bool Equals(AttrValue other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; if (valueCase_ == ValueOneofCase.S) hash ^= S.GetHashCode(); @@ -350,12 +378,17 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else if (valueCase_ == ValueOneofCase.List) { output.WriteRawTag(10); output.WriteMessage(List); @@ -399,9 +432,61 @@ public void WriteTo(pb::CodedOutputStream output) { if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (valueCase_ == ValueOneofCase.List) { + output.WriteRawTag(10); + output.WriteMessage(List); + } + if (valueCase_ == ValueOneofCase.S) { + output.WriteRawTag(18); + output.WriteBytes(S); + } + if (valueCase_ == ValueOneofCase.I) { + output.WriteRawTag(24); + output.WriteInt64(I); + } + if (valueCase_ == ValueOneofCase.F) { + output.WriteRawTag(37); + output.WriteFloat(F); + } + if (valueCase_ == ValueOneofCase.B) { + output.WriteRawTag(40); + output.WriteBool(B); + } + if (valueCase_ == ValueOneofCase.Type) { + output.WriteRawTag(48); + output.WriteEnum((int) Type); + } + if (valueCase_ == ValueOneofCase.Shape) { + output.WriteRawTag(58); + output.WriteMessage(Shape); + } + if (valueCase_ == ValueOneofCase.Tensor) { + output.WriteRawTag(66); + output.WriteMessage(Tensor); + } + if (valueCase_ == ValueOneofCase.Placeholder) { + output.WriteRawTag(74); + output.WriteString(Placeholder); + } + if (valueCase_ == ValueOneofCase.Func) { + output.WriteRawTag(82); + output.WriteMessage(Func); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; if (valueCase_ == ValueOneofCase.S) { @@ -441,6 +526,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(AttrValue other) { if (other == null) { return; @@ -494,7 +580,11 @@ public void MergeFrom(AttrValue other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -564,32 +654,118 @@ public void MergeFrom(pb::CodedInputStream input) { } } } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + global::Tensorflow.AttrValue.Types.ListValue subBuilder = new global::Tensorflow.AttrValue.Types.ListValue(); + if (valueCase_ == ValueOneofCase.List) { + subBuilder.MergeFrom(List); + } + input.ReadMessage(subBuilder); + List = subBuilder; + break; + } + case 18: { + S = input.ReadBytes(); + break; + } + case 24: { + I = input.ReadInt64(); + break; + } + case 37: { + F = input.ReadFloat(); + break; + } + case 40: { + B = input.ReadBool(); + break; + } + case 48: { + value_ = input.ReadEnum(); + valueCase_ = ValueOneofCase.Type; + break; + } + case 58: { + global::Tensorflow.TensorShapeProto subBuilder = new global::Tensorflow.TensorShapeProto(); + if (valueCase_ == ValueOneofCase.Shape) { + subBuilder.MergeFrom(Shape); + } + input.ReadMessage(subBuilder); + Shape = subBuilder; + break; + } + case 66: { + global::Tensorflow.TensorProto subBuilder = new global::Tensorflow.TensorProto(); + if (valueCase_ == ValueOneofCase.Tensor) { + subBuilder.MergeFrom(Tensor); + } + input.ReadMessage(subBuilder); + Tensor = subBuilder; + break; + } + case 74: { + Placeholder = input.ReadString(); + break; + } + case 82: { + global::Tensorflow.NameAttrList subBuilder = new global::Tensorflow.NameAttrList(); + if (valueCase_ == ValueOneofCase.Func) { + subBuilder.MergeFrom(Func); + } + input.ReadMessage(subBuilder); + Func = subBuilder; + break; + } + } + } + } + #endif + #region Nested types /// Container for nested types declared in the AttrValue message type. [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static partial class Types { /// /// LINT.IfChange /// - public sealed partial class ListValue : pb::IMessage { + public sealed partial class ListValue : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ListValue()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.AttrValue.Descriptor.NestedTypes[0]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public ListValue() { OnConstruction(); } @@ -597,6 +773,7 @@ public ListValue() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public ListValue(ListValue other) : this() { s_ = other.s_.Clone(); i_ = other.i_.Clone(); @@ -610,6 +787,7 @@ public ListValue(ListValue other) : this() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public ListValue Clone() { return new ListValue(this); } @@ -623,6 +801,7 @@ public ListValue Clone() { /// "list(string)" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField S { get { return s_; } } @@ -636,6 +815,7 @@ public ListValue Clone() { /// "list(int)" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField I { get { return i_; } } @@ -649,6 +829,7 @@ public ListValue Clone() { /// "list(float)" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField F { get { return f_; } } @@ -662,6 +843,7 @@ public ListValue Clone() { /// "list(bool)" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField B { get { return b_; } } @@ -675,6 +857,7 @@ public ListValue Clone() { /// "list(type)" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField Type { get { return type_; } } @@ -688,6 +871,7 @@ public ListValue Clone() { /// "list(shape)" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField Shape { get { return shape_; } } @@ -701,6 +885,7 @@ public ListValue Clone() { /// "list(tensor)" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField Tensor { get { return tensor_; } } @@ -714,16 +899,19 @@ public ListValue Clone() { /// "list(attr)" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField Func { get { return func_; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as ListValue); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(ListValue other) { if (ReferenceEquals(other, null)) { return false; @@ -743,6 +931,7 @@ public bool Equals(ListValue other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; hash ^= s_.GetHashCode(); @@ -760,12 +949,17 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else s_.WriteTo(output, _repeated_s_codec); i_.WriteTo(output, _repeated_i_codec); f_.WriteTo(output, _repeated_f_codec); @@ -777,9 +971,29 @@ public void WriteTo(pb::CodedOutputStream output) { if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + s_.WriteTo(ref output, _repeated_s_codec); + i_.WriteTo(ref output, _repeated_i_codec); + f_.WriteTo(ref output, _repeated_f_codec); + b_.WriteTo(ref output, _repeated_b_codec); + type_.WriteTo(ref output, _repeated_type_codec); + shape_.WriteTo(ref output, _repeated_shape_codec); + tensor_.WriteTo(ref output, _repeated_tensor_codec); + func_.WriteTo(ref output, _repeated_func_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; size += s_.CalculateSize(_repeated_s_codec); @@ -797,6 +1011,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(ListValue other) { if (other == null) { return; @@ -813,7 +1028,11 @@ public void MergeFrom(ListValue other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -858,7 +1077,59 @@ public void MergeFrom(pb::CodedInputStream input) { } } } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 18: { + s_.AddEntriesFrom(ref input, _repeated_s_codec); + break; + } + case 26: + case 24: { + i_.AddEntriesFrom(ref input, _repeated_i_codec); + break; + } + case 34: + case 37: { + f_.AddEntriesFrom(ref input, _repeated_f_codec); + break; + } + case 42: + case 40: { + b_.AddEntriesFrom(ref input, _repeated_b_codec); + break; + } + case 50: + case 48: { + type_.AddEntriesFrom(ref input, _repeated_type_codec); + break; + } + case 58: { + shape_.AddEntriesFrom(ref input, _repeated_shape_codec); + break; + } + case 66: { + tensor_.AddEntriesFrom(ref input, _repeated_tensor_codec); + break; + } + case 74: { + func_.AddEntriesFrom(ref input, _repeated_func_codec); + break; + } + } + } } + #endif } @@ -871,23 +1142,31 @@ public void MergeFrom(pb::CodedInputStream input) { /// A list of attr names and their values. The whole list is attached /// with a string name. E.g., MatMul[T=float]. /// - public sealed partial class NameAttrList : pb::IMessage { + public sealed partial class NameAttrList : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NameAttrList()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[1]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public NameAttrList() { OnConstruction(); } @@ -895,6 +1174,7 @@ public NameAttrList() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public NameAttrList(NameAttrList other) : this() { name_ = other.name_; attr_ = other.attr_.Clone(); @@ -902,6 +1182,7 @@ public NameAttrList(NameAttrList other) : this() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public NameAttrList Clone() { return new NameAttrList(this); } @@ -910,6 +1191,7 @@ public NameAttrList Clone() { public const int NameFieldNumber = 1; private string name_ = ""; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Name { get { return name_; } set { @@ -920,19 +1202,22 @@ public string Name { /// Field number for the "attr" field. public const int AttrFieldNumber = 2; private static readonly pbc::MapField.Codec _map_attr_codec - = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 18); + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 18); private readonly pbc::MapField attr_ = new pbc::MapField(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::MapField Attr { get { return attr_; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as NameAttrList); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(NameAttrList other) { if (ReferenceEquals(other, null)) { return false; @@ -946,6 +1231,7 @@ public bool Equals(NameAttrList other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; if (Name.Length != 0) hash ^= Name.GetHashCode(); @@ -957,12 +1243,17 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else if (Name.Length != 0) { output.WriteRawTag(10); output.WriteString(Name); @@ -971,9 +1262,26 @@ public void WriteTo(pb::CodedOutputStream output) { if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + attr_.WriteTo(ref output, _map_attr_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } } + #endif [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; if (Name.Length != 0) { @@ -987,6 +1295,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(NameAttrList other) { if (other == null) { return; @@ -999,7 +1308,11 @@ public void MergeFrom(NameAttrList other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -1016,7 +1329,31 @@ public void MergeFrom(pb::CodedInputStream input) { } } } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + attr_.AddEntriesFrom(ref input, _map_attr_codec); + break; + } + } + } } + #endif } diff --git a/src/TensorFlowNET.Core/Protobuf/CheckpointState.cs b/src/TensorFlowNET.Core/Protobuf/CheckpointState.cs new file mode 100644 index 000000000..26d929e24 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/CheckpointState.cs @@ -0,0 +1,347 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/python/training/checkpoint_state.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/python/training/checkpoint_state.proto + public static partial class CheckpointStateReflection { + + #region Descriptor + /// File descriptor for tensorflow/python/training/checkpoint_state.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static CheckpointStateReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjF0ZW5zb3JmbG93L3B5dGhvbi90cmFpbmluZy9jaGVja3BvaW50X3N0YXRl", + "LnByb3RvEgp0ZW5zb3JmbG93Ip8BCg9DaGVja3BvaW50U3RhdGUSHQoVbW9k", + "ZWxfY2hlY2twb2ludF9wYXRoGAEgASgJEiIKGmFsbF9tb2RlbF9jaGVja3Bv", + "aW50X3BhdGhzGAIgAygJEicKH2FsbF9tb2RlbF9jaGVja3BvaW50X3RpbWVz", + "dGFtcHMYAyADKAESIAoYbGFzdF9wcmVzZXJ2ZWRfdGltZXN0YW1wGAQgASgB", + "QgP4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CheckpointState), global::Tensorflow.CheckpointState.Parser, new[]{ "ModelCheckpointPath", "AllModelCheckpointPaths", "AllModelCheckpointTimestamps", "LastPreservedTimestamp" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer representing the checkpoint state. + /// + public sealed partial class CheckpointState : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CheckpointState()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CheckpointStateReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CheckpointState() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CheckpointState(CheckpointState other) : this() { + modelCheckpointPath_ = other.modelCheckpointPath_; + allModelCheckpointPaths_ = other.allModelCheckpointPaths_.Clone(); + allModelCheckpointTimestamps_ = other.allModelCheckpointTimestamps_.Clone(); + lastPreservedTimestamp_ = other.lastPreservedTimestamp_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CheckpointState Clone() { + return new CheckpointState(this); + } + + /// Field number for the "model_checkpoint_path" field. + public const int ModelCheckpointPathFieldNumber = 1; + private string modelCheckpointPath_ = ""; + /// + /// Path to the most-recent model checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ModelCheckpointPath { + get { return modelCheckpointPath_; } + set { + modelCheckpointPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "all_model_checkpoint_paths" field. + public const int AllModelCheckpointPathsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_allModelCheckpointPaths_codec + = pb::FieldCodec.ForString(18); + private readonly pbc::RepeatedField allModelCheckpointPaths_ = new pbc::RepeatedField(); + /// + /// Paths to all not-yet-deleted model checkpoints, sorted from oldest to + /// newest. + /// Note that the value of model_checkpoint_path should be the last item in + /// this list. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField AllModelCheckpointPaths { + get { return allModelCheckpointPaths_; } + } + + /// Field number for the "all_model_checkpoint_timestamps" field. + public const int AllModelCheckpointTimestampsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_allModelCheckpointTimestamps_codec + = pb::FieldCodec.ForDouble(26); + private readonly pbc::RepeatedField allModelCheckpointTimestamps_ = new pbc::RepeatedField(); + /// + /// Unix timestamps corresponding to all_model_checkpoint_paths, indicating + /// when each checkpoint was created. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField AllModelCheckpointTimestamps { + get { return allModelCheckpointTimestamps_; } + } + + /// Field number for the "last_preserved_timestamp" field. + public const int LastPreservedTimestampFieldNumber = 4; + private double lastPreservedTimestamp_; + /// + /// Unix timestamp indicating the creation time for the last preserved + /// checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double LastPreservedTimestamp { + get { return lastPreservedTimestamp_; } + set { + lastPreservedTimestamp_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CheckpointState); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CheckpointState other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ModelCheckpointPath != other.ModelCheckpointPath) return false; + if(!allModelCheckpointPaths_.Equals(other.allModelCheckpointPaths_)) return false; + if(!allModelCheckpointTimestamps_.Equals(other.allModelCheckpointTimestamps_)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(LastPreservedTimestamp, other.LastPreservedTimestamp)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ModelCheckpointPath.Length != 0) hash ^= ModelCheckpointPath.GetHashCode(); + hash ^= allModelCheckpointPaths_.GetHashCode(); + hash ^= allModelCheckpointTimestamps_.GetHashCode(); + if (LastPreservedTimestamp != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(LastPreservedTimestamp); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ModelCheckpointPath.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ModelCheckpointPath); + } + allModelCheckpointPaths_.WriteTo(output, _repeated_allModelCheckpointPaths_codec); + allModelCheckpointTimestamps_.WriteTo(output, _repeated_allModelCheckpointTimestamps_codec); + if (LastPreservedTimestamp != 0D) { + output.WriteRawTag(33); + output.WriteDouble(LastPreservedTimestamp); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ModelCheckpointPath.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ModelCheckpointPath); + } + allModelCheckpointPaths_.WriteTo(ref output, _repeated_allModelCheckpointPaths_codec); + allModelCheckpointTimestamps_.WriteTo(ref output, _repeated_allModelCheckpointTimestamps_codec); + if (LastPreservedTimestamp != 0D) { + output.WriteRawTag(33); + output.WriteDouble(LastPreservedTimestamp); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ModelCheckpointPath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ModelCheckpointPath); + } + size += allModelCheckpointPaths_.CalculateSize(_repeated_allModelCheckpointPaths_codec); + size += allModelCheckpointTimestamps_.CalculateSize(_repeated_allModelCheckpointTimestamps_codec); + if (LastPreservedTimestamp != 0D) { + size += 1 + 8; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CheckpointState other) { + if (other == null) { + return; + } + if (other.ModelCheckpointPath.Length != 0) { + ModelCheckpointPath = other.ModelCheckpointPath; + } + allModelCheckpointPaths_.Add(other.allModelCheckpointPaths_); + allModelCheckpointTimestamps_.Add(other.allModelCheckpointTimestamps_); + if (other.LastPreservedTimestamp != 0D) { + LastPreservedTimestamp = other.LastPreservedTimestamp; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ModelCheckpointPath = input.ReadString(); + break; + } + case 18: { + allModelCheckpointPaths_.AddEntriesFrom(input, _repeated_allModelCheckpointPaths_codec); + break; + } + case 26: + case 25: { + allModelCheckpointTimestamps_.AddEntriesFrom(input, _repeated_allModelCheckpointTimestamps_codec); + break; + } + case 33: { + LastPreservedTimestamp = input.ReadDouble(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + ModelCheckpointPath = input.ReadString(); + break; + } + case 18: { + allModelCheckpointPaths_.AddEntriesFrom(ref input, _repeated_allModelCheckpointPaths_codec); + break; + } + case 26: + case 25: { + allModelCheckpointTimestamps_.AddEntriesFrom(ref input, _repeated_allModelCheckpointTimestamps_codec); + break; + } + case 33: { + LastPreservedTimestamp = input.ReadDouble(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Cluster.cs b/src/TensorFlowNET.Core/Protobuf/Cluster.cs new file mode 100644 index 000000000..4c398c824 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Cluster.cs @@ -0,0 +1,463 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/cluster.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/cluster.proto + public static partial class ClusterReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/cluster.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ClusterReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiZ0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY2x1c3Rlci5wcm90bxIKdGVu", + "c29yZmxvdyJyCgZKb2JEZWYSDAoEbmFtZRgBIAEoCRIsCgV0YXNrcxgCIAMo", + "CzIdLnRlbnNvcmZsb3cuSm9iRGVmLlRhc2tzRW50cnkaLAoKVGFza3NFbnRy", + "eRILCgNrZXkYASABKAUSDQoFdmFsdWUYAiABKAk6AjgBIi0KCkNsdXN0ZXJE", + "ZWYSHwoDam9iGAEgAygLMhIudGVuc29yZmxvdy5Kb2JEZWZChwEKGm9yZy50", + "ZW5zb3JmbG93LmRpc3RydW50aW1lQg1DbHVzdGVyUHJvdG9zUAFaVWdpdGh1", + "Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29y", + "ZS9wcm90b2J1Zi9mb3JfY29yZV9wcm90b3NfZ29fcHJvdG/4AQFiBnByb3Rv", + "Mw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.JobDef), global::Tensorflow.JobDef.Parser, new[]{ "Name", "Tasks" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ClusterDef), global::Tensorflow.ClusterDef.Parser, new[]{ "Job" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Defines a single job in a TensorFlow cluster. + /// + public sealed partial class JobDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new JobDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ClusterReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public JobDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public JobDef(JobDef other) : this() { + name_ = other.name_; + tasks_ = other.tasks_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public JobDef Clone() { + return new JobDef(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// The name of this job. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "tasks" field. + public const int TasksFieldNumber = 2; + private static readonly pbc::MapField.Codec _map_tasks_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForInt32(8, 0), pb::FieldCodec.ForString(18, ""), 18); + private readonly pbc::MapField tasks_ = new pbc::MapField(); + /// + /// Mapping from task ID to "hostname:port" string. + /// + /// If the `name` field contains "worker", and the `tasks` map contains a + /// mapping from 7 to "example.org:2222", then the device prefix + /// "/job:worker/task:7" will be assigned to "example.org:2222". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField Tasks { + get { return tasks_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as JobDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(JobDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (!Tasks.Equals(other.Tasks)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + hash ^= Tasks.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + tasks_.WriteTo(output, _map_tasks_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + tasks_.WriteTo(ref output, _map_tasks_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + size += tasks_.CalculateSize(_map_tasks_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(JobDef other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + tasks_.Add(other.tasks_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + tasks_.AddEntriesFrom(input, _map_tasks_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + tasks_.AddEntriesFrom(ref input, _map_tasks_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Defines a TensorFlow cluster as a set of jobs. + /// + public sealed partial class ClusterDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ClusterDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ClusterReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ClusterDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ClusterDef(ClusterDef other) : this() { + job_ = other.job_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ClusterDef Clone() { + return new ClusterDef(this); + } + + /// Field number for the "job" field. + public const int JobFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_job_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.JobDef.Parser); + private readonly pbc::RepeatedField job_ = new pbc::RepeatedField(); + /// + /// The jobs that comprise the cluster. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Job { + get { return job_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ClusterDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ClusterDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!job_.Equals(other.job_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= job_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + job_.WriteTo(output, _repeated_job_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + job_.WriteTo(ref output, _repeated_job_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += job_.CalculateSize(_repeated_job_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ClusterDef other) { + if (other == null) { + return; + } + job_.Add(other.job_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + job_.AddEntriesFrom(input, _repeated_job_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + job_.AddEntriesFrom(ref input, _repeated_job_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Config.cs b/src/TensorFlowNET.Core/Protobuf/Config.cs new file mode 100644 index 000000000..de7b38637 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Config.cs @@ -0,0 +1,7974 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/config.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/config.proto + public static partial class ConfigReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/config.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ConfigReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiV0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29uZmlnLnByb3RvEgp0ZW5z", + "b3JmbG93Gip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Nvc3RfZ3JhcGgu", + "cHJvdG8aJXRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvZ3JhcGgucHJvdG8a", + "KnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvc3RlcF9zdGF0cy5wcm90bxom", + "dGVuc29yZmxvdy9jb3JlL3Byb3RvYnVmL2NsdXN0ZXIucHJvdG8aMnRlbnNv", + "cmZsb3cvY29yZS9wcm90b2J1Zi9jb29yZGluYXRpb25fY29uZmlnLnByb3Rv", + "GiR0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvZGVidWcucHJvdG8aLnRlbnNv", + "cmZsb3cvY29yZS9wcm90b2J1Zi9yZXdyaXRlcl9jb25maWcucHJvdG8i1wYK", + "CkdQVU9wdGlvbnMSJwofcGVyX3Byb2Nlc3NfZ3B1X21lbW9yeV9mcmFjdGlv", + "bhgBIAEoARIUCgxhbGxvd19ncm93dGgYBCABKAgSFgoOYWxsb2NhdG9yX3R5", + "cGUYAiABKAkSHwoXZGVmZXJyZWRfZGVsZXRpb25fYnl0ZXMYAyABKAMSGwoT", + "dmlzaWJsZV9kZXZpY2VfbGlzdBgFIAEoCRIiChpwb2xsaW5nX2FjdGl2ZV9k", + "ZWxheV91c2VjcxgGIAEoBRIkChxwb2xsaW5nX2luYWN0aXZlX2RlbGF5X21z", + "ZWNzGAcgASgFEhwKFGZvcmNlX2dwdV9jb21wYXRpYmxlGAggASgIEjkKDGV4", + "cGVyaW1lbnRhbBgJIAEoCzIjLnRlbnNvcmZsb3cuR1BVT3B0aW9ucy5FeHBl", + "cmltZW50YWwakAQKDEV4cGVyaW1lbnRhbBJLCg92aXJ0dWFsX2RldmljZXMY", + "ASADKAsyMi50ZW5zb3JmbG93LkdQVU9wdGlvbnMuRXhwZXJpbWVudGFsLlZp", + "cnR1YWxEZXZpY2VzEhoKEnVzZV91bmlmaWVkX21lbW9yeRgCIAEoCBIjChtu", + "dW1fZGV2X3RvX2Rldl9jb3B5X3N0cmVhbXMYAyABKAUSHQoVY29sbGVjdGl2", + "ZV9yaW5nX29yZGVyGAQgASgJEh0KFXRpbWVzdGFtcGVkX2FsbG9jYXRvchgF", + "IAEoCBIjChtrZXJuZWxfdHJhY2tlcl9tYXhfaW50ZXJ2YWwYByABKAUSIAoY", + "a2VybmVsX3RyYWNrZXJfbWF4X2J5dGVzGAggASgFEiIKGmtlcm5lbF90cmFj", + "a2VyX21heF9wZW5kaW5nGAkgASgFEicKH2ludGVybmFsX2ZyYWdtZW50YXRp", + "b25fZnJhY3Rpb24YCiABKAESHQoVdXNlX2N1ZGFfbWFsbG9jX2FzeW5jGAsg", + "ASgIEiwKJGRpc2FsbG93X3JldHJ5X29uX2FsbG9jYXRpb25fZmFpbHVyZRgM", + "IAEoCBpTCg5WaXJ0dWFsRGV2aWNlcxIXCg9tZW1vcnlfbGltaXRfbWIYASAD", + "KAISEAoIcHJpb3JpdHkYAiADKAUSFgoOZGV2aWNlX29yZGluYWwYAyADKAUi", + "nQMKEE9wdGltaXplck9wdGlvbnMSKwojZG9fY29tbW9uX3N1YmV4cHJlc3Np", + "b25fZWxpbWluYXRpb24YASABKAgSGwoTZG9fY29uc3RhbnRfZm9sZGluZxgC", + "IAEoCBIkChxtYXhfZm9sZGVkX2NvbnN0YW50X2luX2J5dGVzGAYgASgDEhwK", + "FGRvX2Z1bmN0aW9uX2lubGluaW5nGAQgASgIEjUKCW9wdF9sZXZlbBgDIAEo", + "DjIiLnRlbnNvcmZsb3cuT3B0aW1pemVyT3B0aW9ucy5MZXZlbBJFChBnbG9i", + "YWxfaml0X2xldmVsGAUgASgOMisudGVuc29yZmxvdy5PcHRpbWl6ZXJPcHRp", + "b25zLkdsb2JhbEppdExldmVsEhYKDmNwdV9nbG9iYWxfaml0GAcgASgIIiAK", + "BUxldmVsEgYKAkwxEAASDwoCTDAQ////////////ASJDCg5HbG9iYWxKaXRM", + "ZXZlbBILCgdERUZBVUxUEAASEAoDT0ZGEP///////////wESCAoET05fMRAB", + "EggKBE9OXzIQAiLuAgoMR3JhcGhPcHRpb25zEh4KFmVuYWJsZV9yZWN2X3Nj", + "aGVkdWxpbmcYAiABKAgSNwoRb3B0aW1pemVyX29wdGlvbnMYAyABKAsyHC50", + "ZW5zb3JmbG93Lk9wdGltaXplck9wdGlvbnMSGAoQYnVpbGRfY29zdF9tb2Rl", + "bBgEIAEoAxIeChZidWlsZF9jb3N0X21vZGVsX2FmdGVyGAkgASgDEhQKDGlu", + "ZmVyX3NoYXBlcxgFIAEoCBIaChJwbGFjZV9wcnVuZWRfZ3JhcGgYBiABKAgS", + "IAoYZW5hYmxlX2JmbG9hdDE2X3NlbmRyZWN2GAcgASgIEhUKDXRpbWVsaW5l", + "X3N0ZXAYCCABKAUSMwoPcmV3cml0ZV9vcHRpb25zGAogASgLMhoudGVuc29y", + "Zmxvdy5SZXdyaXRlckNvbmZpZ0oECAEQAlIlc2tpcF9jb21tb25fc3ViZXhw", + "cmVzc2lvbl9lbGltaW5hdGlvbiJBChVUaHJlYWRQb29sT3B0aW9uUHJvdG8S", + "EwoLbnVtX3RocmVhZHMYASABKAUSEwoLZ2xvYmFsX25hbWUYAiABKAki1QEK", + "ClJQQ09wdGlvbnMSJAocdXNlX3JwY19mb3JfaW5wcm9jZXNzX21hc3RlchgB", + "IAEoCBIdChVjb21wcmVzc2lvbl9hbGdvcml0aG0YAiABKAkSGQoRY29tcHJl", + "c3Npb25fbGV2ZWwYAyABKAUSGgoSY2FjaGVfcnBjX3Jlc3BvbnNlGAQgASgI", + "EioKImRpc2FibGVfc2Vzc2lvbl9jb25uZWN0aW9uX3NoYXJpbmcYBSABKAgS", + "HwoXbnVtX2NoYW5uZWxzX3Blcl90YXJnZXQYBiABKAUiMAoPU2Vzc2lvbk1l", + "dGFkYXRhEgwKBG5hbWUYASABKAkSDwoHdmVyc2lvbhgCIAEoAyKuDgoLQ29u", + "ZmlnUHJvdG8SPgoMZGV2aWNlX2NvdW50GAEgAygLMigudGVuc29yZmxvdy5D", + "b25maWdQcm90by5EZXZpY2VDb3VudEVudHJ5EiQKHGludHJhX29wX3BhcmFs", + "bGVsaXNtX3RocmVhZHMYAiABKAUSJAocaW50ZXJfb3BfcGFyYWxsZWxpc21f", + "dGhyZWFkcxgFIAEoBRIfChd1c2VfcGVyX3Nlc3Npb25fdGhyZWFkcxgJIAEo", + "CBJHChxzZXNzaW9uX2ludGVyX29wX3RocmVhZF9wb29sGAwgAygLMiEudGVu", + "c29yZmxvdy5UaHJlYWRQb29sT3B0aW9uUHJvdG8SGAoQcGxhY2VtZW50X3Bl", + "cmlvZBgDIAEoBRIWCg5kZXZpY2VfZmlsdGVycxgEIAMoCRIrCgtncHVfb3B0", + "aW9ucxgGIAEoCzIWLnRlbnNvcmZsb3cuR1BVT3B0aW9ucxIcChRhbGxvd19z", + "b2Z0X3BsYWNlbWVudBgHIAEoCBIcChRsb2dfZGV2aWNlX3BsYWNlbWVudBgI", + "IAEoCBIvCg1ncmFwaF9vcHRpb25zGAogASgLMhgudGVuc29yZmxvdy5HcmFw", + "aE9wdGlvbnMSHwoXb3BlcmF0aW9uX3RpbWVvdXRfaW5fbXMYCyABKAMSKwoL", + "cnBjX29wdGlvbnMYDSABKAsyFi50ZW5zb3JmbG93LlJQQ09wdGlvbnMSKwoL", + "Y2x1c3Rlcl9kZWYYDiABKAsyFi50ZW5zb3JmbG93LkNsdXN0ZXJEZWYSHQoV", + "aXNvbGF0ZV9zZXNzaW9uX3N0YXRlGA8gASgIEigKIHNoYXJlX2NsdXN0ZXJf", + "ZGV2aWNlc19pbl9zZXNzaW9uGBEgASgIEjoKDGV4cGVyaW1lbnRhbBgQIAEo", + "CzIkLnRlbnNvcmZsb3cuQ29uZmlnUHJvdG8uRXhwZXJpbWVudGFsGjIKEERl", + "dmljZUNvdW50RW50cnkSCwoDa2V5GAEgASgJEg0KBXZhbHVlGAIgASgFOgI4", + "ARqoCAoMRXhwZXJpbWVudGFsEh8KF2NvbGxlY3RpdmVfZ3JvdXBfbGVhZGVy", + "GAEgASgJEhUKDWV4ZWN1dG9yX3R5cGUYAyABKAkSGgoScmVjdl9idWZfbWF4", + "X2NodW5rGAQgASgFEhkKEXVzZV9udW1hX2FmZmluaXR5GAUgASgIEjUKLWNv", + "bGxlY3RpdmVfZGV0ZXJtaW5pc3RpY19zZXF1ZW50aWFsX2V4ZWN1dGlvbhgG", + "IAEoCBIXCg9jb2xsZWN0aXZlX25jY2wYByABKAgSNgouc2hhcmVfc2Vzc2lv", + "bl9zdGF0ZV9pbl9jbHVzdGVyc3BlY19wcm9wYWdhdGlvbhgIIAEoCBIfChdk", + "aXNhYmxlX3RocmVhZF9zcGlubmluZxgJIAEoCBIoCiBzaGFyZV9jbHVzdGVy", + "X2RldmljZXNfaW5fc2Vzc2lvbhgKIAEoCBI1ChBzZXNzaW9uX21ldGFkYXRh", + "GAsgASgLMhsudGVuc29yZmxvdy5TZXNzaW9uTWV0YWRhdGESIQoZb3B0aW1p", + "emVfZm9yX3N0YXRpY19ncmFwaBgMIAEoCBIaChJlbmFibGVfbWxpcl9icmlk", + "Z2UYDSABKAgSUwoTbWxpcl9icmlkZ2Vfcm9sbG91dBgRIAEoDjI2LnRlbnNv", + "cmZsb3cuQ29uZmlnUHJvdG8uRXhwZXJpbWVudGFsLk1saXJCcmlkZ2VSb2xs", + "b3V0EiYKHmVuYWJsZV9tbGlyX2dyYXBoX29wdGltaXphdGlvbhgQIAEoCBIn", + "Ch9kaXNhYmxlX291dHB1dF9wYXJ0aXRpb25fZ3JhcGhzGA4gASgIEiMKG3hs", + "YV9mdXNpb25fYXV0b3R1bmVyX3RocmVzaBgPIAEoAxIQCgh1c2VfdGZydBgS", + "IAEoCBInCh9kaXNhYmxlX2Z1bmN0aW9uYWxfb3BzX2xvd2VyaW5nGBUgASgI", + "EicKH3hsYV9wcmVmZXJfc2luZ2xlX2dyYXBoX2NsdXN0ZXIYFiABKAgSQgoT", + "Y29vcmRpbmF0aW9uX2NvbmZpZxgXIAEoCzIlLnRlbnNvcmZsb3cuQ29vcmRp", + "bmF0aW9uU2VydmljZUNvbmZpZyLaAQoRTWxpckJyaWRnZVJvbGxvdXQSIwof", + "TUxJUl9CUklER0VfUk9MTE9VVF9VTlNQRUNJRklFRBAAEh8KG01MSVJfQlJJ", + "REdFX1JPTExPVVRfRU5BQkxFRBABEiAKHE1MSVJfQlJJREdFX1JPTExPVVRf", + "RElTQUJMRUQQAhIpCiVNTElSX0JSSURHRV9ST0xMT1VUX1NBRkVfTU9ERV9F", + "TkFCTEVEEAMSMgouTUxJUl9CUklER0VfUk9MTE9VVF9TQUZFX01PREVfRkFM", + "TEJBQ0tfRU5BQkxFRBAESgQIAhADSgQIExAUSgQIFBAVIuEECgpSdW5PcHRp", + "b25zEjYKC3RyYWNlX2xldmVsGAEgASgOMiEudGVuc29yZmxvdy5SdW5PcHRp", + "b25zLlRyYWNlTGV2ZWwSFQoNdGltZW91dF9pbl9tcxgCIAEoAxIcChRpbnRl", + "cl9vcF90aHJlYWRfcG9vbBgDIAEoBRIfChdvdXRwdXRfcGFydGl0aW9uX2dy", + "YXBocxgFIAEoCBIvCg1kZWJ1Z19vcHRpb25zGAYgASgLMhgudGVuc29yZmxv", + "dy5EZWJ1Z09wdGlvbnMSKgoicmVwb3J0X3RlbnNvcl9hbGxvY2F0aW9uc191", + "cG9uX29vbRgHIAEoCBI5CgxleHBlcmltZW50YWwYCCABKAsyIy50ZW5zb3Jm", + "bG93LlJ1bk9wdGlvbnMuRXhwZXJpbWVudGFsGtIBCgxFeHBlcmltZW50YWwS", + "HAoUY29sbGVjdGl2ZV9ncmFwaF9rZXkYASABKAMSHAoUdXNlX3J1bl9oYW5k", + "bGVyX3Bvb2wYAiABKAgSWwoYcnVuX2hhbmRsZXJfcG9vbF9vcHRpb25zGAMg", + "ASgLMjkudGVuc29yZmxvdy5SdW5PcHRpb25zLkV4cGVyaW1lbnRhbC5SdW5I", + "YW5kbGVyUG9vbE9wdGlvbnMaKQoVUnVuSGFuZGxlclBvb2xPcHRpb25zEhAK", + "CHByaW9yaXR5GAEgASgDIlIKClRyYWNlTGV2ZWwSDAoITk9fVFJBQ0UQABIS", + "Cg5TT0ZUV0FSRV9UUkFDRRABEhIKDkhBUkRXQVJFX1RSQUNFEAISDgoKRlVM", + "TF9UUkFDRRADSgQIBBAFIr4DCgtSdW5NZXRhZGF0YRIpCgpzdGVwX3N0YXRz", + "GAEgASgLMhUudGVuc29yZmxvdy5TdGVwU3RhdHMSLAoKY29zdF9ncmFwaBgC", + "IAEoCzIYLnRlbnNvcmZsb3cuQ29zdEdyYXBoRGVmEi4KEHBhcnRpdGlvbl9n", + "cmFwaHMYAyADKAsyFC50ZW5zb3JmbG93LkdyYXBoRGVmEj8KD2Z1bmN0aW9u", + "X2dyYXBocxgEIAMoCzImLnRlbnNvcmZsb3cuUnVuTWV0YWRhdGEuRnVuY3Rp", + "b25HcmFwaHMSNQoQc2Vzc2lvbl9tZXRhZGF0YRgFIAEoCzIbLnRlbnNvcmZs", + "b3cuU2Vzc2lvbk1ldGFkYXRhGq0BCg5GdW5jdGlvbkdyYXBocxIuChBwYXJ0", + "aXRpb25fZ3JhcGhzGAEgAygLMhQudGVuc29yZmxvdy5HcmFwaERlZhI0ChZw", + "cmVfb3B0aW1pemF0aW9uX2dyYXBoGAIgASgLMhQudGVuc29yZmxvdy5HcmFw", + "aERlZhI1Chdwb3N0X29wdGltaXphdGlvbl9ncmFwaBgDIAEoCzIULnRlbnNv", + "cmZsb3cuR3JhcGhEZWYiOgoQVGVuc29yQ29ubmVjdGlvbhITCgtmcm9tX3Rl", + "bnNvchgBIAEoCRIRCgl0b190ZW5zb3IYAiABKAkisAMKD0NhbGxhYmxlT3B0", + "aW9ucxIMCgRmZWVkGAEgAygJEg0KBWZldGNoGAIgAygJEg4KBnRhcmdldBgD", + "IAMoCRIrCgtydW5fb3B0aW9ucxgEIAEoCzIWLnRlbnNvcmZsb3cuUnVuT3B0", + "aW9ucxI3ChF0ZW5zb3JfY29ubmVjdGlvbhgFIAMoCzIcLnRlbnNvcmZsb3cu", + "VGVuc29yQ29ubmVjdGlvbhJCCgxmZWVkX2RldmljZXMYBiADKAsyLC50ZW5z", + "b3JmbG93LkNhbGxhYmxlT3B0aW9ucy5GZWVkRGV2aWNlc0VudHJ5EkQKDWZl", + "dGNoX2RldmljZXMYByADKAsyLS50ZW5zb3JmbG93LkNhbGxhYmxlT3B0aW9u", + "cy5GZXRjaERldmljZXNFbnRyeRIXCg9mZXRjaF9za2lwX3N5bmMYCCABKAga", + "MgoQRmVlZERldmljZXNFbnRyeRILCgNrZXkYASABKAkSDQoFdmFsdWUYAiAB", + "KAk6AjgBGjMKEUZldGNoRGV2aWNlc0VudHJ5EgsKA2tleRgBIAEoCRINCgV2", + "YWx1ZRgCIAEoCToCOAFChAEKGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IM", + "Q29uZmlnUHJvdG9zUAFaVWdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3Jm", + "bG93L3RlbnNvcmZsb3cvZ28vY29yZS9wcm90b2J1Zi9mb3JfY29yZV9wcm90", + "b3NfZ29fcHJvdG/4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.CostGraphReflection.Descriptor, global::Tensorflow.GraphReflection.Descriptor, global::Tensorflow.StepStatsReflection.Descriptor, global::Tensorflow.ClusterReflection.Descriptor, global::Tensorflow.CoordinationConfigReflection.Descriptor, global::Tensorflow.DebugReflection.Descriptor, global::Tensorflow.RewriterConfigReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions), global::Tensorflow.GPUOptions.Parser, new[]{ "PerProcessGpuMemoryFraction", "AllowGrowth", "AllocatorType", "DeferredDeletionBytes", "VisibleDeviceList", "PollingActiveDelayUsecs", "PollingInactiveDelayMsecs", "ForceGpuCompatible", "Experimental" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental), global::Tensorflow.GPUOptions.Types.Experimental.Parser, new[]{ "VirtualDevices", "UseUnifiedMemory", "NumDevToDevCopyStreams", "CollectiveRingOrder", "TimestampedAllocator", "KernelTrackerMaxInterval", "KernelTrackerMaxBytes", "KernelTrackerMaxPending", "InternalFragmentationFraction", "UseCudaMallocAsync", "DisallowRetryOnAllocationFailure" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices), global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices.Parser, new[]{ "MemoryLimitMb", "Priority", "DeviceOrdinal" }, null, null, null, null)})}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OptimizerOptions), global::Tensorflow.OptimizerOptions.Parser, new[]{ "DoCommonSubexpressionElimination", "DoConstantFolding", "MaxFoldedConstantInBytes", "DoFunctionInlining", "OptLevel", "GlobalJitLevel", "CpuGlobalJit" }, null, new[]{ typeof(global::Tensorflow.OptimizerOptions.Types.Level), typeof(global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel) }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphOptions), global::Tensorflow.GraphOptions.Parser, new[]{ "EnableRecvScheduling", "OptimizerOptions", "BuildCostModel", "BuildCostModelAfter", "InferShapes", "PlacePrunedGraph", "EnableBfloat16Sendrecv", "TimelineStep", "RewriteOptions" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ThreadPoolOptionProto), global::Tensorflow.ThreadPoolOptionProto.Parser, new[]{ "NumThreads", "GlobalName" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RPCOptions), global::Tensorflow.RPCOptions.Parser, new[]{ "UseRpcForInprocessMaster", "CompressionAlgorithm", "CompressionLevel", "CacheRpcResponse", "DisableSessionConnectionSharing", "NumChannelsPerTarget" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SessionMetadata), global::Tensorflow.SessionMetadata.Parser, new[]{ "Name", "Version" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto), global::Tensorflow.ConfigProto.Parser, new[]{ "DeviceCount", "IntraOpParallelismThreads", "InterOpParallelismThreads", "UsePerSessionThreads", "SessionInterOpThreadPool", "PlacementPeriod", "DeviceFilters", "GpuOptions", "AllowSoftPlacement", "LogDevicePlacement", "GraphOptions", "OperationTimeoutInMs", "RpcOptions", "ClusterDef", "IsolateSessionState", "ShareClusterDevicesInSession", "Experimental" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto.Types.Experimental), global::Tensorflow.ConfigProto.Types.Experimental.Parser, new[]{ "CollectiveGroupLeader", "ExecutorType", "RecvBufMaxChunk", "UseNumaAffinity", "CollectiveDeterministicSequentialExecution", "CollectiveNccl", "ShareSessionStateInClusterspecPropagation", "DisableThreadSpinning", "ShareClusterDevicesInSession", "SessionMetadata", "OptimizeForStaticGraph", "EnableMlirBridge", "MlirBridgeRollout", "EnableMlirGraphOptimization", "DisableOutputPartitionGraphs", "XlaFusionAutotunerThresh", "UseTfrt", "DisableFunctionalOpsLowering", "XlaPreferSingleGraphCluster", "CoordinationConfig" }, null, new[]{ typeof(global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout) }, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions), global::Tensorflow.RunOptions.Parser, new[]{ "TraceLevel", "TimeoutInMs", "InterOpThreadPool", "OutputPartitionGraphs", "DebugOptions", "ReportTensorAllocationsUponOom", "Experimental" }, null, new[]{ typeof(global::Tensorflow.RunOptions.Types.TraceLevel) }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions.Types.Experimental), global::Tensorflow.RunOptions.Types.Experimental.Parser, new[]{ "CollectiveGraphKey", "UseRunHandlerPool", "RunHandlerPoolOptions" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions.Types.Experimental.Types.RunHandlerPoolOptions), global::Tensorflow.RunOptions.Types.Experimental.Types.RunHandlerPoolOptions.Parser, new[]{ "Priority" }, null, null, null, null)})}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata), global::Tensorflow.RunMetadata.Parser, new[]{ "StepStats", "CostGraph", "PartitionGraphs", "FunctionGraphs", "SessionMetadata" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata.Types.FunctionGraphs), global::Tensorflow.RunMetadata.Types.FunctionGraphs.Parser, new[]{ "PartitionGraphs", "PreOptimizationGraph", "PostOptimizationGraph" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorConnection), global::Tensorflow.TensorConnection.Parser, new[]{ "FromTensor", "ToTensor" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CallableOptions), global::Tensorflow.CallableOptions.Parser, new[]{ "Feed", "Fetch", "Target", "RunOptions", "TensorConnection", "FeedDevices", "FetchDevices", "FetchSkipSync" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, }) + })); + } + #endregion + + } + #region Messages + public sealed partial class GPUOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GPUOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ConfigReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GPUOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GPUOptions(GPUOptions other) : this() { + perProcessGpuMemoryFraction_ = other.perProcessGpuMemoryFraction_; + allowGrowth_ = other.allowGrowth_; + allocatorType_ = other.allocatorType_; + deferredDeletionBytes_ = other.deferredDeletionBytes_; + visibleDeviceList_ = other.visibleDeviceList_; + pollingActiveDelayUsecs_ = other.pollingActiveDelayUsecs_; + pollingInactiveDelayMsecs_ = other.pollingInactiveDelayMsecs_; + forceGpuCompatible_ = other.forceGpuCompatible_; + experimental_ = other.experimental_ != null ? other.experimental_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GPUOptions Clone() { + return new GPUOptions(this); + } + + /// Field number for the "per_process_gpu_memory_fraction" field. + public const int PerProcessGpuMemoryFractionFieldNumber = 1; + private double perProcessGpuMemoryFraction_; + /// + /// Fraction of the available GPU memory to allocate for each process. + /// 1 means to allocate all of the GPU memory, 0.5 means the process + /// allocates up to ~50% of the available GPU memory. + /// + /// GPU memory is pre-allocated unless the allow_growth option is enabled. + /// + /// If greater than 1.0, uses CUDA unified memory to potentially oversubscribe + /// the amount of memory available on the GPU device by using host memory as a + /// swap space. Accessing memory not available on the device will be + /// significantly slower as that would require memory transfer between the host + /// and the device. Options to reduce the memory requirement should be + /// considered before enabling this option as this may come with a negative + /// performance impact. Oversubscription using the unified memory requires + /// Pascal class or newer GPUs and it is currently only supported on the Linux + /// operating system. See + /// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-requirements + /// for the detailed requirements. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double PerProcessGpuMemoryFraction { + get { return perProcessGpuMemoryFraction_; } + set { + perProcessGpuMemoryFraction_ = value; + } + } + + /// Field number for the "allow_growth" field. + public const int AllowGrowthFieldNumber = 4; + private bool allowGrowth_; + /// + /// If true, the allocator does not pre-allocate the entire specified + /// GPU memory region, instead starting small and growing as needed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool AllowGrowth { + get { return allowGrowth_; } + set { + allowGrowth_ = value; + } + } + + /// Field number for the "allocator_type" field. + public const int AllocatorTypeFieldNumber = 2; + private string allocatorType_ = ""; + /// + /// The type of GPU allocation strategy to use. + /// + /// Allowed values: + /// "": The empty string (default) uses a system-chosen default + /// which may change over time. + /// + /// "BFC": A "Best-fit with coalescing" algorithm, simplified from a + /// version of dlmalloc. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string AllocatorType { + get { return allocatorType_; } + set { + allocatorType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "deferred_deletion_bytes" field. + public const int DeferredDeletionBytesFieldNumber = 3; + private long deferredDeletionBytes_; + /// + /// Delay deletion of up to this many bytes to reduce the number of + /// interactions with gpu driver code. If 0, the system chooses + /// a reasonable default (several MBs). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long DeferredDeletionBytes { + get { return deferredDeletionBytes_; } + set { + deferredDeletionBytes_ = value; + } + } + + /// Field number for the "visible_device_list" field. + public const int VisibleDeviceListFieldNumber = 5; + private string visibleDeviceList_ = ""; + /// + /// A comma-separated list of GPU ids that determines the 'visible' + /// to 'virtual' mapping of GPU devices. For example, if TensorFlow + /// can see 8 GPU devices in the process, and one wanted to map + /// visible GPU devices 5 and 3 as "/device:GPU:0", and "/device:GPU:1", + /// then one would specify this field as "5,3". This field is similar in + /// spirit to the CUDA_VISIBLE_DEVICES environment variable, except + /// it applies to the visible GPU devices in the process. + /// + /// NOTE: + /// 1. The GPU driver provides the process with the visible GPUs + /// in an order which is not guaranteed to have any correlation to + /// the *physical* GPU id in the machine. This field is used for + /// remapping "visible" to "virtual", which means this operates only + /// after the process starts. Users are required to use vendor + /// specific mechanisms (e.g., CUDA_VISIBLE_DEVICES) to control the + /// physical to visible device mapping prior to invoking TensorFlow. + /// 2. In the code, the ids in this list are also called "platform GPU id"s, + /// and the 'virtual' ids of GPU devices (i.e. the ids in the device + /// name "/device:GPU:<id>") are also called "TF GPU id"s. Please + /// refer to third_party/tensorflow/core/common_runtime/gpu/gpu_id.h + /// for more information. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string VisibleDeviceList { + get { return visibleDeviceList_; } + set { + visibleDeviceList_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "polling_active_delay_usecs" field. + public const int PollingActiveDelayUsecsFieldNumber = 6; + private int pollingActiveDelayUsecs_; + /// + /// In the event polling loop sleep this many microseconds between + /// PollEvents calls, when the queue is not empty. If value is not + /// set or set to 0, gets set to a non-zero default. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int PollingActiveDelayUsecs { + get { return pollingActiveDelayUsecs_; } + set { + pollingActiveDelayUsecs_ = value; + } + } + + /// Field number for the "polling_inactive_delay_msecs" field. + public const int PollingInactiveDelayMsecsFieldNumber = 7; + private int pollingInactiveDelayMsecs_; + /// + /// This field is deprecated and ignored. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int PollingInactiveDelayMsecs { + get { return pollingInactiveDelayMsecs_; } + set { + pollingInactiveDelayMsecs_ = value; + } + } + + /// Field number for the "force_gpu_compatible" field. + public const int ForceGpuCompatibleFieldNumber = 8; + private bool forceGpuCompatible_; + /// + /// Force all tensors to be gpu_compatible. On a GPU-enabled TensorFlow, + /// enabling this option forces all CPU tensors to be allocated with Cuda + /// pinned memory. Normally, TensorFlow will infer which tensors should be + /// allocated as the pinned memory. But in case where the inference is + /// incomplete, this option can significantly speed up the cross-device memory + /// copy performance as long as it fits the memory. + /// Note that this option is not something that should be + /// enabled by default for unknown or very large models, since all Cuda pinned + /// memory is unpageable, having too much pinned memory might negatively impact + /// the overall host system performance. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool ForceGpuCompatible { + get { return forceGpuCompatible_; } + set { + forceGpuCompatible_ = value; + } + } + + /// Field number for the "experimental" field. + public const int ExperimentalFieldNumber = 9; + private global::Tensorflow.GPUOptions.Types.Experimental experimental_; + /// + /// Everything inside experimental is subject to change and is not subject + /// to API stability guarantees in + /// https://www.tensorflow.org/guide/version_compat. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.GPUOptions.Types.Experimental Experimental { + get { return experimental_; } + set { + experimental_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GPUOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GPUOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(PerProcessGpuMemoryFraction, other.PerProcessGpuMemoryFraction)) return false; + if (AllowGrowth != other.AllowGrowth) return false; + if (AllocatorType != other.AllocatorType) return false; + if (DeferredDeletionBytes != other.DeferredDeletionBytes) return false; + if (VisibleDeviceList != other.VisibleDeviceList) return false; + if (PollingActiveDelayUsecs != other.PollingActiveDelayUsecs) return false; + if (PollingInactiveDelayMsecs != other.PollingInactiveDelayMsecs) return false; + if (ForceGpuCompatible != other.ForceGpuCompatible) return false; + if (!object.Equals(Experimental, other.Experimental)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (PerProcessGpuMemoryFraction != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(PerProcessGpuMemoryFraction); + if (AllowGrowth != false) hash ^= AllowGrowth.GetHashCode(); + if (AllocatorType.Length != 0) hash ^= AllocatorType.GetHashCode(); + if (DeferredDeletionBytes != 0L) hash ^= DeferredDeletionBytes.GetHashCode(); + if (VisibleDeviceList.Length != 0) hash ^= VisibleDeviceList.GetHashCode(); + if (PollingActiveDelayUsecs != 0) hash ^= PollingActiveDelayUsecs.GetHashCode(); + if (PollingInactiveDelayMsecs != 0) hash ^= PollingInactiveDelayMsecs.GetHashCode(); + if (ForceGpuCompatible != false) hash ^= ForceGpuCompatible.GetHashCode(); + if (experimental_ != null) hash ^= Experimental.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (PerProcessGpuMemoryFraction != 0D) { + output.WriteRawTag(9); + output.WriteDouble(PerProcessGpuMemoryFraction); + } + if (AllocatorType.Length != 0) { + output.WriteRawTag(18); + output.WriteString(AllocatorType); + } + if (DeferredDeletionBytes != 0L) { + output.WriteRawTag(24); + output.WriteInt64(DeferredDeletionBytes); + } + if (AllowGrowth != false) { + output.WriteRawTag(32); + output.WriteBool(AllowGrowth); + } + if (VisibleDeviceList.Length != 0) { + output.WriteRawTag(42); + output.WriteString(VisibleDeviceList); + } + if (PollingActiveDelayUsecs != 0) { + output.WriteRawTag(48); + output.WriteInt32(PollingActiveDelayUsecs); + } + if (PollingInactiveDelayMsecs != 0) { + output.WriteRawTag(56); + output.WriteInt32(PollingInactiveDelayMsecs); + } + if (ForceGpuCompatible != false) { + output.WriteRawTag(64); + output.WriteBool(ForceGpuCompatible); + } + if (experimental_ != null) { + output.WriteRawTag(74); + output.WriteMessage(Experimental); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (PerProcessGpuMemoryFraction != 0D) { + output.WriteRawTag(9); + output.WriteDouble(PerProcessGpuMemoryFraction); + } + if (AllocatorType.Length != 0) { + output.WriteRawTag(18); + output.WriteString(AllocatorType); + } + if (DeferredDeletionBytes != 0L) { + output.WriteRawTag(24); + output.WriteInt64(DeferredDeletionBytes); + } + if (AllowGrowth != false) { + output.WriteRawTag(32); + output.WriteBool(AllowGrowth); + } + if (VisibleDeviceList.Length != 0) { + output.WriteRawTag(42); + output.WriteString(VisibleDeviceList); + } + if (PollingActiveDelayUsecs != 0) { + output.WriteRawTag(48); + output.WriteInt32(PollingActiveDelayUsecs); + } + if (PollingInactiveDelayMsecs != 0) { + output.WriteRawTag(56); + output.WriteInt32(PollingInactiveDelayMsecs); + } + if (ForceGpuCompatible != false) { + output.WriteRawTag(64); + output.WriteBool(ForceGpuCompatible); + } + if (experimental_ != null) { + output.WriteRawTag(74); + output.WriteMessage(Experimental); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (PerProcessGpuMemoryFraction != 0D) { + size += 1 + 8; + } + if (AllowGrowth != false) { + size += 1 + 1; + } + if (AllocatorType.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(AllocatorType); + } + if (DeferredDeletionBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(DeferredDeletionBytes); + } + if (VisibleDeviceList.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(VisibleDeviceList); + } + if (PollingActiveDelayUsecs != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(PollingActiveDelayUsecs); + } + if (PollingInactiveDelayMsecs != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(PollingInactiveDelayMsecs); + } + if (ForceGpuCompatible != false) { + size += 1 + 1; + } + if (experimental_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Experimental); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GPUOptions other) { + if (other == null) { + return; + } + if (other.PerProcessGpuMemoryFraction != 0D) { + PerProcessGpuMemoryFraction = other.PerProcessGpuMemoryFraction; + } + if (other.AllowGrowth != false) { + AllowGrowth = other.AllowGrowth; + } + if (other.AllocatorType.Length != 0) { + AllocatorType = other.AllocatorType; + } + if (other.DeferredDeletionBytes != 0L) { + DeferredDeletionBytes = other.DeferredDeletionBytes; + } + if (other.VisibleDeviceList.Length != 0) { + VisibleDeviceList = other.VisibleDeviceList; + } + if (other.PollingActiveDelayUsecs != 0) { + PollingActiveDelayUsecs = other.PollingActiveDelayUsecs; + } + if (other.PollingInactiveDelayMsecs != 0) { + PollingInactiveDelayMsecs = other.PollingInactiveDelayMsecs; + } + if (other.ForceGpuCompatible != false) { + ForceGpuCompatible = other.ForceGpuCompatible; + } + if (other.experimental_ != null) { + if (experimental_ == null) { + Experimental = new global::Tensorflow.GPUOptions.Types.Experimental(); + } + Experimental.MergeFrom(other.Experimental); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 9: { + PerProcessGpuMemoryFraction = input.ReadDouble(); + break; + } + case 18: { + AllocatorType = input.ReadString(); + break; + } + case 24: { + DeferredDeletionBytes = input.ReadInt64(); + break; + } + case 32: { + AllowGrowth = input.ReadBool(); + break; + } + case 42: { + VisibleDeviceList = input.ReadString(); + break; + } + case 48: { + PollingActiveDelayUsecs = input.ReadInt32(); + break; + } + case 56: { + PollingInactiveDelayMsecs = input.ReadInt32(); + break; + } + case 64: { + ForceGpuCompatible = input.ReadBool(); + break; + } + case 74: { + if (experimental_ == null) { + Experimental = new global::Tensorflow.GPUOptions.Types.Experimental(); + } + input.ReadMessage(Experimental); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 9: { + PerProcessGpuMemoryFraction = input.ReadDouble(); + break; + } + case 18: { + AllocatorType = input.ReadString(); + break; + } + case 24: { + DeferredDeletionBytes = input.ReadInt64(); + break; + } + case 32: { + AllowGrowth = input.ReadBool(); + break; + } + case 42: { + VisibleDeviceList = input.ReadString(); + break; + } + case 48: { + PollingActiveDelayUsecs = input.ReadInt32(); + break; + } + case 56: { + PollingInactiveDelayMsecs = input.ReadInt32(); + break; + } + case 64: { + ForceGpuCompatible = input.ReadBool(); + break; + } + case 74: { + if (experimental_ == null) { + Experimental = new global::Tensorflow.GPUOptions.Types.Experimental(); + } + input.ReadMessage(Experimental); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the GPUOptions message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public sealed partial class Experimental : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Experimental()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.GPUOptions.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Experimental() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Experimental(Experimental other) : this() { + virtualDevices_ = other.virtualDevices_.Clone(); + useUnifiedMemory_ = other.useUnifiedMemory_; + numDevToDevCopyStreams_ = other.numDevToDevCopyStreams_; + collectiveRingOrder_ = other.collectiveRingOrder_; + timestampedAllocator_ = other.timestampedAllocator_; + kernelTrackerMaxInterval_ = other.kernelTrackerMaxInterval_; + kernelTrackerMaxBytes_ = other.kernelTrackerMaxBytes_; + kernelTrackerMaxPending_ = other.kernelTrackerMaxPending_; + internalFragmentationFraction_ = other.internalFragmentationFraction_; + useCudaMallocAsync_ = other.useCudaMallocAsync_; + disallowRetryOnAllocationFailure_ = other.disallowRetryOnAllocationFailure_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Experimental Clone() { + return new Experimental(this); + } + + /// Field number for the "virtual_devices" field. + public const int VirtualDevicesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_virtualDevices_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices.Parser); + private readonly pbc::RepeatedField virtualDevices_ = new pbc::RepeatedField(); + /// + /// The multi virtual device settings. If empty (not set), it will create + /// single virtual device on each visible GPU, according to the settings + /// in "visible_device_list" above. Otherwise, the number of elements in the + /// list must be the same as the number of visible GPUs (after + /// "visible_device_list" filtering if it is set), and the string represented + /// device names (e.g. /device:GPU:<id>) will refer to the virtual + /// devices and have the <id> field assigned sequentially starting from 0, + /// according to the order of the virtual devices determined by + /// device_ordinal and the location in the virtual device list. + /// + /// For example, + /// visible_device_list = "1,0" + /// virtual_devices { memory_limit: 1GB memory_limit: 2GB } + /// virtual_devices { memory_limit: 3GB memory_limit: 4GB } + /// will create 4 virtual devices as: + /// /device:GPU:0 -> visible GPU 1 with 1GB memory + /// /device:GPU:1 -> visible GPU 1 with 2GB memory + /// /device:GPU:2 -> visible GPU 0 with 3GB memory + /// /device:GPU:3 -> visible GPU 0 with 4GB memory + /// + /// but + /// visible_device_list = "1,0" + /// virtual_devices { memory_limit: 1GB memory_limit: 2GB + /// device_ordinal: 10 device_ordinal: 20} + /// virtual_devices { memory_limit: 3GB memory_limit: 4GB + /// device_ordinal: 10 device_ordinal: 20} + /// will create 4 virtual devices as: + /// /device:GPU:0 -> visible GPU 1 with 1GB memory (ordinal 10) + /// /device:GPU:1 -> visible GPU 0 with 3GB memory (ordinal 10) + /// /device:GPU:2 -> visible GPU 1 with 2GB memory (ordinal 20) + /// /device:GPU:3 -> visible GPU 0 with 4GB memory (ordinal 20) + /// + /// NOTE: + /// 1. It's invalid to set both this and "per_process_gpu_memory_fraction" + /// at the same time. + /// 2. Currently this setting is per-process, not per-session. Using + /// different settings in different sessions within same process will + /// result in undefined behavior. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField VirtualDevices { + get { return virtualDevices_; } + } + + /// Field number for the "use_unified_memory" field. + public const int UseUnifiedMemoryFieldNumber = 2; + private bool useUnifiedMemory_; + /// + /// If true, uses CUDA unified memory for memory allocations. If + /// per_process_gpu_memory_fraction option is greater than 1.0, then unified + /// memory is used regardless of the value for this field. See comments for + /// per_process_gpu_memory_fraction field for more details and requirements + /// of the unified memory. This option is useful to oversubscribe memory if + /// multiple processes are sharing a single GPU while individually using less + /// than 1.0 per process memory fraction. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UseUnifiedMemory { + get { return useUnifiedMemory_; } + set { + useUnifiedMemory_ = value; + } + } + + /// Field number for the "num_dev_to_dev_copy_streams" field. + public const int NumDevToDevCopyStreamsFieldNumber = 3; + private int numDevToDevCopyStreams_; + /// + /// If > 1, the number of device-to-device copy streams to create + /// for each GPUDevice. Default value is 0, which is automatically + /// converted to 1. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NumDevToDevCopyStreams { + get { return numDevToDevCopyStreams_; } + set { + numDevToDevCopyStreams_ = value; + } + } + + /// Field number for the "collective_ring_order" field. + public const int CollectiveRingOrderFieldNumber = 4; + private string collectiveRingOrder_ = ""; + /// + /// If non-empty, defines a good GPU ring order on a single worker based on + /// device interconnect. This assumes that all workers have the same GPU + /// topology. Specify as a comma-separated string, e.g. "3,2,1,0,7,6,5,4". + /// This ring order is used by the RingReducer implementation of + /// CollectiveReduce, and serves as an override to automatic ring order + /// generation in OrderTaskDeviceMap() during CollectiveParam resolution. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string CollectiveRingOrder { + get { return collectiveRingOrder_; } + set { + collectiveRingOrder_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "timestamped_allocator" field. + public const int TimestampedAllocatorFieldNumber = 5; + private bool timestampedAllocator_; + /// + /// If true then extra work is done by GPUDevice and GPUBFCAllocator to + /// keep track of when GPU memory is freed and when kernels actually + /// complete so that we can know when a nominally free memory chunk + /// is really not subject to pending use. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool TimestampedAllocator { + get { return timestampedAllocator_; } + set { + timestampedAllocator_ = value; + } + } + + /// Field number for the "kernel_tracker_max_interval" field. + public const int KernelTrackerMaxIntervalFieldNumber = 7; + private int kernelTrackerMaxInterval_; + /// + /// Parameters for GPUKernelTracker. By default no kernel tracking is done. + /// Note that timestamped_allocator is only effective if some tracking is + /// specified. + /// + /// If kernel_tracker_max_interval = n > 0, then a tracking event + /// is inserted after every n kernels without an event. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int KernelTrackerMaxInterval { + get { return kernelTrackerMaxInterval_; } + set { + kernelTrackerMaxInterval_ = value; + } + } + + /// Field number for the "kernel_tracker_max_bytes" field. + public const int KernelTrackerMaxBytesFieldNumber = 8; + private int kernelTrackerMaxBytes_; + /// + /// If kernel_tracker_max_bytes = n > 0, then a tracking event is + /// inserted after every series of kernels allocating a sum of + /// memory >= n. If one kernel allocates b * n bytes, then one + /// event will be inserted after it, but it will count as b against + /// the pending limit. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int KernelTrackerMaxBytes { + get { return kernelTrackerMaxBytes_; } + set { + kernelTrackerMaxBytes_ = value; + } + } + + /// Field number for the "kernel_tracker_max_pending" field. + public const int KernelTrackerMaxPendingFieldNumber = 9; + private int kernelTrackerMaxPending_; + /// + /// If kernel_tracker_max_pending > 0 then no more than this many + /// tracking events can be outstanding at a time. An attempt to + /// launch an additional kernel will stall until an event + /// completes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int KernelTrackerMaxPending { + get { return kernelTrackerMaxPending_; } + set { + kernelTrackerMaxPending_ = value; + } + } + + /// Field number for the "internal_fragmentation_fraction" field. + public const int InternalFragmentationFractionFieldNumber = 10; + private double internalFragmentationFraction_; + /// + /// BFC Allocator can return an allocated chunk of memory upto 2x the + /// requested size. For virtual devices with tight memory constraints, and + /// proportionately large allocation requests, this can lead to a significant + /// reduction in available memory. The threshold below controls when a chunk + /// should be split if the chunk size exceeds requested memory size. It is + /// expressed as a fraction of total available memory for the tf device. For + /// example setting it to 0.05 would imply a chunk needs to be split if its + /// size exceeds the requested memory by 5% of the total virtual device/gpu + /// memory size. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double InternalFragmentationFraction { + get { return internalFragmentationFraction_; } + set { + internalFragmentationFraction_ = value; + } + } + + /// Field number for the "use_cuda_malloc_async" field. + public const int UseCudaMallocAsyncFieldNumber = 11; + private bool useCudaMallocAsync_; + /// + /// When true, use CUDA cudaMallocAsync API instead of TF gpu allocator. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UseCudaMallocAsync { + get { return useCudaMallocAsync_; } + set { + useCudaMallocAsync_ = value; + } + } + + /// Field number for the "disallow_retry_on_allocation_failure" field. + public const int DisallowRetryOnAllocationFailureFieldNumber = 12; + private bool disallowRetryOnAllocationFailure_; + /// + /// By default, BFCAllocator may sleep when it runs out of memory, in the + /// hopes that another thread will free up memory in the meantime. Setting + /// this to true disables the sleep; instead we'll OOM immediately. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool DisallowRetryOnAllocationFailure { + get { return disallowRetryOnAllocationFailure_; } + set { + disallowRetryOnAllocationFailure_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Experimental); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Experimental other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!virtualDevices_.Equals(other.virtualDevices_)) return false; + if (UseUnifiedMemory != other.UseUnifiedMemory) return false; + if (NumDevToDevCopyStreams != other.NumDevToDevCopyStreams) return false; + if (CollectiveRingOrder != other.CollectiveRingOrder) return false; + if (TimestampedAllocator != other.TimestampedAllocator) return false; + if (KernelTrackerMaxInterval != other.KernelTrackerMaxInterval) return false; + if (KernelTrackerMaxBytes != other.KernelTrackerMaxBytes) return false; + if (KernelTrackerMaxPending != other.KernelTrackerMaxPending) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(InternalFragmentationFraction, other.InternalFragmentationFraction)) return false; + if (UseCudaMallocAsync != other.UseCudaMallocAsync) return false; + if (DisallowRetryOnAllocationFailure != other.DisallowRetryOnAllocationFailure) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= virtualDevices_.GetHashCode(); + if (UseUnifiedMemory != false) hash ^= UseUnifiedMemory.GetHashCode(); + if (NumDevToDevCopyStreams != 0) hash ^= NumDevToDevCopyStreams.GetHashCode(); + if (CollectiveRingOrder.Length != 0) hash ^= CollectiveRingOrder.GetHashCode(); + if (TimestampedAllocator != false) hash ^= TimestampedAllocator.GetHashCode(); + if (KernelTrackerMaxInterval != 0) hash ^= KernelTrackerMaxInterval.GetHashCode(); + if (KernelTrackerMaxBytes != 0) hash ^= KernelTrackerMaxBytes.GetHashCode(); + if (KernelTrackerMaxPending != 0) hash ^= KernelTrackerMaxPending.GetHashCode(); + if (InternalFragmentationFraction != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(InternalFragmentationFraction); + if (UseCudaMallocAsync != false) hash ^= UseCudaMallocAsync.GetHashCode(); + if (DisallowRetryOnAllocationFailure != false) hash ^= DisallowRetryOnAllocationFailure.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + virtualDevices_.WriteTo(output, _repeated_virtualDevices_codec); + if (UseUnifiedMemory != false) { + output.WriteRawTag(16); + output.WriteBool(UseUnifiedMemory); + } + if (NumDevToDevCopyStreams != 0) { + output.WriteRawTag(24); + output.WriteInt32(NumDevToDevCopyStreams); + } + if (CollectiveRingOrder.Length != 0) { + output.WriteRawTag(34); + output.WriteString(CollectiveRingOrder); + } + if (TimestampedAllocator != false) { + output.WriteRawTag(40); + output.WriteBool(TimestampedAllocator); + } + if (KernelTrackerMaxInterval != 0) { + output.WriteRawTag(56); + output.WriteInt32(KernelTrackerMaxInterval); + } + if (KernelTrackerMaxBytes != 0) { + output.WriteRawTag(64); + output.WriteInt32(KernelTrackerMaxBytes); + } + if (KernelTrackerMaxPending != 0) { + output.WriteRawTag(72); + output.WriteInt32(KernelTrackerMaxPending); + } + if (InternalFragmentationFraction != 0D) { + output.WriteRawTag(81); + output.WriteDouble(InternalFragmentationFraction); + } + if (UseCudaMallocAsync != false) { + output.WriteRawTag(88); + output.WriteBool(UseCudaMallocAsync); + } + if (DisallowRetryOnAllocationFailure != false) { + output.WriteRawTag(96); + output.WriteBool(DisallowRetryOnAllocationFailure); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + virtualDevices_.WriteTo(ref output, _repeated_virtualDevices_codec); + if (UseUnifiedMemory != false) { + output.WriteRawTag(16); + output.WriteBool(UseUnifiedMemory); + } + if (NumDevToDevCopyStreams != 0) { + output.WriteRawTag(24); + output.WriteInt32(NumDevToDevCopyStreams); + } + if (CollectiveRingOrder.Length != 0) { + output.WriteRawTag(34); + output.WriteString(CollectiveRingOrder); + } + if (TimestampedAllocator != false) { + output.WriteRawTag(40); + output.WriteBool(TimestampedAllocator); + } + if (KernelTrackerMaxInterval != 0) { + output.WriteRawTag(56); + output.WriteInt32(KernelTrackerMaxInterval); + } + if (KernelTrackerMaxBytes != 0) { + output.WriteRawTag(64); + output.WriteInt32(KernelTrackerMaxBytes); + } + if (KernelTrackerMaxPending != 0) { + output.WriteRawTag(72); + output.WriteInt32(KernelTrackerMaxPending); + } + if (InternalFragmentationFraction != 0D) { + output.WriteRawTag(81); + output.WriteDouble(InternalFragmentationFraction); + } + if (UseCudaMallocAsync != false) { + output.WriteRawTag(88); + output.WriteBool(UseCudaMallocAsync); + } + if (DisallowRetryOnAllocationFailure != false) { + output.WriteRawTag(96); + output.WriteBool(DisallowRetryOnAllocationFailure); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += virtualDevices_.CalculateSize(_repeated_virtualDevices_codec); + if (UseUnifiedMemory != false) { + size += 1 + 1; + } + if (NumDevToDevCopyStreams != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumDevToDevCopyStreams); + } + if (CollectiveRingOrder.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(CollectiveRingOrder); + } + if (TimestampedAllocator != false) { + size += 1 + 1; + } + if (KernelTrackerMaxInterval != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxInterval); + } + if (KernelTrackerMaxBytes != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxBytes); + } + if (KernelTrackerMaxPending != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxPending); + } + if (InternalFragmentationFraction != 0D) { + size += 1 + 8; + } + if (UseCudaMallocAsync != false) { + size += 1 + 1; + } + if (DisallowRetryOnAllocationFailure != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Experimental other) { + if (other == null) { + return; + } + virtualDevices_.Add(other.virtualDevices_); + if (other.UseUnifiedMemory != false) { + UseUnifiedMemory = other.UseUnifiedMemory; + } + if (other.NumDevToDevCopyStreams != 0) { + NumDevToDevCopyStreams = other.NumDevToDevCopyStreams; + } + if (other.CollectiveRingOrder.Length != 0) { + CollectiveRingOrder = other.CollectiveRingOrder; + } + if (other.TimestampedAllocator != false) { + TimestampedAllocator = other.TimestampedAllocator; + } + if (other.KernelTrackerMaxInterval != 0) { + KernelTrackerMaxInterval = other.KernelTrackerMaxInterval; + } + if (other.KernelTrackerMaxBytes != 0) { + KernelTrackerMaxBytes = other.KernelTrackerMaxBytes; + } + if (other.KernelTrackerMaxPending != 0) { + KernelTrackerMaxPending = other.KernelTrackerMaxPending; + } + if (other.InternalFragmentationFraction != 0D) { + InternalFragmentationFraction = other.InternalFragmentationFraction; + } + if (other.UseCudaMallocAsync != false) { + UseCudaMallocAsync = other.UseCudaMallocAsync; + } + if (other.DisallowRetryOnAllocationFailure != false) { + DisallowRetryOnAllocationFailure = other.DisallowRetryOnAllocationFailure; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + virtualDevices_.AddEntriesFrom(input, _repeated_virtualDevices_codec); + break; + } + case 16: { + UseUnifiedMemory = input.ReadBool(); + break; + } + case 24: { + NumDevToDevCopyStreams = input.ReadInt32(); + break; + } + case 34: { + CollectiveRingOrder = input.ReadString(); + break; + } + case 40: { + TimestampedAllocator = input.ReadBool(); + break; + } + case 56: { + KernelTrackerMaxInterval = input.ReadInt32(); + break; + } + case 64: { + KernelTrackerMaxBytes = input.ReadInt32(); + break; + } + case 72: { + KernelTrackerMaxPending = input.ReadInt32(); + break; + } + case 81: { + InternalFragmentationFraction = input.ReadDouble(); + break; + } + case 88: { + UseCudaMallocAsync = input.ReadBool(); + break; + } + case 96: { + DisallowRetryOnAllocationFailure = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + virtualDevices_.AddEntriesFrom(ref input, _repeated_virtualDevices_codec); + break; + } + case 16: { + UseUnifiedMemory = input.ReadBool(); + break; + } + case 24: { + NumDevToDevCopyStreams = input.ReadInt32(); + break; + } + case 34: { + CollectiveRingOrder = input.ReadString(); + break; + } + case 40: { + TimestampedAllocator = input.ReadBool(); + break; + } + case 56: { + KernelTrackerMaxInterval = input.ReadInt32(); + break; + } + case 64: { + KernelTrackerMaxBytes = input.ReadInt32(); + break; + } + case 72: { + KernelTrackerMaxPending = input.ReadInt32(); + break; + } + case 81: { + InternalFragmentationFraction = input.ReadDouble(); + break; + } + case 88: { + UseCudaMallocAsync = input.ReadBool(); + break; + } + case 96: { + DisallowRetryOnAllocationFailure = input.ReadBool(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the Experimental message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Configuration for breaking down a visible GPU into multiple "virtual" + /// devices. + /// + public sealed partial class VirtualDevices : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new VirtualDevices()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.GPUOptions.Types.Experimental.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public VirtualDevices() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public VirtualDevices(VirtualDevices other) : this() { + memoryLimitMb_ = other.memoryLimitMb_.Clone(); + priority_ = other.priority_.Clone(); + deviceOrdinal_ = other.deviceOrdinal_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public VirtualDevices Clone() { + return new VirtualDevices(this); + } + + /// Field number for the "memory_limit_mb" field. + public const int MemoryLimitMbFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_memoryLimitMb_codec + = pb::FieldCodec.ForFloat(10); + private readonly pbc::RepeatedField memoryLimitMb_ = new pbc::RepeatedField(); + /// + /// Per "virtual" device memory limit, in MB. The number of elements in + /// the list is the number of virtual devices to create on the + /// corresponding visible GPU (see "virtual_devices" below). + /// If empty, it will create single virtual device taking all available + /// memory from the device. + /// + /// For the concept of "visible" and "virtual" GPU, see the comments for + /// "visible_device_list" above for more information. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField MemoryLimitMb { + get { return memoryLimitMb_; } + } + + /// Field number for the "priority" field. + public const int PriorityFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_priority_codec + = pb::FieldCodec.ForInt32(18); + private readonly pbc::RepeatedField priority_ = new pbc::RepeatedField(); + /// + /// Priority values to use with the virtual devices. Use the cuda function + /// cudaDeviceGetStreamPriorityRange to query for valid range of values for + /// priority. + /// + /// On a P4000 GPU with cuda 10.1, the priority range reported was 0 for + /// least priority and -1 for greatest priority. + /// + /// If this field is not specified, then the virtual devices will be + /// created with the default. If this field has values set, then the size + /// of this must match with the above memory_limit_mb. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Priority { + get { return priority_; } + } + + /// Field number for the "device_ordinal" field. + public const int DeviceOrdinalFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_deviceOrdinal_codec + = pb::FieldCodec.ForInt32(26); + private readonly pbc::RepeatedField deviceOrdinal_ = new pbc::RepeatedField(); + /// + /// Virtual Device ordinal number determines the device ID of the device. + /// A Virtual device with a lower ordinal number always receives the a + /// smaller device id. The phyiscal device id and location in the + /// virtual device list is used to break ties. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DeviceOrdinal { + get { return deviceOrdinal_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as VirtualDevices); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(VirtualDevices other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!memoryLimitMb_.Equals(other.memoryLimitMb_)) return false; + if(!priority_.Equals(other.priority_)) return false; + if(!deviceOrdinal_.Equals(other.deviceOrdinal_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= memoryLimitMb_.GetHashCode(); + hash ^= priority_.GetHashCode(); + hash ^= deviceOrdinal_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + memoryLimitMb_.WriteTo(output, _repeated_memoryLimitMb_codec); + priority_.WriteTo(output, _repeated_priority_codec); + deviceOrdinal_.WriteTo(output, _repeated_deviceOrdinal_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + memoryLimitMb_.WriteTo(ref output, _repeated_memoryLimitMb_codec); + priority_.WriteTo(ref output, _repeated_priority_codec); + deviceOrdinal_.WriteTo(ref output, _repeated_deviceOrdinal_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += memoryLimitMb_.CalculateSize(_repeated_memoryLimitMb_codec); + size += priority_.CalculateSize(_repeated_priority_codec); + size += deviceOrdinal_.CalculateSize(_repeated_deviceOrdinal_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(VirtualDevices other) { + if (other == null) { + return; + } + memoryLimitMb_.Add(other.memoryLimitMb_); + priority_.Add(other.priority_); + deviceOrdinal_.Add(other.deviceOrdinal_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 13: { + memoryLimitMb_.AddEntriesFrom(input, _repeated_memoryLimitMb_codec); + break; + } + case 18: + case 16: { + priority_.AddEntriesFrom(input, _repeated_priority_codec); + break; + } + case 26: + case 24: { + deviceOrdinal_.AddEntriesFrom(input, _repeated_deviceOrdinal_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 13: { + memoryLimitMb_.AddEntriesFrom(ref input, _repeated_memoryLimitMb_codec); + break; + } + case 18: + case 16: { + priority_.AddEntriesFrom(ref input, _repeated_priority_codec); + break; + } + case 26: + case 24: { + deviceOrdinal_.AddEntriesFrom(ref input, _repeated_deviceOrdinal_codec); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + } + #endregion + + } + + /// + /// Options passed to the graph optimizer + /// + public sealed partial class OptimizerOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OptimizerOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ConfigReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OptimizerOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OptimizerOptions(OptimizerOptions other) : this() { + doCommonSubexpressionElimination_ = other.doCommonSubexpressionElimination_; + doConstantFolding_ = other.doConstantFolding_; + maxFoldedConstantInBytes_ = other.maxFoldedConstantInBytes_; + doFunctionInlining_ = other.doFunctionInlining_; + optLevel_ = other.optLevel_; + globalJitLevel_ = other.globalJitLevel_; + cpuGlobalJit_ = other.cpuGlobalJit_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OptimizerOptions Clone() { + return new OptimizerOptions(this); + } + + /// Field number for the "do_common_subexpression_elimination" field. + public const int DoCommonSubexpressionEliminationFieldNumber = 1; + private bool doCommonSubexpressionElimination_; + /// + /// If true, optimize the graph using common subexpression elimination. + /// Note: the optimization Level L1 will override this setting to true. So in + /// order to disable common subexpression elimination the opt_level has to be + /// set to L0. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool DoCommonSubexpressionElimination { + get { return doCommonSubexpressionElimination_; } + set { + doCommonSubexpressionElimination_ = value; + } + } + + /// Field number for the "do_constant_folding" field. + public const int DoConstantFoldingFieldNumber = 2; + private bool doConstantFolding_; + /// + /// If true, perform constant folding optimization on the graph. + /// Note: the optimization Level L1 will override this setting to true. So in + /// order to disable constant folding the opt_level has to be set to L0. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool DoConstantFolding { + get { return doConstantFolding_; } + set { + doConstantFolding_ = value; + } + } + + /// Field number for the "max_folded_constant_in_bytes" field. + public const int MaxFoldedConstantInBytesFieldNumber = 6; + private long maxFoldedConstantInBytes_; + /// + /// Constant folding optimization replaces tensors whose values can be + /// predetermined, with constant nodes. To avoid inserting too large constants, + /// the size of each constant created can be limited. If this value is zero, a + /// default limit of 10 MiB will be applied. If constant folding optimization + /// is disabled, this value is ignored. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long MaxFoldedConstantInBytes { + get { return maxFoldedConstantInBytes_; } + set { + maxFoldedConstantInBytes_ = value; + } + } + + /// Field number for the "do_function_inlining" field. + public const int DoFunctionInliningFieldNumber = 4; + private bool doFunctionInlining_; + /// + /// If true, perform function inlining on the graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool DoFunctionInlining { + get { return doFunctionInlining_; } + set { + doFunctionInlining_ = value; + } + } + + /// Field number for the "opt_level" field. + public const int OptLevelFieldNumber = 3; + private global::Tensorflow.OptimizerOptions.Types.Level optLevel_ = global::Tensorflow.OptimizerOptions.Types.Level.L1; + /// + /// Overall optimization level. The actual optimizations applied will be the + /// logical OR of the flags that this level implies and any flags already set. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.OptimizerOptions.Types.Level OptLevel { + get { return optLevel_; } + set { + optLevel_ = value; + } + } + + /// Field number for the "global_jit_level" field. + public const int GlobalJitLevelFieldNumber = 5; + private global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel globalJitLevel_ = global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel.Default; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel GlobalJitLevel { + get { return globalJitLevel_; } + set { + globalJitLevel_ = value; + } + } + + /// Field number for the "cpu_global_jit" field. + public const int CpuGlobalJitFieldNumber = 7; + private bool cpuGlobalJit_; + /// + /// CPU code will be autoclustered only if global_jit_level >= ON_1 and either: + /// - this flag is true, or + /// - TF_XLA_FLAGS contains --tf_xla_cpu_global_jit=true. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool CpuGlobalJit { + get { return cpuGlobalJit_; } + set { + cpuGlobalJit_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as OptimizerOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(OptimizerOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (DoCommonSubexpressionElimination != other.DoCommonSubexpressionElimination) return false; + if (DoConstantFolding != other.DoConstantFolding) return false; + if (MaxFoldedConstantInBytes != other.MaxFoldedConstantInBytes) return false; + if (DoFunctionInlining != other.DoFunctionInlining) return false; + if (OptLevel != other.OptLevel) return false; + if (GlobalJitLevel != other.GlobalJitLevel) return false; + if (CpuGlobalJit != other.CpuGlobalJit) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (DoCommonSubexpressionElimination != false) hash ^= DoCommonSubexpressionElimination.GetHashCode(); + if (DoConstantFolding != false) hash ^= DoConstantFolding.GetHashCode(); + if (MaxFoldedConstantInBytes != 0L) hash ^= MaxFoldedConstantInBytes.GetHashCode(); + if (DoFunctionInlining != false) hash ^= DoFunctionInlining.GetHashCode(); + if (OptLevel != global::Tensorflow.OptimizerOptions.Types.Level.L1) hash ^= OptLevel.GetHashCode(); + if (GlobalJitLevel != global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel.Default) hash ^= GlobalJitLevel.GetHashCode(); + if (CpuGlobalJit != false) hash ^= CpuGlobalJit.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (DoCommonSubexpressionElimination != false) { + output.WriteRawTag(8); + output.WriteBool(DoCommonSubexpressionElimination); + } + if (DoConstantFolding != false) { + output.WriteRawTag(16); + output.WriteBool(DoConstantFolding); + } + if (OptLevel != global::Tensorflow.OptimizerOptions.Types.Level.L1) { + output.WriteRawTag(24); + output.WriteEnum((int) OptLevel); + } + if (DoFunctionInlining != false) { + output.WriteRawTag(32); + output.WriteBool(DoFunctionInlining); + } + if (GlobalJitLevel != global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel.Default) { + output.WriteRawTag(40); + output.WriteEnum((int) GlobalJitLevel); + } + if (MaxFoldedConstantInBytes != 0L) { + output.WriteRawTag(48); + output.WriteInt64(MaxFoldedConstantInBytes); + } + if (CpuGlobalJit != false) { + output.WriteRawTag(56); + output.WriteBool(CpuGlobalJit); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (DoCommonSubexpressionElimination != false) { + output.WriteRawTag(8); + output.WriteBool(DoCommonSubexpressionElimination); + } + if (DoConstantFolding != false) { + output.WriteRawTag(16); + output.WriteBool(DoConstantFolding); + } + if (OptLevel != global::Tensorflow.OptimizerOptions.Types.Level.L1) { + output.WriteRawTag(24); + output.WriteEnum((int) OptLevel); + } + if (DoFunctionInlining != false) { + output.WriteRawTag(32); + output.WriteBool(DoFunctionInlining); + } + if (GlobalJitLevel != global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel.Default) { + output.WriteRawTag(40); + output.WriteEnum((int) GlobalJitLevel); + } + if (MaxFoldedConstantInBytes != 0L) { + output.WriteRawTag(48); + output.WriteInt64(MaxFoldedConstantInBytes); + } + if (CpuGlobalJit != false) { + output.WriteRawTag(56); + output.WriteBool(CpuGlobalJit); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (DoCommonSubexpressionElimination != false) { + size += 1 + 1; + } + if (DoConstantFolding != false) { + size += 1 + 1; + } + if (MaxFoldedConstantInBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(MaxFoldedConstantInBytes); + } + if (DoFunctionInlining != false) { + size += 1 + 1; + } + if (OptLevel != global::Tensorflow.OptimizerOptions.Types.Level.L1) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) OptLevel); + } + if (GlobalJitLevel != global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) GlobalJitLevel); + } + if (CpuGlobalJit != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(OptimizerOptions other) { + if (other == null) { + return; + } + if (other.DoCommonSubexpressionElimination != false) { + DoCommonSubexpressionElimination = other.DoCommonSubexpressionElimination; + } + if (other.DoConstantFolding != false) { + DoConstantFolding = other.DoConstantFolding; + } + if (other.MaxFoldedConstantInBytes != 0L) { + MaxFoldedConstantInBytes = other.MaxFoldedConstantInBytes; + } + if (other.DoFunctionInlining != false) { + DoFunctionInlining = other.DoFunctionInlining; + } + if (other.OptLevel != global::Tensorflow.OptimizerOptions.Types.Level.L1) { + OptLevel = other.OptLevel; + } + if (other.GlobalJitLevel != global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel.Default) { + GlobalJitLevel = other.GlobalJitLevel; + } + if (other.CpuGlobalJit != false) { + CpuGlobalJit = other.CpuGlobalJit; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + DoCommonSubexpressionElimination = input.ReadBool(); + break; + } + case 16: { + DoConstantFolding = input.ReadBool(); + break; + } + case 24: { + OptLevel = (global::Tensorflow.OptimizerOptions.Types.Level) input.ReadEnum(); + break; + } + case 32: { + DoFunctionInlining = input.ReadBool(); + break; + } + case 40: { + GlobalJitLevel = (global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel) input.ReadEnum(); + break; + } + case 48: { + MaxFoldedConstantInBytes = input.ReadInt64(); + break; + } + case 56: { + CpuGlobalJit = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + DoCommonSubexpressionElimination = input.ReadBool(); + break; + } + case 16: { + DoConstantFolding = input.ReadBool(); + break; + } + case 24: { + OptLevel = (global::Tensorflow.OptimizerOptions.Types.Level) input.ReadEnum(); + break; + } + case 32: { + DoFunctionInlining = input.ReadBool(); + break; + } + case 40: { + GlobalJitLevel = (global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel) input.ReadEnum(); + break; + } + case 48: { + MaxFoldedConstantInBytes = input.ReadInt64(); + break; + } + case 56: { + CpuGlobalJit = input.ReadBool(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the OptimizerOptions message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Optimization level + /// + public enum Level { + /// + /// L1 is the default level. + /// Optimization performed at L1 : + /// 1. Common subexpression elimination + /// 2. Constant folding + /// + [pbr::OriginalName("L1")] L1 = 0, + /// + /// No optimizations + /// + [pbr::OriginalName("L0")] L0 = -1, + } + + /// + /// Control the use of the compiler/jit. Experimental. + /// + public enum GlobalJitLevel { + /// + /// Default setting ("off" now, but later expected to be "on") + /// + [pbr::OriginalName("DEFAULT")] Default = 0, + [pbr::OriginalName("OFF")] Off = -1, + /// + /// The following settings turn on compilation, with higher values being + /// more aggressive. Higher values may reduce opportunities for parallelism + /// and may use more memory. (At present, there is no distinction, but this + /// is expected to change.) + /// + [pbr::OriginalName("ON_1")] On1 = 1, + [pbr::OriginalName("ON_2")] On2 = 2, + } + + } + #endregion + + } + + public sealed partial class GraphOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ConfigReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphOptions(GraphOptions other) : this() { + enableRecvScheduling_ = other.enableRecvScheduling_; + optimizerOptions_ = other.optimizerOptions_ != null ? other.optimizerOptions_.Clone() : null; + buildCostModel_ = other.buildCostModel_; + buildCostModelAfter_ = other.buildCostModelAfter_; + inferShapes_ = other.inferShapes_; + placePrunedGraph_ = other.placePrunedGraph_; + enableBfloat16Sendrecv_ = other.enableBfloat16Sendrecv_; + timelineStep_ = other.timelineStep_; + rewriteOptions_ = other.rewriteOptions_ != null ? other.rewriteOptions_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphOptions Clone() { + return new GraphOptions(this); + } + + /// Field number for the "enable_recv_scheduling" field. + public const int EnableRecvSchedulingFieldNumber = 2; + private bool enableRecvScheduling_; + /// + /// If true, use control flow to schedule the activation of Recv nodes. + /// (Currently ignored.) + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool EnableRecvScheduling { + get { return enableRecvScheduling_; } + set { + enableRecvScheduling_ = value; + } + } + + /// Field number for the "optimizer_options" field. + public const int OptimizerOptionsFieldNumber = 3; + private global::Tensorflow.OptimizerOptions optimizerOptions_; + /// + /// Options controlling how graph is optimized. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.OptimizerOptions OptimizerOptions { + get { return optimizerOptions_; } + set { + optimizerOptions_ = value; + } + } + + /// Field number for the "build_cost_model" field. + public const int BuildCostModelFieldNumber = 4; + private long buildCostModel_; + /// + /// The number of steps to run before returning a cost model detailing + /// the memory usage and performance of each node of the graph. 0 means + /// no cost model. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long BuildCostModel { + get { return buildCostModel_; } + set { + buildCostModel_ = value; + } + } + + /// Field number for the "build_cost_model_after" field. + public const int BuildCostModelAfterFieldNumber = 9; + private long buildCostModelAfter_; + /// + /// The number of steps to skip before collecting statistics for the + /// cost model. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long BuildCostModelAfter { + get { return buildCostModelAfter_; } + set { + buildCostModelAfter_ = value; + } + } + + /// Field number for the "infer_shapes" field. + public const int InferShapesFieldNumber = 5; + private bool inferShapes_; + /// + /// Annotate each Node with Op output shape data, to the extent it can + /// be statically inferred. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool InferShapes { + get { return inferShapes_; } + set { + inferShapes_ = value; + } + } + + /// Field number for the "place_pruned_graph" field. + public const int PlacePrunedGraphFieldNumber = 6; + private bool placePrunedGraph_; + /// + /// Only place the subgraphs that are run, rather than the entire graph. + /// + /// This is useful for interactive graph building, where one might + /// produce graphs that cannot be placed during the debugging + /// process. In particular, it allows the client to continue work in + /// a session after adding a node to a graph whose placement + /// constraints are unsatisfiable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool PlacePrunedGraph { + get { return placePrunedGraph_; } + set { + placePrunedGraph_ = value; + } + } + + /// Field number for the "enable_bfloat16_sendrecv" field. + public const int EnableBfloat16SendrecvFieldNumber = 7; + private bool enableBfloat16Sendrecv_; + /// + /// If true, transfer float values between processes as bfloat16. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool EnableBfloat16Sendrecv { + get { return enableBfloat16Sendrecv_; } + set { + enableBfloat16Sendrecv_ = value; + } + } + + /// Field number for the "timeline_step" field. + public const int TimelineStepFieldNumber = 8; + private int timelineStep_; + /// + /// If > 0, record a timeline every this many steps. + /// EXPERIMENTAL: This currently has no effect in MasterSession. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int TimelineStep { + get { return timelineStep_; } + set { + timelineStep_ = value; + } + } + + /// Field number for the "rewrite_options" field. + public const int RewriteOptionsFieldNumber = 10; + private global::Tensorflow.RewriterConfig rewriteOptions_; + /// + /// Options that control the type and amount of graph rewriting. + /// Not currently configurable via the public Python API (i.e. there is no API + /// stability guarantee if you import RewriterConfig explicitly). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig RewriteOptions { + get { return rewriteOptions_; } + set { + rewriteOptions_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GraphOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GraphOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (EnableRecvScheduling != other.EnableRecvScheduling) return false; + if (!object.Equals(OptimizerOptions, other.OptimizerOptions)) return false; + if (BuildCostModel != other.BuildCostModel) return false; + if (BuildCostModelAfter != other.BuildCostModelAfter) return false; + if (InferShapes != other.InferShapes) return false; + if (PlacePrunedGraph != other.PlacePrunedGraph) return false; + if (EnableBfloat16Sendrecv != other.EnableBfloat16Sendrecv) return false; + if (TimelineStep != other.TimelineStep) return false; + if (!object.Equals(RewriteOptions, other.RewriteOptions)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (EnableRecvScheduling != false) hash ^= EnableRecvScheduling.GetHashCode(); + if (optimizerOptions_ != null) hash ^= OptimizerOptions.GetHashCode(); + if (BuildCostModel != 0L) hash ^= BuildCostModel.GetHashCode(); + if (BuildCostModelAfter != 0L) hash ^= BuildCostModelAfter.GetHashCode(); + if (InferShapes != false) hash ^= InferShapes.GetHashCode(); + if (PlacePrunedGraph != false) hash ^= PlacePrunedGraph.GetHashCode(); + if (EnableBfloat16Sendrecv != false) hash ^= EnableBfloat16Sendrecv.GetHashCode(); + if (TimelineStep != 0) hash ^= TimelineStep.GetHashCode(); + if (rewriteOptions_ != null) hash ^= RewriteOptions.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (EnableRecvScheduling != false) { + output.WriteRawTag(16); + output.WriteBool(EnableRecvScheduling); + } + if (optimizerOptions_ != null) { + output.WriteRawTag(26); + output.WriteMessage(OptimizerOptions); + } + if (BuildCostModel != 0L) { + output.WriteRawTag(32); + output.WriteInt64(BuildCostModel); + } + if (InferShapes != false) { + output.WriteRawTag(40); + output.WriteBool(InferShapes); + } + if (PlacePrunedGraph != false) { + output.WriteRawTag(48); + output.WriteBool(PlacePrunedGraph); + } + if (EnableBfloat16Sendrecv != false) { + output.WriteRawTag(56); + output.WriteBool(EnableBfloat16Sendrecv); + } + if (TimelineStep != 0) { + output.WriteRawTag(64); + output.WriteInt32(TimelineStep); + } + if (BuildCostModelAfter != 0L) { + output.WriteRawTag(72); + output.WriteInt64(BuildCostModelAfter); + } + if (rewriteOptions_ != null) { + output.WriteRawTag(82); + output.WriteMessage(RewriteOptions); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (EnableRecvScheduling != false) { + output.WriteRawTag(16); + output.WriteBool(EnableRecvScheduling); + } + if (optimizerOptions_ != null) { + output.WriteRawTag(26); + output.WriteMessage(OptimizerOptions); + } + if (BuildCostModel != 0L) { + output.WriteRawTag(32); + output.WriteInt64(BuildCostModel); + } + if (InferShapes != false) { + output.WriteRawTag(40); + output.WriteBool(InferShapes); + } + if (PlacePrunedGraph != false) { + output.WriteRawTag(48); + output.WriteBool(PlacePrunedGraph); + } + if (EnableBfloat16Sendrecv != false) { + output.WriteRawTag(56); + output.WriteBool(EnableBfloat16Sendrecv); + } + if (TimelineStep != 0) { + output.WriteRawTag(64); + output.WriteInt32(TimelineStep); + } + if (BuildCostModelAfter != 0L) { + output.WriteRawTag(72); + output.WriteInt64(BuildCostModelAfter); + } + if (rewriteOptions_ != null) { + output.WriteRawTag(82); + output.WriteMessage(RewriteOptions); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (EnableRecvScheduling != false) { + size += 1 + 1; + } + if (optimizerOptions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(OptimizerOptions); + } + if (BuildCostModel != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(BuildCostModel); + } + if (BuildCostModelAfter != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(BuildCostModelAfter); + } + if (InferShapes != false) { + size += 1 + 1; + } + if (PlacePrunedGraph != false) { + size += 1 + 1; + } + if (EnableBfloat16Sendrecv != false) { + size += 1 + 1; + } + if (TimelineStep != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(TimelineStep); + } + if (rewriteOptions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RewriteOptions); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GraphOptions other) { + if (other == null) { + return; + } + if (other.EnableRecvScheduling != false) { + EnableRecvScheduling = other.EnableRecvScheduling; + } + if (other.optimizerOptions_ != null) { + if (optimizerOptions_ == null) { + OptimizerOptions = new global::Tensorflow.OptimizerOptions(); + } + OptimizerOptions.MergeFrom(other.OptimizerOptions); + } + if (other.BuildCostModel != 0L) { + BuildCostModel = other.BuildCostModel; + } + if (other.BuildCostModelAfter != 0L) { + BuildCostModelAfter = other.BuildCostModelAfter; + } + if (other.InferShapes != false) { + InferShapes = other.InferShapes; + } + if (other.PlacePrunedGraph != false) { + PlacePrunedGraph = other.PlacePrunedGraph; + } + if (other.EnableBfloat16Sendrecv != false) { + EnableBfloat16Sendrecv = other.EnableBfloat16Sendrecv; + } + if (other.TimelineStep != 0) { + TimelineStep = other.TimelineStep; + } + if (other.rewriteOptions_ != null) { + if (rewriteOptions_ == null) { + RewriteOptions = new global::Tensorflow.RewriterConfig(); + } + RewriteOptions.MergeFrom(other.RewriteOptions); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 16: { + EnableRecvScheduling = input.ReadBool(); + break; + } + case 26: { + if (optimizerOptions_ == null) { + OptimizerOptions = new global::Tensorflow.OptimizerOptions(); + } + input.ReadMessage(OptimizerOptions); + break; + } + case 32: { + BuildCostModel = input.ReadInt64(); + break; + } + case 40: { + InferShapes = input.ReadBool(); + break; + } + case 48: { + PlacePrunedGraph = input.ReadBool(); + break; + } + case 56: { + EnableBfloat16Sendrecv = input.ReadBool(); + break; + } + case 64: { + TimelineStep = input.ReadInt32(); + break; + } + case 72: { + BuildCostModelAfter = input.ReadInt64(); + break; + } + case 82: { + if (rewriteOptions_ == null) { + RewriteOptions = new global::Tensorflow.RewriterConfig(); + } + input.ReadMessage(RewriteOptions); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 16: { + EnableRecvScheduling = input.ReadBool(); + break; + } + case 26: { + if (optimizerOptions_ == null) { + OptimizerOptions = new global::Tensorflow.OptimizerOptions(); + } + input.ReadMessage(OptimizerOptions); + break; + } + case 32: { + BuildCostModel = input.ReadInt64(); + break; + } + case 40: { + InferShapes = input.ReadBool(); + break; + } + case 48: { + PlacePrunedGraph = input.ReadBool(); + break; + } + case 56: { + EnableBfloat16Sendrecv = input.ReadBool(); + break; + } + case 64: { + TimelineStep = input.ReadInt32(); + break; + } + case 72: { + BuildCostModelAfter = input.ReadInt64(); + break; + } + case 82: { + if (rewriteOptions_ == null) { + RewriteOptions = new global::Tensorflow.RewriterConfig(); + } + input.ReadMessage(RewriteOptions); + break; + } + } + } + } + #endif + + } + + public sealed partial class ThreadPoolOptionProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ThreadPoolOptionProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ConfigReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ThreadPoolOptionProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ThreadPoolOptionProto(ThreadPoolOptionProto other) : this() { + numThreads_ = other.numThreads_; + globalName_ = other.globalName_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ThreadPoolOptionProto Clone() { + return new ThreadPoolOptionProto(this); + } + + /// Field number for the "num_threads" field. + public const int NumThreadsFieldNumber = 1; + private int numThreads_; + /// + /// The number of threads in the pool. + /// + /// 0 means the system picks a value based on where this option proto is used + /// (see the declaration of the specific field for more info). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NumThreads { + get { return numThreads_; } + set { + numThreads_ = value; + } + } + + /// Field number for the "global_name" field. + public const int GlobalNameFieldNumber = 2; + private string globalName_ = ""; + /// + /// The global name of the threadpool. + /// + /// If empty, then the threadpool is made and used according to the scope it's + /// in - e.g., for a session threadpool, it is used by that session only. + /// + /// If non-empty, then: + /// - a global threadpool associated with this name is looked + /// up or created. This allows, for example, sharing one threadpool across + /// many sessions (e.g., like the default behavior, if + /// inter_op_parallelism_threads is not configured), but still partitioning + /// into a large and small pool. + /// - if the threadpool for this global_name already exists, then it is an + /// error if the existing pool was created using a different num_threads + /// value as is specified on this call. + /// - threadpools created this way are never garbage collected. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string GlobalName { + get { return globalName_; } + set { + globalName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ThreadPoolOptionProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ThreadPoolOptionProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NumThreads != other.NumThreads) return false; + if (GlobalName != other.GlobalName) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (NumThreads != 0) hash ^= NumThreads.GetHashCode(); + if (GlobalName.Length != 0) hash ^= GlobalName.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (NumThreads != 0) { + output.WriteRawTag(8); + output.WriteInt32(NumThreads); + } + if (GlobalName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(GlobalName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (NumThreads != 0) { + output.WriteRawTag(8); + output.WriteInt32(NumThreads); + } + if (GlobalName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(GlobalName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (NumThreads != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumThreads); + } + if (GlobalName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(GlobalName); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ThreadPoolOptionProto other) { + if (other == null) { + return; + } + if (other.NumThreads != 0) { + NumThreads = other.NumThreads; + } + if (other.GlobalName.Length != 0) { + GlobalName = other.GlobalName; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NumThreads = input.ReadInt32(); + break; + } + case 18: { + GlobalName = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + NumThreads = input.ReadInt32(); + break; + } + case 18: { + GlobalName = input.ReadString(); + break; + } + } + } + } + #endif + + } + + public sealed partial class RPCOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RPCOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ConfigReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RPCOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RPCOptions(RPCOptions other) : this() { + useRpcForInprocessMaster_ = other.useRpcForInprocessMaster_; + compressionAlgorithm_ = other.compressionAlgorithm_; + compressionLevel_ = other.compressionLevel_; + cacheRpcResponse_ = other.cacheRpcResponse_; + disableSessionConnectionSharing_ = other.disableSessionConnectionSharing_; + numChannelsPerTarget_ = other.numChannelsPerTarget_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RPCOptions Clone() { + return new RPCOptions(this); + } + + /// Field number for the "use_rpc_for_inprocess_master" field. + public const int UseRpcForInprocessMasterFieldNumber = 1; + private bool useRpcForInprocessMaster_; + /// + /// If true, always use RPC to contact the session target. + /// + /// If false (the default option), TensorFlow may use an optimized + /// transport for client-master communication that avoids the RPC + /// stack. This option is primarily for used testing the RPC stack. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UseRpcForInprocessMaster { + get { return useRpcForInprocessMaster_; } + set { + useRpcForInprocessMaster_ = value; + } + } + + /// Field number for the "compression_algorithm" field. + public const int CompressionAlgorithmFieldNumber = 2; + private string compressionAlgorithm_ = ""; + /// + /// The compression algorithm to be used. One of "deflate", "gzip". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string CompressionAlgorithm { + get { return compressionAlgorithm_; } + set { + compressionAlgorithm_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "compression_level" field. + public const int CompressionLevelFieldNumber = 3; + private int compressionLevel_; + /// + /// If compression_algorithm is set, the compression level to be used. + /// From 0 (no compression), up to 3. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CompressionLevel { + get { return compressionLevel_; } + set { + compressionLevel_ = value; + } + } + + /// Field number for the "cache_rpc_response" field. + public const int CacheRpcResponseFieldNumber = 4; + private bool cacheRpcResponse_; + /// + /// Setting cache_rpc_response to true will enable sender side caching of + /// response for RecvTensorAsync and RecvBufAsync to allow receiver to retry + /// requests . This is only necessary when the network fabric is experiencing a + /// significant error rate. Without it we'll fail a step on an network error, + /// while with it we'll be able to complete long steps (like complex + /// initializations) in the face of some network errors during RecvTensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool CacheRpcResponse { + get { return cacheRpcResponse_; } + set { + cacheRpcResponse_ = value; + } + } + + /// Field number for the "disable_session_connection_sharing" field. + public const int DisableSessionConnectionSharingFieldNumber = 5; + private bool disableSessionConnectionSharing_; + /// + /// Disables TCP connection sharing when opening a new RPC channel. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool DisableSessionConnectionSharing { + get { return disableSessionConnectionSharing_; } + set { + disableSessionConnectionSharing_ = value; + } + } + + /// Field number for the "num_channels_per_target" field. + public const int NumChannelsPerTargetFieldNumber = 6; + private int numChannelsPerTarget_; + /// + /// Setting num_channels_per_target > 0 allows uses of multiple channels to + /// communicate to the same target. This can be used to improve the aggregate + /// throughput on high speed links (e.g 100G) where single connection is not + /// sufficient to maximize link utilization. Note that a single RPC only goes + /// on a single channel, this only helps in situations where there are multiple + /// transfers to the same target overlapping in time. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NumChannelsPerTarget { + get { return numChannelsPerTarget_; } + set { + numChannelsPerTarget_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as RPCOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(RPCOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (UseRpcForInprocessMaster != other.UseRpcForInprocessMaster) return false; + if (CompressionAlgorithm != other.CompressionAlgorithm) return false; + if (CompressionLevel != other.CompressionLevel) return false; + if (CacheRpcResponse != other.CacheRpcResponse) return false; + if (DisableSessionConnectionSharing != other.DisableSessionConnectionSharing) return false; + if (NumChannelsPerTarget != other.NumChannelsPerTarget) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (UseRpcForInprocessMaster != false) hash ^= UseRpcForInprocessMaster.GetHashCode(); + if (CompressionAlgorithm.Length != 0) hash ^= CompressionAlgorithm.GetHashCode(); + if (CompressionLevel != 0) hash ^= CompressionLevel.GetHashCode(); + if (CacheRpcResponse != false) hash ^= CacheRpcResponse.GetHashCode(); + if (DisableSessionConnectionSharing != false) hash ^= DisableSessionConnectionSharing.GetHashCode(); + if (NumChannelsPerTarget != 0) hash ^= NumChannelsPerTarget.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (UseRpcForInprocessMaster != false) { + output.WriteRawTag(8); + output.WriteBool(UseRpcForInprocessMaster); + } + if (CompressionAlgorithm.Length != 0) { + output.WriteRawTag(18); + output.WriteString(CompressionAlgorithm); + } + if (CompressionLevel != 0) { + output.WriteRawTag(24); + output.WriteInt32(CompressionLevel); + } + if (CacheRpcResponse != false) { + output.WriteRawTag(32); + output.WriteBool(CacheRpcResponse); + } + if (DisableSessionConnectionSharing != false) { + output.WriteRawTag(40); + output.WriteBool(DisableSessionConnectionSharing); + } + if (NumChannelsPerTarget != 0) { + output.WriteRawTag(48); + output.WriteInt32(NumChannelsPerTarget); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (UseRpcForInprocessMaster != false) { + output.WriteRawTag(8); + output.WriteBool(UseRpcForInprocessMaster); + } + if (CompressionAlgorithm.Length != 0) { + output.WriteRawTag(18); + output.WriteString(CompressionAlgorithm); + } + if (CompressionLevel != 0) { + output.WriteRawTag(24); + output.WriteInt32(CompressionLevel); + } + if (CacheRpcResponse != false) { + output.WriteRawTag(32); + output.WriteBool(CacheRpcResponse); + } + if (DisableSessionConnectionSharing != false) { + output.WriteRawTag(40); + output.WriteBool(DisableSessionConnectionSharing); + } + if (NumChannelsPerTarget != 0) { + output.WriteRawTag(48); + output.WriteInt32(NumChannelsPerTarget); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (UseRpcForInprocessMaster != false) { + size += 1 + 1; + } + if (CompressionAlgorithm.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(CompressionAlgorithm); + } + if (CompressionLevel != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(CompressionLevel); + } + if (CacheRpcResponse != false) { + size += 1 + 1; + } + if (DisableSessionConnectionSharing != false) { + size += 1 + 1; + } + if (NumChannelsPerTarget != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumChannelsPerTarget); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(RPCOptions other) { + if (other == null) { + return; + } + if (other.UseRpcForInprocessMaster != false) { + UseRpcForInprocessMaster = other.UseRpcForInprocessMaster; + } + if (other.CompressionAlgorithm.Length != 0) { + CompressionAlgorithm = other.CompressionAlgorithm; + } + if (other.CompressionLevel != 0) { + CompressionLevel = other.CompressionLevel; + } + if (other.CacheRpcResponse != false) { + CacheRpcResponse = other.CacheRpcResponse; + } + if (other.DisableSessionConnectionSharing != false) { + DisableSessionConnectionSharing = other.DisableSessionConnectionSharing; + } + if (other.NumChannelsPerTarget != 0) { + NumChannelsPerTarget = other.NumChannelsPerTarget; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + UseRpcForInprocessMaster = input.ReadBool(); + break; + } + case 18: { + CompressionAlgorithm = input.ReadString(); + break; + } + case 24: { + CompressionLevel = input.ReadInt32(); + break; + } + case 32: { + CacheRpcResponse = input.ReadBool(); + break; + } + case 40: { + DisableSessionConnectionSharing = input.ReadBool(); + break; + } + case 48: { + NumChannelsPerTarget = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + UseRpcForInprocessMaster = input.ReadBool(); + break; + } + case 18: { + CompressionAlgorithm = input.ReadString(); + break; + } + case 24: { + CompressionLevel = input.ReadInt32(); + break; + } + case 32: { + CacheRpcResponse = input.ReadBool(); + break; + } + case 40: { + DisableSessionConnectionSharing = input.ReadBool(); + break; + } + case 48: { + NumChannelsPerTarget = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + /// + /// Metadata about the session. + /// + /// This can be used by the runtime and the Ops for debugging, monitoring, etc. + /// + /// The (name, version) tuple is expected to be a unique identifier for + /// sessions within the same process. + /// + /// NOTE: This is currently used and propagated only by the direct session. + /// + public sealed partial class SessionMetadata : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SessionMetadata()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ConfigReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SessionMetadata() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SessionMetadata(SessionMetadata other) : this() { + name_ = other.name_; + version_ = other.version_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SessionMetadata Clone() { + return new SessionMetadata(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "version" field. + public const int VersionFieldNumber = 2; + private long version_; + /// + /// The version is optional. If set, needs to be >= 0. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Version { + get { return version_; } + set { + version_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SessionMetadata); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SessionMetadata other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (Version != other.Version) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Version != 0L) hash ^= Version.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Version != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Version); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Version != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Version); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Version != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Version); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SessionMetadata other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Version != 0L) { + Version = other.Version; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + Version = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + Version = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + /// + /// Session configuration parameters. + /// The system picks appropriate values for fields that are not set. + /// + public sealed partial class ConfigProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ConfigProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ConfigReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ConfigProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ConfigProto(ConfigProto other) : this() { + deviceCount_ = other.deviceCount_.Clone(); + intraOpParallelismThreads_ = other.intraOpParallelismThreads_; + interOpParallelismThreads_ = other.interOpParallelismThreads_; + usePerSessionThreads_ = other.usePerSessionThreads_; + sessionInterOpThreadPool_ = other.sessionInterOpThreadPool_.Clone(); + placementPeriod_ = other.placementPeriod_; + deviceFilters_ = other.deviceFilters_.Clone(); + gpuOptions_ = other.gpuOptions_ != null ? other.gpuOptions_.Clone() : null; + allowSoftPlacement_ = other.allowSoftPlacement_; + logDevicePlacement_ = other.logDevicePlacement_; + graphOptions_ = other.graphOptions_ != null ? other.graphOptions_.Clone() : null; + operationTimeoutInMs_ = other.operationTimeoutInMs_; + rpcOptions_ = other.rpcOptions_ != null ? other.rpcOptions_.Clone() : null; + clusterDef_ = other.clusterDef_ != null ? other.clusterDef_.Clone() : null; + isolateSessionState_ = other.isolateSessionState_; + shareClusterDevicesInSession_ = other.shareClusterDevicesInSession_; + experimental_ = other.experimental_ != null ? other.experimental_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ConfigProto Clone() { + return new ConfigProto(this); + } + + /// Field number for the "device_count" field. + public const int DeviceCountFieldNumber = 1; + private static readonly pbc::MapField.Codec _map_deviceCount_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForInt32(16, 0), 10); + private readonly pbc::MapField deviceCount_ = new pbc::MapField(); + /// + /// Map from device type name (e.g., "CPU" or "GPU" ) to maximum + /// number of devices of that type to use. If a particular device + /// type is not found in the map, the system picks an appropriate + /// number. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField DeviceCount { + get { return deviceCount_; } + } + + /// Field number for the "intra_op_parallelism_threads" field. + public const int IntraOpParallelismThreadsFieldNumber = 2; + private int intraOpParallelismThreads_; + /// + /// The execution of an individual op (for some op types) can be + /// parallelized on a pool of intra_op_parallelism_threads. + /// 0 means the system picks an appropriate number. + /// + /// If you create an ordinary session, e.g., from Python or C++, + /// then there is exactly one intra op thread pool per process. + /// The first session created determines the number of threads in this pool. + /// All subsequent sessions reuse/share this one global pool. + /// + /// There are notable exceptions to the default behavior described above: + /// 1. There is an environment variable for overriding this thread pool, + /// named TF_OVERRIDE_GLOBAL_THREADPOOL. + /// 2. When connecting to a server, such as a remote `tf.train.Server` + /// instance, then this option will be ignored altogether. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int IntraOpParallelismThreads { + get { return intraOpParallelismThreads_; } + set { + intraOpParallelismThreads_ = value; + } + } + + /// Field number for the "inter_op_parallelism_threads" field. + public const int InterOpParallelismThreadsFieldNumber = 5; + private int interOpParallelismThreads_; + /// + /// Nodes that perform blocking operations are enqueued on a pool of + /// inter_op_parallelism_threads available in each process. + /// + /// 0 means the system picks an appropriate number. + /// Negative means all operations are performed in caller's thread. + /// + /// Note that the first Session created in the process sets the + /// number of threads for all future sessions unless use_per_session_threads is + /// true or session_inter_op_thread_pool is configured. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int InterOpParallelismThreads { + get { return interOpParallelismThreads_; } + set { + interOpParallelismThreads_ = value; + } + } + + /// Field number for the "use_per_session_threads" field. + public const int UsePerSessionThreadsFieldNumber = 9; + private bool usePerSessionThreads_; + /// + /// If true, use a new set of threads for this session rather than the global + /// pool of threads. Only supported by direct sessions. + /// + /// If false, use the global threads created by the first session, or the + /// per-session thread pools configured by session_inter_op_thread_pool. + /// + /// This option is deprecated. The same effect can be achieved by setting + /// session_inter_op_thread_pool to have one element, whose num_threads equals + /// inter_op_parallelism_threads. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UsePerSessionThreads { + get { return usePerSessionThreads_; } + set { + usePerSessionThreads_ = value; + } + } + + /// Field number for the "session_inter_op_thread_pool" field. + public const int SessionInterOpThreadPoolFieldNumber = 12; + private static readonly pb::FieldCodec _repeated_sessionInterOpThreadPool_codec + = pb::FieldCodec.ForMessage(98, global::Tensorflow.ThreadPoolOptionProto.Parser); + private readonly pbc::RepeatedField sessionInterOpThreadPool_ = new pbc::RepeatedField(); + /// + /// This option is experimental - it may be replaced with a different mechanism + /// in the future. + /// + /// Configures session thread pools. If this is configured, then RunOptions for + /// a Run call can select the thread pool to use. + /// + /// The intended use is for when some session invocations need to run in a + /// background pool limited to a small number of threads: + /// - For example, a session may be configured to have one large pool (for + /// regular compute) and one small pool (for periodic, low priority work); + /// using the small pool is currently the mechanism for limiting the inter-op + /// parallelism of the low priority work. Note that it does not limit the + /// parallelism of work spawned by a single op kernel implementation. + /// - Using this setting is normally not needed in training, but may help some + /// serving use cases. + /// - It is also generally recommended to set the global_name field of this + /// proto, to avoid creating multiple large pools. It is typically better to + /// run the non-low-priority work, even across sessions, in a single large + /// pool. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SessionInterOpThreadPool { + get { return sessionInterOpThreadPool_; } + } + + /// Field number for the "placement_period" field. + public const int PlacementPeriodFieldNumber = 3; + private int placementPeriod_; + /// + /// Assignment of Nodes to Devices is recomputed every placement_period + /// steps until the system warms up (at which point the recomputation + /// typically slows down automatically). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int PlacementPeriod { + get { return placementPeriod_; } + set { + placementPeriod_ = value; + } + } + + /// Field number for the "device_filters" field. + public const int DeviceFiltersFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_deviceFilters_codec + = pb::FieldCodec.ForString(34); + private readonly pbc::RepeatedField deviceFilters_ = new pbc::RepeatedField(); + /// + /// When any filters are present sessions will ignore all devices which do not + /// match the filters. Each filter can be partially specified, e.g. "/job:ps" + /// "/job:worker/replica:3", etc. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DeviceFilters { + get { return deviceFilters_; } + } + + /// Field number for the "gpu_options" field. + public const int GpuOptionsFieldNumber = 6; + private global::Tensorflow.GPUOptions gpuOptions_; + /// + /// Options that apply to all GPUs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.GPUOptions GpuOptions { + get { return gpuOptions_; } + set { + gpuOptions_ = value; + } + } + + /// Field number for the "allow_soft_placement" field. + public const int AllowSoftPlacementFieldNumber = 7; + private bool allowSoftPlacement_; + /// + /// Whether soft placement is allowed. If allow_soft_placement is true, + /// an op will be placed on CPU if + /// 1. there's no GPU implementation for the OP + /// or + /// 2. no GPU devices are known or registered + /// or + /// 3. need to co-locate with reftype input(s) which are from CPU. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool AllowSoftPlacement { + get { return allowSoftPlacement_; } + set { + allowSoftPlacement_ = value; + } + } + + /// Field number for the "log_device_placement" field. + public const int LogDevicePlacementFieldNumber = 8; + private bool logDevicePlacement_; + /// + /// Whether device placements should be logged. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool LogDevicePlacement { + get { return logDevicePlacement_; } + set { + logDevicePlacement_ = value; + } + } + + /// Field number for the "graph_options" field. + public const int GraphOptionsFieldNumber = 10; + private global::Tensorflow.GraphOptions graphOptions_; + /// + /// Options that apply to all graphs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.GraphOptions GraphOptions { + get { return graphOptions_; } + set { + graphOptions_ = value; + } + } + + /// Field number for the "operation_timeout_in_ms" field. + public const int OperationTimeoutInMsFieldNumber = 11; + private long operationTimeoutInMs_; + /// + /// Global timeout for all blocking operations in this session. If non-zero, + /// and not overridden on a per-operation basis, this value will be used as the + /// deadline for all blocking operations. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long OperationTimeoutInMs { + get { return operationTimeoutInMs_; } + set { + operationTimeoutInMs_ = value; + } + } + + /// Field number for the "rpc_options" field. + public const int RpcOptionsFieldNumber = 13; + private global::Tensorflow.RPCOptions rpcOptions_; + /// + /// Options that apply when this session uses the distributed runtime. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RPCOptions RpcOptions { + get { return rpcOptions_; } + set { + rpcOptions_ = value; + } + } + + /// Field number for the "cluster_def" field. + public const int ClusterDefFieldNumber = 14; + private global::Tensorflow.ClusterDef clusterDef_; + /// + /// Optional list of all workers to use in this session. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.ClusterDef ClusterDef { + get { return clusterDef_; } + set { + clusterDef_ = value; + } + } + + /// Field number for the "isolate_session_state" field. + public const int IsolateSessionStateFieldNumber = 15; + private bool isolateSessionState_; + /// + /// If true, any resources such as Variables used in the session will not be + /// shared with other sessions. However, when clusterspec propagation is + /// enabled, this field is ignored and sessions are always isolated. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsolateSessionState { + get { return isolateSessionState_; } + set { + isolateSessionState_ = value; + } + } + + /// Field number for the "share_cluster_devices_in_session" field. + public const int ShareClusterDevicesInSessionFieldNumber = 17; + private bool shareClusterDevicesInSession_; + /// + /// When true, WorkerSessions are created with device attributes from the + /// full cluster. + /// This is helpful when a worker wants to partition a graph + /// (for example during a PartitionedCallOp). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool ShareClusterDevicesInSession { + get { return shareClusterDevicesInSession_; } + set { + shareClusterDevicesInSession_ = value; + } + } + + /// Field number for the "experimental" field. + public const int ExperimentalFieldNumber = 16; + private global::Tensorflow.ConfigProto.Types.Experimental experimental_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.ConfigProto.Types.Experimental Experimental { + get { return experimental_; } + set { + experimental_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ConfigProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ConfigProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!DeviceCount.Equals(other.DeviceCount)) return false; + if (IntraOpParallelismThreads != other.IntraOpParallelismThreads) return false; + if (InterOpParallelismThreads != other.InterOpParallelismThreads) return false; + if (UsePerSessionThreads != other.UsePerSessionThreads) return false; + if(!sessionInterOpThreadPool_.Equals(other.sessionInterOpThreadPool_)) return false; + if (PlacementPeriod != other.PlacementPeriod) return false; + if(!deviceFilters_.Equals(other.deviceFilters_)) return false; + if (!object.Equals(GpuOptions, other.GpuOptions)) return false; + if (AllowSoftPlacement != other.AllowSoftPlacement) return false; + if (LogDevicePlacement != other.LogDevicePlacement) return false; + if (!object.Equals(GraphOptions, other.GraphOptions)) return false; + if (OperationTimeoutInMs != other.OperationTimeoutInMs) return false; + if (!object.Equals(RpcOptions, other.RpcOptions)) return false; + if (!object.Equals(ClusterDef, other.ClusterDef)) return false; + if (IsolateSessionState != other.IsolateSessionState) return false; + if (ShareClusterDevicesInSession != other.ShareClusterDevicesInSession) return false; + if (!object.Equals(Experimental, other.Experimental)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= DeviceCount.GetHashCode(); + if (IntraOpParallelismThreads != 0) hash ^= IntraOpParallelismThreads.GetHashCode(); + if (InterOpParallelismThreads != 0) hash ^= InterOpParallelismThreads.GetHashCode(); + if (UsePerSessionThreads != false) hash ^= UsePerSessionThreads.GetHashCode(); + hash ^= sessionInterOpThreadPool_.GetHashCode(); + if (PlacementPeriod != 0) hash ^= PlacementPeriod.GetHashCode(); + hash ^= deviceFilters_.GetHashCode(); + if (gpuOptions_ != null) hash ^= GpuOptions.GetHashCode(); + if (AllowSoftPlacement != false) hash ^= AllowSoftPlacement.GetHashCode(); + if (LogDevicePlacement != false) hash ^= LogDevicePlacement.GetHashCode(); + if (graphOptions_ != null) hash ^= GraphOptions.GetHashCode(); + if (OperationTimeoutInMs != 0L) hash ^= OperationTimeoutInMs.GetHashCode(); + if (rpcOptions_ != null) hash ^= RpcOptions.GetHashCode(); + if (clusterDef_ != null) hash ^= ClusterDef.GetHashCode(); + if (IsolateSessionState != false) hash ^= IsolateSessionState.GetHashCode(); + if (ShareClusterDevicesInSession != false) hash ^= ShareClusterDevicesInSession.GetHashCode(); + if (experimental_ != null) hash ^= Experimental.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + deviceCount_.WriteTo(output, _map_deviceCount_codec); + if (IntraOpParallelismThreads != 0) { + output.WriteRawTag(16); + output.WriteInt32(IntraOpParallelismThreads); + } + if (PlacementPeriod != 0) { + output.WriteRawTag(24); + output.WriteInt32(PlacementPeriod); + } + deviceFilters_.WriteTo(output, _repeated_deviceFilters_codec); + if (InterOpParallelismThreads != 0) { + output.WriteRawTag(40); + output.WriteInt32(InterOpParallelismThreads); + } + if (gpuOptions_ != null) { + output.WriteRawTag(50); + output.WriteMessage(GpuOptions); + } + if (AllowSoftPlacement != false) { + output.WriteRawTag(56); + output.WriteBool(AllowSoftPlacement); + } + if (LogDevicePlacement != false) { + output.WriteRawTag(64); + output.WriteBool(LogDevicePlacement); + } + if (UsePerSessionThreads != false) { + output.WriteRawTag(72); + output.WriteBool(UsePerSessionThreads); + } + if (graphOptions_ != null) { + output.WriteRawTag(82); + output.WriteMessage(GraphOptions); + } + if (OperationTimeoutInMs != 0L) { + output.WriteRawTag(88); + output.WriteInt64(OperationTimeoutInMs); + } + sessionInterOpThreadPool_.WriteTo(output, _repeated_sessionInterOpThreadPool_codec); + if (rpcOptions_ != null) { + output.WriteRawTag(106); + output.WriteMessage(RpcOptions); + } + if (clusterDef_ != null) { + output.WriteRawTag(114); + output.WriteMessage(ClusterDef); + } + if (IsolateSessionState != false) { + output.WriteRawTag(120); + output.WriteBool(IsolateSessionState); + } + if (experimental_ != null) { + output.WriteRawTag(130, 1); + output.WriteMessage(Experimental); + } + if (ShareClusterDevicesInSession != false) { + output.WriteRawTag(136, 1); + output.WriteBool(ShareClusterDevicesInSession); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + deviceCount_.WriteTo(ref output, _map_deviceCount_codec); + if (IntraOpParallelismThreads != 0) { + output.WriteRawTag(16); + output.WriteInt32(IntraOpParallelismThreads); + } + if (PlacementPeriod != 0) { + output.WriteRawTag(24); + output.WriteInt32(PlacementPeriod); + } + deviceFilters_.WriteTo(ref output, _repeated_deviceFilters_codec); + if (InterOpParallelismThreads != 0) { + output.WriteRawTag(40); + output.WriteInt32(InterOpParallelismThreads); + } + if (gpuOptions_ != null) { + output.WriteRawTag(50); + output.WriteMessage(GpuOptions); + } + if (AllowSoftPlacement != false) { + output.WriteRawTag(56); + output.WriteBool(AllowSoftPlacement); + } + if (LogDevicePlacement != false) { + output.WriteRawTag(64); + output.WriteBool(LogDevicePlacement); + } + if (UsePerSessionThreads != false) { + output.WriteRawTag(72); + output.WriteBool(UsePerSessionThreads); + } + if (graphOptions_ != null) { + output.WriteRawTag(82); + output.WriteMessage(GraphOptions); + } + if (OperationTimeoutInMs != 0L) { + output.WriteRawTag(88); + output.WriteInt64(OperationTimeoutInMs); + } + sessionInterOpThreadPool_.WriteTo(ref output, _repeated_sessionInterOpThreadPool_codec); + if (rpcOptions_ != null) { + output.WriteRawTag(106); + output.WriteMessage(RpcOptions); + } + if (clusterDef_ != null) { + output.WriteRawTag(114); + output.WriteMessage(ClusterDef); + } + if (IsolateSessionState != false) { + output.WriteRawTag(120); + output.WriteBool(IsolateSessionState); + } + if (experimental_ != null) { + output.WriteRawTag(130, 1); + output.WriteMessage(Experimental); + } + if (ShareClusterDevicesInSession != false) { + output.WriteRawTag(136, 1); + output.WriteBool(ShareClusterDevicesInSession); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += deviceCount_.CalculateSize(_map_deviceCount_codec); + if (IntraOpParallelismThreads != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(IntraOpParallelismThreads); + } + if (InterOpParallelismThreads != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(InterOpParallelismThreads); + } + if (UsePerSessionThreads != false) { + size += 1 + 1; + } + size += sessionInterOpThreadPool_.CalculateSize(_repeated_sessionInterOpThreadPool_codec); + if (PlacementPeriod != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(PlacementPeriod); + } + size += deviceFilters_.CalculateSize(_repeated_deviceFilters_codec); + if (gpuOptions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(GpuOptions); + } + if (AllowSoftPlacement != false) { + size += 1 + 1; + } + if (LogDevicePlacement != false) { + size += 1 + 1; + } + if (graphOptions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(GraphOptions); + } + if (OperationTimeoutInMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(OperationTimeoutInMs); + } + if (rpcOptions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RpcOptions); + } + if (clusterDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ClusterDef); + } + if (IsolateSessionState != false) { + size += 1 + 1; + } + if (ShareClusterDevicesInSession != false) { + size += 2 + 1; + } + if (experimental_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(Experimental); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ConfigProto other) { + if (other == null) { + return; + } + deviceCount_.Add(other.deviceCount_); + if (other.IntraOpParallelismThreads != 0) { + IntraOpParallelismThreads = other.IntraOpParallelismThreads; + } + if (other.InterOpParallelismThreads != 0) { + InterOpParallelismThreads = other.InterOpParallelismThreads; + } + if (other.UsePerSessionThreads != false) { + UsePerSessionThreads = other.UsePerSessionThreads; + } + sessionInterOpThreadPool_.Add(other.sessionInterOpThreadPool_); + if (other.PlacementPeriod != 0) { + PlacementPeriod = other.PlacementPeriod; + } + deviceFilters_.Add(other.deviceFilters_); + if (other.gpuOptions_ != null) { + if (gpuOptions_ == null) { + GpuOptions = new global::Tensorflow.GPUOptions(); + } + GpuOptions.MergeFrom(other.GpuOptions); + } + if (other.AllowSoftPlacement != false) { + AllowSoftPlacement = other.AllowSoftPlacement; + } + if (other.LogDevicePlacement != false) { + LogDevicePlacement = other.LogDevicePlacement; + } + if (other.graphOptions_ != null) { + if (graphOptions_ == null) { + GraphOptions = new global::Tensorflow.GraphOptions(); + } + GraphOptions.MergeFrom(other.GraphOptions); + } + if (other.OperationTimeoutInMs != 0L) { + OperationTimeoutInMs = other.OperationTimeoutInMs; + } + if (other.rpcOptions_ != null) { + if (rpcOptions_ == null) { + RpcOptions = new global::Tensorflow.RPCOptions(); + } + RpcOptions.MergeFrom(other.RpcOptions); + } + if (other.clusterDef_ != null) { + if (clusterDef_ == null) { + ClusterDef = new global::Tensorflow.ClusterDef(); + } + ClusterDef.MergeFrom(other.ClusterDef); + } + if (other.IsolateSessionState != false) { + IsolateSessionState = other.IsolateSessionState; + } + if (other.ShareClusterDevicesInSession != false) { + ShareClusterDevicesInSession = other.ShareClusterDevicesInSession; + } + if (other.experimental_ != null) { + if (experimental_ == null) { + Experimental = new global::Tensorflow.ConfigProto.Types.Experimental(); + } + Experimental.MergeFrom(other.Experimental); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + deviceCount_.AddEntriesFrom(input, _map_deviceCount_codec); + break; + } + case 16: { + IntraOpParallelismThreads = input.ReadInt32(); + break; + } + case 24: { + PlacementPeriod = input.ReadInt32(); + break; + } + case 34: { + deviceFilters_.AddEntriesFrom(input, _repeated_deviceFilters_codec); + break; + } + case 40: { + InterOpParallelismThreads = input.ReadInt32(); + break; + } + case 50: { + if (gpuOptions_ == null) { + GpuOptions = new global::Tensorflow.GPUOptions(); + } + input.ReadMessage(GpuOptions); + break; + } + case 56: { + AllowSoftPlacement = input.ReadBool(); + break; + } + case 64: { + LogDevicePlacement = input.ReadBool(); + break; + } + case 72: { + UsePerSessionThreads = input.ReadBool(); + break; + } + case 82: { + if (graphOptions_ == null) { + GraphOptions = new global::Tensorflow.GraphOptions(); + } + input.ReadMessage(GraphOptions); + break; + } + case 88: { + OperationTimeoutInMs = input.ReadInt64(); + break; + } + case 98: { + sessionInterOpThreadPool_.AddEntriesFrom(input, _repeated_sessionInterOpThreadPool_codec); + break; + } + case 106: { + if (rpcOptions_ == null) { + RpcOptions = new global::Tensorflow.RPCOptions(); + } + input.ReadMessage(RpcOptions); + break; + } + case 114: { + if (clusterDef_ == null) { + ClusterDef = new global::Tensorflow.ClusterDef(); + } + input.ReadMessage(ClusterDef); + break; + } + case 120: { + IsolateSessionState = input.ReadBool(); + break; + } + case 130: { + if (experimental_ == null) { + Experimental = new global::Tensorflow.ConfigProto.Types.Experimental(); + } + input.ReadMessage(Experimental); + break; + } + case 136: { + ShareClusterDevicesInSession = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + deviceCount_.AddEntriesFrom(ref input, _map_deviceCount_codec); + break; + } + case 16: { + IntraOpParallelismThreads = input.ReadInt32(); + break; + } + case 24: { + PlacementPeriod = input.ReadInt32(); + break; + } + case 34: { + deviceFilters_.AddEntriesFrom(ref input, _repeated_deviceFilters_codec); + break; + } + case 40: { + InterOpParallelismThreads = input.ReadInt32(); + break; + } + case 50: { + if (gpuOptions_ == null) { + GpuOptions = new global::Tensorflow.GPUOptions(); + } + input.ReadMessage(GpuOptions); + break; + } + case 56: { + AllowSoftPlacement = input.ReadBool(); + break; + } + case 64: { + LogDevicePlacement = input.ReadBool(); + break; + } + case 72: { + UsePerSessionThreads = input.ReadBool(); + break; + } + case 82: { + if (graphOptions_ == null) { + GraphOptions = new global::Tensorflow.GraphOptions(); + } + input.ReadMessage(GraphOptions); + break; + } + case 88: { + OperationTimeoutInMs = input.ReadInt64(); + break; + } + case 98: { + sessionInterOpThreadPool_.AddEntriesFrom(ref input, _repeated_sessionInterOpThreadPool_codec); + break; + } + case 106: { + if (rpcOptions_ == null) { + RpcOptions = new global::Tensorflow.RPCOptions(); + } + input.ReadMessage(RpcOptions); + break; + } + case 114: { + if (clusterDef_ == null) { + ClusterDef = new global::Tensorflow.ClusterDef(); + } + input.ReadMessage(ClusterDef); + break; + } + case 120: { + IsolateSessionState = input.ReadBool(); + break; + } + case 130: { + if (experimental_ == null) { + Experimental = new global::Tensorflow.ConfigProto.Types.Experimental(); + } + input.ReadMessage(Experimental); + break; + } + case 136: { + ShareClusterDevicesInSession = input.ReadBool(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the ConfigProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Everything inside Experimental is subject to change and is not subject + /// to API stability guarantees in + /// https://www.tensorflow.org/guide/version_compat. + /// + public sealed partial class Experimental : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Experimental()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ConfigProto.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Experimental() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Experimental(Experimental other) : this() { + collectiveGroupLeader_ = other.collectiveGroupLeader_; + executorType_ = other.executorType_; + recvBufMaxChunk_ = other.recvBufMaxChunk_; + useNumaAffinity_ = other.useNumaAffinity_; + collectiveDeterministicSequentialExecution_ = other.collectiveDeterministicSequentialExecution_; + collectiveNccl_ = other.collectiveNccl_; + shareSessionStateInClusterspecPropagation_ = other.shareSessionStateInClusterspecPropagation_; + disableThreadSpinning_ = other.disableThreadSpinning_; + shareClusterDevicesInSession_ = other.shareClusterDevicesInSession_; + sessionMetadata_ = other.sessionMetadata_ != null ? other.sessionMetadata_.Clone() : null; + optimizeForStaticGraph_ = other.optimizeForStaticGraph_; + enableMlirBridge_ = other.enableMlirBridge_; + mlirBridgeRollout_ = other.mlirBridgeRollout_; + enableMlirGraphOptimization_ = other.enableMlirGraphOptimization_; + disableOutputPartitionGraphs_ = other.disableOutputPartitionGraphs_; + xlaFusionAutotunerThresh_ = other.xlaFusionAutotunerThresh_; + useTfrt_ = other.useTfrt_; + disableFunctionalOpsLowering_ = other.disableFunctionalOpsLowering_; + xlaPreferSingleGraphCluster_ = other.xlaPreferSingleGraphCluster_; + coordinationConfig_ = other.coordinationConfig_ != null ? other.coordinationConfig_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Experimental Clone() { + return new Experimental(this); + } + + /// Field number for the "collective_group_leader" field. + public const int CollectiveGroupLeaderFieldNumber = 1; + private string collectiveGroupLeader_ = ""; + /// + /// Task name for group resolution. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string CollectiveGroupLeader { + get { return collectiveGroupLeader_; } + set { + collectiveGroupLeader_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "executor_type" field. + public const int ExecutorTypeFieldNumber = 3; + private string executorType_ = ""; + /// + /// Which executor to use, the default executor will be used + /// if it is an empty string or "DEFAULT" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ExecutorType { + get { return executorType_; } + set { + executorType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "recv_buf_max_chunk" field. + public const int RecvBufMaxChunkFieldNumber = 4; + private int recvBufMaxChunk_; + /// + /// Guidance to formatting of large RecvBuf fields for transfer. + /// Any positive value sets the max chunk size. 0 defaults to 4096. + /// Any negative value indicates no max, i.e. one chunk only. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int RecvBufMaxChunk { + get { return recvBufMaxChunk_; } + set { + recvBufMaxChunk_ = value; + } + } + + /// Field number for the "use_numa_affinity" field. + public const int UseNumaAffinityFieldNumber = 5; + private bool useNumaAffinity_; + /// + /// If true, and supported by the platform, the runtime will attempt to + /// use NUMA affinity where applicable. One consequence will be the + /// existence of as many CPU devices as there are available NUMA nodes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UseNumaAffinity { + get { return useNumaAffinity_; } + set { + useNumaAffinity_ = value; + } + } + + /// Field number for the "collective_deterministic_sequential_execution" field. + public const int CollectiveDeterministicSequentialExecutionFieldNumber = 6; + private bool collectiveDeterministicSequentialExecution_; + /// + /// If true, make collective op execution order sequential and deterministic + /// for potentially concurrent collective instances. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool CollectiveDeterministicSequentialExecution { + get { return collectiveDeterministicSequentialExecution_; } + set { + collectiveDeterministicSequentialExecution_ = value; + } + } + + /// Field number for the "collective_nccl" field. + public const int CollectiveNcclFieldNumber = 7; + private bool collectiveNccl_; + /// + /// If true, use NCCL for CollectiveOps. This feature is highly + /// experimental. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool CollectiveNccl { + get { return collectiveNccl_; } + set { + collectiveNccl_ = value; + } + } + + /// Field number for the "share_session_state_in_clusterspec_propagation" field. + public const int ShareSessionStateInClusterspecPropagationFieldNumber = 8; + private bool shareSessionStateInClusterspecPropagation_; + /// + /// In the following, session state means the value of a variable, elements + /// in a hash table, or any other resource, accessible by worker sessions + /// held by a TF server. + /// + /// When ClusterSpec propagation is enabled, the value of + /// isolate_session_state is ignored when deciding whether to share session + /// states in a TF server (for backwards compatibility reasons). + /// - If share_session_state_in_clusterspec_propagation is true, the session + /// states are shared. + /// - If share_session_state_in_clusterspec_propagation is false, session + /// states are isolated. + /// + /// When clusterspec propagation is not used, the value of + /// share_session_state_in_clusterspec_propagation is ignored when deciding + /// whether to share session states in a TF server. + /// - If isolate_session_state is true, session states are isolated. + /// - If isolate_session_state is false, session states are shared. + /// + /// TODO(b/129330037): Add a single API that consistently treats + /// isolate_session_state and ClusterSpec propagation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool ShareSessionStateInClusterspecPropagation { + get { return shareSessionStateInClusterspecPropagation_; } + set { + shareSessionStateInClusterspecPropagation_ = value; + } + } + + /// Field number for the "disable_thread_spinning" field. + public const int DisableThreadSpinningFieldNumber = 9; + private bool disableThreadSpinning_; + /// + /// If using a direct session, disable spinning while waiting for work in + /// the thread pool. This may result in higher latency for completing ops, + /// but in the case where there is a lot of spinning may result in lower + /// CPU usage. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool DisableThreadSpinning { + get { return disableThreadSpinning_; } + set { + disableThreadSpinning_ = value; + } + } + + /// Field number for the "share_cluster_devices_in_session" field. + public const int ShareClusterDevicesInSessionFieldNumber = 10; + private bool shareClusterDevicesInSession_; + /// + /// This was promoted to a non-experimental API. Please use + /// ConfigProto.share_cluster_devices_in_session instead. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool ShareClusterDevicesInSession { + get { return shareClusterDevicesInSession_; } + set { + shareClusterDevicesInSession_ = value; + } + } + + /// Field number for the "session_metadata" field. + public const int SessionMetadataFieldNumber = 11; + private global::Tensorflow.SessionMetadata sessionMetadata_; + /// + /// Metadata about the session. + /// + /// If set, this can be used by the runtime and the Ops for debugging, + /// monitoring, etc. + /// + /// NOTE: This is currently used and propagated only by the direct session. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SessionMetadata SessionMetadata { + get { return sessionMetadata_; } + set { + sessionMetadata_ = value; + } + } + + /// Field number for the "optimize_for_static_graph" field. + public const int OptimizeForStaticGraphFieldNumber = 12; + private bool optimizeForStaticGraph_; + /// + /// If true, the session may treat the graph as being static for optimization + /// purposes. + /// + /// If this option is set to true when a session is created, the full + /// GraphDef must be passed in a single call to Session::Create(), and + /// Session::Extend() may not be supported. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool OptimizeForStaticGraph { + get { return optimizeForStaticGraph_; } + set { + optimizeForStaticGraph_ = value; + } + } + + /// Field number for the "enable_mlir_bridge" field. + public const int EnableMlirBridgeFieldNumber = 13; + private bool enableMlirBridge_; + /// + /// This field will eventually be deprecated and replaced by + /// mlir_bridge_rollout (b/166038521). + /// + /// Whether to enable the MLIR-based TF->XLA bridge. + /// + /// This is a replacement to the existing bridge, and not ready for + /// production usage yet. + /// If this option is set to true when a session is created, MLIR is used to + /// perform the set of graph transformations to put the graph in a form that + /// can be executed with delegation of some computations to an accelerator. + /// This builds on the model of XLA where a subset of the graph is + /// encapsulated and attached to a "compile" operation, whose result is fed + /// to an "execute" operation. The kernel for these operations is responsible + /// to lower the encapsulated graph to a particular device. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool EnableMlirBridge { + get { return enableMlirBridge_; } + set { + enableMlirBridge_ = value; + } + } + + /// Field number for the "mlir_bridge_rollout" field. + public const int MlirBridgeRolloutFieldNumber = 17; + private global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout mlirBridgeRollout_ = global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout.Unspecified; + /// + /// This field is underdevelopment, for now use enable_mlir_bridge + /// (b/166038521). + /// + /// Whether to enable the MLIR-based TF->XLA bridge. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout MlirBridgeRollout { + get { return mlirBridgeRollout_; } + set { + mlirBridgeRollout_ = value; + } + } + + /// Field number for the "enable_mlir_graph_optimization" field. + public const int EnableMlirGraphOptimizationFieldNumber = 16; + private bool enableMlirGraphOptimization_; + /// + /// Whether to enable the MLIR-based Graph optimizations. + /// + /// This will become a part of standard Tensorflow graph optimization + /// pipeline, currently this is only used for gradual migration and testing + /// new passes that are replacing existing optimizations in Grappler. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool EnableMlirGraphOptimization { + get { return enableMlirGraphOptimization_; } + set { + enableMlirGraphOptimization_ = value; + } + } + + /// Field number for the "disable_output_partition_graphs" field. + public const int DisableOutputPartitionGraphsFieldNumber = 14; + private bool disableOutputPartitionGraphs_; + /// + /// If true, the session will not store an additional copy of the graph for + /// each subgraph. + /// + /// If this option is set to true when a session is created, the + /// `RunOptions.output_partition_graphs` options must not be set. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool DisableOutputPartitionGraphs { + get { return disableOutputPartitionGraphs_; } + set { + disableOutputPartitionGraphs_ = value; + } + } + + /// Field number for the "xla_fusion_autotuner_thresh" field. + public const int XlaFusionAutotunerThreshFieldNumber = 15; + private long xlaFusionAutotunerThresh_; + /// + /// Minimum number of batches run through the XLA graph before XLA fusion + /// autotuner is enabled. Default value of zero disables the autotuner. + /// + /// The XLA fusion autotuner can improve performance by executing a heuristic + /// search on the compiler parameters. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long XlaFusionAutotunerThresh { + get { return xlaFusionAutotunerThresh_; } + set { + xlaFusionAutotunerThresh_ = value; + } + } + + /// Field number for the "use_tfrt" field. + public const int UseTfrtFieldNumber = 18; + private bool useTfrt_; + /// + /// Whether runtime execution uses TFRT. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UseTfrt { + get { return useTfrt_; } + set { + useTfrt_ = value; + } + } + + /// Field number for the "disable_functional_ops_lowering" field. + public const int DisableFunctionalOpsLoweringFieldNumber = 21; + private bool disableFunctionalOpsLowering_; + /// + /// Whether functional control flow op lowering should be disabled. This is + /// useful when executing within a portable runtime where control flow op + /// kernels may not be loaded due to selective registration. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool DisableFunctionalOpsLowering { + get { return disableFunctionalOpsLowering_; } + set { + disableFunctionalOpsLowering_ = value; + } + } + + /// Field number for the "xla_prefer_single_graph_cluster" field. + public const int XlaPreferSingleGraphClusterFieldNumber = 22; + private bool xlaPreferSingleGraphCluster_; + /// + /// Provides a hint to XLA auto clustering to prefer forming a single large + /// cluster that encompases most of the graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaPreferSingleGraphCluster { + get { return xlaPreferSingleGraphCluster_; } + set { + xlaPreferSingleGraphCluster_ = value; + } + } + + /// Field number for the "coordination_config" field. + public const int CoordinationConfigFieldNumber = 23; + private global::Tensorflow.CoordinationServiceConfig coordinationConfig_; + /// + /// Distributed coordination service configurations. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinationServiceConfig CoordinationConfig { + get { return coordinationConfig_; } + set { + coordinationConfig_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Experimental); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Experimental other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (CollectiveGroupLeader != other.CollectiveGroupLeader) return false; + if (ExecutorType != other.ExecutorType) return false; + if (RecvBufMaxChunk != other.RecvBufMaxChunk) return false; + if (UseNumaAffinity != other.UseNumaAffinity) return false; + if (CollectiveDeterministicSequentialExecution != other.CollectiveDeterministicSequentialExecution) return false; + if (CollectiveNccl != other.CollectiveNccl) return false; + if (ShareSessionStateInClusterspecPropagation != other.ShareSessionStateInClusterspecPropagation) return false; + if (DisableThreadSpinning != other.DisableThreadSpinning) return false; + if (ShareClusterDevicesInSession != other.ShareClusterDevicesInSession) return false; + if (!object.Equals(SessionMetadata, other.SessionMetadata)) return false; + if (OptimizeForStaticGraph != other.OptimizeForStaticGraph) return false; + if (EnableMlirBridge != other.EnableMlirBridge) return false; + if (MlirBridgeRollout != other.MlirBridgeRollout) return false; + if (EnableMlirGraphOptimization != other.EnableMlirGraphOptimization) return false; + if (DisableOutputPartitionGraphs != other.DisableOutputPartitionGraphs) return false; + if (XlaFusionAutotunerThresh != other.XlaFusionAutotunerThresh) return false; + if (UseTfrt != other.UseTfrt) return false; + if (DisableFunctionalOpsLowering != other.DisableFunctionalOpsLowering) return false; + if (XlaPreferSingleGraphCluster != other.XlaPreferSingleGraphCluster) return false; + if (!object.Equals(CoordinationConfig, other.CoordinationConfig)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (CollectiveGroupLeader.Length != 0) hash ^= CollectiveGroupLeader.GetHashCode(); + if (ExecutorType.Length != 0) hash ^= ExecutorType.GetHashCode(); + if (RecvBufMaxChunk != 0) hash ^= RecvBufMaxChunk.GetHashCode(); + if (UseNumaAffinity != false) hash ^= UseNumaAffinity.GetHashCode(); + if (CollectiveDeterministicSequentialExecution != false) hash ^= CollectiveDeterministicSequentialExecution.GetHashCode(); + if (CollectiveNccl != false) hash ^= CollectiveNccl.GetHashCode(); + if (ShareSessionStateInClusterspecPropagation != false) hash ^= ShareSessionStateInClusterspecPropagation.GetHashCode(); + if (DisableThreadSpinning != false) hash ^= DisableThreadSpinning.GetHashCode(); + if (ShareClusterDevicesInSession != false) hash ^= ShareClusterDevicesInSession.GetHashCode(); + if (sessionMetadata_ != null) hash ^= SessionMetadata.GetHashCode(); + if (OptimizeForStaticGraph != false) hash ^= OptimizeForStaticGraph.GetHashCode(); + if (EnableMlirBridge != false) hash ^= EnableMlirBridge.GetHashCode(); + if (MlirBridgeRollout != global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout.Unspecified) hash ^= MlirBridgeRollout.GetHashCode(); + if (EnableMlirGraphOptimization != false) hash ^= EnableMlirGraphOptimization.GetHashCode(); + if (DisableOutputPartitionGraphs != false) hash ^= DisableOutputPartitionGraphs.GetHashCode(); + if (XlaFusionAutotunerThresh != 0L) hash ^= XlaFusionAutotunerThresh.GetHashCode(); + if (UseTfrt != false) hash ^= UseTfrt.GetHashCode(); + if (DisableFunctionalOpsLowering != false) hash ^= DisableFunctionalOpsLowering.GetHashCode(); + if (XlaPreferSingleGraphCluster != false) hash ^= XlaPreferSingleGraphCluster.GetHashCode(); + if (coordinationConfig_ != null) hash ^= CoordinationConfig.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (CollectiveGroupLeader.Length != 0) { + output.WriteRawTag(10); + output.WriteString(CollectiveGroupLeader); + } + if (ExecutorType.Length != 0) { + output.WriteRawTag(26); + output.WriteString(ExecutorType); + } + if (RecvBufMaxChunk != 0) { + output.WriteRawTag(32); + output.WriteInt32(RecvBufMaxChunk); + } + if (UseNumaAffinity != false) { + output.WriteRawTag(40); + output.WriteBool(UseNumaAffinity); + } + if (CollectiveDeterministicSequentialExecution != false) { + output.WriteRawTag(48); + output.WriteBool(CollectiveDeterministicSequentialExecution); + } + if (CollectiveNccl != false) { + output.WriteRawTag(56); + output.WriteBool(CollectiveNccl); + } + if (ShareSessionStateInClusterspecPropagation != false) { + output.WriteRawTag(64); + output.WriteBool(ShareSessionStateInClusterspecPropagation); + } + if (DisableThreadSpinning != false) { + output.WriteRawTag(72); + output.WriteBool(DisableThreadSpinning); + } + if (ShareClusterDevicesInSession != false) { + output.WriteRawTag(80); + output.WriteBool(ShareClusterDevicesInSession); + } + if (sessionMetadata_ != null) { + output.WriteRawTag(90); + output.WriteMessage(SessionMetadata); + } + if (OptimizeForStaticGraph != false) { + output.WriteRawTag(96); + output.WriteBool(OptimizeForStaticGraph); + } + if (EnableMlirBridge != false) { + output.WriteRawTag(104); + output.WriteBool(EnableMlirBridge); + } + if (DisableOutputPartitionGraphs != false) { + output.WriteRawTag(112); + output.WriteBool(DisableOutputPartitionGraphs); + } + if (XlaFusionAutotunerThresh != 0L) { + output.WriteRawTag(120); + output.WriteInt64(XlaFusionAutotunerThresh); + } + if (EnableMlirGraphOptimization != false) { + output.WriteRawTag(128, 1); + output.WriteBool(EnableMlirGraphOptimization); + } + if (MlirBridgeRollout != global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout.Unspecified) { + output.WriteRawTag(136, 1); + output.WriteEnum((int) MlirBridgeRollout); + } + if (UseTfrt != false) { + output.WriteRawTag(144, 1); + output.WriteBool(UseTfrt); + } + if (DisableFunctionalOpsLowering != false) { + output.WriteRawTag(168, 1); + output.WriteBool(DisableFunctionalOpsLowering); + } + if (XlaPreferSingleGraphCluster != false) { + output.WriteRawTag(176, 1); + output.WriteBool(XlaPreferSingleGraphCluster); + } + if (coordinationConfig_ != null) { + output.WriteRawTag(186, 1); + output.WriteMessage(CoordinationConfig); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (CollectiveGroupLeader.Length != 0) { + output.WriteRawTag(10); + output.WriteString(CollectiveGroupLeader); + } + if (ExecutorType.Length != 0) { + output.WriteRawTag(26); + output.WriteString(ExecutorType); + } + if (RecvBufMaxChunk != 0) { + output.WriteRawTag(32); + output.WriteInt32(RecvBufMaxChunk); + } + if (UseNumaAffinity != false) { + output.WriteRawTag(40); + output.WriteBool(UseNumaAffinity); + } + if (CollectiveDeterministicSequentialExecution != false) { + output.WriteRawTag(48); + output.WriteBool(CollectiveDeterministicSequentialExecution); + } + if (CollectiveNccl != false) { + output.WriteRawTag(56); + output.WriteBool(CollectiveNccl); + } + if (ShareSessionStateInClusterspecPropagation != false) { + output.WriteRawTag(64); + output.WriteBool(ShareSessionStateInClusterspecPropagation); + } + if (DisableThreadSpinning != false) { + output.WriteRawTag(72); + output.WriteBool(DisableThreadSpinning); + } + if (ShareClusterDevicesInSession != false) { + output.WriteRawTag(80); + output.WriteBool(ShareClusterDevicesInSession); + } + if (sessionMetadata_ != null) { + output.WriteRawTag(90); + output.WriteMessage(SessionMetadata); + } + if (OptimizeForStaticGraph != false) { + output.WriteRawTag(96); + output.WriteBool(OptimizeForStaticGraph); + } + if (EnableMlirBridge != false) { + output.WriteRawTag(104); + output.WriteBool(EnableMlirBridge); + } + if (DisableOutputPartitionGraphs != false) { + output.WriteRawTag(112); + output.WriteBool(DisableOutputPartitionGraphs); + } + if (XlaFusionAutotunerThresh != 0L) { + output.WriteRawTag(120); + output.WriteInt64(XlaFusionAutotunerThresh); + } + if (EnableMlirGraphOptimization != false) { + output.WriteRawTag(128, 1); + output.WriteBool(EnableMlirGraphOptimization); + } + if (MlirBridgeRollout != global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout.Unspecified) { + output.WriteRawTag(136, 1); + output.WriteEnum((int) MlirBridgeRollout); + } + if (UseTfrt != false) { + output.WriteRawTag(144, 1); + output.WriteBool(UseTfrt); + } + if (DisableFunctionalOpsLowering != false) { + output.WriteRawTag(168, 1); + output.WriteBool(DisableFunctionalOpsLowering); + } + if (XlaPreferSingleGraphCluster != false) { + output.WriteRawTag(176, 1); + output.WriteBool(XlaPreferSingleGraphCluster); + } + if (coordinationConfig_ != null) { + output.WriteRawTag(186, 1); + output.WriteMessage(CoordinationConfig); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (CollectiveGroupLeader.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(CollectiveGroupLeader); + } + if (ExecutorType.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ExecutorType); + } + if (RecvBufMaxChunk != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(RecvBufMaxChunk); + } + if (UseNumaAffinity != false) { + size += 1 + 1; + } + if (CollectiveDeterministicSequentialExecution != false) { + size += 1 + 1; + } + if (CollectiveNccl != false) { + size += 1 + 1; + } + if (ShareSessionStateInClusterspecPropagation != false) { + size += 1 + 1; + } + if (DisableThreadSpinning != false) { + size += 1 + 1; + } + if (ShareClusterDevicesInSession != false) { + size += 1 + 1; + } + if (sessionMetadata_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SessionMetadata); + } + if (OptimizeForStaticGraph != false) { + size += 1 + 1; + } + if (EnableMlirBridge != false) { + size += 1 + 1; + } + if (MlirBridgeRollout != global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout.Unspecified) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) MlirBridgeRollout); + } + if (EnableMlirGraphOptimization != false) { + size += 2 + 1; + } + if (DisableOutputPartitionGraphs != false) { + size += 1 + 1; + } + if (XlaFusionAutotunerThresh != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(XlaFusionAutotunerThresh); + } + if (UseTfrt != false) { + size += 2 + 1; + } + if (DisableFunctionalOpsLowering != false) { + size += 2 + 1; + } + if (XlaPreferSingleGraphCluster != false) { + size += 2 + 1; + } + if (coordinationConfig_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(CoordinationConfig); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Experimental other) { + if (other == null) { + return; + } + if (other.CollectiveGroupLeader.Length != 0) { + CollectiveGroupLeader = other.CollectiveGroupLeader; + } + if (other.ExecutorType.Length != 0) { + ExecutorType = other.ExecutorType; + } + if (other.RecvBufMaxChunk != 0) { + RecvBufMaxChunk = other.RecvBufMaxChunk; + } + if (other.UseNumaAffinity != false) { + UseNumaAffinity = other.UseNumaAffinity; + } + if (other.CollectiveDeterministicSequentialExecution != false) { + CollectiveDeterministicSequentialExecution = other.CollectiveDeterministicSequentialExecution; + } + if (other.CollectiveNccl != false) { + CollectiveNccl = other.CollectiveNccl; + } + if (other.ShareSessionStateInClusterspecPropagation != false) { + ShareSessionStateInClusterspecPropagation = other.ShareSessionStateInClusterspecPropagation; + } + if (other.DisableThreadSpinning != false) { + DisableThreadSpinning = other.DisableThreadSpinning; + } + if (other.ShareClusterDevicesInSession != false) { + ShareClusterDevicesInSession = other.ShareClusterDevicesInSession; + } + if (other.sessionMetadata_ != null) { + if (sessionMetadata_ == null) { + SessionMetadata = new global::Tensorflow.SessionMetadata(); + } + SessionMetadata.MergeFrom(other.SessionMetadata); + } + if (other.OptimizeForStaticGraph != false) { + OptimizeForStaticGraph = other.OptimizeForStaticGraph; + } + if (other.EnableMlirBridge != false) { + EnableMlirBridge = other.EnableMlirBridge; + } + if (other.MlirBridgeRollout != global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout.Unspecified) { + MlirBridgeRollout = other.MlirBridgeRollout; + } + if (other.EnableMlirGraphOptimization != false) { + EnableMlirGraphOptimization = other.EnableMlirGraphOptimization; + } + if (other.DisableOutputPartitionGraphs != false) { + DisableOutputPartitionGraphs = other.DisableOutputPartitionGraphs; + } + if (other.XlaFusionAutotunerThresh != 0L) { + XlaFusionAutotunerThresh = other.XlaFusionAutotunerThresh; + } + if (other.UseTfrt != false) { + UseTfrt = other.UseTfrt; + } + if (other.DisableFunctionalOpsLowering != false) { + DisableFunctionalOpsLowering = other.DisableFunctionalOpsLowering; + } + if (other.XlaPreferSingleGraphCluster != false) { + XlaPreferSingleGraphCluster = other.XlaPreferSingleGraphCluster; + } + if (other.coordinationConfig_ != null) { + if (coordinationConfig_ == null) { + CoordinationConfig = new global::Tensorflow.CoordinationServiceConfig(); + } + CoordinationConfig.MergeFrom(other.CoordinationConfig); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + CollectiveGroupLeader = input.ReadString(); + break; + } + case 26: { + ExecutorType = input.ReadString(); + break; + } + case 32: { + RecvBufMaxChunk = input.ReadInt32(); + break; + } + case 40: { + UseNumaAffinity = input.ReadBool(); + break; + } + case 48: { + CollectiveDeterministicSequentialExecution = input.ReadBool(); + break; + } + case 56: { + CollectiveNccl = input.ReadBool(); + break; + } + case 64: { + ShareSessionStateInClusterspecPropagation = input.ReadBool(); + break; + } + case 72: { + DisableThreadSpinning = input.ReadBool(); + break; + } + case 80: { + ShareClusterDevicesInSession = input.ReadBool(); + break; + } + case 90: { + if (sessionMetadata_ == null) { + SessionMetadata = new global::Tensorflow.SessionMetadata(); + } + input.ReadMessage(SessionMetadata); + break; + } + case 96: { + OptimizeForStaticGraph = input.ReadBool(); + break; + } + case 104: { + EnableMlirBridge = input.ReadBool(); + break; + } + case 112: { + DisableOutputPartitionGraphs = input.ReadBool(); + break; + } + case 120: { + XlaFusionAutotunerThresh = input.ReadInt64(); + break; + } + case 128: { + EnableMlirGraphOptimization = input.ReadBool(); + break; + } + case 136: { + MlirBridgeRollout = (global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout) input.ReadEnum(); + break; + } + case 144: { + UseTfrt = input.ReadBool(); + break; + } + case 168: { + DisableFunctionalOpsLowering = input.ReadBool(); + break; + } + case 176: { + XlaPreferSingleGraphCluster = input.ReadBool(); + break; + } + case 186: { + if (coordinationConfig_ == null) { + CoordinationConfig = new global::Tensorflow.CoordinationServiceConfig(); + } + input.ReadMessage(CoordinationConfig); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + CollectiveGroupLeader = input.ReadString(); + break; + } + case 26: { + ExecutorType = input.ReadString(); + break; + } + case 32: { + RecvBufMaxChunk = input.ReadInt32(); + break; + } + case 40: { + UseNumaAffinity = input.ReadBool(); + break; + } + case 48: { + CollectiveDeterministicSequentialExecution = input.ReadBool(); + break; + } + case 56: { + CollectiveNccl = input.ReadBool(); + break; + } + case 64: { + ShareSessionStateInClusterspecPropagation = input.ReadBool(); + break; + } + case 72: { + DisableThreadSpinning = input.ReadBool(); + break; + } + case 80: { + ShareClusterDevicesInSession = input.ReadBool(); + break; + } + case 90: { + if (sessionMetadata_ == null) { + SessionMetadata = new global::Tensorflow.SessionMetadata(); + } + input.ReadMessage(SessionMetadata); + break; + } + case 96: { + OptimizeForStaticGraph = input.ReadBool(); + break; + } + case 104: { + EnableMlirBridge = input.ReadBool(); + break; + } + case 112: { + DisableOutputPartitionGraphs = input.ReadBool(); + break; + } + case 120: { + XlaFusionAutotunerThresh = input.ReadInt64(); + break; + } + case 128: { + EnableMlirGraphOptimization = input.ReadBool(); + break; + } + case 136: { + MlirBridgeRollout = (global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout) input.ReadEnum(); + break; + } + case 144: { + UseTfrt = input.ReadBool(); + break; + } + case 168: { + DisableFunctionalOpsLowering = input.ReadBool(); + break; + } + case 176: { + XlaPreferSingleGraphCluster = input.ReadBool(); + break; + } + case 186: { + if (coordinationConfig_ == null) { + CoordinationConfig = new global::Tensorflow.CoordinationServiceConfig(); + } + input.ReadMessage(CoordinationConfig); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the Experimental message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// An enum that describes the state of the MLIR bridge rollout. + /// + public enum MlirBridgeRollout { + /// + /// If this field is left unspecified, the MLIR bridge may be selectively + /// enabled on a per graph basis. + /// + [pbr::OriginalName("MLIR_BRIDGE_ROLLOUT_UNSPECIFIED")] Unspecified = 0, + /// + /// Enabling the MLIR bridge enables it for all graphs in this session. + /// + [pbr::OriginalName("MLIR_BRIDGE_ROLLOUT_ENABLED")] Enabled = 1, + /// + /// Disabling the MLIR bridge disables it for all graphs in this session. + /// + [pbr::OriginalName("MLIR_BRIDGE_ROLLOUT_DISABLED")] Disabled = 2, + /// + /// Enable the MLIR bridge on a per graph basis based on an analysis of + /// the features used in the graph. If the features used by the graph are + /// supported by the MLIR bridge, the MLIR bridge will be used to run the + /// graph. + /// + [pbr::OriginalName("MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED")] SafeModeEnabled = 3, + /// + /// Enable the MLIR bridge in a fallback mode on a per graph basis based + /// on an analysis of the features used in the graph. + /// Running the MLIR bridge in the fallback mode means that it is + /// executed and it commits all the changes to the TF graph in case + /// of success. And it does not in case of failures and let the old bridge + /// to process the TF graph. + /// + [pbr::OriginalName("MLIR_BRIDGE_ROLLOUT_SAFE_MODE_FALLBACK_ENABLED")] SafeModeFallbackEnabled = 4, + } + + } + #endregion + + } + + } + #endregion + + } + + /// + /// Options for a single Run() call. + /// + public sealed partial class RunOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RunOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ConfigReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RunOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RunOptions(RunOptions other) : this() { + traceLevel_ = other.traceLevel_; + timeoutInMs_ = other.timeoutInMs_; + interOpThreadPool_ = other.interOpThreadPool_; + outputPartitionGraphs_ = other.outputPartitionGraphs_; + debugOptions_ = other.debugOptions_ != null ? other.debugOptions_.Clone() : null; + reportTensorAllocationsUponOom_ = other.reportTensorAllocationsUponOom_; + experimental_ = other.experimental_ != null ? other.experimental_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RunOptions Clone() { + return new RunOptions(this); + } + + /// Field number for the "trace_level" field. + public const int TraceLevelFieldNumber = 1; + private global::Tensorflow.RunOptions.Types.TraceLevel traceLevel_ = global::Tensorflow.RunOptions.Types.TraceLevel.NoTrace; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RunOptions.Types.TraceLevel TraceLevel { + get { return traceLevel_; } + set { + traceLevel_ = value; + } + } + + /// Field number for the "timeout_in_ms" field. + public const int TimeoutInMsFieldNumber = 2; + private long timeoutInMs_; + /// + /// Time to wait for operation to complete in milliseconds. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long TimeoutInMs { + get { return timeoutInMs_; } + set { + timeoutInMs_ = value; + } + } + + /// Field number for the "inter_op_thread_pool" field. + public const int InterOpThreadPoolFieldNumber = 3; + private int interOpThreadPool_; + /// + /// The thread pool to use, if session_inter_op_thread_pool is configured. + /// To use the caller thread set this to -1 - this uses the caller thread + /// to execute Session::Run() and thus avoids a context switch. Using the + /// caller thread to execute Session::Run() should be done ONLY for simple + /// graphs, where the overhead of an additional context switch is + /// comparable with the overhead of Session::Run(). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int InterOpThreadPool { + get { return interOpThreadPool_; } + set { + interOpThreadPool_ = value; + } + } + + /// Field number for the "output_partition_graphs" field. + public const int OutputPartitionGraphsFieldNumber = 5; + private bool outputPartitionGraphs_; + /// + /// Whether the partition graph(s) executed by the executor(s) should be + /// outputted via RunMetadata. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool OutputPartitionGraphs { + get { return outputPartitionGraphs_; } + set { + outputPartitionGraphs_ = value; + } + } + + /// Field number for the "debug_options" field. + public const int DebugOptionsFieldNumber = 6; + private global::Tensorflow.DebugOptions debugOptions_; + /// + /// EXPERIMENTAL. Options used to initialize DebuggerState, if enabled. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DebugOptions DebugOptions { + get { return debugOptions_; } + set { + debugOptions_ = value; + } + } + + /// Field number for the "report_tensor_allocations_upon_oom" field. + public const int ReportTensorAllocationsUponOomFieldNumber = 7; + private bool reportTensorAllocationsUponOom_; + /// + /// When enabled, causes tensor allocation information to be included in + /// the error message when the Run() call fails because the allocator ran + /// out of memory (OOM). + /// + /// Enabling this option can slow down the Run() call. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool ReportTensorAllocationsUponOom { + get { return reportTensorAllocationsUponOom_; } + set { + reportTensorAllocationsUponOom_ = value; + } + } + + /// Field number for the "experimental" field. + public const int ExperimentalFieldNumber = 8; + private global::Tensorflow.RunOptions.Types.Experimental experimental_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RunOptions.Types.Experimental Experimental { + get { return experimental_; } + set { + experimental_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as RunOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(RunOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (TraceLevel != other.TraceLevel) return false; + if (TimeoutInMs != other.TimeoutInMs) return false; + if (InterOpThreadPool != other.InterOpThreadPool) return false; + if (OutputPartitionGraphs != other.OutputPartitionGraphs) return false; + if (!object.Equals(DebugOptions, other.DebugOptions)) return false; + if (ReportTensorAllocationsUponOom != other.ReportTensorAllocationsUponOom) return false; + if (!object.Equals(Experimental, other.Experimental)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (TraceLevel != global::Tensorflow.RunOptions.Types.TraceLevel.NoTrace) hash ^= TraceLevel.GetHashCode(); + if (TimeoutInMs != 0L) hash ^= TimeoutInMs.GetHashCode(); + if (InterOpThreadPool != 0) hash ^= InterOpThreadPool.GetHashCode(); + if (OutputPartitionGraphs != false) hash ^= OutputPartitionGraphs.GetHashCode(); + if (debugOptions_ != null) hash ^= DebugOptions.GetHashCode(); + if (ReportTensorAllocationsUponOom != false) hash ^= ReportTensorAllocationsUponOom.GetHashCode(); + if (experimental_ != null) hash ^= Experimental.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (TraceLevel != global::Tensorflow.RunOptions.Types.TraceLevel.NoTrace) { + output.WriteRawTag(8); + output.WriteEnum((int) TraceLevel); + } + if (TimeoutInMs != 0L) { + output.WriteRawTag(16); + output.WriteInt64(TimeoutInMs); + } + if (InterOpThreadPool != 0) { + output.WriteRawTag(24); + output.WriteInt32(InterOpThreadPool); + } + if (OutputPartitionGraphs != false) { + output.WriteRawTag(40); + output.WriteBool(OutputPartitionGraphs); + } + if (debugOptions_ != null) { + output.WriteRawTag(50); + output.WriteMessage(DebugOptions); + } + if (ReportTensorAllocationsUponOom != false) { + output.WriteRawTag(56); + output.WriteBool(ReportTensorAllocationsUponOom); + } + if (experimental_ != null) { + output.WriteRawTag(66); + output.WriteMessage(Experimental); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (TraceLevel != global::Tensorflow.RunOptions.Types.TraceLevel.NoTrace) { + output.WriteRawTag(8); + output.WriteEnum((int) TraceLevel); + } + if (TimeoutInMs != 0L) { + output.WriteRawTag(16); + output.WriteInt64(TimeoutInMs); + } + if (InterOpThreadPool != 0) { + output.WriteRawTag(24); + output.WriteInt32(InterOpThreadPool); + } + if (OutputPartitionGraphs != false) { + output.WriteRawTag(40); + output.WriteBool(OutputPartitionGraphs); + } + if (debugOptions_ != null) { + output.WriteRawTag(50); + output.WriteMessage(DebugOptions); + } + if (ReportTensorAllocationsUponOom != false) { + output.WriteRawTag(56); + output.WriteBool(ReportTensorAllocationsUponOom); + } + if (experimental_ != null) { + output.WriteRawTag(66); + output.WriteMessage(Experimental); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (TraceLevel != global::Tensorflow.RunOptions.Types.TraceLevel.NoTrace) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) TraceLevel); + } + if (TimeoutInMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(TimeoutInMs); + } + if (InterOpThreadPool != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(InterOpThreadPool); + } + if (OutputPartitionGraphs != false) { + size += 1 + 1; + } + if (debugOptions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DebugOptions); + } + if (ReportTensorAllocationsUponOom != false) { + size += 1 + 1; + } + if (experimental_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Experimental); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(RunOptions other) { + if (other == null) { + return; + } + if (other.TraceLevel != global::Tensorflow.RunOptions.Types.TraceLevel.NoTrace) { + TraceLevel = other.TraceLevel; + } + if (other.TimeoutInMs != 0L) { + TimeoutInMs = other.TimeoutInMs; + } + if (other.InterOpThreadPool != 0) { + InterOpThreadPool = other.InterOpThreadPool; + } + if (other.OutputPartitionGraphs != false) { + OutputPartitionGraphs = other.OutputPartitionGraphs; + } + if (other.debugOptions_ != null) { + if (debugOptions_ == null) { + DebugOptions = new global::Tensorflow.DebugOptions(); + } + DebugOptions.MergeFrom(other.DebugOptions); + } + if (other.ReportTensorAllocationsUponOom != false) { + ReportTensorAllocationsUponOom = other.ReportTensorAllocationsUponOom; + } + if (other.experimental_ != null) { + if (experimental_ == null) { + Experimental = new global::Tensorflow.RunOptions.Types.Experimental(); + } + Experimental.MergeFrom(other.Experimental); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + TraceLevel = (global::Tensorflow.RunOptions.Types.TraceLevel) input.ReadEnum(); + break; + } + case 16: { + TimeoutInMs = input.ReadInt64(); + break; + } + case 24: { + InterOpThreadPool = input.ReadInt32(); + break; + } + case 40: { + OutputPartitionGraphs = input.ReadBool(); + break; + } + case 50: { + if (debugOptions_ == null) { + DebugOptions = new global::Tensorflow.DebugOptions(); + } + input.ReadMessage(DebugOptions); + break; + } + case 56: { + ReportTensorAllocationsUponOom = input.ReadBool(); + break; + } + case 66: { + if (experimental_ == null) { + Experimental = new global::Tensorflow.RunOptions.Types.Experimental(); + } + input.ReadMessage(Experimental); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + TraceLevel = (global::Tensorflow.RunOptions.Types.TraceLevel) input.ReadEnum(); + break; + } + case 16: { + TimeoutInMs = input.ReadInt64(); + break; + } + case 24: { + InterOpThreadPool = input.ReadInt32(); + break; + } + case 40: { + OutputPartitionGraphs = input.ReadBool(); + break; + } + case 50: { + if (debugOptions_ == null) { + DebugOptions = new global::Tensorflow.DebugOptions(); + } + input.ReadMessage(DebugOptions); + break; + } + case 56: { + ReportTensorAllocationsUponOom = input.ReadBool(); + break; + } + case 66: { + if (experimental_ == null) { + Experimental = new global::Tensorflow.RunOptions.Types.Experimental(); + } + input.ReadMessage(Experimental); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the RunOptions message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// TODO(pbar) Turn this into a TraceOptions proto which allows + /// tracing to be controlled in a more orthogonal manner? + /// + public enum TraceLevel { + [pbr::OriginalName("NO_TRACE")] NoTrace = 0, + [pbr::OriginalName("SOFTWARE_TRACE")] SoftwareTrace = 1, + [pbr::OriginalName("HARDWARE_TRACE")] HardwareTrace = 2, + [pbr::OriginalName("FULL_TRACE")] FullTrace = 3, + } + + /// + /// Everything inside Experimental is subject to change and is not subject + /// to API stability guarantees in + /// https://www.tensorflow.org/guide/version_compat. + /// + public sealed partial class Experimental : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Experimental()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.RunOptions.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Experimental() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Experimental(Experimental other) : this() { + collectiveGraphKey_ = other.collectiveGraphKey_; + useRunHandlerPool_ = other.useRunHandlerPool_; + runHandlerPoolOptions_ = other.runHandlerPoolOptions_ != null ? other.runHandlerPoolOptions_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Experimental Clone() { + return new Experimental(this); + } + + /// Field number for the "collective_graph_key" field. + public const int CollectiveGraphKeyFieldNumber = 1; + private long collectiveGraphKey_; + /// + /// If non-zero, declares that this graph is going to use collective + /// ops and must synchronize step_ids with any other graph with this + /// same group_key value (in a distributed computation where tasks + /// run disjoint graphs). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long CollectiveGraphKey { + get { return collectiveGraphKey_; } + set { + collectiveGraphKey_ = value; + } + } + + /// Field number for the "use_run_handler_pool" field. + public const int UseRunHandlerPoolFieldNumber = 2; + private bool useRunHandlerPool_; + /// + /// If true, then operations (using the inter-op pool) across all + /// session::run() calls will be centrally scheduled, optimizing for (median + /// and tail) latency. + /// Consider using this option for CPU-bound workloads like inference. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UseRunHandlerPool { + get { return useRunHandlerPool_; } + set { + useRunHandlerPool_ = value; + } + } + + /// Field number for the "run_handler_pool_options" field. + public const int RunHandlerPoolOptionsFieldNumber = 3; + private global::Tensorflow.RunOptions.Types.Experimental.Types.RunHandlerPoolOptions runHandlerPoolOptions_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RunOptions.Types.Experimental.Types.RunHandlerPoolOptions RunHandlerPoolOptions { + get { return runHandlerPoolOptions_; } + set { + runHandlerPoolOptions_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Experimental); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Experimental other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (CollectiveGraphKey != other.CollectiveGraphKey) return false; + if (UseRunHandlerPool != other.UseRunHandlerPool) return false; + if (!object.Equals(RunHandlerPoolOptions, other.RunHandlerPoolOptions)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (CollectiveGraphKey != 0L) hash ^= CollectiveGraphKey.GetHashCode(); + if (UseRunHandlerPool != false) hash ^= UseRunHandlerPool.GetHashCode(); + if (runHandlerPoolOptions_ != null) hash ^= RunHandlerPoolOptions.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (CollectiveGraphKey != 0L) { + output.WriteRawTag(8); + output.WriteInt64(CollectiveGraphKey); + } + if (UseRunHandlerPool != false) { + output.WriteRawTag(16); + output.WriteBool(UseRunHandlerPool); + } + if (runHandlerPoolOptions_ != null) { + output.WriteRawTag(26); + output.WriteMessage(RunHandlerPoolOptions); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (CollectiveGraphKey != 0L) { + output.WriteRawTag(8); + output.WriteInt64(CollectiveGraphKey); + } + if (UseRunHandlerPool != false) { + output.WriteRawTag(16); + output.WriteBool(UseRunHandlerPool); + } + if (runHandlerPoolOptions_ != null) { + output.WriteRawTag(26); + output.WriteMessage(RunHandlerPoolOptions); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (CollectiveGraphKey != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(CollectiveGraphKey); + } + if (UseRunHandlerPool != false) { + size += 1 + 1; + } + if (runHandlerPoolOptions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RunHandlerPoolOptions); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Experimental other) { + if (other == null) { + return; + } + if (other.CollectiveGraphKey != 0L) { + CollectiveGraphKey = other.CollectiveGraphKey; + } + if (other.UseRunHandlerPool != false) { + UseRunHandlerPool = other.UseRunHandlerPool; + } + if (other.runHandlerPoolOptions_ != null) { + if (runHandlerPoolOptions_ == null) { + RunHandlerPoolOptions = new global::Tensorflow.RunOptions.Types.Experimental.Types.RunHandlerPoolOptions(); + } + RunHandlerPoolOptions.MergeFrom(other.RunHandlerPoolOptions); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + CollectiveGraphKey = input.ReadInt64(); + break; + } + case 16: { + UseRunHandlerPool = input.ReadBool(); + break; + } + case 26: { + if (runHandlerPoolOptions_ == null) { + RunHandlerPoolOptions = new global::Tensorflow.RunOptions.Types.Experimental.Types.RunHandlerPoolOptions(); + } + input.ReadMessage(RunHandlerPoolOptions); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + CollectiveGraphKey = input.ReadInt64(); + break; + } + case 16: { + UseRunHandlerPool = input.ReadBool(); + break; + } + case 26: { + if (runHandlerPoolOptions_ == null) { + RunHandlerPoolOptions = new global::Tensorflow.RunOptions.Types.Experimental.Types.RunHandlerPoolOptions(); + } + input.ReadMessage(RunHandlerPoolOptions); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the Experimental message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Options for run handler thread pool. + /// + public sealed partial class RunHandlerPoolOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RunHandlerPoolOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.RunOptions.Types.Experimental.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RunHandlerPoolOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RunHandlerPoolOptions(RunHandlerPoolOptions other) : this() { + priority_ = other.priority_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RunHandlerPoolOptions Clone() { + return new RunHandlerPoolOptions(this); + } + + /// Field number for the "priority" field. + public const int PriorityFieldNumber = 1; + private long priority_; + /// + /// Priority of the request. The run handler thread pool will schedule ops + /// based on the priority number. The larger number means higher priority. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Priority { + get { return priority_; } + set { + priority_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as RunHandlerPoolOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(RunHandlerPoolOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Priority != other.Priority) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Priority != 0L) hash ^= Priority.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Priority != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Priority); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Priority != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Priority); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Priority != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Priority); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(RunHandlerPoolOptions other) { + if (other == null) { + return; + } + if (other.Priority != 0L) { + Priority = other.Priority; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Priority = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Priority = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + } + #endregion + + } + + /// + /// Metadata output (i.e., non-Tensor) for a single Run() call. + /// + public sealed partial class RunMetadata : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RunMetadata()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ConfigReflection.Descriptor.MessageTypes[8]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RunMetadata() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RunMetadata(RunMetadata other) : this() { + stepStats_ = other.stepStats_ != null ? other.stepStats_.Clone() : null; + costGraph_ = other.costGraph_ != null ? other.costGraph_.Clone() : null; + partitionGraphs_ = other.partitionGraphs_.Clone(); + functionGraphs_ = other.functionGraphs_.Clone(); + sessionMetadata_ = other.sessionMetadata_ != null ? other.sessionMetadata_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RunMetadata Clone() { + return new RunMetadata(this); + } + + /// Field number for the "step_stats" field. + public const int StepStatsFieldNumber = 1; + private global::Tensorflow.StepStats stepStats_; + /// + /// Statistics traced for this step. Populated if tracing is turned on via the + /// "RunOptions" proto. + /// EXPERIMENTAL: The format and set of events may change in future versions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.StepStats StepStats { + get { return stepStats_; } + set { + stepStats_ = value; + } + } + + /// Field number for the "cost_graph" field. + public const int CostGraphFieldNumber = 2; + private global::Tensorflow.CostGraphDef costGraph_; + /// + /// The cost graph for the computation defined by the run call. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CostGraphDef CostGraph { + get { return costGraph_; } + set { + costGraph_ = value; + } + } + + /// Field number for the "partition_graphs" field. + public const int PartitionGraphsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_partitionGraphs_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.GraphDef.Parser); + private readonly pbc::RepeatedField partitionGraphs_ = new pbc::RepeatedField(); + /// + /// Graphs of the partitions executed by executors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField PartitionGraphs { + get { return partitionGraphs_; } + } + + /// Field number for the "function_graphs" field. + public const int FunctionGraphsFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_functionGraphs_codec + = pb::FieldCodec.ForMessage(34, global::Tensorflow.RunMetadata.Types.FunctionGraphs.Parser); + private readonly pbc::RepeatedField functionGraphs_ = new pbc::RepeatedField(); + /// + /// This is only populated for graphs that are run as functions in TensorFlow + /// V2. There will be an entry below for each function that is traced. + /// The main use cases of the post_optimization_graph and the partition_graphs + /// is to give the caller insight into the graphs that were actually run by the + /// runtime. Additional information (such as those in step_stats) will match + /// these graphs. + /// We also include the pre_optimization_graph since it is usually easier to + /// read, and is helpful in situations where the caller wants to get a high + /// level idea of what the built graph looks like (since the various graph + /// optimization passes might change the structure of the graph significantly). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField FunctionGraphs { + get { return functionGraphs_; } + } + + /// Field number for the "session_metadata" field. + public const int SessionMetadataFieldNumber = 5; + private global::Tensorflow.SessionMetadata sessionMetadata_; + /// + /// Metadata about the session. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SessionMetadata SessionMetadata { + get { return sessionMetadata_; } + set { + sessionMetadata_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as RunMetadata); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(RunMetadata other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(StepStats, other.StepStats)) return false; + if (!object.Equals(CostGraph, other.CostGraph)) return false; + if(!partitionGraphs_.Equals(other.partitionGraphs_)) return false; + if(!functionGraphs_.Equals(other.functionGraphs_)) return false; + if (!object.Equals(SessionMetadata, other.SessionMetadata)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (stepStats_ != null) hash ^= StepStats.GetHashCode(); + if (costGraph_ != null) hash ^= CostGraph.GetHashCode(); + hash ^= partitionGraphs_.GetHashCode(); + hash ^= functionGraphs_.GetHashCode(); + if (sessionMetadata_ != null) hash ^= SessionMetadata.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (stepStats_ != null) { + output.WriteRawTag(10); + output.WriteMessage(StepStats); + } + if (costGraph_ != null) { + output.WriteRawTag(18); + output.WriteMessage(CostGraph); + } + partitionGraphs_.WriteTo(output, _repeated_partitionGraphs_codec); + functionGraphs_.WriteTo(output, _repeated_functionGraphs_codec); + if (sessionMetadata_ != null) { + output.WriteRawTag(42); + output.WriteMessage(SessionMetadata); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (stepStats_ != null) { + output.WriteRawTag(10); + output.WriteMessage(StepStats); + } + if (costGraph_ != null) { + output.WriteRawTag(18); + output.WriteMessage(CostGraph); + } + partitionGraphs_.WriteTo(ref output, _repeated_partitionGraphs_codec); + functionGraphs_.WriteTo(ref output, _repeated_functionGraphs_codec); + if (sessionMetadata_ != null) { + output.WriteRawTag(42); + output.WriteMessage(SessionMetadata); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (stepStats_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(StepStats); + } + if (costGraph_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(CostGraph); + } + size += partitionGraphs_.CalculateSize(_repeated_partitionGraphs_codec); + size += functionGraphs_.CalculateSize(_repeated_functionGraphs_codec); + if (sessionMetadata_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SessionMetadata); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(RunMetadata other) { + if (other == null) { + return; + } + if (other.stepStats_ != null) { + if (stepStats_ == null) { + StepStats = new global::Tensorflow.StepStats(); + } + StepStats.MergeFrom(other.StepStats); + } + if (other.costGraph_ != null) { + if (costGraph_ == null) { + CostGraph = new global::Tensorflow.CostGraphDef(); + } + CostGraph.MergeFrom(other.CostGraph); + } + partitionGraphs_.Add(other.partitionGraphs_); + functionGraphs_.Add(other.functionGraphs_); + if (other.sessionMetadata_ != null) { + if (sessionMetadata_ == null) { + SessionMetadata = new global::Tensorflow.SessionMetadata(); + } + SessionMetadata.MergeFrom(other.SessionMetadata); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (stepStats_ == null) { + StepStats = new global::Tensorflow.StepStats(); + } + input.ReadMessage(StepStats); + break; + } + case 18: { + if (costGraph_ == null) { + CostGraph = new global::Tensorflow.CostGraphDef(); + } + input.ReadMessage(CostGraph); + break; + } + case 26: { + partitionGraphs_.AddEntriesFrom(input, _repeated_partitionGraphs_codec); + break; + } + case 34: { + functionGraphs_.AddEntriesFrom(input, _repeated_functionGraphs_codec); + break; + } + case 42: { + if (sessionMetadata_ == null) { + SessionMetadata = new global::Tensorflow.SessionMetadata(); + } + input.ReadMessage(SessionMetadata); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (stepStats_ == null) { + StepStats = new global::Tensorflow.StepStats(); + } + input.ReadMessage(StepStats); + break; + } + case 18: { + if (costGraph_ == null) { + CostGraph = new global::Tensorflow.CostGraphDef(); + } + input.ReadMessage(CostGraph); + break; + } + case 26: { + partitionGraphs_.AddEntriesFrom(ref input, _repeated_partitionGraphs_codec); + break; + } + case 34: { + functionGraphs_.AddEntriesFrom(ref input, _repeated_functionGraphs_codec); + break; + } + case 42: { + if (sessionMetadata_ == null) { + SessionMetadata = new global::Tensorflow.SessionMetadata(); + } + input.ReadMessage(SessionMetadata); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the RunMetadata message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public sealed partial class FunctionGraphs : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FunctionGraphs()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.RunMetadata.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FunctionGraphs() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FunctionGraphs(FunctionGraphs other) : this() { + partitionGraphs_ = other.partitionGraphs_.Clone(); + preOptimizationGraph_ = other.preOptimizationGraph_ != null ? other.preOptimizationGraph_.Clone() : null; + postOptimizationGraph_ = other.postOptimizationGraph_ != null ? other.postOptimizationGraph_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FunctionGraphs Clone() { + return new FunctionGraphs(this); + } + + /// Field number for the "partition_graphs" field. + public const int PartitionGraphsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_partitionGraphs_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.GraphDef.Parser); + private readonly pbc::RepeatedField partitionGraphs_ = new pbc::RepeatedField(); + /// + /// TODO(nareshmodi): Include some sort of function/cache-key identifier? + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField PartitionGraphs { + get { return partitionGraphs_; } + } + + /// Field number for the "pre_optimization_graph" field. + public const int PreOptimizationGraphFieldNumber = 2; + private global::Tensorflow.GraphDef preOptimizationGraph_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.GraphDef PreOptimizationGraph { + get { return preOptimizationGraph_; } + set { + preOptimizationGraph_ = value; + } + } + + /// Field number for the "post_optimization_graph" field. + public const int PostOptimizationGraphFieldNumber = 3; + private global::Tensorflow.GraphDef postOptimizationGraph_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.GraphDef PostOptimizationGraph { + get { return postOptimizationGraph_; } + set { + postOptimizationGraph_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as FunctionGraphs); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(FunctionGraphs other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!partitionGraphs_.Equals(other.partitionGraphs_)) return false; + if (!object.Equals(PreOptimizationGraph, other.PreOptimizationGraph)) return false; + if (!object.Equals(PostOptimizationGraph, other.PostOptimizationGraph)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= partitionGraphs_.GetHashCode(); + if (preOptimizationGraph_ != null) hash ^= PreOptimizationGraph.GetHashCode(); + if (postOptimizationGraph_ != null) hash ^= PostOptimizationGraph.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + partitionGraphs_.WriteTo(output, _repeated_partitionGraphs_codec); + if (preOptimizationGraph_ != null) { + output.WriteRawTag(18); + output.WriteMessage(PreOptimizationGraph); + } + if (postOptimizationGraph_ != null) { + output.WriteRawTag(26); + output.WriteMessage(PostOptimizationGraph); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + partitionGraphs_.WriteTo(ref output, _repeated_partitionGraphs_codec); + if (preOptimizationGraph_ != null) { + output.WriteRawTag(18); + output.WriteMessage(PreOptimizationGraph); + } + if (postOptimizationGraph_ != null) { + output.WriteRawTag(26); + output.WriteMessage(PostOptimizationGraph); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += partitionGraphs_.CalculateSize(_repeated_partitionGraphs_codec); + if (preOptimizationGraph_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(PreOptimizationGraph); + } + if (postOptimizationGraph_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(PostOptimizationGraph); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(FunctionGraphs other) { + if (other == null) { + return; + } + partitionGraphs_.Add(other.partitionGraphs_); + if (other.preOptimizationGraph_ != null) { + if (preOptimizationGraph_ == null) { + PreOptimizationGraph = new global::Tensorflow.GraphDef(); + } + PreOptimizationGraph.MergeFrom(other.PreOptimizationGraph); + } + if (other.postOptimizationGraph_ != null) { + if (postOptimizationGraph_ == null) { + PostOptimizationGraph = new global::Tensorflow.GraphDef(); + } + PostOptimizationGraph.MergeFrom(other.PostOptimizationGraph); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + partitionGraphs_.AddEntriesFrom(input, _repeated_partitionGraphs_codec); + break; + } + case 18: { + if (preOptimizationGraph_ == null) { + PreOptimizationGraph = new global::Tensorflow.GraphDef(); + } + input.ReadMessage(PreOptimizationGraph); + break; + } + case 26: { + if (postOptimizationGraph_ == null) { + PostOptimizationGraph = new global::Tensorflow.GraphDef(); + } + input.ReadMessage(PostOptimizationGraph); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + partitionGraphs_.AddEntriesFrom(ref input, _repeated_partitionGraphs_codec); + break; + } + case 18: { + if (preOptimizationGraph_ == null) { + PreOptimizationGraph = new global::Tensorflow.GraphDef(); + } + input.ReadMessage(PreOptimizationGraph); + break; + } + case 26: { + if (postOptimizationGraph_ == null) { + PostOptimizationGraph = new global::Tensorflow.GraphDef(); + } + input.ReadMessage(PostOptimizationGraph); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// Defines a connection between two tensors in a `GraphDef`. + /// + public sealed partial class TensorConnection : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorConnection()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ConfigReflection.Descriptor.MessageTypes[9]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorConnection() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorConnection(TensorConnection other) : this() { + fromTensor_ = other.fromTensor_; + toTensor_ = other.toTensor_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorConnection Clone() { + return new TensorConnection(this); + } + + /// Field number for the "from_tensor" field. + public const int FromTensorFieldNumber = 1; + private string fromTensor_ = ""; + /// + /// A tensor name. The value of this tensor will be substituted for + /// the tensor named in `to_tensor`. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string FromTensor { + get { return fromTensor_; } + set { + fromTensor_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "to_tensor" field. + public const int ToTensorFieldNumber = 2; + private string toTensor_ = ""; + /// + /// A tensor name. The value of this tensor will be bound to the + /// value of the tensor named in `from_tensor`. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ToTensor { + get { return toTensor_; } + set { + toTensor_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TensorConnection); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TensorConnection other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (FromTensor != other.FromTensor) return false; + if (ToTensor != other.ToTensor) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (FromTensor.Length != 0) hash ^= FromTensor.GetHashCode(); + if (ToTensor.Length != 0) hash ^= ToTensor.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (FromTensor.Length != 0) { + output.WriteRawTag(10); + output.WriteString(FromTensor); + } + if (ToTensor.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ToTensor); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (FromTensor.Length != 0) { + output.WriteRawTag(10); + output.WriteString(FromTensor); + } + if (ToTensor.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ToTensor); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (FromTensor.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(FromTensor); + } + if (ToTensor.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ToTensor); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TensorConnection other) { + if (other == null) { + return; + } + if (other.FromTensor.Length != 0) { + FromTensor = other.FromTensor; + } + if (other.ToTensor.Length != 0) { + ToTensor = other.ToTensor; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + FromTensor = input.ReadString(); + break; + } + case 18: { + ToTensor = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + FromTensor = input.ReadString(); + break; + } + case 18: { + ToTensor = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// Defines a subgraph in another `GraphDef` as a set of feed points and nodes + /// to be fetched or executed. + /// + /// Compare with the arguments to `Session::Run()`. + /// + public sealed partial class CallableOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CallableOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ConfigReflection.Descriptor.MessageTypes[10]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CallableOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CallableOptions(CallableOptions other) : this() { + feed_ = other.feed_.Clone(); + fetch_ = other.fetch_.Clone(); + target_ = other.target_.Clone(); + runOptions_ = other.runOptions_ != null ? other.runOptions_.Clone() : null; + tensorConnection_ = other.tensorConnection_.Clone(); + feedDevices_ = other.feedDevices_.Clone(); + fetchDevices_ = other.fetchDevices_.Clone(); + fetchSkipSync_ = other.fetchSkipSync_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CallableOptions Clone() { + return new CallableOptions(this); + } + + /// Field number for the "feed" field. + public const int FeedFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_feed_codec + = pb::FieldCodec.ForString(10); + private readonly pbc::RepeatedField feed_ = new pbc::RepeatedField(); + /// + /// Tensors to be fed in the callable. Each feed is the name of a tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Feed { + get { return feed_; } + } + + /// Field number for the "fetch" field. + public const int FetchFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_fetch_codec + = pb::FieldCodec.ForString(18); + private readonly pbc::RepeatedField fetch_ = new pbc::RepeatedField(); + /// + /// Fetches. A list of tensor names. The caller of the callable expects a + /// tensor to be returned for each fetch[i] (see RunStepResponse.tensor). The + /// order of specified fetches does not change the execution order. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Fetch { + get { return fetch_; } + } + + /// Field number for the "target" field. + public const int TargetFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_target_codec + = pb::FieldCodec.ForString(26); + private readonly pbc::RepeatedField target_ = new pbc::RepeatedField(); + /// + /// Target Nodes. A list of node names. The named nodes will be run by the + /// callable but their outputs will not be returned. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Target { + get { return target_; } + } + + /// Field number for the "run_options" field. + public const int RunOptionsFieldNumber = 4; + private global::Tensorflow.RunOptions runOptions_; + /// + /// Options that will be applied to each run. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RunOptions RunOptions { + get { return runOptions_; } + set { + runOptions_ = value; + } + } + + /// Field number for the "tensor_connection" field. + public const int TensorConnectionFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_tensorConnection_codec + = pb::FieldCodec.ForMessage(42, global::Tensorflow.TensorConnection.Parser); + private readonly pbc::RepeatedField tensorConnection_ = new pbc::RepeatedField(); + /// + /// Tensors to be connected in the callable. Each TensorConnection denotes + /// a pair of tensors in the graph, between which an edge will be created + /// in the callable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField TensorConnection { + get { return tensorConnection_; } + } + + /// Field number for the "feed_devices" field. + public const int FeedDevicesFieldNumber = 6; + private static readonly pbc::MapField.Codec _map_feedDevices_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForString(18, ""), 50); + private readonly pbc::MapField feedDevices_ = new pbc::MapField(); + /// + /// The Tensor objects fed in the callable and fetched from the callable + /// are expected to be backed by host (CPU) memory by default. + /// + /// The options below allow changing that - feeding tensors backed by + /// device memory, or returning tensors that are backed by device memory. + /// + /// The maps below map the name of a feed/fetch tensor (which appears in + /// 'feed' or 'fetch' fields above), to the fully qualified name of the device + /// owning the memory backing the contents of the tensor. + /// + /// For example, creating a callable with the following options: + /// + /// CallableOptions { + /// feed: "a:0" + /// feed: "b:0" + /// + /// fetch: "x:0" + /// fetch: "y:0" + /// + /// feed_devices: { + /// "a:0": "/job:localhost/replica:0/task:0/device:GPU:0" + /// } + /// + /// fetch_devices: { + /// "y:0": "/job:localhost/replica:0/task:0/device:GPU:0" + /// } + /// } + /// + /// means that the Callable expects: + /// - The first argument ("a:0") is a Tensor backed by GPU memory. + /// - The second argument ("b:0") is a Tensor backed by host memory. + /// and of its return values: + /// - The first output ("x:0") will be backed by host memory. + /// - The second output ("y:0") will be backed by GPU memory. + /// + /// FEEDS: + /// It is the responsibility of the caller to ensure that the memory of the fed + /// tensors will be correctly initialized and synchronized before it is + /// accessed by operations executed during the call to Session::RunCallable(). + /// + /// This is typically ensured by using the TensorFlow memory allocators + /// (Device::GetAllocator()) to create the Tensor to be fed. + /// + /// Alternatively, for CUDA-enabled GPU devices, this typically means that the + /// operation that produced the contents of the tensor has completed, i.e., the + /// CUDA stream has been synchronized (e.g., via cuCtxSynchronize() or + /// cuStreamSynchronize()). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField FeedDevices { + get { return feedDevices_; } + } + + /// Field number for the "fetch_devices" field. + public const int FetchDevicesFieldNumber = 7; + private static readonly pbc::MapField.Codec _map_fetchDevices_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForString(18, ""), 58); + private readonly pbc::MapField fetchDevices_ = new pbc::MapField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField FetchDevices { + get { return fetchDevices_; } + } + + /// Field number for the "fetch_skip_sync" field. + public const int FetchSkipSyncFieldNumber = 8; + private bool fetchSkipSync_; + /// + /// By default, RunCallable() will synchronize the GPU stream before returning + /// fetched tensors on a GPU device, to ensure that the values in those tensors + /// have been produced. This simplifies interacting with the tensors, but + /// potentially incurs a performance hit. + /// + /// If this options is set to true, the caller is responsible for ensuring + /// that the values in the fetched tensors have been produced before they are + /// used. The caller can do this by invoking `Device::Sync()` on the underlying + /// device(s), or by feeding the tensors back to the same Session using + /// `feed_devices` with the same corresponding device name. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool FetchSkipSync { + get { return fetchSkipSync_; } + set { + fetchSkipSync_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CallableOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CallableOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!feed_.Equals(other.feed_)) return false; + if(!fetch_.Equals(other.fetch_)) return false; + if(!target_.Equals(other.target_)) return false; + if (!object.Equals(RunOptions, other.RunOptions)) return false; + if(!tensorConnection_.Equals(other.tensorConnection_)) return false; + if (!FeedDevices.Equals(other.FeedDevices)) return false; + if (!FetchDevices.Equals(other.FetchDevices)) return false; + if (FetchSkipSync != other.FetchSkipSync) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= feed_.GetHashCode(); + hash ^= fetch_.GetHashCode(); + hash ^= target_.GetHashCode(); + if (runOptions_ != null) hash ^= RunOptions.GetHashCode(); + hash ^= tensorConnection_.GetHashCode(); + hash ^= FeedDevices.GetHashCode(); + hash ^= FetchDevices.GetHashCode(); + if (FetchSkipSync != false) hash ^= FetchSkipSync.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + feed_.WriteTo(output, _repeated_feed_codec); + fetch_.WriteTo(output, _repeated_fetch_codec); + target_.WriteTo(output, _repeated_target_codec); + if (runOptions_ != null) { + output.WriteRawTag(34); + output.WriteMessage(RunOptions); + } + tensorConnection_.WriteTo(output, _repeated_tensorConnection_codec); + feedDevices_.WriteTo(output, _map_feedDevices_codec); + fetchDevices_.WriteTo(output, _map_fetchDevices_codec); + if (FetchSkipSync != false) { + output.WriteRawTag(64); + output.WriteBool(FetchSkipSync); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + feed_.WriteTo(ref output, _repeated_feed_codec); + fetch_.WriteTo(ref output, _repeated_fetch_codec); + target_.WriteTo(ref output, _repeated_target_codec); + if (runOptions_ != null) { + output.WriteRawTag(34); + output.WriteMessage(RunOptions); + } + tensorConnection_.WriteTo(ref output, _repeated_tensorConnection_codec); + feedDevices_.WriteTo(ref output, _map_feedDevices_codec); + fetchDevices_.WriteTo(ref output, _map_fetchDevices_codec); + if (FetchSkipSync != false) { + output.WriteRawTag(64); + output.WriteBool(FetchSkipSync); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += feed_.CalculateSize(_repeated_feed_codec); + size += fetch_.CalculateSize(_repeated_fetch_codec); + size += target_.CalculateSize(_repeated_target_codec); + if (runOptions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RunOptions); + } + size += tensorConnection_.CalculateSize(_repeated_tensorConnection_codec); + size += feedDevices_.CalculateSize(_map_feedDevices_codec); + size += fetchDevices_.CalculateSize(_map_fetchDevices_codec); + if (FetchSkipSync != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CallableOptions other) { + if (other == null) { + return; + } + feed_.Add(other.feed_); + fetch_.Add(other.fetch_); + target_.Add(other.target_); + if (other.runOptions_ != null) { + if (runOptions_ == null) { + RunOptions = new global::Tensorflow.RunOptions(); + } + RunOptions.MergeFrom(other.RunOptions); + } + tensorConnection_.Add(other.tensorConnection_); + feedDevices_.Add(other.feedDevices_); + fetchDevices_.Add(other.fetchDevices_); + if (other.FetchSkipSync != false) { + FetchSkipSync = other.FetchSkipSync; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + feed_.AddEntriesFrom(input, _repeated_feed_codec); + break; + } + case 18: { + fetch_.AddEntriesFrom(input, _repeated_fetch_codec); + break; + } + case 26: { + target_.AddEntriesFrom(input, _repeated_target_codec); + break; + } + case 34: { + if (runOptions_ == null) { + RunOptions = new global::Tensorflow.RunOptions(); + } + input.ReadMessage(RunOptions); + break; + } + case 42: { + tensorConnection_.AddEntriesFrom(input, _repeated_tensorConnection_codec); + break; + } + case 50: { + feedDevices_.AddEntriesFrom(input, _map_feedDevices_codec); + break; + } + case 58: { + fetchDevices_.AddEntriesFrom(input, _map_fetchDevices_codec); + break; + } + case 64: { + FetchSkipSync = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + feed_.AddEntriesFrom(ref input, _repeated_feed_codec); + break; + } + case 18: { + fetch_.AddEntriesFrom(ref input, _repeated_fetch_codec); + break; + } + case 26: { + target_.AddEntriesFrom(ref input, _repeated_target_codec); + break; + } + case 34: { + if (runOptions_ == null) { + RunOptions = new global::Tensorflow.RunOptions(); + } + input.ReadMessage(RunOptions); + break; + } + case 42: { + tensorConnection_.AddEntriesFrom(ref input, _repeated_tensorConnection_codec); + break; + } + case 50: { + feedDevices_.AddEntriesFrom(ref input, _map_feedDevices_codec); + break; + } + case 58: { + fetchDevices_.AddEntriesFrom(ref input, _map_fetchDevices_codec); + break; + } + case 64: { + FetchSkipSync = input.ReadBool(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/ControlFlow.cs b/src/TensorFlowNET.Core/Protobuf/ControlFlow.cs new file mode 100644 index 000000000..3ede374cb --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/ControlFlow.cs @@ -0,0 +1,1574 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/control_flow.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/control_flow.proto + public static partial class ControlFlowReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/control_flow.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ControlFlowReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cit0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29udHJvbF9mbG93LnByb3Rv", + "Egp0ZW5zb3JmbG93IpYBCglWYWx1ZXNEZWYSDgoGdmFsdWVzGAEgAygJEkIK", + "D2V4dGVybmFsX3ZhbHVlcxgCIAMoCzIpLnRlbnNvcmZsb3cuVmFsdWVzRGVm", + "LkV4dGVybmFsVmFsdWVzRW50cnkaNQoTRXh0ZXJuYWxWYWx1ZXNFbnRyeRIL", + "CgNrZXkYASABKAkSDQoFdmFsdWUYAiABKAk6AjgBIoMBChVDb250cm9sRmxv", + "d0NvbnRleHREZWYSLwoJY29uZF9jdHh0GAEgASgLMhoudGVuc29yZmxvdy5D", + "b25kQ29udGV4dERlZkgAEjEKCndoaWxlX2N0eHQYAiABKAsyGy50ZW5zb3Jm", + "bG93LldoaWxlQ29udGV4dERlZkgAQgYKBGN0eHQixAEKDkNvbmRDb250ZXh0", + "RGVmEhQKDGNvbnRleHRfbmFtZRgBIAEoCRIRCglwcmVkX25hbWUYAiABKAkS", + "EgoKcGl2b3RfbmFtZRgDIAEoCRIOCgZicmFuY2gYBCABKAUSKQoKdmFsdWVz", + "X2RlZhgFIAEoCzIVLnRlbnNvcmZsb3cuVmFsdWVzRGVmEjoKD25lc3RlZF9j", + "b250ZXh0cxgGIAMoCzIhLnRlbnNvcmZsb3cuQ29udHJvbEZsb3dDb250ZXh0", + "RGVmIvUCCg9XaGlsZUNvbnRleHREZWYSFAoMY29udGV4dF9uYW1lGAEgASgJ", + "EhsKE3BhcmFsbGVsX2l0ZXJhdGlvbnMYAiABKAUSEQoJYmFja19wcm9wGAMg", + "ASgIEhMKC3N3YXBfbWVtb3J5GAQgASgIEhIKCnBpdm90X25hbWUYBSABKAkS", + "GwoTcGl2b3RfZm9yX3ByZWRfbmFtZRgGIAEoCRIbChNwaXZvdF9mb3JfYm9k", + "eV9uYW1lGAcgASgJEhcKD2xvb3BfZXhpdF9uYW1lcxgIIAMoCRIYChBsb29w", + "X2VudGVyX25hbWVzGAogAygJEikKCnZhbHVlc19kZWYYCSABKAsyFS50ZW5z", + "b3JmbG93LlZhbHVlc0RlZhIfChdtYXhpbXVtX2l0ZXJhdGlvbnNfbmFtZRgL", + "IAEoCRI6Cg9uZXN0ZWRfY29udGV4dHMYDCADKAsyIS50ZW5zb3JmbG93LkNv", + "bnRyb2xGbG93Q29udGV4dERlZkKJAQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3", + "b3JrQhFDb250cm9sRmxvd1Byb3Rvc1ABWlVnaXRodWIuY29tL3RlbnNvcmZs", + "b3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dvL2NvcmUvcHJvdG9idWYvZm9y", + "X2NvcmVfcHJvdG9zX2dvX3Byb3Rv+AEBYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ValuesDef), global::Tensorflow.ValuesDef.Parser, new[]{ "Values", "ExternalValues" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ControlFlowContextDef), global::Tensorflow.ControlFlowContextDef.Parser, new[]{ "CondCtxt", "WhileCtxt" }, new[]{ "Ctxt" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CondContextDef), global::Tensorflow.CondContextDef.Parser, new[]{ "ContextName", "PredName", "PivotName", "Branch", "ValuesDef", "NestedContexts" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.WhileContextDef), global::Tensorflow.WhileContextDef.Parser, new[]{ "ContextName", "ParallelIterations", "BackProp", "SwapMemory", "PivotName", "PivotForPredName", "PivotForBodyName", "LoopExitNames", "LoopEnterNames", "ValuesDef", "MaximumIterationsName", "NestedContexts" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer representing the values in ControlFlowContext. + /// + public sealed partial class ValuesDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ValuesDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ValuesDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ValuesDef(ValuesDef other) : this() { + values_ = other.values_.Clone(); + externalValues_ = other.externalValues_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ValuesDef Clone() { + return new ValuesDef(this); + } + + /// Field number for the "values" field. + public const int ValuesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_values_codec + = pb::FieldCodec.ForString(10); + private readonly pbc::RepeatedField values_ = new pbc::RepeatedField(); + /// + /// Value names that have been seen in this context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Values { + get { return values_; } + } + + /// Field number for the "external_values" field. + public const int ExternalValuesFieldNumber = 2; + private static readonly pbc::MapField.Codec _map_externalValues_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForString(18, ""), 18); + private readonly pbc::MapField externalValues_ = new pbc::MapField(); + /// + /// Value names referenced by but external to this context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField ExternalValues { + get { return externalValues_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ValuesDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ValuesDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!values_.Equals(other.values_)) return false; + if (!ExternalValues.Equals(other.ExternalValues)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= values_.GetHashCode(); + hash ^= ExternalValues.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + values_.WriteTo(output, _repeated_values_codec); + externalValues_.WriteTo(output, _map_externalValues_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + values_.WriteTo(ref output, _repeated_values_codec); + externalValues_.WriteTo(ref output, _map_externalValues_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += values_.CalculateSize(_repeated_values_codec); + size += externalValues_.CalculateSize(_map_externalValues_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ValuesDef other) { + if (other == null) { + return; + } + values_.Add(other.values_); + externalValues_.Add(other.externalValues_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + values_.AddEntriesFrom(input, _repeated_values_codec); + break; + } + case 18: { + externalValues_.AddEntriesFrom(input, _map_externalValues_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + values_.AddEntriesFrom(ref input, _repeated_values_codec); + break; + } + case 18: { + externalValues_.AddEntriesFrom(ref input, _map_externalValues_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Container for any kind of control flow context. Any other control flow + /// contexts that are added below should also be added here. + /// + public sealed partial class ControlFlowContextDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ControlFlowContextDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ControlFlowContextDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ControlFlowContextDef(ControlFlowContextDef other) : this() { + switch (other.CtxtCase) { + case CtxtOneofCase.CondCtxt: + CondCtxt = other.CondCtxt.Clone(); + break; + case CtxtOneofCase.WhileCtxt: + WhileCtxt = other.WhileCtxt.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ControlFlowContextDef Clone() { + return new ControlFlowContextDef(this); + } + + /// Field number for the "cond_ctxt" field. + public const int CondCtxtFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CondContextDef CondCtxt { + get { return ctxtCase_ == CtxtOneofCase.CondCtxt ? (global::Tensorflow.CondContextDef) ctxt_ : null; } + set { + ctxt_ = value; + ctxtCase_ = value == null ? CtxtOneofCase.None : CtxtOneofCase.CondCtxt; + } + } + + /// Field number for the "while_ctxt" field. + public const int WhileCtxtFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.WhileContextDef WhileCtxt { + get { return ctxtCase_ == CtxtOneofCase.WhileCtxt ? (global::Tensorflow.WhileContextDef) ctxt_ : null; } + set { + ctxt_ = value; + ctxtCase_ = value == null ? CtxtOneofCase.None : CtxtOneofCase.WhileCtxt; + } + } + + private object ctxt_; + /// Enum of possible cases for the "ctxt" oneof. + public enum CtxtOneofCase { + None = 0, + CondCtxt = 1, + WhileCtxt = 2, + } + private CtxtOneofCase ctxtCase_ = CtxtOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CtxtOneofCase CtxtCase { + get { return ctxtCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void ClearCtxt() { + ctxtCase_ = CtxtOneofCase.None; + ctxt_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ControlFlowContextDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ControlFlowContextDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(CondCtxt, other.CondCtxt)) return false; + if (!object.Equals(WhileCtxt, other.WhileCtxt)) return false; + if (CtxtCase != other.CtxtCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ctxtCase_ == CtxtOneofCase.CondCtxt) hash ^= CondCtxt.GetHashCode(); + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) hash ^= WhileCtxt.GetHashCode(); + hash ^= (int) ctxtCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ctxtCase_ == CtxtOneofCase.CondCtxt) { + output.WriteRawTag(10); + output.WriteMessage(CondCtxt); + } + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { + output.WriteRawTag(18); + output.WriteMessage(WhileCtxt); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ctxtCase_ == CtxtOneofCase.CondCtxt) { + output.WriteRawTag(10); + output.WriteMessage(CondCtxt); + } + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { + output.WriteRawTag(18); + output.WriteMessage(WhileCtxt); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ctxtCase_ == CtxtOneofCase.CondCtxt) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(CondCtxt); + } + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(WhileCtxt); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ControlFlowContextDef other) { + if (other == null) { + return; + } + switch (other.CtxtCase) { + case CtxtOneofCase.CondCtxt: + if (CondCtxt == null) { + CondCtxt = new global::Tensorflow.CondContextDef(); + } + CondCtxt.MergeFrom(other.CondCtxt); + break; + case CtxtOneofCase.WhileCtxt: + if (WhileCtxt == null) { + WhileCtxt = new global::Tensorflow.WhileContextDef(); + } + WhileCtxt.MergeFrom(other.WhileCtxt); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.CondContextDef subBuilder = new global::Tensorflow.CondContextDef(); + if (ctxtCase_ == CtxtOneofCase.CondCtxt) { + subBuilder.MergeFrom(CondCtxt); + } + input.ReadMessage(subBuilder); + CondCtxt = subBuilder; + break; + } + case 18: { + global::Tensorflow.WhileContextDef subBuilder = new global::Tensorflow.WhileContextDef(); + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { + subBuilder.MergeFrom(WhileCtxt); + } + input.ReadMessage(subBuilder); + WhileCtxt = subBuilder; + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + global::Tensorflow.CondContextDef subBuilder = new global::Tensorflow.CondContextDef(); + if (ctxtCase_ == CtxtOneofCase.CondCtxt) { + subBuilder.MergeFrom(CondCtxt); + } + input.ReadMessage(subBuilder); + CondCtxt = subBuilder; + break; + } + case 18: { + global::Tensorflow.WhileContextDef subBuilder = new global::Tensorflow.WhileContextDef(); + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { + subBuilder.MergeFrom(WhileCtxt); + } + input.ReadMessage(subBuilder); + WhileCtxt = subBuilder; + break; + } + } + } + } + #endif + + } + + /// + /// Protocol buffer representing a CondContext object. + /// + public sealed partial class CondContextDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CondContextDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CondContextDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CondContextDef(CondContextDef other) : this() { + contextName_ = other.contextName_; + predName_ = other.predName_; + pivotName_ = other.pivotName_; + branch_ = other.branch_; + valuesDef_ = other.valuesDef_ != null ? other.valuesDef_.Clone() : null; + nestedContexts_ = other.nestedContexts_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CondContextDef Clone() { + return new CondContextDef(this); + } + + /// Field number for the "context_name" field. + public const int ContextNameFieldNumber = 1; + private string contextName_ = ""; + /// + /// Name of the context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ContextName { + get { return contextName_; } + set { + contextName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pred_name" field. + public const int PredNameFieldNumber = 2; + private string predName_ = ""; + /// + /// Name of the pred tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string PredName { + get { return predName_; } + set { + predName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pivot_name" field. + public const int PivotNameFieldNumber = 3; + private string pivotName_ = ""; + /// + /// Name of the pivot tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string PivotName { + get { return pivotName_; } + set { + pivotName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "branch" field. + public const int BranchFieldNumber = 4; + private int branch_; + /// + /// Branch prediction. 0 or 1. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int Branch { + get { return branch_; } + set { + branch_ = value; + } + } + + /// Field number for the "values_def" field. + public const int ValuesDefFieldNumber = 5; + private global::Tensorflow.ValuesDef valuesDef_; + /// + /// Values and external values in control flow context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.ValuesDef ValuesDef { + get { return valuesDef_; } + set { + valuesDef_ = value; + } + } + + /// Field number for the "nested_contexts" field. + public const int NestedContextsFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_nestedContexts_codec + = pb::FieldCodec.ForMessage(50, global::Tensorflow.ControlFlowContextDef.Parser); + private readonly pbc::RepeatedField nestedContexts_ = new pbc::RepeatedField(); + /// + /// Contexts contained inside this context (e.g. nested conds). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField NestedContexts { + get { return nestedContexts_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CondContextDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CondContextDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ContextName != other.ContextName) return false; + if (PredName != other.PredName) return false; + if (PivotName != other.PivotName) return false; + if (Branch != other.Branch) return false; + if (!object.Equals(ValuesDef, other.ValuesDef)) return false; + if(!nestedContexts_.Equals(other.nestedContexts_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); + if (PredName.Length != 0) hash ^= PredName.GetHashCode(); + if (PivotName.Length != 0) hash ^= PivotName.GetHashCode(); + if (Branch != 0) hash ^= Branch.GetHashCode(); + if (valuesDef_ != null) hash ^= ValuesDef.GetHashCode(); + hash ^= nestedContexts_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ContextName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ContextName); + } + if (PredName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(PredName); + } + if (PivotName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(PivotName); + } + if (Branch != 0) { + output.WriteRawTag(32); + output.WriteInt32(Branch); + } + if (valuesDef_ != null) { + output.WriteRawTag(42); + output.WriteMessage(ValuesDef); + } + nestedContexts_.WriteTo(output, _repeated_nestedContexts_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ContextName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ContextName); + } + if (PredName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(PredName); + } + if (PivotName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(PivotName); + } + if (Branch != 0) { + output.WriteRawTag(32); + output.WriteInt32(Branch); + } + if (valuesDef_ != null) { + output.WriteRawTag(42); + output.WriteMessage(ValuesDef); + } + nestedContexts_.WriteTo(ref output, _repeated_nestedContexts_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ContextName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ContextName); + } + if (PredName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PredName); + } + if (PivotName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PivotName); + } + if (Branch != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Branch); + } + if (valuesDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ValuesDef); + } + size += nestedContexts_.CalculateSize(_repeated_nestedContexts_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CondContextDef other) { + if (other == null) { + return; + } + if (other.ContextName.Length != 0) { + ContextName = other.ContextName; + } + if (other.PredName.Length != 0) { + PredName = other.PredName; + } + if (other.PivotName.Length != 0) { + PivotName = other.PivotName; + } + if (other.Branch != 0) { + Branch = other.Branch; + } + if (other.valuesDef_ != null) { + if (valuesDef_ == null) { + ValuesDef = new global::Tensorflow.ValuesDef(); + } + ValuesDef.MergeFrom(other.ValuesDef); + } + nestedContexts_.Add(other.nestedContexts_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ContextName = input.ReadString(); + break; + } + case 18: { + PredName = input.ReadString(); + break; + } + case 26: { + PivotName = input.ReadString(); + break; + } + case 32: { + Branch = input.ReadInt32(); + break; + } + case 42: { + if (valuesDef_ == null) { + ValuesDef = new global::Tensorflow.ValuesDef(); + } + input.ReadMessage(ValuesDef); + break; + } + case 50: { + nestedContexts_.AddEntriesFrom(input, _repeated_nestedContexts_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + ContextName = input.ReadString(); + break; + } + case 18: { + PredName = input.ReadString(); + break; + } + case 26: { + PivotName = input.ReadString(); + break; + } + case 32: { + Branch = input.ReadInt32(); + break; + } + case 42: { + if (valuesDef_ == null) { + ValuesDef = new global::Tensorflow.ValuesDef(); + } + input.ReadMessage(ValuesDef); + break; + } + case 50: { + nestedContexts_.AddEntriesFrom(ref input, _repeated_nestedContexts_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Protocol buffer representing a WhileContext object. + /// + public sealed partial class WhileContextDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WhileContextDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WhileContextDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WhileContextDef(WhileContextDef other) : this() { + contextName_ = other.contextName_; + parallelIterations_ = other.parallelIterations_; + backProp_ = other.backProp_; + swapMemory_ = other.swapMemory_; + pivotName_ = other.pivotName_; + pivotForPredName_ = other.pivotForPredName_; + pivotForBodyName_ = other.pivotForBodyName_; + loopExitNames_ = other.loopExitNames_.Clone(); + loopEnterNames_ = other.loopEnterNames_.Clone(); + valuesDef_ = other.valuesDef_ != null ? other.valuesDef_.Clone() : null; + maximumIterationsName_ = other.maximumIterationsName_; + nestedContexts_ = other.nestedContexts_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WhileContextDef Clone() { + return new WhileContextDef(this); + } + + /// Field number for the "context_name" field. + public const int ContextNameFieldNumber = 1; + private string contextName_ = ""; + /// + /// Name of the context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ContextName { + get { return contextName_; } + set { + contextName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "parallel_iterations" field. + public const int ParallelIterationsFieldNumber = 2; + private int parallelIterations_; + /// + /// The number of iterations allowed to run in parallel. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ParallelIterations { + get { return parallelIterations_; } + set { + parallelIterations_ = value; + } + } + + /// Field number for the "back_prop" field. + public const int BackPropFieldNumber = 3; + private bool backProp_; + /// + /// Whether backprop is enabled for this while loop. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool BackProp { + get { return backProp_; } + set { + backProp_ = value; + } + } + + /// Field number for the "swap_memory" field. + public const int SwapMemoryFieldNumber = 4; + private bool swapMemory_; + /// + /// Whether GPU-CPU memory swap is enabled for this loop. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool SwapMemory { + get { return swapMemory_; } + set { + swapMemory_ = value; + } + } + + /// Field number for the "pivot_name" field. + public const int PivotNameFieldNumber = 5; + private string pivotName_ = ""; + /// + /// Name of the pivot tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string PivotName { + get { return pivotName_; } + set { + pivotName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pivot_for_pred_name" field. + public const int PivotForPredNameFieldNumber = 6; + private string pivotForPredName_ = ""; + /// + /// Name of the pivot_for_pred tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string PivotForPredName { + get { return pivotForPredName_; } + set { + pivotForPredName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pivot_for_body_name" field. + public const int PivotForBodyNameFieldNumber = 7; + private string pivotForBodyName_ = ""; + /// + /// Name of the pivot_for_body tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string PivotForBodyName { + get { return pivotForBodyName_; } + set { + pivotForBodyName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "loop_exit_names" field. + public const int LoopExitNamesFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_loopExitNames_codec + = pb::FieldCodec.ForString(66); + private readonly pbc::RepeatedField loopExitNames_ = new pbc::RepeatedField(); + /// + /// List of names for exit tensors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField LoopExitNames { + get { return loopExitNames_; } + } + + /// Field number for the "loop_enter_names" field. + public const int LoopEnterNamesFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_loopEnterNames_codec + = pb::FieldCodec.ForString(82); + private readonly pbc::RepeatedField loopEnterNames_ = new pbc::RepeatedField(); + /// + /// List of names for enter tensors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField LoopEnterNames { + get { return loopEnterNames_; } + } + + /// Field number for the "values_def" field. + public const int ValuesDefFieldNumber = 9; + private global::Tensorflow.ValuesDef valuesDef_; + /// + /// Values and external values in control flow context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.ValuesDef ValuesDef { + get { return valuesDef_; } + set { + valuesDef_ = value; + } + } + + /// Field number for the "maximum_iterations_name" field. + public const int MaximumIterationsNameFieldNumber = 11; + private string maximumIterationsName_ = ""; + /// + /// Optional name of the maximum_iterations tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string MaximumIterationsName { + get { return maximumIterationsName_; } + set { + maximumIterationsName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "nested_contexts" field. + public const int NestedContextsFieldNumber = 12; + private static readonly pb::FieldCodec _repeated_nestedContexts_codec + = pb::FieldCodec.ForMessage(98, global::Tensorflow.ControlFlowContextDef.Parser); + private readonly pbc::RepeatedField nestedContexts_ = new pbc::RepeatedField(); + /// + /// Contexts contained inside this context (e.g. nested whiles). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField NestedContexts { + get { return nestedContexts_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WhileContextDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WhileContextDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ContextName != other.ContextName) return false; + if (ParallelIterations != other.ParallelIterations) return false; + if (BackProp != other.BackProp) return false; + if (SwapMemory != other.SwapMemory) return false; + if (PivotName != other.PivotName) return false; + if (PivotForPredName != other.PivotForPredName) return false; + if (PivotForBodyName != other.PivotForBodyName) return false; + if(!loopExitNames_.Equals(other.loopExitNames_)) return false; + if(!loopEnterNames_.Equals(other.loopEnterNames_)) return false; + if (!object.Equals(ValuesDef, other.ValuesDef)) return false; + if (MaximumIterationsName != other.MaximumIterationsName) return false; + if(!nestedContexts_.Equals(other.nestedContexts_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); + if (ParallelIterations != 0) hash ^= ParallelIterations.GetHashCode(); + if (BackProp != false) hash ^= BackProp.GetHashCode(); + if (SwapMemory != false) hash ^= SwapMemory.GetHashCode(); + if (PivotName.Length != 0) hash ^= PivotName.GetHashCode(); + if (PivotForPredName.Length != 0) hash ^= PivotForPredName.GetHashCode(); + if (PivotForBodyName.Length != 0) hash ^= PivotForBodyName.GetHashCode(); + hash ^= loopExitNames_.GetHashCode(); + hash ^= loopEnterNames_.GetHashCode(); + if (valuesDef_ != null) hash ^= ValuesDef.GetHashCode(); + if (MaximumIterationsName.Length != 0) hash ^= MaximumIterationsName.GetHashCode(); + hash ^= nestedContexts_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ContextName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ContextName); + } + if (ParallelIterations != 0) { + output.WriteRawTag(16); + output.WriteInt32(ParallelIterations); + } + if (BackProp != false) { + output.WriteRawTag(24); + output.WriteBool(BackProp); + } + if (SwapMemory != false) { + output.WriteRawTag(32); + output.WriteBool(SwapMemory); + } + if (PivotName.Length != 0) { + output.WriteRawTag(42); + output.WriteString(PivotName); + } + if (PivotForPredName.Length != 0) { + output.WriteRawTag(50); + output.WriteString(PivotForPredName); + } + if (PivotForBodyName.Length != 0) { + output.WriteRawTag(58); + output.WriteString(PivotForBodyName); + } + loopExitNames_.WriteTo(output, _repeated_loopExitNames_codec); + if (valuesDef_ != null) { + output.WriteRawTag(74); + output.WriteMessage(ValuesDef); + } + loopEnterNames_.WriteTo(output, _repeated_loopEnterNames_codec); + if (MaximumIterationsName.Length != 0) { + output.WriteRawTag(90); + output.WriteString(MaximumIterationsName); + } + nestedContexts_.WriteTo(output, _repeated_nestedContexts_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ContextName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ContextName); + } + if (ParallelIterations != 0) { + output.WriteRawTag(16); + output.WriteInt32(ParallelIterations); + } + if (BackProp != false) { + output.WriteRawTag(24); + output.WriteBool(BackProp); + } + if (SwapMemory != false) { + output.WriteRawTag(32); + output.WriteBool(SwapMemory); + } + if (PivotName.Length != 0) { + output.WriteRawTag(42); + output.WriteString(PivotName); + } + if (PivotForPredName.Length != 0) { + output.WriteRawTag(50); + output.WriteString(PivotForPredName); + } + if (PivotForBodyName.Length != 0) { + output.WriteRawTag(58); + output.WriteString(PivotForBodyName); + } + loopExitNames_.WriteTo(ref output, _repeated_loopExitNames_codec); + if (valuesDef_ != null) { + output.WriteRawTag(74); + output.WriteMessage(ValuesDef); + } + loopEnterNames_.WriteTo(ref output, _repeated_loopEnterNames_codec); + if (MaximumIterationsName.Length != 0) { + output.WriteRawTag(90); + output.WriteString(MaximumIterationsName); + } + nestedContexts_.WriteTo(ref output, _repeated_nestedContexts_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ContextName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ContextName); + } + if (ParallelIterations != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ParallelIterations); + } + if (BackProp != false) { + size += 1 + 1; + } + if (SwapMemory != false) { + size += 1 + 1; + } + if (PivotName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PivotName); + } + if (PivotForPredName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PivotForPredName); + } + if (PivotForBodyName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PivotForBodyName); + } + size += loopExitNames_.CalculateSize(_repeated_loopExitNames_codec); + size += loopEnterNames_.CalculateSize(_repeated_loopEnterNames_codec); + if (valuesDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ValuesDef); + } + if (MaximumIterationsName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MaximumIterationsName); + } + size += nestedContexts_.CalculateSize(_repeated_nestedContexts_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WhileContextDef other) { + if (other == null) { + return; + } + if (other.ContextName.Length != 0) { + ContextName = other.ContextName; + } + if (other.ParallelIterations != 0) { + ParallelIterations = other.ParallelIterations; + } + if (other.BackProp != false) { + BackProp = other.BackProp; + } + if (other.SwapMemory != false) { + SwapMemory = other.SwapMemory; + } + if (other.PivotName.Length != 0) { + PivotName = other.PivotName; + } + if (other.PivotForPredName.Length != 0) { + PivotForPredName = other.PivotForPredName; + } + if (other.PivotForBodyName.Length != 0) { + PivotForBodyName = other.PivotForBodyName; + } + loopExitNames_.Add(other.loopExitNames_); + loopEnterNames_.Add(other.loopEnterNames_); + if (other.valuesDef_ != null) { + if (valuesDef_ == null) { + ValuesDef = new global::Tensorflow.ValuesDef(); + } + ValuesDef.MergeFrom(other.ValuesDef); + } + if (other.MaximumIterationsName.Length != 0) { + MaximumIterationsName = other.MaximumIterationsName; + } + nestedContexts_.Add(other.nestedContexts_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ContextName = input.ReadString(); + break; + } + case 16: { + ParallelIterations = input.ReadInt32(); + break; + } + case 24: { + BackProp = input.ReadBool(); + break; + } + case 32: { + SwapMemory = input.ReadBool(); + break; + } + case 42: { + PivotName = input.ReadString(); + break; + } + case 50: { + PivotForPredName = input.ReadString(); + break; + } + case 58: { + PivotForBodyName = input.ReadString(); + break; + } + case 66: { + loopExitNames_.AddEntriesFrom(input, _repeated_loopExitNames_codec); + break; + } + case 74: { + if (valuesDef_ == null) { + ValuesDef = new global::Tensorflow.ValuesDef(); + } + input.ReadMessage(ValuesDef); + break; + } + case 82: { + loopEnterNames_.AddEntriesFrom(input, _repeated_loopEnterNames_codec); + break; + } + case 90: { + MaximumIterationsName = input.ReadString(); + break; + } + case 98: { + nestedContexts_.AddEntriesFrom(input, _repeated_nestedContexts_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + ContextName = input.ReadString(); + break; + } + case 16: { + ParallelIterations = input.ReadInt32(); + break; + } + case 24: { + BackProp = input.ReadBool(); + break; + } + case 32: { + SwapMemory = input.ReadBool(); + break; + } + case 42: { + PivotName = input.ReadString(); + break; + } + case 50: { + PivotForPredName = input.ReadString(); + break; + } + case 58: { + PivotForBodyName = input.ReadString(); + break; + } + case 66: { + loopExitNames_.AddEntriesFrom(ref input, _repeated_loopExitNames_codec); + break; + } + case 74: { + if (valuesDef_ == null) { + ValuesDef = new global::Tensorflow.ValuesDef(); + } + input.ReadMessage(ValuesDef); + break; + } + case 82: { + loopEnterNames_.AddEntriesFrom(ref input, _repeated_loopEnterNames_codec); + break; + } + case 90: { + MaximumIterationsName = input.ReadString(); + break; + } + case 98: { + nestedContexts_.AddEntriesFrom(ref input, _repeated_nestedContexts_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/CoordinationConfig.cs b/src/TensorFlowNET.Core/Protobuf/CoordinationConfig.cs new file mode 100644 index 000000000..c949067cd --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/CoordinationConfig.cs @@ -0,0 +1,791 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/coordination_config.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/coordination_config.proto + public static partial class CoordinationConfigReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/coordination_config.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static CoordinationConfigReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjJ0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29vcmRpbmF0aW9uX2NvbmZp", + "Zy5wcm90bxIKdGVuc29yZmxvdyIxCg5Db29yZGluYXRlZEpvYhIMCgRuYW1l", + "GAEgASgJEhEKCW51bV90YXNrcxgCIAEoBSLdAgoZQ29vcmRpbmF0aW9uU2Vy", + "dmljZUNvbmZpZxIUCgxzZXJ2aWNlX3R5cGUYASABKAkSFgoOc2VydmljZV9s", + "ZWFkZXIYAiABKAkSGwoTZW5hYmxlX2hlYWx0aF9jaGVjaxgDIAEoCBImCh5j", + "bHVzdGVyX3JlZ2lzdGVyX3RpbWVvdXRfaW5fbXMYBCABKAMSHwoXaGVhcnRi", + "ZWF0X3RpbWVvdXRfaW5fbXMYBSABKAMSOAoUY29vcmRpbmF0ZWRfam9iX2xp", + "c3QYCiADKAsyGi50ZW5zb3JmbG93LkNvb3JkaW5hdGVkSm9iEiYKHnNodXRk", + "b3duX2JhcnJpZXJfdGltZW91dF9pbl9tcxgHIAEoAxIqCiJhZ2VudF9kZXN0", + "cnVjdGlvbl93aXRob3V0X3NodXRkb3duGAggASgIEhgKEHJlY292ZXJhYmxl", + "X2pvYnMYCSADKAlKBAgGEAdCV1pVZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl", + "bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL3Byb3RvYnVmL2Zvcl9jb3Jl", + "X3Byb3Rvc19nb19wcm90b2IGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinatedJob), global::Tensorflow.CoordinatedJob.Parser, new[]{ "Name", "NumTasks" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinationServiceConfig), global::Tensorflow.CoordinationServiceConfig.Parser, new[]{ "ServiceType", "ServiceLeader", "EnableHealthCheck", "ClusterRegisterTimeoutInMs", "HeartbeatTimeoutInMs", "CoordinatedJobList", "ShutdownBarrierTimeoutInMs", "AgentDestructionWithoutShutdown", "RecoverableJobs" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Represents a job type and the number of tasks under this job. + /// For example, ("worker", 20) implies that there will be 20 worker tasks. + /// + public sealed partial class CoordinatedJob : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CoordinatedJob()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationConfigReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinatedJob() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinatedJob(CoordinatedJob other) : this() { + name_ = other.name_; + numTasks_ = other.numTasks_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinatedJob Clone() { + return new CoordinatedJob(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "num_tasks" field. + public const int NumTasksFieldNumber = 2; + private int numTasks_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NumTasks { + get { return numTasks_; } + set { + numTasks_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CoordinatedJob); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CoordinatedJob other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (NumTasks != other.NumTasks) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (NumTasks != 0) hash ^= NumTasks.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (NumTasks != 0) { + output.WriteRawTag(16); + output.WriteInt32(NumTasks); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (NumTasks != 0) { + output.WriteRawTag(16); + output.WriteInt32(NumTasks); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (NumTasks != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumTasks); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CoordinatedJob other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.NumTasks != 0) { + NumTasks = other.NumTasks; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + NumTasks = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + NumTasks = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + /// + /// Coordination service configuration parameters. + /// The system picks appropriate values for fields that are not set. + /// + public sealed partial class CoordinationServiceConfig : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CoordinationServiceConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationConfigReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinationServiceConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinationServiceConfig(CoordinationServiceConfig other) : this() { + serviceType_ = other.serviceType_; + serviceLeader_ = other.serviceLeader_; + enableHealthCheck_ = other.enableHealthCheck_; + clusterRegisterTimeoutInMs_ = other.clusterRegisterTimeoutInMs_; + heartbeatTimeoutInMs_ = other.heartbeatTimeoutInMs_; + coordinatedJobList_ = other.coordinatedJobList_.Clone(); + shutdownBarrierTimeoutInMs_ = other.shutdownBarrierTimeoutInMs_; + agentDestructionWithoutShutdown_ = other.agentDestructionWithoutShutdown_; + recoverableJobs_ = other.recoverableJobs_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinationServiceConfig Clone() { + return new CoordinationServiceConfig(this); + } + + /// Field number for the "service_type" field. + public const int ServiceTypeFieldNumber = 1; + private string serviceType_ = ""; + /// + /// Type of coordination service implementation to enable. + /// For example, setting the service type as "standalone" starts a service + /// instance on the leader task to provide the coordination services such as + /// heartbeats and consistent key-value store. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ServiceType { + get { return serviceType_; } + set { + serviceType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "service_leader" field. + public const int ServiceLeaderFieldNumber = 2; + private string serviceLeader_ = ""; + /// + /// Address where the coordination service instance is hosted. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ServiceLeader { + get { return serviceLeader_; } + set { + serviceLeader_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "enable_health_check" field. + public const int EnableHealthCheckFieldNumber = 3; + private bool enableHealthCheck_; + /// + /// Whether to enable the health check mechanism. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool EnableHealthCheck { + get { return enableHealthCheck_; } + set { + enableHealthCheck_ = value; + } + } + + /// Field number for the "cluster_register_timeout_in_ms" field. + public const int ClusterRegisterTimeoutInMsFieldNumber = 4; + private long clusterRegisterTimeoutInMs_; + /// + /// Maximum wait time for all members in the cluster to be registered. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ClusterRegisterTimeoutInMs { + get { return clusterRegisterTimeoutInMs_; } + set { + clusterRegisterTimeoutInMs_ = value; + } + } + + /// Field number for the "heartbeat_timeout_in_ms" field. + public const int HeartbeatTimeoutInMsFieldNumber = 5; + private long heartbeatTimeoutInMs_; + /// + /// Heartbeat timeout, if a task does not record heartbeat in this time + /// window, it will be considered disconnected. + /// Note: This is also used as a grace period to accept any heartbeats after + /// the agent has disconnected, to account for the lag time between the service + /// recording the state change and the agent stopping heartbeats. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long HeartbeatTimeoutInMs { + get { return heartbeatTimeoutInMs_; } + set { + heartbeatTimeoutInMs_ = value; + } + } + + /// Field number for the "coordinated_job_list" field. + public const int CoordinatedJobListFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_coordinatedJobList_codec + = pb::FieldCodec.ForMessage(82, global::Tensorflow.CoordinatedJob.Parser); + private readonly pbc::RepeatedField coordinatedJobList_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField CoordinatedJobList { + get { return coordinatedJobList_; } + } + + /// Field number for the "shutdown_barrier_timeout_in_ms" field. + public const int ShutdownBarrierTimeoutInMsFieldNumber = 7; + private long shutdownBarrierTimeoutInMs_; + /// + /// Denotes how long to wait for all coordination agents to reach the barriers + /// (after the first shutdown request) before disconnecting together. If + /// set to 0, no barrier is imposed upon shutdown and each worker can + /// disconnect individually. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ShutdownBarrierTimeoutInMs { + get { return shutdownBarrierTimeoutInMs_; } + set { + shutdownBarrierTimeoutInMs_ = value; + } + } + + /// Field number for the "agent_destruction_without_shutdown" field. + public const int AgentDestructionWithoutShutdownFieldNumber = 8; + private bool agentDestructionWithoutShutdown_; + /// + /// If set, agents do not make an explicit Shutdown() call. Service will only + /// find out about the disconnecte agent via stale heartbeats. Used for + /// testing. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool AgentDestructionWithoutShutdown { + get { return agentDestructionWithoutShutdown_; } + set { + agentDestructionWithoutShutdown_ = value; + } + } + + /// Field number for the "recoverable_jobs" field. + public const int RecoverableJobsFieldNumber = 9; + private static readonly pb::FieldCodec _repeated_recoverableJobs_codec + = pb::FieldCodec.ForString(74); + private readonly pbc::RepeatedField recoverableJobs_ = new pbc::RepeatedField(); + /// + /// The list of jobs which are recoverable. If a task in this list fails, + /// it will not propagate error to other tasks. + /// If empty, no jobs will be recoverable and every task failure will cause + /// error propagation to other tasks. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField RecoverableJobs { + get { return recoverableJobs_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CoordinationServiceConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CoordinationServiceConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ServiceType != other.ServiceType) return false; + if (ServiceLeader != other.ServiceLeader) return false; + if (EnableHealthCheck != other.EnableHealthCheck) return false; + if (ClusterRegisterTimeoutInMs != other.ClusterRegisterTimeoutInMs) return false; + if (HeartbeatTimeoutInMs != other.HeartbeatTimeoutInMs) return false; + if(!coordinatedJobList_.Equals(other.coordinatedJobList_)) return false; + if (ShutdownBarrierTimeoutInMs != other.ShutdownBarrierTimeoutInMs) return false; + if (AgentDestructionWithoutShutdown != other.AgentDestructionWithoutShutdown) return false; + if(!recoverableJobs_.Equals(other.recoverableJobs_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ServiceType.Length != 0) hash ^= ServiceType.GetHashCode(); + if (ServiceLeader.Length != 0) hash ^= ServiceLeader.GetHashCode(); + if (EnableHealthCheck != false) hash ^= EnableHealthCheck.GetHashCode(); + if (ClusterRegisterTimeoutInMs != 0L) hash ^= ClusterRegisterTimeoutInMs.GetHashCode(); + if (HeartbeatTimeoutInMs != 0L) hash ^= HeartbeatTimeoutInMs.GetHashCode(); + hash ^= coordinatedJobList_.GetHashCode(); + if (ShutdownBarrierTimeoutInMs != 0L) hash ^= ShutdownBarrierTimeoutInMs.GetHashCode(); + if (AgentDestructionWithoutShutdown != false) hash ^= AgentDestructionWithoutShutdown.GetHashCode(); + hash ^= recoverableJobs_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ServiceType.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ServiceType); + } + if (ServiceLeader.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ServiceLeader); + } + if (EnableHealthCheck != false) { + output.WriteRawTag(24); + output.WriteBool(EnableHealthCheck); + } + if (ClusterRegisterTimeoutInMs != 0L) { + output.WriteRawTag(32); + output.WriteInt64(ClusterRegisterTimeoutInMs); + } + if (HeartbeatTimeoutInMs != 0L) { + output.WriteRawTag(40); + output.WriteInt64(HeartbeatTimeoutInMs); + } + if (ShutdownBarrierTimeoutInMs != 0L) { + output.WriteRawTag(56); + output.WriteInt64(ShutdownBarrierTimeoutInMs); + } + if (AgentDestructionWithoutShutdown != false) { + output.WriteRawTag(64); + output.WriteBool(AgentDestructionWithoutShutdown); + } + recoverableJobs_.WriteTo(output, _repeated_recoverableJobs_codec); + coordinatedJobList_.WriteTo(output, _repeated_coordinatedJobList_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ServiceType.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ServiceType); + } + if (ServiceLeader.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ServiceLeader); + } + if (EnableHealthCheck != false) { + output.WriteRawTag(24); + output.WriteBool(EnableHealthCheck); + } + if (ClusterRegisterTimeoutInMs != 0L) { + output.WriteRawTag(32); + output.WriteInt64(ClusterRegisterTimeoutInMs); + } + if (HeartbeatTimeoutInMs != 0L) { + output.WriteRawTag(40); + output.WriteInt64(HeartbeatTimeoutInMs); + } + if (ShutdownBarrierTimeoutInMs != 0L) { + output.WriteRawTag(56); + output.WriteInt64(ShutdownBarrierTimeoutInMs); + } + if (AgentDestructionWithoutShutdown != false) { + output.WriteRawTag(64); + output.WriteBool(AgentDestructionWithoutShutdown); + } + recoverableJobs_.WriteTo(ref output, _repeated_recoverableJobs_codec); + coordinatedJobList_.WriteTo(ref output, _repeated_coordinatedJobList_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ServiceType.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ServiceType); + } + if (ServiceLeader.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ServiceLeader); + } + if (EnableHealthCheck != false) { + size += 1 + 1; + } + if (ClusterRegisterTimeoutInMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ClusterRegisterTimeoutInMs); + } + if (HeartbeatTimeoutInMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(HeartbeatTimeoutInMs); + } + size += coordinatedJobList_.CalculateSize(_repeated_coordinatedJobList_codec); + if (ShutdownBarrierTimeoutInMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ShutdownBarrierTimeoutInMs); + } + if (AgentDestructionWithoutShutdown != false) { + size += 1 + 1; + } + size += recoverableJobs_.CalculateSize(_repeated_recoverableJobs_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CoordinationServiceConfig other) { + if (other == null) { + return; + } + if (other.ServiceType.Length != 0) { + ServiceType = other.ServiceType; + } + if (other.ServiceLeader.Length != 0) { + ServiceLeader = other.ServiceLeader; + } + if (other.EnableHealthCheck != false) { + EnableHealthCheck = other.EnableHealthCheck; + } + if (other.ClusterRegisterTimeoutInMs != 0L) { + ClusterRegisterTimeoutInMs = other.ClusterRegisterTimeoutInMs; + } + if (other.HeartbeatTimeoutInMs != 0L) { + HeartbeatTimeoutInMs = other.HeartbeatTimeoutInMs; + } + coordinatedJobList_.Add(other.coordinatedJobList_); + if (other.ShutdownBarrierTimeoutInMs != 0L) { + ShutdownBarrierTimeoutInMs = other.ShutdownBarrierTimeoutInMs; + } + if (other.AgentDestructionWithoutShutdown != false) { + AgentDestructionWithoutShutdown = other.AgentDestructionWithoutShutdown; + } + recoverableJobs_.Add(other.recoverableJobs_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ServiceType = input.ReadString(); + break; + } + case 18: { + ServiceLeader = input.ReadString(); + break; + } + case 24: { + EnableHealthCheck = input.ReadBool(); + break; + } + case 32: { + ClusterRegisterTimeoutInMs = input.ReadInt64(); + break; + } + case 40: { + HeartbeatTimeoutInMs = input.ReadInt64(); + break; + } + case 56: { + ShutdownBarrierTimeoutInMs = input.ReadInt64(); + break; + } + case 64: { + AgentDestructionWithoutShutdown = input.ReadBool(); + break; + } + case 74: { + recoverableJobs_.AddEntriesFrom(input, _repeated_recoverableJobs_codec); + break; + } + case 82: { + coordinatedJobList_.AddEntriesFrom(input, _repeated_coordinatedJobList_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + ServiceType = input.ReadString(); + break; + } + case 18: { + ServiceLeader = input.ReadString(); + break; + } + case 24: { + EnableHealthCheck = input.ReadBool(); + break; + } + case 32: { + ClusterRegisterTimeoutInMs = input.ReadInt64(); + break; + } + case 40: { + HeartbeatTimeoutInMs = input.ReadInt64(); + break; + } + case 56: { + ShutdownBarrierTimeoutInMs = input.ReadInt64(); + break; + } + case 64: { + AgentDestructionWithoutShutdown = input.ReadBool(); + break; + } + case 74: { + recoverableJobs_.AddEntriesFrom(ref input, _repeated_recoverableJobs_codec); + break; + } + case 82: { + coordinatedJobList_.AddEntriesFrom(ref input, _repeated_coordinatedJobList_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/CoordinationService.cs b/src/TensorFlowNET.Core/Protobuf/CoordinationService.cs new file mode 100644 index 000000000..a974d724d --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/CoordinationService.cs @@ -0,0 +1,7964 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/coordination_service.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/coordination_service.proto + public static partial class CoordinationServiceReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/coordination_service.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static CoordinationServiceReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjN0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29vcmRpbmF0aW9uX3NlcnZp", + "Y2UucHJvdG8SCnRlbnNvcmZsb3caN3RlbnNvcmZsb3cvY29tcGlsZXIveGxh", + "L3BqcnQvZGlzdHJpYnV0ZWQvcHJvdG9jb2wucHJvdG8aMXRlbnNvcmZsb3cv", + "Y29yZS9mcmFtZXdvcmsvZGV2aWNlX2F0dHJpYnV0ZXMucHJvdG8iNAoPQ29v", + "cmRpbmF0ZWRUYXNrEhAKCGpvYl9uYW1lGAEgASgJEg8KB3Rhc2tfaWQYAiAB", + "KAUicwoYQ29vcmRpbmF0aW9uU2VydmljZUVycm9yEhkKEWlzX3JlcG9ydGVk", + "X2Vycm9yGAMgASgIEjAKC3NvdXJjZV90YXNrGAQgASgLMhsudGVuc29yZmxv", + "dy5Db29yZGluYXRlZFRhc2tKBAgBEAJKBAgCEAMi3gEKGENvb3JkaW5hdGVk", + "VGFza1N0YXRlSW5mbxIpCgR0YXNrGAEgASgLMhsudGVuc29yZmxvdy5Db29y", + "ZGluYXRlZFRhc2sSLwoFc3RhdGUYAiABKA4yIC50ZW5zb3JmbG93LkNvb3Jk", + "aW5hdGVkVGFza1N0YXRlEhIKCmVycm9yX2NvZGUYAyABKAUSFQoNZXJyb3Jf", + "bWVzc2FnZRgEIAEoCRI7Cg1lcnJvcl9wYXlsb2FkGAUgASgLMiQudGVuc29y", + "Zmxvdy5Db29yZGluYXRpb25TZXJ2aWNlRXJyb3IiPQoMVGZEZXZpY2VMaXN0", + "Ei0KB2RldmljZXMYASADKAsyHC50ZW5zb3JmbG93LkRldmljZUF0dHJpYnV0", + "ZXMiOgoNWGxhRGV2aWNlTGlzdBIpCgdkZXZpY2VzGAEgASgLMhgueGxhLkds", + "b2JhbFRvcG9sb2d5UHJvdG8ieQodQ29vcmRpbmF0aW9uU2VydmljZURldmlj", + "ZUluZm8SJgoCdGYYASABKAsyGC50ZW5zb3JmbG93LlRmRGV2aWNlTGlzdEgA", + "EigKA3hsYRgCIAEoCzIZLnRlbnNvcmZsb3cuWGxhRGV2aWNlTGlzdEgAQgYK", + "BHR5cGUibgoTUmVnaXN0ZXJUYXNrUmVxdWVzdBITCgtpbmNhcm5hdGlvbhgD", + "IAEoBhIwCgtzb3VyY2VfdGFzaxgFIAEoCzIbLnRlbnNvcmZsb3cuQ29vcmRp", + "bmF0ZWRUYXNrSgQIARACSgQIAhADSgQIBBAFIjIKFFJlZ2lzdGVyVGFza1Jl", + "c3BvbnNlEhoKEmxlYWRlcl9pbmNhcm5hdGlvbhgBIAEoBiJlChBIZWFydGJl", + "YXRSZXF1ZXN0EhMKC2luY2FybmF0aW9uGAMgASgGEjAKC3NvdXJjZV90YXNr", + "GAQgASgLMhsudGVuc29yZmxvdy5Db29yZGluYXRlZFRhc2tKBAgBEAJKBAgC", + "EAMiLwoRSGVhcnRiZWF0UmVzcG9uc2USGgoSbGVhZGVyX2luY2FybmF0aW9u", + "GAEgASgGIqIBChZXYWl0Rm9yQWxsVGFza3NSZXF1ZXN0EkQKEWxvY2FsX2Rl", + "dmljZV9pbmZvGAQgASgLMikudGVuc29yZmxvdy5Db29yZGluYXRpb25TZXJ2", + "aWNlRGV2aWNlSW5mbxIwCgtzb3VyY2VfdGFzaxgFIAEoCzIbLnRlbnNvcmZs", + "b3cuQ29vcmRpbmF0ZWRUYXNrSgQIARACSgQIAhADSgQIAxAEIoMBChdXYWl0", + "Rm9yQWxsVGFza3NSZXNwb25zZRIaChJsZWFkZXJfaW5jYXJuYXRpb24YASAB", + "KAYSRgoTY2x1c3Rlcl9kZXZpY2VfaW5mbxgDIAEoCzIpLnRlbnNvcmZsb3cu", + "Q29vcmRpbmF0aW9uU2VydmljZURldmljZUluZm9KBAgCEAMiRwoTU2h1dGRv", + "d25UYXNrUmVxdWVzdBIwCgtzb3VyY2VfdGFzaxgBIAEoCzIbLnRlbnNvcmZs", + "b3cuQ29vcmRpbmF0ZWRUYXNrIhYKFFNodXRkb3duVGFza1Jlc3BvbnNlIkQK", + "EFJlc2V0VGFza1JlcXVlc3QSMAoLc291cmNlX3Rhc2sYASABKAsyGy50ZW5z", + "b3JmbG93LkNvb3JkaW5hdGVkVGFzayITChFSZXNldFRhc2tSZXNwb25zZSKO", + "AQoYUmVwb3J0RXJyb3JUb1Rhc2tSZXF1ZXN0EhIKCmVycm9yX2NvZGUYASAB", + "KAUSFQoNZXJyb3JfbWVzc2FnZRgCIAEoCRI7Cg1lcnJvcl9wYXlsb2FkGAUg", + "ASgLMiQudGVuc29yZmxvdy5Db29yZGluYXRpb25TZXJ2aWNlRXJyb3JKBAgD", + "EARKBAgEEAUiGwoZUmVwb3J0RXJyb3JUb1Rhc2tSZXNwb25zZSKHAQobUmVw", + "b3J0RXJyb3JUb1NlcnZpY2VSZXF1ZXN0EhIKCmVycm9yX2NvZGUYASABKAUS", + "FQoNZXJyb3JfbWVzc2FnZRgCIAEoCRIxCgxlcnJvcl9vcmlnaW4YBSABKAsy", + "Gy50ZW5zb3JmbG93LkNvb3JkaW5hdGVkVGFza0oECAMQBEoECAQQBSIeChxS", + "ZXBvcnRFcnJvclRvU2VydmljZVJlc3BvbnNlIkcKE0dldFRhc2tTdGF0ZVJl", + "cXVlc3QSMAoLc291cmNlX3Rhc2sYASADKAsyGy50ZW5zb3JmbG93LkNvb3Jk", + "aW5hdGVkVGFzayJQChRHZXRUYXNrU3RhdGVSZXNwb25zZRI4Cgp0YXNrX3N0", + "YXRlGAEgAygLMiQudGVuc29yZmxvdy5Db29yZGluYXRlZFRhc2tTdGF0ZUlu", + "Zm8iKwoNS2V5VmFsdWVFbnRyeRILCgNrZXkYASABKAkSDQoFdmFsdWUYAiAB", + "KAwiPgoVSW5zZXJ0S2V5VmFsdWVSZXF1ZXN0EiUKAmt2GAEgASgLMhkudGVu", + "c29yZmxvdy5LZXlWYWx1ZUVudHJ5IhgKFkluc2VydEtleVZhbHVlUmVzcG9u", + "c2UiIQoSR2V0S2V5VmFsdWVSZXF1ZXN0EgsKA2tleRgBIAEoCSI8ChNHZXRL", + "ZXlWYWx1ZVJlc3BvbnNlEiUKAmt2GAEgASgLMhkudGVuc29yZmxvdy5LZXlW", + "YWx1ZUVudHJ5IiQKFVRyeUdldEtleVZhbHVlUmVxdWVzdBILCgNrZXkYASAB", + "KAkiPwoWVHJ5R2V0S2V5VmFsdWVSZXNwb25zZRIlCgJrdhgBIAEoCzIZLnRl", + "bnNvcmZsb3cuS2V5VmFsdWVFbnRyeSIuChVHZXRLZXlWYWx1ZURpclJlcXVl", + "c3QSFQoNZGlyZWN0b3J5X2tleRgBIAEoCSJWChZHZXRLZXlWYWx1ZURpclJl", + "c3BvbnNlEhUKDWRpcmVjdG9yeV9rZXkYASABKAkSJQoCa3YYAiADKAsyGS50", + "ZW5zb3JmbG93LktleVZhbHVlRW50cnkiOgoVRGVsZXRlS2V5VmFsdWVSZXF1", + "ZXN0EgsKA2tleRgBIAEoCRIUCgxpc19kaXJlY3RvcnkYAiABKAgiGAoWRGVs", + "ZXRlS2V5VmFsdWVSZXNwb25zZSKhAQoOQmFycmllclJlcXVlc3QSEgoKYmFy", + "cmllcl9pZBgBIAEoCRIdChViYXJyaWVyX3RpbWVvdXRfaW5fbXMYAiABKAMS", + "KgoFdGFza3MYAyADKAsyGy50ZW5zb3JmbG93LkNvb3JkaW5hdGVkVGFzaxIw", + "Cgtzb3VyY2VfdGFzaxgEIAEoCzIbLnRlbnNvcmZsb3cuQ29vcmRpbmF0ZWRU", + "YXNrIhEKD0JhcnJpZXJSZXNwb25zZSJcChRDYW5jZWxCYXJyaWVyUmVxdWVz", + "dBISCgpiYXJyaWVyX2lkGAEgASgJEjAKC3NvdXJjZV90YXNrGAIgASgLMhsu", + "dGVuc29yZmxvdy5Db29yZGluYXRlZFRhc2siFwoVQ2FuY2VsQmFycmllclJl", + "c3BvbnNlKpgBChRDb29yZGluYXRlZFRhc2tTdGF0ZRIZChVUQVNLU1RBVEVf", + "VU5TUEVDSUZJRUQQABIbChdUQVNLU1RBVEVfVU5JTklUSUFMSVpFRBABEhoK", + "FlRBU0tTVEFURV9ESVNDT05ORUNURUQQAhIXChNUQVNLU1RBVEVfQ09OTkVD", + "VEVEEAMSEwoPVEFTS1NUQVRFX0VSUk9SEAQymQoKE0Nvb3JkaW5hdGlvblNl", + "cnZpY2USUQoMUmVnaXN0ZXJUYXNrEh8udGVuc29yZmxvdy5SZWdpc3RlclRh", + "c2tSZXF1ZXN0GiAudGVuc29yZmxvdy5SZWdpc3RlclRhc2tSZXNwb25zZRJI", + "CglIZWFydGJlYXQSHC50ZW5zb3JmbG93LkhlYXJ0YmVhdFJlcXVlc3QaHS50", + "ZW5zb3JmbG93LkhlYXJ0YmVhdFJlc3BvbnNlEloKD1dhaXRGb3JBbGxUYXNr", + "cxIiLnRlbnNvcmZsb3cuV2FpdEZvckFsbFRhc2tzUmVxdWVzdBojLnRlbnNv", + "cmZsb3cuV2FpdEZvckFsbFRhc2tzUmVzcG9uc2USUQoMU2h1dGRvd25UYXNr", + "Eh8udGVuc29yZmxvdy5TaHV0ZG93blRhc2tSZXF1ZXN0GiAudGVuc29yZmxv", + "dy5TaHV0ZG93blRhc2tSZXNwb25zZRJICglSZXNldFRhc2sSHC50ZW5zb3Jm", + "bG93LlJlc2V0VGFza1JlcXVlc3QaHS50ZW5zb3JmbG93LlJlc2V0VGFza1Jl", + "c3BvbnNlEmAKEVJlcG9ydEVycm9yVG9UYXNrEiQudGVuc29yZmxvdy5SZXBv", + "cnRFcnJvclRvVGFza1JlcXVlc3QaJS50ZW5zb3JmbG93LlJlcG9ydEVycm9y", + "VG9UYXNrUmVzcG9uc2USaQoUUmVwb3J0RXJyb3JUb1NlcnZpY2USJy50ZW5z", + "b3JmbG93LlJlcG9ydEVycm9yVG9TZXJ2aWNlUmVxdWVzdBooLnRlbnNvcmZs", + "b3cuUmVwb3J0RXJyb3JUb1NlcnZpY2VSZXNwb25zZRJRCgxHZXRUYXNrU3Rh", + "dGUSHy50ZW5zb3JmbG93LkdldFRhc2tTdGF0ZVJlcXVlc3QaIC50ZW5zb3Jm", + "bG93LkdldFRhc2tTdGF0ZVJlc3BvbnNlElcKDkluc2VydEtleVZhbHVlEiEu", + "dGVuc29yZmxvdy5JbnNlcnRLZXlWYWx1ZVJlcXVlc3QaIi50ZW5zb3JmbG93", + "Lkluc2VydEtleVZhbHVlUmVzcG9uc2USTgoLR2V0S2V5VmFsdWUSHi50ZW5z", + "b3JmbG93LkdldEtleVZhbHVlUmVxdWVzdBofLnRlbnNvcmZsb3cuR2V0S2V5", + "VmFsdWVSZXNwb25zZRJXCg5UcnlHZXRLZXlWYWx1ZRIhLnRlbnNvcmZsb3cu", + "VHJ5R2V0S2V5VmFsdWVSZXF1ZXN0GiIudGVuc29yZmxvdy5UcnlHZXRLZXlW", + "YWx1ZVJlc3BvbnNlElcKDkdldEtleVZhbHVlRGlyEiEudGVuc29yZmxvdy5H", + "ZXRLZXlWYWx1ZURpclJlcXVlc3QaIi50ZW5zb3JmbG93LkdldEtleVZhbHVl", + "RGlyUmVzcG9uc2USVwoORGVsZXRlS2V5VmFsdWUSIS50ZW5zb3JmbG93LkRl", + "bGV0ZUtleVZhbHVlUmVxdWVzdBoiLnRlbnNvcmZsb3cuRGVsZXRlS2V5VmFs", + "dWVSZXNwb25zZRJCCgdCYXJyaWVyEhoudGVuc29yZmxvdy5CYXJyaWVyUmVx", + "dWVzdBobLnRlbnNvcmZsb3cuQmFycmllclJlc3BvbnNlElQKDUNhbmNlbEJh", + "cnJpZXISIC50ZW5zb3JmbG93LkNhbmNlbEJhcnJpZXJSZXF1ZXN0GiEudGVu", + "c29yZmxvdy5DYW5jZWxCYXJyaWVyUmVzcG9uc2VCV1pVZ2l0aHViLmNvbS90", + "ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL3Byb3Rv", + "YnVmL2Zvcl9jb3JlX3Byb3Rvc19nb19wcm90b2IGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Xla.ProtocolReflection.Descriptor, global::Tensorflow.DeviceAttributesReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.CoordinatedTaskState), }, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinatedTask), global::Tensorflow.CoordinatedTask.Parser, new[]{ "JobName", "TaskId" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinationServiceError), global::Tensorflow.CoordinationServiceError.Parser, new[]{ "IsReportedError", "SourceTask" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinatedTaskStateInfo), global::Tensorflow.CoordinatedTaskStateInfo.Parser, new[]{ "Task", "State", "ErrorCode", "ErrorMessage", "ErrorPayload" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TfDeviceList), global::Tensorflow.TfDeviceList.Parser, new[]{ "Devices" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.XlaDeviceList), global::Tensorflow.XlaDeviceList.Parser, new[]{ "Devices" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinationServiceDeviceInfo), global::Tensorflow.CoordinationServiceDeviceInfo.Parser, new[]{ "Tf", "Xla" }, new[]{ "Type" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RegisterTaskRequest), global::Tensorflow.RegisterTaskRequest.Parser, new[]{ "Incarnation", "SourceTask" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RegisterTaskResponse), global::Tensorflow.RegisterTaskResponse.Parser, new[]{ "LeaderIncarnation" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.HeartbeatRequest), global::Tensorflow.HeartbeatRequest.Parser, new[]{ "Incarnation", "SourceTask" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.HeartbeatResponse), global::Tensorflow.HeartbeatResponse.Parser, new[]{ "LeaderIncarnation" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.WaitForAllTasksRequest), global::Tensorflow.WaitForAllTasksRequest.Parser, new[]{ "LocalDeviceInfo", "SourceTask" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.WaitForAllTasksResponse), global::Tensorflow.WaitForAllTasksResponse.Parser, new[]{ "LeaderIncarnation", "ClusterDeviceInfo" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ShutdownTaskRequest), global::Tensorflow.ShutdownTaskRequest.Parser, new[]{ "SourceTask" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ShutdownTaskResponse), global::Tensorflow.ShutdownTaskResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ResetTaskRequest), global::Tensorflow.ResetTaskRequest.Parser, new[]{ "SourceTask" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ResetTaskResponse), global::Tensorflow.ResetTaskResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ReportErrorToTaskRequest), global::Tensorflow.ReportErrorToTaskRequest.Parser, new[]{ "ErrorCode", "ErrorMessage", "ErrorPayload" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ReportErrorToTaskResponse), global::Tensorflow.ReportErrorToTaskResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ReportErrorToServiceRequest), global::Tensorflow.ReportErrorToServiceRequest.Parser, new[]{ "ErrorCode", "ErrorMessage", "ErrorOrigin" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ReportErrorToServiceResponse), global::Tensorflow.ReportErrorToServiceResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GetTaskStateRequest), global::Tensorflow.GetTaskStateRequest.Parser, new[]{ "SourceTask" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GetTaskStateResponse), global::Tensorflow.GetTaskStateResponse.Parser, new[]{ "TaskState" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.KeyValueEntry), global::Tensorflow.KeyValueEntry.Parser, new[]{ "Key", "Value" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.InsertKeyValueRequest), global::Tensorflow.InsertKeyValueRequest.Parser, new[]{ "Kv" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.InsertKeyValueResponse), global::Tensorflow.InsertKeyValueResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GetKeyValueRequest), global::Tensorflow.GetKeyValueRequest.Parser, new[]{ "Key" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GetKeyValueResponse), global::Tensorflow.GetKeyValueResponse.Parser, new[]{ "Kv" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TryGetKeyValueRequest), global::Tensorflow.TryGetKeyValueRequest.Parser, new[]{ "Key" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TryGetKeyValueResponse), global::Tensorflow.TryGetKeyValueResponse.Parser, new[]{ "Kv" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GetKeyValueDirRequest), global::Tensorflow.GetKeyValueDirRequest.Parser, new[]{ "DirectoryKey" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GetKeyValueDirResponse), global::Tensorflow.GetKeyValueDirResponse.Parser, new[]{ "DirectoryKey", "Kv" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeleteKeyValueRequest), global::Tensorflow.DeleteKeyValueRequest.Parser, new[]{ "Key", "IsDirectory" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeleteKeyValueResponse), global::Tensorflow.DeleteKeyValueResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.BarrierRequest), global::Tensorflow.BarrierRequest.Parser, new[]{ "BarrierId", "BarrierTimeoutInMs", "Tasks", "SourceTask" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.BarrierResponse), global::Tensorflow.BarrierResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CancelBarrierRequest), global::Tensorflow.CancelBarrierRequest.Parser, new[]{ "BarrierId", "SourceTask" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CancelBarrierResponse), global::Tensorflow.CancelBarrierResponse.Parser, null, null, null, null, null) + })); + } + #endregion + + } + #region Enums + /// + /// Represents the state of a remote worker + /// + public enum CoordinatedTaskState { + /// + /// TASKSTATE_UNSPECIFIED is an invalid state such that indicates a bug. + /// + [pbr::OriginalName("TASKSTATE_UNSPECIFIED")] TaskstateUnspecified = 0, + /// + /// TASKSTATE_UNINITIALIZED is an agent-only state. While the agent is + /// disconnected, the service has no way of knowing if the task is + /// initialized/uninitialized. + /// + [pbr::OriginalName("TASKSTATE_UNINITIALIZED")] TaskstateUninitialized = 1, + [pbr::OriginalName("TASKSTATE_DISCONNECTED")] TaskstateDisconnected = 2, + [pbr::OriginalName("TASKSTATE_CONNECTED")] TaskstateConnected = 3, + [pbr::OriginalName("TASKSTATE_ERROR")] TaskstateError = 4, + } + + #endregion + + #region Messages + /// + /// Represents a remote worker task, specified by job name and task id. + /// + public sealed partial class CoordinatedTask : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CoordinatedTask()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinatedTask() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinatedTask(CoordinatedTask other) : this() { + jobName_ = other.jobName_; + taskId_ = other.taskId_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinatedTask Clone() { + return new CoordinatedTask(this); + } + + /// Field number for the "job_name" field. + public const int JobNameFieldNumber = 1; + private string jobName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string JobName { + get { return jobName_; } + set { + jobName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "task_id" field. + public const int TaskIdFieldNumber = 2; + private int taskId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int TaskId { + get { return taskId_; } + set { + taskId_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CoordinatedTask); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CoordinatedTask other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (JobName != other.JobName) return false; + if (TaskId != other.TaskId) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (JobName.Length != 0) hash ^= JobName.GetHashCode(); + if (TaskId != 0) hash ^= TaskId.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (JobName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(JobName); + } + if (TaskId != 0) { + output.WriteRawTag(16); + output.WriteInt32(TaskId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (JobName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(JobName); + } + if (TaskId != 0) { + output.WriteRawTag(16); + output.WriteInt32(TaskId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (JobName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(JobName); + } + if (TaskId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(TaskId); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CoordinatedTask other) { + if (other == null) { + return; + } + if (other.JobName.Length != 0) { + JobName = other.JobName; + } + if (other.TaskId != 0) { + TaskId = other.TaskId; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + JobName = input.ReadString(); + break; + } + case 16: { + TaskId = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + JobName = input.ReadString(); + break; + } + case 16: { + TaskId = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + /// + /// Status payload for all coordination service errors. + /// Note: an empty proto may be set if the error is triggered by the task's own + /// agent calls (i.e. not propagated by the service from another remote task). + /// + public sealed partial class CoordinationServiceError : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CoordinationServiceError()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinationServiceError() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinationServiceError(CoordinationServiceError other) : this() { + isReportedError_ = other.isReportedError_; + sourceTask_ = other.sourceTask_ != null ? other.sourceTask_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinationServiceError Clone() { + return new CoordinationServiceError(this); + } + + /// Field number for the "is_reported_error" field. + public const int IsReportedErrorFieldNumber = 3; + private bool isReportedError_; + /// + /// If true, error is reported via the agent API by the user (and not an + /// internal service error). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsReportedError { + get { return isReportedError_; } + set { + isReportedError_ = value; + } + } + + /// Field number for the "source_task" field. + public const int SourceTaskFieldNumber = 4; + private global::Tensorflow.CoordinatedTask sourceTask_; + /// + /// Denotes which task hit the error. If unset, the error originated from the + /// same task that is processing this error. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinatedTask SourceTask { + get { return sourceTask_; } + set { + sourceTask_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CoordinationServiceError); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CoordinationServiceError other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (IsReportedError != other.IsReportedError) return false; + if (!object.Equals(SourceTask, other.SourceTask)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (IsReportedError != false) hash ^= IsReportedError.GetHashCode(); + if (sourceTask_ != null) hash ^= SourceTask.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (IsReportedError != false) { + output.WriteRawTag(24); + output.WriteBool(IsReportedError); + } + if (sourceTask_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (IsReportedError != false) { + output.WriteRawTag(24); + output.WriteBool(IsReportedError); + } + if (sourceTask_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (IsReportedError != false) { + size += 1 + 1; + } + if (sourceTask_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SourceTask); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CoordinationServiceError other) { + if (other == null) { + return; + } + if (other.IsReportedError != false) { + IsReportedError = other.IsReportedError; + } + if (other.sourceTask_ != null) { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + SourceTask.MergeFrom(other.SourceTask); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 24: { + IsReportedError = input.ReadBool(); + break; + } + case 34: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 24: { + IsReportedError = input.ReadBool(); + break; + } + case 34: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + } + #endif + + } + + public sealed partial class CoordinatedTaskStateInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CoordinatedTaskStateInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinatedTaskStateInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinatedTaskStateInfo(CoordinatedTaskStateInfo other) : this() { + task_ = other.task_ != null ? other.task_.Clone() : null; + state_ = other.state_; + errorCode_ = other.errorCode_; + errorMessage_ = other.errorMessage_; + errorPayload_ = other.errorPayload_ != null ? other.errorPayload_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinatedTaskStateInfo Clone() { + return new CoordinatedTaskStateInfo(this); + } + + /// Field number for the "task" field. + public const int TaskFieldNumber = 1; + private global::Tensorflow.CoordinatedTask task_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinatedTask Task { + get { return task_; } + set { + task_ = value; + } + } + + /// Field number for the "state" field. + public const int StateFieldNumber = 2; + private global::Tensorflow.CoordinatedTaskState state_ = global::Tensorflow.CoordinatedTaskState.TaskstateUnspecified; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinatedTaskState State { + get { return state_; } + set { + state_ = value; + } + } + + /// Field number for the "error_code" field. + public const int ErrorCodeFieldNumber = 3; + private int errorCode_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ErrorCode { + get { return errorCode_; } + set { + errorCode_ = value; + } + } + + /// Field number for the "error_message" field. + public const int ErrorMessageFieldNumber = 4; + private string errorMessage_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ErrorMessage { + get { return errorMessage_; } + set { + errorMessage_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "error_payload" field. + public const int ErrorPayloadFieldNumber = 5; + private global::Tensorflow.CoordinationServiceError errorPayload_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinationServiceError ErrorPayload { + get { return errorPayload_; } + set { + errorPayload_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CoordinatedTaskStateInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CoordinatedTaskStateInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Task, other.Task)) return false; + if (State != other.State) return false; + if (ErrorCode != other.ErrorCode) return false; + if (ErrorMessage != other.ErrorMessage) return false; + if (!object.Equals(ErrorPayload, other.ErrorPayload)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (task_ != null) hash ^= Task.GetHashCode(); + if (State != global::Tensorflow.CoordinatedTaskState.TaskstateUnspecified) hash ^= State.GetHashCode(); + if (ErrorCode != 0) hash ^= ErrorCode.GetHashCode(); + if (ErrorMessage.Length != 0) hash ^= ErrorMessage.GetHashCode(); + if (errorPayload_ != null) hash ^= ErrorPayload.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (task_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Task); + } + if (State != global::Tensorflow.CoordinatedTaskState.TaskstateUnspecified) { + output.WriteRawTag(16); + output.WriteEnum((int) State); + } + if (ErrorCode != 0) { + output.WriteRawTag(24); + output.WriteInt32(ErrorCode); + } + if (ErrorMessage.Length != 0) { + output.WriteRawTag(34); + output.WriteString(ErrorMessage); + } + if (errorPayload_ != null) { + output.WriteRawTag(42); + output.WriteMessage(ErrorPayload); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (task_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Task); + } + if (State != global::Tensorflow.CoordinatedTaskState.TaskstateUnspecified) { + output.WriteRawTag(16); + output.WriteEnum((int) State); + } + if (ErrorCode != 0) { + output.WriteRawTag(24); + output.WriteInt32(ErrorCode); + } + if (ErrorMessage.Length != 0) { + output.WriteRawTag(34); + output.WriteString(ErrorMessage); + } + if (errorPayload_ != null) { + output.WriteRawTag(42); + output.WriteMessage(ErrorPayload); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (task_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Task); + } + if (State != global::Tensorflow.CoordinatedTaskState.TaskstateUnspecified) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) State); + } + if (ErrorCode != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ErrorCode); + } + if (ErrorMessage.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ErrorMessage); + } + if (errorPayload_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ErrorPayload); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CoordinatedTaskStateInfo other) { + if (other == null) { + return; + } + if (other.task_ != null) { + if (task_ == null) { + Task = new global::Tensorflow.CoordinatedTask(); + } + Task.MergeFrom(other.Task); + } + if (other.State != global::Tensorflow.CoordinatedTaskState.TaskstateUnspecified) { + State = other.State; + } + if (other.ErrorCode != 0) { + ErrorCode = other.ErrorCode; + } + if (other.ErrorMessage.Length != 0) { + ErrorMessage = other.ErrorMessage; + } + if (other.errorPayload_ != null) { + if (errorPayload_ == null) { + ErrorPayload = new global::Tensorflow.CoordinationServiceError(); + } + ErrorPayload.MergeFrom(other.ErrorPayload); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (task_ == null) { + Task = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(Task); + break; + } + case 16: { + State = (global::Tensorflow.CoordinatedTaskState) input.ReadEnum(); + break; + } + case 24: { + ErrorCode = input.ReadInt32(); + break; + } + case 34: { + ErrorMessage = input.ReadString(); + break; + } + case 42: { + if (errorPayload_ == null) { + ErrorPayload = new global::Tensorflow.CoordinationServiceError(); + } + input.ReadMessage(ErrorPayload); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (task_ == null) { + Task = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(Task); + break; + } + case 16: { + State = (global::Tensorflow.CoordinatedTaskState) input.ReadEnum(); + break; + } + case 24: { + ErrorCode = input.ReadInt32(); + break; + } + case 34: { + ErrorMessage = input.ReadString(); + break; + } + case 42: { + if (errorPayload_ == null) { + ErrorPayload = new global::Tensorflow.CoordinationServiceError(); + } + input.ReadMessage(ErrorPayload); + break; + } + } + } + } + #endif + + } + + /// + /// Represent device information from different runtimes. + /// + public sealed partial class TfDeviceList : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TfDeviceList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TfDeviceList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TfDeviceList(TfDeviceList other) : this() { + devices_ = other.devices_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TfDeviceList Clone() { + return new TfDeviceList(this); + } + + /// Field number for the "devices" field. + public const int DevicesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_devices_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.DeviceAttributes.Parser); + private readonly pbc::RepeatedField devices_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Devices { + get { return devices_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TfDeviceList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TfDeviceList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!devices_.Equals(other.devices_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= devices_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + devices_.WriteTo(output, _repeated_devices_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + devices_.WriteTo(ref output, _repeated_devices_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += devices_.CalculateSize(_repeated_devices_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TfDeviceList other) { + if (other == null) { + return; + } + devices_.Add(other.devices_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + devices_.AddEntriesFrom(input, _repeated_devices_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + devices_.AddEntriesFrom(ref input, _repeated_devices_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class XlaDeviceList : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new XlaDeviceList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public XlaDeviceList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public XlaDeviceList(XlaDeviceList other) : this() { + devices_ = other.devices_ != null ? other.devices_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public XlaDeviceList Clone() { + return new XlaDeviceList(this); + } + + /// Field number for the "devices" field. + public const int DevicesFieldNumber = 1; + private global::Xla.GlobalTopologyProto devices_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.GlobalTopologyProto Devices { + get { return devices_; } + set { + devices_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as XlaDeviceList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(XlaDeviceList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Devices, other.Devices)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (devices_ != null) hash ^= Devices.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (devices_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Devices); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (devices_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Devices); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (devices_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Devices); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(XlaDeviceList other) { + if (other == null) { + return; + } + if (other.devices_ != null) { + if (devices_ == null) { + Devices = new global::Xla.GlobalTopologyProto(); + } + Devices.MergeFrom(other.Devices); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (devices_ == null) { + Devices = new global::Xla.GlobalTopologyProto(); + } + input.ReadMessage(Devices); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (devices_ == null) { + Devices = new global::Xla.GlobalTopologyProto(); + } + input.ReadMessage(Devices); + break; + } + } + } + } + #endif + + } + + public sealed partial class CoordinationServiceDeviceInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CoordinationServiceDeviceInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinationServiceDeviceInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinationServiceDeviceInfo(CoordinationServiceDeviceInfo other) : this() { + switch (other.TypeCase) { + case TypeOneofCase.Tf: + Tf = other.Tf.Clone(); + break; + case TypeOneofCase.Xla: + Xla = other.Xla.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CoordinationServiceDeviceInfo Clone() { + return new CoordinationServiceDeviceInfo(this); + } + + /// Field number for the "tf" field. + public const int TfFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TfDeviceList Tf { + get { return typeCase_ == TypeOneofCase.Tf ? (global::Tensorflow.TfDeviceList) type_ : null; } + set { + type_ = value; + typeCase_ = value == null ? TypeOneofCase.None : TypeOneofCase.Tf; + } + } + + /// Field number for the "xla" field. + public const int XlaFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.XlaDeviceList Xla { + get { return typeCase_ == TypeOneofCase.Xla ? (global::Tensorflow.XlaDeviceList) type_ : null; } + set { + type_ = value; + typeCase_ = value == null ? TypeOneofCase.None : TypeOneofCase.Xla; + } + } + + private object type_; + /// Enum of possible cases for the "type" oneof. + public enum TypeOneofCase { + None = 0, + Tf = 1, + Xla = 2, + } + private TypeOneofCase typeCase_ = TypeOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TypeOneofCase TypeCase { + get { return typeCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void ClearType() { + typeCase_ = TypeOneofCase.None; + type_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CoordinationServiceDeviceInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CoordinationServiceDeviceInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Tf, other.Tf)) return false; + if (!object.Equals(Xla, other.Xla)) return false; + if (TypeCase != other.TypeCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (typeCase_ == TypeOneofCase.Tf) hash ^= Tf.GetHashCode(); + if (typeCase_ == TypeOneofCase.Xla) hash ^= Xla.GetHashCode(); + hash ^= (int) typeCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (typeCase_ == TypeOneofCase.Tf) { + output.WriteRawTag(10); + output.WriteMessage(Tf); + } + if (typeCase_ == TypeOneofCase.Xla) { + output.WriteRawTag(18); + output.WriteMessage(Xla); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (typeCase_ == TypeOneofCase.Tf) { + output.WriteRawTag(10); + output.WriteMessage(Tf); + } + if (typeCase_ == TypeOneofCase.Xla) { + output.WriteRawTag(18); + output.WriteMessage(Xla); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (typeCase_ == TypeOneofCase.Tf) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Tf); + } + if (typeCase_ == TypeOneofCase.Xla) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Xla); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CoordinationServiceDeviceInfo other) { + if (other == null) { + return; + } + switch (other.TypeCase) { + case TypeOneofCase.Tf: + if (Tf == null) { + Tf = new global::Tensorflow.TfDeviceList(); + } + Tf.MergeFrom(other.Tf); + break; + case TypeOneofCase.Xla: + if (Xla == null) { + Xla = new global::Tensorflow.XlaDeviceList(); + } + Xla.MergeFrom(other.Xla); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.TfDeviceList subBuilder = new global::Tensorflow.TfDeviceList(); + if (typeCase_ == TypeOneofCase.Tf) { + subBuilder.MergeFrom(Tf); + } + input.ReadMessage(subBuilder); + Tf = subBuilder; + break; + } + case 18: { + global::Tensorflow.XlaDeviceList subBuilder = new global::Tensorflow.XlaDeviceList(); + if (typeCase_ == TypeOneofCase.Xla) { + subBuilder.MergeFrom(Xla); + } + input.ReadMessage(subBuilder); + Xla = subBuilder; + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + global::Tensorflow.TfDeviceList subBuilder = new global::Tensorflow.TfDeviceList(); + if (typeCase_ == TypeOneofCase.Tf) { + subBuilder.MergeFrom(Tf); + } + input.ReadMessage(subBuilder); + Tf = subBuilder; + break; + } + case 18: { + global::Tensorflow.XlaDeviceList subBuilder = new global::Tensorflow.XlaDeviceList(); + if (typeCase_ == TypeOneofCase.Xla) { + subBuilder.MergeFrom(Xla); + } + input.ReadMessage(subBuilder); + Xla = subBuilder; + break; + } + } + } + } + #endif + + } + + /// + /// Request and response messages for registering a task to the cluster leader. + /// A task is uniquely represented by its `job_name`, `task_id` and + /// `incarnation`. Leader responds with its `incarnation` to identify a leader + /// process. + /// + public sealed partial class RegisterTaskRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RegisterTaskRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RegisterTaskRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RegisterTaskRequest(RegisterTaskRequest other) : this() { + incarnation_ = other.incarnation_; + sourceTask_ = other.sourceTask_ != null ? other.sourceTask_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RegisterTaskRequest Clone() { + return new RegisterTaskRequest(this); + } + + /// Field number for the "incarnation" field. + public const int IncarnationFieldNumber = 3; + private ulong incarnation_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong Incarnation { + get { return incarnation_; } + set { + incarnation_ = value; + } + } + + /// Field number for the "source_task" field. + public const int SourceTaskFieldNumber = 5; + private global::Tensorflow.CoordinatedTask sourceTask_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinatedTask SourceTask { + get { return sourceTask_; } + set { + sourceTask_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as RegisterTaskRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(RegisterTaskRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Incarnation != other.Incarnation) return false; + if (!object.Equals(SourceTask, other.SourceTask)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Incarnation != 0UL) hash ^= Incarnation.GetHashCode(); + if (sourceTask_ != null) hash ^= SourceTask.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Incarnation != 0UL) { + output.WriteRawTag(25); + output.WriteFixed64(Incarnation); + } + if (sourceTask_ != null) { + output.WriteRawTag(42); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Incarnation != 0UL) { + output.WriteRawTag(25); + output.WriteFixed64(Incarnation); + } + if (sourceTask_ != null) { + output.WriteRawTag(42); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Incarnation != 0UL) { + size += 1 + 8; + } + if (sourceTask_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SourceTask); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(RegisterTaskRequest other) { + if (other == null) { + return; + } + if (other.Incarnation != 0UL) { + Incarnation = other.Incarnation; + } + if (other.sourceTask_ != null) { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + SourceTask.MergeFrom(other.SourceTask); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 25: { + Incarnation = input.ReadFixed64(); + break; + } + case 42: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 25: { + Incarnation = input.ReadFixed64(); + break; + } + case 42: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + } + #endif + + } + + public sealed partial class RegisterTaskResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RegisterTaskResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RegisterTaskResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RegisterTaskResponse(RegisterTaskResponse other) : this() { + leaderIncarnation_ = other.leaderIncarnation_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RegisterTaskResponse Clone() { + return new RegisterTaskResponse(this); + } + + /// Field number for the "leader_incarnation" field. + public const int LeaderIncarnationFieldNumber = 1; + private ulong leaderIncarnation_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong LeaderIncarnation { + get { return leaderIncarnation_; } + set { + leaderIncarnation_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as RegisterTaskResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(RegisterTaskResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (LeaderIncarnation != other.LeaderIncarnation) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (LeaderIncarnation != 0UL) hash ^= LeaderIncarnation.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (LeaderIncarnation != 0UL) { + output.WriteRawTag(9); + output.WriteFixed64(LeaderIncarnation); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (LeaderIncarnation != 0UL) { + output.WriteRawTag(9); + output.WriteFixed64(LeaderIncarnation); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (LeaderIncarnation != 0UL) { + size += 1 + 8; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(RegisterTaskResponse other) { + if (other == null) { + return; + } + if (other.LeaderIncarnation != 0UL) { + LeaderIncarnation = other.LeaderIncarnation; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 9: { + LeaderIncarnation = input.ReadFixed64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 9: { + LeaderIncarnation = input.ReadFixed64(); + break; + } + } + } + } + #endif + + } + + /// + /// Request and response messages for sending heartbeats. + /// + public sealed partial class HeartbeatRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HeartbeatRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[8]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeartbeatRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeartbeatRequest(HeartbeatRequest other) : this() { + incarnation_ = other.incarnation_; + sourceTask_ = other.sourceTask_ != null ? other.sourceTask_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeartbeatRequest Clone() { + return new HeartbeatRequest(this); + } + + /// Field number for the "incarnation" field. + public const int IncarnationFieldNumber = 3; + private ulong incarnation_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong Incarnation { + get { return incarnation_; } + set { + incarnation_ = value; + } + } + + /// Field number for the "source_task" field. + public const int SourceTaskFieldNumber = 4; + private global::Tensorflow.CoordinatedTask sourceTask_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinatedTask SourceTask { + get { return sourceTask_; } + set { + sourceTask_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HeartbeatRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HeartbeatRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Incarnation != other.Incarnation) return false; + if (!object.Equals(SourceTask, other.SourceTask)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Incarnation != 0UL) hash ^= Incarnation.GetHashCode(); + if (sourceTask_ != null) hash ^= SourceTask.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Incarnation != 0UL) { + output.WriteRawTag(25); + output.WriteFixed64(Incarnation); + } + if (sourceTask_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Incarnation != 0UL) { + output.WriteRawTag(25); + output.WriteFixed64(Incarnation); + } + if (sourceTask_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Incarnation != 0UL) { + size += 1 + 8; + } + if (sourceTask_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SourceTask); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HeartbeatRequest other) { + if (other == null) { + return; + } + if (other.Incarnation != 0UL) { + Incarnation = other.Incarnation; + } + if (other.sourceTask_ != null) { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + SourceTask.MergeFrom(other.SourceTask); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 25: { + Incarnation = input.ReadFixed64(); + break; + } + case 34: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 25: { + Incarnation = input.ReadFixed64(); + break; + } + case 34: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + } + #endif + + } + + public sealed partial class HeartbeatResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HeartbeatResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[9]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeartbeatResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeartbeatResponse(HeartbeatResponse other) : this() { + leaderIncarnation_ = other.leaderIncarnation_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeartbeatResponse Clone() { + return new HeartbeatResponse(this); + } + + /// Field number for the "leader_incarnation" field. + public const int LeaderIncarnationFieldNumber = 1; + private ulong leaderIncarnation_; + /// + /// If there are failures in cluster, use additional metadata in response to + /// broadcast error code and message to other tasks. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong LeaderIncarnation { + get { return leaderIncarnation_; } + set { + leaderIncarnation_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HeartbeatResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HeartbeatResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (LeaderIncarnation != other.LeaderIncarnation) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (LeaderIncarnation != 0UL) hash ^= LeaderIncarnation.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (LeaderIncarnation != 0UL) { + output.WriteRawTag(9); + output.WriteFixed64(LeaderIncarnation); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (LeaderIncarnation != 0UL) { + output.WriteRawTag(9); + output.WriteFixed64(LeaderIncarnation); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (LeaderIncarnation != 0UL) { + size += 1 + 8; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HeartbeatResponse other) { + if (other == null) { + return; + } + if (other.LeaderIncarnation != 0UL) { + LeaderIncarnation = other.LeaderIncarnation; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 9: { + LeaderIncarnation = input.ReadFixed64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 9: { + LeaderIncarnation = input.ReadFixed64(); + break; + } + } + } + } + #endif + + } + + /// + /// Request and response messages for waiting for all tasks. + /// + public sealed partial class WaitForAllTasksRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WaitForAllTasksRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[10]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitForAllTasksRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitForAllTasksRequest(WaitForAllTasksRequest other) : this() { + localDeviceInfo_ = other.localDeviceInfo_ != null ? other.localDeviceInfo_.Clone() : null; + sourceTask_ = other.sourceTask_ != null ? other.sourceTask_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitForAllTasksRequest Clone() { + return new WaitForAllTasksRequest(this); + } + + /// Field number for the "local_device_info" field. + public const int LocalDeviceInfoFieldNumber = 4; + private global::Tensorflow.CoordinationServiceDeviceInfo localDeviceInfo_; + /// + /// All local device attributes on the request sender. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinationServiceDeviceInfo LocalDeviceInfo { + get { return localDeviceInfo_; } + set { + localDeviceInfo_ = value; + } + } + + /// Field number for the "source_task" field. + public const int SourceTaskFieldNumber = 5; + private global::Tensorflow.CoordinatedTask sourceTask_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinatedTask SourceTask { + get { return sourceTask_; } + set { + sourceTask_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WaitForAllTasksRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WaitForAllTasksRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(LocalDeviceInfo, other.LocalDeviceInfo)) return false; + if (!object.Equals(SourceTask, other.SourceTask)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (localDeviceInfo_ != null) hash ^= LocalDeviceInfo.GetHashCode(); + if (sourceTask_ != null) hash ^= SourceTask.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (localDeviceInfo_ != null) { + output.WriteRawTag(34); + output.WriteMessage(LocalDeviceInfo); + } + if (sourceTask_ != null) { + output.WriteRawTag(42); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (localDeviceInfo_ != null) { + output.WriteRawTag(34); + output.WriteMessage(LocalDeviceInfo); + } + if (sourceTask_ != null) { + output.WriteRawTag(42); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (localDeviceInfo_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(LocalDeviceInfo); + } + if (sourceTask_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SourceTask); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WaitForAllTasksRequest other) { + if (other == null) { + return; + } + if (other.localDeviceInfo_ != null) { + if (localDeviceInfo_ == null) { + LocalDeviceInfo = new global::Tensorflow.CoordinationServiceDeviceInfo(); + } + LocalDeviceInfo.MergeFrom(other.LocalDeviceInfo); + } + if (other.sourceTask_ != null) { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + SourceTask.MergeFrom(other.SourceTask); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 34: { + if (localDeviceInfo_ == null) { + LocalDeviceInfo = new global::Tensorflow.CoordinationServiceDeviceInfo(); + } + input.ReadMessage(LocalDeviceInfo); + break; + } + case 42: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 34: { + if (localDeviceInfo_ == null) { + LocalDeviceInfo = new global::Tensorflow.CoordinationServiceDeviceInfo(); + } + input.ReadMessage(LocalDeviceInfo); + break; + } + case 42: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + } + #endif + + } + + public sealed partial class WaitForAllTasksResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WaitForAllTasksResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[11]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitForAllTasksResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitForAllTasksResponse(WaitForAllTasksResponse other) : this() { + leaderIncarnation_ = other.leaderIncarnation_; + clusterDeviceInfo_ = other.clusterDeviceInfo_ != null ? other.clusterDeviceInfo_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitForAllTasksResponse Clone() { + return new WaitForAllTasksResponse(this); + } + + /// Field number for the "leader_incarnation" field. + public const int LeaderIncarnationFieldNumber = 1; + private ulong leaderIncarnation_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong LeaderIncarnation { + get { return leaderIncarnation_; } + set { + leaderIncarnation_ = value; + } + } + + /// Field number for the "cluster_device_info" field. + public const int ClusterDeviceInfoFieldNumber = 3; + private global::Tensorflow.CoordinationServiceDeviceInfo clusterDeviceInfo_; + /// + /// All devices in the cluster. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinationServiceDeviceInfo ClusterDeviceInfo { + get { return clusterDeviceInfo_; } + set { + clusterDeviceInfo_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WaitForAllTasksResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WaitForAllTasksResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (LeaderIncarnation != other.LeaderIncarnation) return false; + if (!object.Equals(ClusterDeviceInfo, other.ClusterDeviceInfo)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (LeaderIncarnation != 0UL) hash ^= LeaderIncarnation.GetHashCode(); + if (clusterDeviceInfo_ != null) hash ^= ClusterDeviceInfo.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (LeaderIncarnation != 0UL) { + output.WriteRawTag(9); + output.WriteFixed64(LeaderIncarnation); + } + if (clusterDeviceInfo_ != null) { + output.WriteRawTag(26); + output.WriteMessage(ClusterDeviceInfo); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (LeaderIncarnation != 0UL) { + output.WriteRawTag(9); + output.WriteFixed64(LeaderIncarnation); + } + if (clusterDeviceInfo_ != null) { + output.WriteRawTag(26); + output.WriteMessage(ClusterDeviceInfo); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (LeaderIncarnation != 0UL) { + size += 1 + 8; + } + if (clusterDeviceInfo_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ClusterDeviceInfo); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WaitForAllTasksResponse other) { + if (other == null) { + return; + } + if (other.LeaderIncarnation != 0UL) { + LeaderIncarnation = other.LeaderIncarnation; + } + if (other.clusterDeviceInfo_ != null) { + if (clusterDeviceInfo_ == null) { + ClusterDeviceInfo = new global::Tensorflow.CoordinationServiceDeviceInfo(); + } + ClusterDeviceInfo.MergeFrom(other.ClusterDeviceInfo); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 9: { + LeaderIncarnation = input.ReadFixed64(); + break; + } + case 26: { + if (clusterDeviceInfo_ == null) { + ClusterDeviceInfo = new global::Tensorflow.CoordinationServiceDeviceInfo(); + } + input.ReadMessage(ClusterDeviceInfo); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 9: { + LeaderIncarnation = input.ReadFixed64(); + break; + } + case 26: { + if (clusterDeviceInfo_ == null) { + ClusterDeviceInfo = new global::Tensorflow.CoordinationServiceDeviceInfo(); + } + input.ReadMessage(ClusterDeviceInfo); + break; + } + } + } + } + #endif + + } + + /// + /// Request and response messages for disconnecting a task from the service. + /// + public sealed partial class ShutdownTaskRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ShutdownTaskRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[12]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShutdownTaskRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShutdownTaskRequest(ShutdownTaskRequest other) : this() { + sourceTask_ = other.sourceTask_ != null ? other.sourceTask_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShutdownTaskRequest Clone() { + return new ShutdownTaskRequest(this); + } + + /// Field number for the "source_task" field. + public const int SourceTaskFieldNumber = 1; + private global::Tensorflow.CoordinatedTask sourceTask_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinatedTask SourceTask { + get { return sourceTask_; } + set { + sourceTask_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ShutdownTaskRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ShutdownTaskRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(SourceTask, other.SourceTask)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (sourceTask_ != null) hash ^= SourceTask.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (sourceTask_ != null) { + output.WriteRawTag(10); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (sourceTask_ != null) { + output.WriteRawTag(10); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (sourceTask_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SourceTask); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ShutdownTaskRequest other) { + if (other == null) { + return; + } + if (other.sourceTask_ != null) { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + SourceTask.MergeFrom(other.SourceTask); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + } + #endif + + } + + public sealed partial class ShutdownTaskResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ShutdownTaskResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[13]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShutdownTaskResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShutdownTaskResponse(ShutdownTaskResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShutdownTaskResponse Clone() { + return new ShutdownTaskResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ShutdownTaskResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ShutdownTaskResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ShutdownTaskResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + /// + /// Request and response messages for resetting a task state in the service. + /// + public sealed partial class ResetTaskRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ResetTaskRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[14]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ResetTaskRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ResetTaskRequest(ResetTaskRequest other) : this() { + sourceTask_ = other.sourceTask_ != null ? other.sourceTask_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ResetTaskRequest Clone() { + return new ResetTaskRequest(this); + } + + /// Field number for the "source_task" field. + public const int SourceTaskFieldNumber = 1; + private global::Tensorflow.CoordinatedTask sourceTask_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinatedTask SourceTask { + get { return sourceTask_; } + set { + sourceTask_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ResetTaskRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ResetTaskRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(SourceTask, other.SourceTask)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (sourceTask_ != null) hash ^= SourceTask.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (sourceTask_ != null) { + output.WriteRawTag(10); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (sourceTask_ != null) { + output.WriteRawTag(10); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (sourceTask_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SourceTask); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ResetTaskRequest other) { + if (other == null) { + return; + } + if (other.sourceTask_ != null) { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + SourceTask.MergeFrom(other.SourceTask); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + } + #endif + + } + + public sealed partial class ResetTaskResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ResetTaskResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[15]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ResetTaskResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ResetTaskResponse(ResetTaskResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ResetTaskResponse Clone() { + return new ResetTaskResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ResetTaskResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ResetTaskResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ResetTaskResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + /// + /// Request and response messages for reporting errors to task. + /// + public sealed partial class ReportErrorToTaskRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ReportErrorToTaskRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[16]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReportErrorToTaskRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReportErrorToTaskRequest(ReportErrorToTaskRequest other) : this() { + errorCode_ = other.errorCode_; + errorMessage_ = other.errorMessage_; + errorPayload_ = other.errorPayload_ != null ? other.errorPayload_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReportErrorToTaskRequest Clone() { + return new ReportErrorToTaskRequest(this); + } + + /// Field number for the "error_code" field. + public const int ErrorCodeFieldNumber = 1; + private int errorCode_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ErrorCode { + get { return errorCode_; } + set { + errorCode_ = value; + } + } + + /// Field number for the "error_message" field. + public const int ErrorMessageFieldNumber = 2; + private string errorMessage_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ErrorMessage { + get { return errorMessage_; } + set { + errorMessage_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "error_payload" field. + public const int ErrorPayloadFieldNumber = 5; + private global::Tensorflow.CoordinationServiceError errorPayload_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinationServiceError ErrorPayload { + get { return errorPayload_; } + set { + errorPayload_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ReportErrorToTaskRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ReportErrorToTaskRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ErrorCode != other.ErrorCode) return false; + if (ErrorMessage != other.ErrorMessage) return false; + if (!object.Equals(ErrorPayload, other.ErrorPayload)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ErrorCode != 0) hash ^= ErrorCode.GetHashCode(); + if (ErrorMessage.Length != 0) hash ^= ErrorMessage.GetHashCode(); + if (errorPayload_ != null) hash ^= ErrorPayload.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ErrorCode != 0) { + output.WriteRawTag(8); + output.WriteInt32(ErrorCode); + } + if (ErrorMessage.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ErrorMessage); + } + if (errorPayload_ != null) { + output.WriteRawTag(42); + output.WriteMessage(ErrorPayload); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ErrorCode != 0) { + output.WriteRawTag(8); + output.WriteInt32(ErrorCode); + } + if (ErrorMessage.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ErrorMessage); + } + if (errorPayload_ != null) { + output.WriteRawTag(42); + output.WriteMessage(ErrorPayload); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ErrorCode != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ErrorCode); + } + if (ErrorMessage.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ErrorMessage); + } + if (errorPayload_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ErrorPayload); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ReportErrorToTaskRequest other) { + if (other == null) { + return; + } + if (other.ErrorCode != 0) { + ErrorCode = other.ErrorCode; + } + if (other.ErrorMessage.Length != 0) { + ErrorMessage = other.ErrorMessage; + } + if (other.errorPayload_ != null) { + if (errorPayload_ == null) { + ErrorPayload = new global::Tensorflow.CoordinationServiceError(); + } + ErrorPayload.MergeFrom(other.ErrorPayload); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + ErrorCode = input.ReadInt32(); + break; + } + case 18: { + ErrorMessage = input.ReadString(); + break; + } + case 42: { + if (errorPayload_ == null) { + ErrorPayload = new global::Tensorflow.CoordinationServiceError(); + } + input.ReadMessage(ErrorPayload); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + ErrorCode = input.ReadInt32(); + break; + } + case 18: { + ErrorMessage = input.ReadString(); + break; + } + case 42: { + if (errorPayload_ == null) { + ErrorPayload = new global::Tensorflow.CoordinationServiceError(); + } + input.ReadMessage(ErrorPayload); + break; + } + } + } + } + #endif + + } + + public sealed partial class ReportErrorToTaskResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ReportErrorToTaskResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[17]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReportErrorToTaskResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReportErrorToTaskResponse(ReportErrorToTaskResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReportErrorToTaskResponse Clone() { + return new ReportErrorToTaskResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ReportErrorToTaskResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ReportErrorToTaskResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ReportErrorToTaskResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + /// + /// Request and response messages for reporting errors to service instance. + /// + public sealed partial class ReportErrorToServiceRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ReportErrorToServiceRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[18]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReportErrorToServiceRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReportErrorToServiceRequest(ReportErrorToServiceRequest other) : this() { + errorCode_ = other.errorCode_; + errorMessage_ = other.errorMessage_; + errorOrigin_ = other.errorOrigin_ != null ? other.errorOrigin_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReportErrorToServiceRequest Clone() { + return new ReportErrorToServiceRequest(this); + } + + /// Field number for the "error_code" field. + public const int ErrorCodeFieldNumber = 1; + private int errorCode_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ErrorCode { + get { return errorCode_; } + set { + errorCode_ = value; + } + } + + /// Field number for the "error_message" field. + public const int ErrorMessageFieldNumber = 2; + private string errorMessage_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ErrorMessage { + get { return errorMessage_; } + set { + errorMessage_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "error_origin" field. + public const int ErrorOriginFieldNumber = 5; + private global::Tensorflow.CoordinatedTask errorOrigin_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinatedTask ErrorOrigin { + get { return errorOrigin_; } + set { + errorOrigin_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ReportErrorToServiceRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ReportErrorToServiceRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ErrorCode != other.ErrorCode) return false; + if (ErrorMessage != other.ErrorMessage) return false; + if (!object.Equals(ErrorOrigin, other.ErrorOrigin)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ErrorCode != 0) hash ^= ErrorCode.GetHashCode(); + if (ErrorMessage.Length != 0) hash ^= ErrorMessage.GetHashCode(); + if (errorOrigin_ != null) hash ^= ErrorOrigin.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ErrorCode != 0) { + output.WriteRawTag(8); + output.WriteInt32(ErrorCode); + } + if (ErrorMessage.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ErrorMessage); + } + if (errorOrigin_ != null) { + output.WriteRawTag(42); + output.WriteMessage(ErrorOrigin); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ErrorCode != 0) { + output.WriteRawTag(8); + output.WriteInt32(ErrorCode); + } + if (ErrorMessage.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ErrorMessage); + } + if (errorOrigin_ != null) { + output.WriteRawTag(42); + output.WriteMessage(ErrorOrigin); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ErrorCode != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ErrorCode); + } + if (ErrorMessage.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ErrorMessage); + } + if (errorOrigin_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ErrorOrigin); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ReportErrorToServiceRequest other) { + if (other == null) { + return; + } + if (other.ErrorCode != 0) { + ErrorCode = other.ErrorCode; + } + if (other.ErrorMessage.Length != 0) { + ErrorMessage = other.ErrorMessage; + } + if (other.errorOrigin_ != null) { + if (errorOrigin_ == null) { + ErrorOrigin = new global::Tensorflow.CoordinatedTask(); + } + ErrorOrigin.MergeFrom(other.ErrorOrigin); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + ErrorCode = input.ReadInt32(); + break; + } + case 18: { + ErrorMessage = input.ReadString(); + break; + } + case 42: { + if (errorOrigin_ == null) { + ErrorOrigin = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(ErrorOrigin); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + ErrorCode = input.ReadInt32(); + break; + } + case 18: { + ErrorMessage = input.ReadString(); + break; + } + case 42: { + if (errorOrigin_ == null) { + ErrorOrigin = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(ErrorOrigin); + break; + } + } + } + } + #endif + + } + + public sealed partial class ReportErrorToServiceResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ReportErrorToServiceResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[19]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReportErrorToServiceResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReportErrorToServiceResponse(ReportErrorToServiceResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReportErrorToServiceResponse Clone() { + return new ReportErrorToServiceResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ReportErrorToServiceResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ReportErrorToServiceResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ReportErrorToServiceResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + /// + /// Request and response messages for getting state of a remote task. + /// + public sealed partial class GetTaskStateRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GetTaskStateRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[20]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetTaskStateRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetTaskStateRequest(GetTaskStateRequest other) : this() { + sourceTask_ = other.sourceTask_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetTaskStateRequest Clone() { + return new GetTaskStateRequest(this); + } + + /// Field number for the "source_task" field. + public const int SourceTaskFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_sourceTask_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.CoordinatedTask.Parser); + private readonly pbc::RepeatedField sourceTask_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SourceTask { + get { return sourceTask_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GetTaskStateRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GetTaskStateRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!sourceTask_.Equals(other.sourceTask_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= sourceTask_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + sourceTask_.WriteTo(output, _repeated_sourceTask_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + sourceTask_.WriteTo(ref output, _repeated_sourceTask_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += sourceTask_.CalculateSize(_repeated_sourceTask_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GetTaskStateRequest other) { + if (other == null) { + return; + } + sourceTask_.Add(other.sourceTask_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + sourceTask_.AddEntriesFrom(input, _repeated_sourceTask_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + sourceTask_.AddEntriesFrom(ref input, _repeated_sourceTask_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class GetTaskStateResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GetTaskStateResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[21]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetTaskStateResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetTaskStateResponse(GetTaskStateResponse other) : this() { + taskState_ = other.taskState_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetTaskStateResponse Clone() { + return new GetTaskStateResponse(this); + } + + /// Field number for the "task_state" field. + public const int TaskStateFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_taskState_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.CoordinatedTaskStateInfo.Parser); + private readonly pbc::RepeatedField taskState_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField TaskState { + get { return taskState_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GetTaskStateResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GetTaskStateResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!taskState_.Equals(other.taskState_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= taskState_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + taskState_.WriteTo(output, _repeated_taskState_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + taskState_.WriteTo(ref output, _repeated_taskState_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += taskState_.CalculateSize(_repeated_taskState_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GetTaskStateResponse other) { + if (other == null) { + return; + } + taskState_.Add(other.taskState_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + taskState_.AddEntriesFrom(input, _repeated_taskState_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + taskState_.AddEntriesFrom(ref input, _repeated_taskState_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Message for configuration key value. + /// Key is structured like Unix file system, with multiple levels of directory + /// names separated by the slash ('/') characters. + /// + public sealed partial class KeyValueEntry : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new KeyValueEntry()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[22]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueEntry() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueEntry(KeyValueEntry other) : this() { + key_ = other.key_; + value_ = other.value_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueEntry Clone() { + return new KeyValueEntry(this); + } + + /// Field number for the "key" field. + public const int KeyFieldNumber = 1; + private string key_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Key { + get { return key_; } + set { + key_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 2; + private pb::ByteString value_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString Value { + get { return value_; } + set { + value_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as KeyValueEntry); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(KeyValueEntry other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Key != other.Key) return false; + if (Value != other.Value) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Key.Length != 0) hash ^= Key.GetHashCode(); + if (Value.Length != 0) hash ^= Value.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Key.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Key); + } + if (Value.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(Value); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Key.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Key); + } + if (Value.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(Value); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Key.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Key); + } + if (Value.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(Value); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(KeyValueEntry other) { + if (other == null) { + return; + } + if (other.Key.Length != 0) { + Key = other.Key; + } + if (other.Value.Length != 0) { + Value = other.Value; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Key = input.ReadString(); + break; + } + case 18: { + Value = input.ReadBytes(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Key = input.ReadString(); + break; + } + case 18: { + Value = input.ReadBytes(); + break; + } + } + } + } + #endif + + } + + /// + /// Request and response messages for inserting configuration key-value data. + /// + public sealed partial class InsertKeyValueRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new InsertKeyValueRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[23]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InsertKeyValueRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InsertKeyValueRequest(InsertKeyValueRequest other) : this() { + kv_ = other.kv_ != null ? other.kv_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InsertKeyValueRequest Clone() { + return new InsertKeyValueRequest(this); + } + + /// Field number for the "kv" field. + public const int KvFieldNumber = 1; + private global::Tensorflow.KeyValueEntry kv_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.KeyValueEntry Kv { + get { return kv_; } + set { + kv_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as InsertKeyValueRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(InsertKeyValueRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Kv, other.Kv)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (kv_ != null) hash ^= Kv.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (kv_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Kv); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (kv_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Kv); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (kv_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Kv); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(InsertKeyValueRequest other) { + if (other == null) { + return; + } + if (other.kv_ != null) { + if (kv_ == null) { + Kv = new global::Tensorflow.KeyValueEntry(); + } + Kv.MergeFrom(other.Kv); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (kv_ == null) { + Kv = new global::Tensorflow.KeyValueEntry(); + } + input.ReadMessage(Kv); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (kv_ == null) { + Kv = new global::Tensorflow.KeyValueEntry(); + } + input.ReadMessage(Kv); + break; + } + } + } + } + #endif + + } + + public sealed partial class InsertKeyValueResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new InsertKeyValueResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[24]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InsertKeyValueResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InsertKeyValueResponse(InsertKeyValueResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InsertKeyValueResponse Clone() { + return new InsertKeyValueResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as InsertKeyValueResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(InsertKeyValueResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(InsertKeyValueResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + /// + /// Request and response messages for getting configuration key-value data. + /// + public sealed partial class GetKeyValueRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GetKeyValueRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[25]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetKeyValueRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetKeyValueRequest(GetKeyValueRequest other) : this() { + key_ = other.key_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetKeyValueRequest Clone() { + return new GetKeyValueRequest(this); + } + + /// Field number for the "key" field. + public const int KeyFieldNumber = 1; + private string key_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Key { + get { return key_; } + set { + key_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GetKeyValueRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GetKeyValueRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Key != other.Key) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Key.Length != 0) hash ^= Key.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Key.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Key); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Key.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Key); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Key.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Key); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GetKeyValueRequest other) { + if (other == null) { + return; + } + if (other.Key.Length != 0) { + Key = other.Key; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Key = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Key = input.ReadString(); + break; + } + } + } + } + #endif + + } + + public sealed partial class GetKeyValueResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GetKeyValueResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[26]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetKeyValueResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetKeyValueResponse(GetKeyValueResponse other) : this() { + kv_ = other.kv_ != null ? other.kv_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetKeyValueResponse Clone() { + return new GetKeyValueResponse(this); + } + + /// Field number for the "kv" field. + public const int KvFieldNumber = 1; + private global::Tensorflow.KeyValueEntry kv_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.KeyValueEntry Kv { + get { return kv_; } + set { + kv_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GetKeyValueResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GetKeyValueResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Kv, other.Kv)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (kv_ != null) hash ^= Kv.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (kv_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Kv); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (kv_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Kv); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (kv_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Kv); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GetKeyValueResponse other) { + if (other == null) { + return; + } + if (other.kv_ != null) { + if (kv_ == null) { + Kv = new global::Tensorflow.KeyValueEntry(); + } + Kv.MergeFrom(other.Kv); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (kv_ == null) { + Kv = new global::Tensorflow.KeyValueEntry(); + } + input.ReadMessage(Kv); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (kv_ == null) { + Kv = new global::Tensorflow.KeyValueEntry(); + } + input.ReadMessage(Kv); + break; + } + } + } + } + #endif + + } + + public sealed partial class TryGetKeyValueRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TryGetKeyValueRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[27]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TryGetKeyValueRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TryGetKeyValueRequest(TryGetKeyValueRequest other) : this() { + key_ = other.key_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TryGetKeyValueRequest Clone() { + return new TryGetKeyValueRequest(this); + } + + /// Field number for the "key" field. + public const int KeyFieldNumber = 1; + private string key_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Key { + get { return key_; } + set { + key_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TryGetKeyValueRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TryGetKeyValueRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Key != other.Key) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Key.Length != 0) hash ^= Key.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Key.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Key); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Key.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Key); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Key.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Key); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TryGetKeyValueRequest other) { + if (other == null) { + return; + } + if (other.Key.Length != 0) { + Key = other.Key; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Key = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Key = input.ReadString(); + break; + } + } + } + } + #endif + + } + + public sealed partial class TryGetKeyValueResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TryGetKeyValueResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[28]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TryGetKeyValueResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TryGetKeyValueResponse(TryGetKeyValueResponse other) : this() { + kv_ = other.kv_ != null ? other.kv_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TryGetKeyValueResponse Clone() { + return new TryGetKeyValueResponse(this); + } + + /// Field number for the "kv" field. + public const int KvFieldNumber = 1; + private global::Tensorflow.KeyValueEntry kv_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.KeyValueEntry Kv { + get { return kv_; } + set { + kv_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TryGetKeyValueResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TryGetKeyValueResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Kv, other.Kv)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (kv_ != null) hash ^= Kv.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (kv_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Kv); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (kv_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Kv); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (kv_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Kv); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TryGetKeyValueResponse other) { + if (other == null) { + return; + } + if (other.kv_ != null) { + if (kv_ == null) { + Kv = new global::Tensorflow.KeyValueEntry(); + } + Kv.MergeFrom(other.Kv); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (kv_ == null) { + Kv = new global::Tensorflow.KeyValueEntry(); + } + input.ReadMessage(Kv); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (kv_ == null) { + Kv = new global::Tensorflow.KeyValueEntry(); + } + input.ReadMessage(Kv); + break; + } + } + } + } + #endif + + } + + public sealed partial class GetKeyValueDirRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GetKeyValueDirRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[29]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetKeyValueDirRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetKeyValueDirRequest(GetKeyValueDirRequest other) : this() { + directoryKey_ = other.directoryKey_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetKeyValueDirRequest Clone() { + return new GetKeyValueDirRequest(this); + } + + /// Field number for the "directory_key" field. + public const int DirectoryKeyFieldNumber = 1; + private string directoryKey_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string DirectoryKey { + get { return directoryKey_; } + set { + directoryKey_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GetKeyValueDirRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GetKeyValueDirRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (DirectoryKey != other.DirectoryKey) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (DirectoryKey.Length != 0) hash ^= DirectoryKey.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (DirectoryKey.Length != 0) { + output.WriteRawTag(10); + output.WriteString(DirectoryKey); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (DirectoryKey.Length != 0) { + output.WriteRawTag(10); + output.WriteString(DirectoryKey); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (DirectoryKey.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DirectoryKey); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GetKeyValueDirRequest other) { + if (other == null) { + return; + } + if (other.DirectoryKey.Length != 0) { + DirectoryKey = other.DirectoryKey; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + DirectoryKey = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + DirectoryKey = input.ReadString(); + break; + } + } + } + } + #endif + + } + + public sealed partial class GetKeyValueDirResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GetKeyValueDirResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[30]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetKeyValueDirResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetKeyValueDirResponse(GetKeyValueDirResponse other) : this() { + directoryKey_ = other.directoryKey_; + kv_ = other.kv_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetKeyValueDirResponse Clone() { + return new GetKeyValueDirResponse(this); + } + + /// Field number for the "directory_key" field. + public const int DirectoryKeyFieldNumber = 1; + private string directoryKey_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string DirectoryKey { + get { return directoryKey_; } + set { + directoryKey_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "kv" field. + public const int KvFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_kv_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.KeyValueEntry.Parser); + private readonly pbc::RepeatedField kv_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Kv { + get { return kv_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GetKeyValueDirResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GetKeyValueDirResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (DirectoryKey != other.DirectoryKey) return false; + if(!kv_.Equals(other.kv_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (DirectoryKey.Length != 0) hash ^= DirectoryKey.GetHashCode(); + hash ^= kv_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (DirectoryKey.Length != 0) { + output.WriteRawTag(10); + output.WriteString(DirectoryKey); + } + kv_.WriteTo(output, _repeated_kv_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (DirectoryKey.Length != 0) { + output.WriteRawTag(10); + output.WriteString(DirectoryKey); + } + kv_.WriteTo(ref output, _repeated_kv_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (DirectoryKey.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DirectoryKey); + } + size += kv_.CalculateSize(_repeated_kv_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GetKeyValueDirResponse other) { + if (other == null) { + return; + } + if (other.DirectoryKey.Length != 0) { + DirectoryKey = other.DirectoryKey; + } + kv_.Add(other.kv_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + DirectoryKey = input.ReadString(); + break; + } + case 18: { + kv_.AddEntriesFrom(input, _repeated_kv_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + DirectoryKey = input.ReadString(); + break; + } + case 18: { + kv_.AddEntriesFrom(ref input, _repeated_kv_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Request and response messages for deleting configuration key-value data. + /// When is_directory is true, delete key-values recursively under `key`. + /// + public sealed partial class DeleteKeyValueRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DeleteKeyValueRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[31]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeleteKeyValueRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeleteKeyValueRequest(DeleteKeyValueRequest other) : this() { + key_ = other.key_; + isDirectory_ = other.isDirectory_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeleteKeyValueRequest Clone() { + return new DeleteKeyValueRequest(this); + } + + /// Field number for the "key" field. + public const int KeyFieldNumber = 1; + private string key_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Key { + get { return key_; } + set { + key_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "is_directory" field. + public const int IsDirectoryFieldNumber = 2; + private bool isDirectory_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsDirectory { + get { return isDirectory_; } + set { + isDirectory_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DeleteKeyValueRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DeleteKeyValueRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Key != other.Key) return false; + if (IsDirectory != other.IsDirectory) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Key.Length != 0) hash ^= Key.GetHashCode(); + if (IsDirectory != false) hash ^= IsDirectory.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Key.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Key); + } + if (IsDirectory != false) { + output.WriteRawTag(16); + output.WriteBool(IsDirectory); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Key.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Key); + } + if (IsDirectory != false) { + output.WriteRawTag(16); + output.WriteBool(IsDirectory); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Key.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Key); + } + if (IsDirectory != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DeleteKeyValueRequest other) { + if (other == null) { + return; + } + if (other.Key.Length != 0) { + Key = other.Key; + } + if (other.IsDirectory != false) { + IsDirectory = other.IsDirectory; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Key = input.ReadString(); + break; + } + case 16: { + IsDirectory = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Key = input.ReadString(); + break; + } + case 16: { + IsDirectory = input.ReadBool(); + break; + } + } + } + } + #endif + + } + + public sealed partial class DeleteKeyValueResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DeleteKeyValueResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[32]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeleteKeyValueResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeleteKeyValueResponse(DeleteKeyValueResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeleteKeyValueResponse Clone() { + return new DeleteKeyValueResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DeleteKeyValueResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DeleteKeyValueResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DeleteKeyValueResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + /// + /// Request and response messages for generic sync barriers. + /// + public sealed partial class BarrierRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BarrierRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[33]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BarrierRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BarrierRequest(BarrierRequest other) : this() { + barrierId_ = other.barrierId_; + barrierTimeoutInMs_ = other.barrierTimeoutInMs_; + tasks_ = other.tasks_.Clone(); + sourceTask_ = other.sourceTask_ != null ? other.sourceTask_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BarrierRequest Clone() { + return new BarrierRequest(this); + } + + /// Field number for the "barrier_id" field. + public const int BarrierIdFieldNumber = 1; + private string barrierId_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string BarrierId { + get { return barrierId_; } + set { + barrierId_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "barrier_timeout_in_ms" field. + public const int BarrierTimeoutInMsFieldNumber = 2; + private long barrierTimeoutInMs_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long BarrierTimeoutInMs { + get { return barrierTimeoutInMs_; } + set { + barrierTimeoutInMs_ = value; + } + } + + /// Field number for the "tasks" field. + public const int TasksFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_tasks_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.CoordinatedTask.Parser); + private readonly pbc::RepeatedField tasks_ = new pbc::RepeatedField(); + /// + /// Denotes list of tasks that will wait for the barrier. If unspecified, it + /// implies that the entire cluster is participating in the barrier. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Tasks { + get { return tasks_; } + } + + /// Field number for the "source_task" field. + public const int SourceTaskFieldNumber = 4; + private global::Tensorflow.CoordinatedTask sourceTask_; + /// + /// Task that is making the request. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinatedTask SourceTask { + get { return sourceTask_; } + set { + sourceTask_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as BarrierRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(BarrierRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (BarrierId != other.BarrierId) return false; + if (BarrierTimeoutInMs != other.BarrierTimeoutInMs) return false; + if(!tasks_.Equals(other.tasks_)) return false; + if (!object.Equals(SourceTask, other.SourceTask)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (BarrierId.Length != 0) hash ^= BarrierId.GetHashCode(); + if (BarrierTimeoutInMs != 0L) hash ^= BarrierTimeoutInMs.GetHashCode(); + hash ^= tasks_.GetHashCode(); + if (sourceTask_ != null) hash ^= SourceTask.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (BarrierId.Length != 0) { + output.WriteRawTag(10); + output.WriteString(BarrierId); + } + if (BarrierTimeoutInMs != 0L) { + output.WriteRawTag(16); + output.WriteInt64(BarrierTimeoutInMs); + } + tasks_.WriteTo(output, _repeated_tasks_codec); + if (sourceTask_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (BarrierId.Length != 0) { + output.WriteRawTag(10); + output.WriteString(BarrierId); + } + if (BarrierTimeoutInMs != 0L) { + output.WriteRawTag(16); + output.WriteInt64(BarrierTimeoutInMs); + } + tasks_.WriteTo(ref output, _repeated_tasks_codec); + if (sourceTask_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (BarrierId.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(BarrierId); + } + if (BarrierTimeoutInMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(BarrierTimeoutInMs); + } + size += tasks_.CalculateSize(_repeated_tasks_codec); + if (sourceTask_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SourceTask); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(BarrierRequest other) { + if (other == null) { + return; + } + if (other.BarrierId.Length != 0) { + BarrierId = other.BarrierId; + } + if (other.BarrierTimeoutInMs != 0L) { + BarrierTimeoutInMs = other.BarrierTimeoutInMs; + } + tasks_.Add(other.tasks_); + if (other.sourceTask_ != null) { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + SourceTask.MergeFrom(other.SourceTask); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + BarrierId = input.ReadString(); + break; + } + case 16: { + BarrierTimeoutInMs = input.ReadInt64(); + break; + } + case 26: { + tasks_.AddEntriesFrom(input, _repeated_tasks_codec); + break; + } + case 34: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + BarrierId = input.ReadString(); + break; + } + case 16: { + BarrierTimeoutInMs = input.ReadInt64(); + break; + } + case 26: { + tasks_.AddEntriesFrom(ref input, _repeated_tasks_codec); + break; + } + case 34: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + } + #endif + + } + + public sealed partial class BarrierResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BarrierResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[34]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BarrierResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BarrierResponse(BarrierResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BarrierResponse Clone() { + return new BarrierResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as BarrierResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(BarrierResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(BarrierResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + /// + /// Request and response messages for cancelling generic sync barriers. + /// + public sealed partial class CancelBarrierRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CancelBarrierRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[35]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CancelBarrierRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CancelBarrierRequest(CancelBarrierRequest other) : this() { + barrierId_ = other.barrierId_; + sourceTask_ = other.sourceTask_ != null ? other.sourceTask_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CancelBarrierRequest Clone() { + return new CancelBarrierRequest(this); + } + + /// Field number for the "barrier_id" field. + public const int BarrierIdFieldNumber = 1; + private string barrierId_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string BarrierId { + get { return barrierId_; } + set { + barrierId_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "source_task" field. + public const int SourceTaskFieldNumber = 2; + private global::Tensorflow.CoordinatedTask sourceTask_; + /// + /// Task that is making the request. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CoordinatedTask SourceTask { + get { return sourceTask_; } + set { + sourceTask_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CancelBarrierRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CancelBarrierRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (BarrierId != other.BarrierId) return false; + if (!object.Equals(SourceTask, other.SourceTask)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (BarrierId.Length != 0) hash ^= BarrierId.GetHashCode(); + if (sourceTask_ != null) hash ^= SourceTask.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (BarrierId.Length != 0) { + output.WriteRawTag(10); + output.WriteString(BarrierId); + } + if (sourceTask_ != null) { + output.WriteRawTag(18); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (BarrierId.Length != 0) { + output.WriteRawTag(10); + output.WriteString(BarrierId); + } + if (sourceTask_ != null) { + output.WriteRawTag(18); + output.WriteMessage(SourceTask); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (BarrierId.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(BarrierId); + } + if (sourceTask_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SourceTask); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CancelBarrierRequest other) { + if (other == null) { + return; + } + if (other.BarrierId.Length != 0) { + BarrierId = other.BarrierId; + } + if (other.sourceTask_ != null) { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + SourceTask.MergeFrom(other.SourceTask); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + BarrierId = input.ReadString(); + break; + } + case 18: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + BarrierId = input.ReadString(); + break; + } + case 18: { + if (sourceTask_ == null) { + SourceTask = new global::Tensorflow.CoordinatedTask(); + } + input.ReadMessage(SourceTask); + break; + } + } + } + } + #endif + + } + + public sealed partial class CancelBarrierResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CancelBarrierResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CoordinationServiceReflection.Descriptor.MessageTypes[36]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CancelBarrierResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CancelBarrierResponse(CancelBarrierResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CancelBarrierResponse Clone() { + return new CancelBarrierResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CancelBarrierResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CancelBarrierResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CancelBarrierResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/CostGraph.cs b/src/TensorFlowNET.Core/Protobuf/CostGraph.cs new file mode 100644 index 000000000..fc655d400 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/CostGraph.cs @@ -0,0 +1,1825 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/cost_graph.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/cost_graph.proto + public static partial class CostGraphReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/cost_graph.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static CostGraphReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Nvc3RfZ3JhcGgucHJvdG8S", + "CnRlbnNvcmZsb3caLHRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvdGVuc29y", + "X3NoYXBlLnByb3RvGiV0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3R5cGVz", + "LnByb3RvIsoGCgxDb3N0R3JhcGhEZWYSKwoEbm9kZRgBIAMoCzIdLnRlbnNv", + "cmZsb3cuQ29zdEdyYXBoRGVmLk5vZGUSNQoEY29zdBgCIAMoCzInLnRlbnNv", + "cmZsb3cuQ29zdEdyYXBoRGVmLkFnZ3JlZ2F0ZWRDb3N0GqIFCgROb2RlEgwK", + "BG5hbWUYASABKAkSDgoGZGV2aWNlGAIgASgJEgoKAmlkGAMgASgFEjsKCmlu", + "cHV0X2luZm8YBCADKAsyJy50ZW5zb3JmbG93LkNvc3RHcmFwaERlZi5Ob2Rl", + "LklucHV0SW5mbxI9CgtvdXRwdXRfaW5mbxgFIAMoCzIoLnRlbnNvcmZsb3cu", + "Q29zdEdyYXBoRGVmLk5vZGUuT3V0cHV0SW5mbxIdChV0ZW1wb3JhcnlfbWVt", + "b3J5X3NpemUYBiABKAMSHgoWcGVyc2lzdGVudF9tZW1vcnlfc2l6ZRgMIAEo", + "AxIhChVob3N0X3RlbXBfbWVtb3J5X3NpemUYCiABKANCAhgBEiMKF2Rldmlj", + "ZV90ZW1wX21lbW9yeV9zaXplGAsgASgDQgIYARIpCh1kZXZpY2VfcGVyc2lz", + "dGVudF9tZW1vcnlfc2l6ZRgQIAEoA0ICGAESFAoMY29tcHV0ZV9jb3N0GAkg", + "ASgDEhQKDGNvbXB1dGVfdGltZRgOIAEoAxITCgttZW1vcnlfdGltZRgPIAEo", + "AxIQCghpc19maW5hbBgHIAEoCBIVCg1jb250cm9sX2lucHV0GAggAygFEhIK", + "CmluYWNjdXJhdGUYESABKAgaOwoJSW5wdXRJbmZvEhYKDnByZWNlZGluZ19u", + "b2RlGAEgASgFEhYKDnByZWNlZGluZ19wb3J0GAIgASgFGoYBCgpPdXRwdXRJ", + "bmZvEgwKBHNpemUYASABKAMSGAoQYWxpYXNfaW5wdXRfcG9ydBgCIAEoAxIr", + "CgVzaGFwZRgDIAEoCzIcLnRlbnNvcmZsb3cuVGVuc29yU2hhcGVQcm90bxIj", + "CgVkdHlwZRgEIAEoDjIULnRlbnNvcmZsb3cuRGF0YVR5cGUaMQoOQWdncmVn", + "YXRlZENvc3QSDAoEY29zdBgBIAEoAhIRCglkaW1lbnNpb24YAiABKAlCgwEK", + "GG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IPQ29zdEdyYXBoUHJvdG9zUAFa", + "UWdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cv", + "Z28vY29yZS9mcmFtZXdvcmsvY29zdF9ncmFwaF9nb19wcm90b/gBAWIGcHJv", + "dG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CostGraphDef), global::Tensorflow.CostGraphDef.Parser, new[]{ "Node", "Cost" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CostGraphDef.Types.Node), global::Tensorflow.CostGraphDef.Types.Node.Parser, new[]{ "Name", "Device", "Id", "InputInfo", "OutputInfo", "TemporaryMemorySize", "PersistentMemorySize", "HostTempMemorySize", "DeviceTempMemorySize", "DevicePersistentMemorySize", "ComputeCost", "ComputeTime", "MemoryTime", "IsFinal", "ControlInput", "Inaccurate" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CostGraphDef.Types.Node.Types.InputInfo), global::Tensorflow.CostGraphDef.Types.Node.Types.InputInfo.Parser, new[]{ "PrecedingNode", "PrecedingPort" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CostGraphDef.Types.Node.Types.OutputInfo), global::Tensorflow.CostGraphDef.Types.Node.Types.OutputInfo.Parser, new[]{ "Size", "AliasInputPort", "Shape", "Dtype" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CostGraphDef.Types.AggregatedCost), global::Tensorflow.CostGraphDef.Types.AggregatedCost.Parser, new[]{ "Cost", "Dimension" }, null, null, null, null)}) + })); + } + #endregion + + } + #region Messages + public sealed partial class CostGraphDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CostGraphDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CostGraphReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CostGraphDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CostGraphDef(CostGraphDef other) : this() { + node_ = other.node_.Clone(); + cost_ = other.cost_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CostGraphDef Clone() { + return new CostGraphDef(this); + } + + /// Field number for the "node" field. + public const int NodeFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_node_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.CostGraphDef.Types.Node.Parser); + private readonly pbc::RepeatedField node_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Node { + get { return node_; } + } + + /// Field number for the "cost" field. + public const int CostFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_cost_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.CostGraphDef.Types.AggregatedCost.Parser); + private readonly pbc::RepeatedField cost_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Cost { + get { return cost_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CostGraphDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CostGraphDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!node_.Equals(other.node_)) return false; + if(!cost_.Equals(other.cost_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= node_.GetHashCode(); + hash ^= cost_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + node_.WriteTo(output, _repeated_node_codec); + cost_.WriteTo(output, _repeated_cost_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + node_.WriteTo(ref output, _repeated_node_codec); + cost_.WriteTo(ref output, _repeated_cost_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += node_.CalculateSize(_repeated_node_codec); + size += cost_.CalculateSize(_repeated_cost_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CostGraphDef other) { + if (other == null) { + return; + } + node_.Add(other.node_); + cost_.Add(other.cost_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + node_.AddEntriesFrom(input, _repeated_node_codec); + break; + } + case 18: { + cost_.AddEntriesFrom(input, _repeated_cost_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + node_.AddEntriesFrom(ref input, _repeated_node_codec); + break; + } + case 18: { + cost_.AddEntriesFrom(ref input, _repeated_cost_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the CostGraphDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public sealed partial class Node : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Node()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CostGraphDef.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Node() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Node(Node other) : this() { + name_ = other.name_; + device_ = other.device_; + id_ = other.id_; + inputInfo_ = other.inputInfo_.Clone(); + outputInfo_ = other.outputInfo_.Clone(); + temporaryMemorySize_ = other.temporaryMemorySize_; + persistentMemorySize_ = other.persistentMemorySize_; + hostTempMemorySize_ = other.hostTempMemorySize_; + deviceTempMemorySize_ = other.deviceTempMemorySize_; + devicePersistentMemorySize_ = other.devicePersistentMemorySize_; + computeCost_ = other.computeCost_; + computeTime_ = other.computeTime_; + memoryTime_ = other.memoryTime_; + isFinal_ = other.isFinal_; + controlInput_ = other.controlInput_.Clone(); + inaccurate_ = other.inaccurate_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Node Clone() { + return new Node(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// The name of the node. Names are globally unique. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "device" field. + public const int DeviceFieldNumber = 2; + private string device_ = ""; + /// + /// The device of the node. Can be empty if the node is mapped to the + /// default partition or partitioning hasn't been run yet. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Device { + get { return device_; } + set { + device_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "id" field. + public const int IdFieldNumber = 3; + private int id_; + /// + /// The id of the node. Node ids are only unique inside a partition. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int Id { + get { return id_; } + set { + id_ = value; + } + } + + /// Field number for the "input_info" field. + public const int InputInfoFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_inputInfo_codec + = pb::FieldCodec.ForMessage(34, global::Tensorflow.CostGraphDef.Types.Node.Types.InputInfo.Parser); + private readonly pbc::RepeatedField inputInfo_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField InputInfo { + get { return inputInfo_; } + } + + /// Field number for the "output_info" field. + public const int OutputInfoFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_outputInfo_codec + = pb::FieldCodec.ForMessage(42, global::Tensorflow.CostGraphDef.Types.Node.Types.OutputInfo.Parser); + private readonly pbc::RepeatedField outputInfo_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OutputInfo { + get { return outputInfo_; } + } + + /// Field number for the "temporary_memory_size" field. + public const int TemporaryMemorySizeFieldNumber = 6; + private long temporaryMemorySize_; + /// + /// Temporary memory used by this node. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long TemporaryMemorySize { + get { return temporaryMemorySize_; } + set { + temporaryMemorySize_ = value; + } + } + + /// Field number for the "persistent_memory_size" field. + public const int PersistentMemorySizeFieldNumber = 12; + private long persistentMemorySize_; + /// + /// Persistent memory used by this node. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long PersistentMemorySize { + get { return persistentMemorySize_; } + set { + persistentMemorySize_ = value; + } + } + + /// Field number for the "host_temp_memory_size" field. + public const int HostTempMemorySizeFieldNumber = 10; + private long hostTempMemorySize_; + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long HostTempMemorySize { + get { return hostTempMemorySize_; } + set { + hostTempMemorySize_ = value; + } + } + + /// Field number for the "device_temp_memory_size" field. + public const int DeviceTempMemorySizeFieldNumber = 11; + private long deviceTempMemorySize_; + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long DeviceTempMemorySize { + get { return deviceTempMemorySize_; } + set { + deviceTempMemorySize_ = value; + } + } + + /// Field number for the "device_persistent_memory_size" field. + public const int DevicePersistentMemorySizeFieldNumber = 16; + private long devicePersistentMemorySize_; + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long DevicePersistentMemorySize { + get { return devicePersistentMemorySize_; } + set { + devicePersistentMemorySize_ = value; + } + } + + /// Field number for the "compute_cost" field. + public const int ComputeCostFieldNumber = 9; + private long computeCost_; + /// + /// Estimate of the computational cost of this node, in microseconds. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ComputeCost { + get { return computeCost_; } + set { + computeCost_ = value; + } + } + + /// Field number for the "compute_time" field. + public const int ComputeTimeFieldNumber = 14; + private long computeTime_; + /// + /// Analytical estimate of the computational cost of this node, in + /// microseconds. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ComputeTime { + get { return computeTime_; } + set { + computeTime_ = value; + } + } + + /// Field number for the "memory_time" field. + public const int MemoryTimeFieldNumber = 15; + private long memoryTime_; + /// + /// Analytical estimate of the memory access cost of this node, in + /// microseconds. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long MemoryTime { + get { return memoryTime_; } + set { + memoryTime_ = value; + } + } + + /// Field number for the "is_final" field. + public const int IsFinalFieldNumber = 7; + private bool isFinal_; + /// + /// If true, the output is permanent: it can't be discarded, because this + /// node is part of the "final output". Nodes may depend on final nodes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsFinal { + get { return isFinal_; } + set { + isFinal_ = value; + } + } + + /// Field number for the "control_input" field. + public const int ControlInputFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_controlInput_codec + = pb::FieldCodec.ForInt32(66); + private readonly pbc::RepeatedField controlInput_ = new pbc::RepeatedField(); + /// + /// Ids of the control inputs for this node. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ControlInput { + get { return controlInput_; } + } + + /// Field number for the "inaccurate" field. + public const int InaccurateFieldNumber = 17; + private bool inaccurate_; + /// + /// Are the costs inaccurate? + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Inaccurate { + get { return inaccurate_; } + set { + inaccurate_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Node); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Node other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (Device != other.Device) return false; + if (Id != other.Id) return false; + if(!inputInfo_.Equals(other.inputInfo_)) return false; + if(!outputInfo_.Equals(other.outputInfo_)) return false; + if (TemporaryMemorySize != other.TemporaryMemorySize) return false; + if (PersistentMemorySize != other.PersistentMemorySize) return false; + if (HostTempMemorySize != other.HostTempMemorySize) return false; + if (DeviceTempMemorySize != other.DeviceTempMemorySize) return false; + if (DevicePersistentMemorySize != other.DevicePersistentMemorySize) return false; + if (ComputeCost != other.ComputeCost) return false; + if (ComputeTime != other.ComputeTime) return false; + if (MemoryTime != other.MemoryTime) return false; + if (IsFinal != other.IsFinal) return false; + if(!controlInput_.Equals(other.controlInput_)) return false; + if (Inaccurate != other.Inaccurate) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Device.Length != 0) hash ^= Device.GetHashCode(); + if (Id != 0) hash ^= Id.GetHashCode(); + hash ^= inputInfo_.GetHashCode(); + hash ^= outputInfo_.GetHashCode(); + if (TemporaryMemorySize != 0L) hash ^= TemporaryMemorySize.GetHashCode(); + if (PersistentMemorySize != 0L) hash ^= PersistentMemorySize.GetHashCode(); + if (HostTempMemorySize != 0L) hash ^= HostTempMemorySize.GetHashCode(); + if (DeviceTempMemorySize != 0L) hash ^= DeviceTempMemorySize.GetHashCode(); + if (DevicePersistentMemorySize != 0L) hash ^= DevicePersistentMemorySize.GetHashCode(); + if (ComputeCost != 0L) hash ^= ComputeCost.GetHashCode(); + if (ComputeTime != 0L) hash ^= ComputeTime.GetHashCode(); + if (MemoryTime != 0L) hash ^= MemoryTime.GetHashCode(); + if (IsFinal != false) hash ^= IsFinal.GetHashCode(); + hash ^= controlInput_.GetHashCode(); + if (Inaccurate != false) hash ^= Inaccurate.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Device.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Device); + } + if (Id != 0) { + output.WriteRawTag(24); + output.WriteInt32(Id); + } + inputInfo_.WriteTo(output, _repeated_inputInfo_codec); + outputInfo_.WriteTo(output, _repeated_outputInfo_codec); + if (TemporaryMemorySize != 0L) { + output.WriteRawTag(48); + output.WriteInt64(TemporaryMemorySize); + } + if (IsFinal != false) { + output.WriteRawTag(56); + output.WriteBool(IsFinal); + } + controlInput_.WriteTo(output, _repeated_controlInput_codec); + if (ComputeCost != 0L) { + output.WriteRawTag(72); + output.WriteInt64(ComputeCost); + } + if (HostTempMemorySize != 0L) { + output.WriteRawTag(80); + output.WriteInt64(HostTempMemorySize); + } + if (DeviceTempMemorySize != 0L) { + output.WriteRawTag(88); + output.WriteInt64(DeviceTempMemorySize); + } + if (PersistentMemorySize != 0L) { + output.WriteRawTag(96); + output.WriteInt64(PersistentMemorySize); + } + if (ComputeTime != 0L) { + output.WriteRawTag(112); + output.WriteInt64(ComputeTime); + } + if (MemoryTime != 0L) { + output.WriteRawTag(120); + output.WriteInt64(MemoryTime); + } + if (DevicePersistentMemorySize != 0L) { + output.WriteRawTag(128, 1); + output.WriteInt64(DevicePersistentMemorySize); + } + if (Inaccurate != false) { + output.WriteRawTag(136, 1); + output.WriteBool(Inaccurate); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Device.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Device); + } + if (Id != 0) { + output.WriteRawTag(24); + output.WriteInt32(Id); + } + inputInfo_.WriteTo(ref output, _repeated_inputInfo_codec); + outputInfo_.WriteTo(ref output, _repeated_outputInfo_codec); + if (TemporaryMemorySize != 0L) { + output.WriteRawTag(48); + output.WriteInt64(TemporaryMemorySize); + } + if (IsFinal != false) { + output.WriteRawTag(56); + output.WriteBool(IsFinal); + } + controlInput_.WriteTo(ref output, _repeated_controlInput_codec); + if (ComputeCost != 0L) { + output.WriteRawTag(72); + output.WriteInt64(ComputeCost); + } + if (HostTempMemorySize != 0L) { + output.WriteRawTag(80); + output.WriteInt64(HostTempMemorySize); + } + if (DeviceTempMemorySize != 0L) { + output.WriteRawTag(88); + output.WriteInt64(DeviceTempMemorySize); + } + if (PersistentMemorySize != 0L) { + output.WriteRawTag(96); + output.WriteInt64(PersistentMemorySize); + } + if (ComputeTime != 0L) { + output.WriteRawTag(112); + output.WriteInt64(ComputeTime); + } + if (MemoryTime != 0L) { + output.WriteRawTag(120); + output.WriteInt64(MemoryTime); + } + if (DevicePersistentMemorySize != 0L) { + output.WriteRawTag(128, 1); + output.WriteInt64(DevicePersistentMemorySize); + } + if (Inaccurate != false) { + output.WriteRawTag(136, 1); + output.WriteBool(Inaccurate); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Device.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Device); + } + if (Id != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id); + } + size += inputInfo_.CalculateSize(_repeated_inputInfo_codec); + size += outputInfo_.CalculateSize(_repeated_outputInfo_codec); + if (TemporaryMemorySize != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(TemporaryMemorySize); + } + if (PersistentMemorySize != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(PersistentMemorySize); + } + if (HostTempMemorySize != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(HostTempMemorySize); + } + if (DeviceTempMemorySize != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(DeviceTempMemorySize); + } + if (DevicePersistentMemorySize != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(DevicePersistentMemorySize); + } + if (ComputeCost != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ComputeCost); + } + if (ComputeTime != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ComputeTime); + } + if (MemoryTime != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(MemoryTime); + } + if (IsFinal != false) { + size += 1 + 1; + } + size += controlInput_.CalculateSize(_repeated_controlInput_codec); + if (Inaccurate != false) { + size += 2 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Node other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Device.Length != 0) { + Device = other.Device; + } + if (other.Id != 0) { + Id = other.Id; + } + inputInfo_.Add(other.inputInfo_); + outputInfo_.Add(other.outputInfo_); + if (other.TemporaryMemorySize != 0L) { + TemporaryMemorySize = other.TemporaryMemorySize; + } + if (other.PersistentMemorySize != 0L) { + PersistentMemorySize = other.PersistentMemorySize; + } + if (other.HostTempMemorySize != 0L) { + HostTempMemorySize = other.HostTempMemorySize; + } + if (other.DeviceTempMemorySize != 0L) { + DeviceTempMemorySize = other.DeviceTempMemorySize; + } + if (other.DevicePersistentMemorySize != 0L) { + DevicePersistentMemorySize = other.DevicePersistentMemorySize; + } + if (other.ComputeCost != 0L) { + ComputeCost = other.ComputeCost; + } + if (other.ComputeTime != 0L) { + ComputeTime = other.ComputeTime; + } + if (other.MemoryTime != 0L) { + MemoryTime = other.MemoryTime; + } + if (other.IsFinal != false) { + IsFinal = other.IsFinal; + } + controlInput_.Add(other.controlInput_); + if (other.Inaccurate != false) { + Inaccurate = other.Inaccurate; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + Device = input.ReadString(); + break; + } + case 24: { + Id = input.ReadInt32(); + break; + } + case 34: { + inputInfo_.AddEntriesFrom(input, _repeated_inputInfo_codec); + break; + } + case 42: { + outputInfo_.AddEntriesFrom(input, _repeated_outputInfo_codec); + break; + } + case 48: { + TemporaryMemorySize = input.ReadInt64(); + break; + } + case 56: { + IsFinal = input.ReadBool(); + break; + } + case 66: + case 64: { + controlInput_.AddEntriesFrom(input, _repeated_controlInput_codec); + break; + } + case 72: { + ComputeCost = input.ReadInt64(); + break; + } + case 80: { + HostTempMemorySize = input.ReadInt64(); + break; + } + case 88: { + DeviceTempMemorySize = input.ReadInt64(); + break; + } + case 96: { + PersistentMemorySize = input.ReadInt64(); + break; + } + case 112: { + ComputeTime = input.ReadInt64(); + break; + } + case 120: { + MemoryTime = input.ReadInt64(); + break; + } + case 128: { + DevicePersistentMemorySize = input.ReadInt64(); + break; + } + case 136: { + Inaccurate = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + Device = input.ReadString(); + break; + } + case 24: { + Id = input.ReadInt32(); + break; + } + case 34: { + inputInfo_.AddEntriesFrom(ref input, _repeated_inputInfo_codec); + break; + } + case 42: { + outputInfo_.AddEntriesFrom(ref input, _repeated_outputInfo_codec); + break; + } + case 48: { + TemporaryMemorySize = input.ReadInt64(); + break; + } + case 56: { + IsFinal = input.ReadBool(); + break; + } + case 66: + case 64: { + controlInput_.AddEntriesFrom(ref input, _repeated_controlInput_codec); + break; + } + case 72: { + ComputeCost = input.ReadInt64(); + break; + } + case 80: { + HostTempMemorySize = input.ReadInt64(); + break; + } + case 88: { + DeviceTempMemorySize = input.ReadInt64(); + break; + } + case 96: { + PersistentMemorySize = input.ReadInt64(); + break; + } + case 112: { + ComputeTime = input.ReadInt64(); + break; + } + case 120: { + MemoryTime = input.ReadInt64(); + break; + } + case 128: { + DevicePersistentMemorySize = input.ReadInt64(); + break; + } + case 136: { + Inaccurate = input.ReadBool(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the Node message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Inputs of this node. They must be executed before this node can be + /// executed. An input is a particular output of another node, specified + /// by the node id and the output index. + /// + public sealed partial class InputInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new InputInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CostGraphDef.Types.Node.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InputInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InputInfo(InputInfo other) : this() { + precedingNode_ = other.precedingNode_; + precedingPort_ = other.precedingPort_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InputInfo Clone() { + return new InputInfo(this); + } + + /// Field number for the "preceding_node" field. + public const int PrecedingNodeFieldNumber = 1; + private int precedingNode_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int PrecedingNode { + get { return precedingNode_; } + set { + precedingNode_ = value; + } + } + + /// Field number for the "preceding_port" field. + public const int PrecedingPortFieldNumber = 2; + private int precedingPort_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int PrecedingPort { + get { return precedingPort_; } + set { + precedingPort_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as InputInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(InputInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (PrecedingNode != other.PrecedingNode) return false; + if (PrecedingPort != other.PrecedingPort) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (PrecedingNode != 0) hash ^= PrecedingNode.GetHashCode(); + if (PrecedingPort != 0) hash ^= PrecedingPort.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (PrecedingNode != 0) { + output.WriteRawTag(8); + output.WriteInt32(PrecedingNode); + } + if (PrecedingPort != 0) { + output.WriteRawTag(16); + output.WriteInt32(PrecedingPort); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (PrecedingNode != 0) { + output.WriteRawTag(8); + output.WriteInt32(PrecedingNode); + } + if (PrecedingPort != 0) { + output.WriteRawTag(16); + output.WriteInt32(PrecedingPort); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (PrecedingNode != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(PrecedingNode); + } + if (PrecedingPort != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(PrecedingPort); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(InputInfo other) { + if (other == null) { + return; + } + if (other.PrecedingNode != 0) { + PrecedingNode = other.PrecedingNode; + } + if (other.PrecedingPort != 0) { + PrecedingPort = other.PrecedingPort; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + PrecedingNode = input.ReadInt32(); + break; + } + case 16: { + PrecedingPort = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + PrecedingNode = input.ReadInt32(); + break; + } + case 16: { + PrecedingPort = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + /// + /// Outputs of this node. + /// + public sealed partial class OutputInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OutputInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CostGraphDef.Types.Node.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OutputInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OutputInfo(OutputInfo other) : this() { + size_ = other.size_; + aliasInputPort_ = other.aliasInputPort_; + shape_ = other.shape_ != null ? other.shape_.Clone() : null; + dtype_ = other.dtype_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OutputInfo Clone() { + return new OutputInfo(this); + } + + /// Field number for the "size" field. + public const int SizeFieldNumber = 1; + private long size_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Size { + get { return size_; } + set { + size_ = value; + } + } + + /// Field number for the "alias_input_port" field. + public const int AliasInputPortFieldNumber = 2; + private long aliasInputPort_; + /// + /// If >= 0, the output is an alias of an input. Note that an alias input + /// may itself be an alias. The algorithm will therefore need to follow + /// those pointers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AliasInputPort { + get { return aliasInputPort_; } + set { + aliasInputPort_ = value; + } + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 3; + private global::Tensorflow.TensorShapeProto shape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorShapeProto Shape { + get { return shape_; } + set { + shape_ = value; + } + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 4; + private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as OutputInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(OutputInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Size != other.Size) return false; + if (AliasInputPort != other.AliasInputPort) return false; + if (!object.Equals(Shape, other.Shape)) return false; + if (Dtype != other.Dtype) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Size != 0L) hash ^= Size.GetHashCode(); + if (AliasInputPort != 0L) hash ^= AliasInputPort.GetHashCode(); + if (shape_ != null) hash ^= Shape.GetHashCode(); + if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Size != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Size); + } + if (AliasInputPort != 0L) { + output.WriteRawTag(16); + output.WriteInt64(AliasInputPort); + } + if (shape_ != null) { + output.WriteRawTag(26); + output.WriteMessage(Shape); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(32); + output.WriteEnum((int) Dtype); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Size != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Size); + } + if (AliasInputPort != 0L) { + output.WriteRawTag(16); + output.WriteInt64(AliasInputPort); + } + if (shape_ != null) { + output.WriteRawTag(26); + output.WriteMessage(Shape); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(32); + output.WriteEnum((int) Dtype); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Size != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Size); + } + if (AliasInputPort != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AliasInputPort); + } + if (shape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(OutputInfo other) { + if (other == null) { + return; + } + if (other.Size != 0L) { + Size = other.Size; + } + if (other.AliasInputPort != 0L) { + AliasInputPort = other.AliasInputPort; + } + if (other.shape_ != null) { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + Shape.MergeFrom(other.Shape); + } + if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { + Dtype = other.Dtype; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Size = input.ReadInt64(); + break; + } + case 16: { + AliasInputPort = input.ReadInt64(); + break; + } + case 26: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 32: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Size = input.ReadInt64(); + break; + } + case 16: { + AliasInputPort = input.ReadInt64(); + break; + } + case 26: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 32: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// Total cost of this graph, typically used for balancing decisions. + /// + public sealed partial class AggregatedCost : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AggregatedCost()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CostGraphDef.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AggregatedCost() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AggregatedCost(AggregatedCost other) : this() { + cost_ = other.cost_; + dimension_ = other.dimension_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AggregatedCost Clone() { + return new AggregatedCost(this); + } + + /// Field number for the "cost" field. + public const int CostFieldNumber = 1; + private float cost_; + /// + /// Aggregated cost value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public float Cost { + get { return cost_; } + set { + cost_ = value; + } + } + + /// Field number for the "dimension" field. + public const int DimensionFieldNumber = 2; + private string dimension_ = ""; + /// + /// Aggregated cost dimension (e.g. 'memory', 'compute', 'network'). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Dimension { + get { return dimension_; } + set { + dimension_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as AggregatedCost); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(AggregatedCost other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Cost, other.Cost)) return false; + if (Dimension != other.Dimension) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Cost != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Cost); + if (Dimension.Length != 0) hash ^= Dimension.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Cost != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Cost); + } + if (Dimension.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Dimension); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Cost != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Cost); + } + if (Dimension.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Dimension); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Cost != 0F) { + size += 1 + 4; + } + if (Dimension.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Dimension); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(AggregatedCost other) { + if (other == null) { + return; + } + if (other.Cost != 0F) { + Cost = other.Cost; + } + if (other.Dimension.Length != 0) { + Dimension = other.Dimension; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Cost = input.ReadFloat(); + break; + } + case 18: { + Dimension = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 13: { + Cost = input.ReadFloat(); + break; + } + case 18: { + Dimension = input.ReadString(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs b/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs new file mode 100644 index 000000000..c6de97c6b --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs @@ -0,0 +1,1021 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/python/framework/cpp_shape_inference.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/python/framework/cpp_shape_inference.proto + public static partial class CppShapeInferenceReflection { + + #region Descriptor + /// File descriptor for tensorflow/python/framework/cpp_shape_inference.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static CppShapeInferenceReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjV0ZW5zb3JmbG93L3B5dGhvbi9mcmFtZXdvcmsvY3BwX3NoYXBlX2luZmVy", + "ZW5jZS5wcm90bxIKdGVuc29yZmxvdxopdGVuc29yZmxvdy9jb3JlL2ZyYW1l", + "d29yay9mdWxsX3R5cGUucHJvdG8aLHRlbnNvcmZsb3cvY29yZS9mcmFtZXdv", + "cmsvdGVuc29yX3NoYXBlLnByb3RvGiV0ZW5zb3JmbG93L2NvcmUvZnJhbWV3", + "b3JrL3R5cGVzLnByb3RvIpsDChdDcHBTaGFwZUluZmVyZW5jZVJlc3VsdBIr", + "CgVzaGFwZRgBIAEoCzIcLnRlbnNvcmZsb3cuVGVuc29yU2hhcGVQcm90bxJD", + "CgtoYW5kbGVfZGF0YRgEIAEoCzIuLnRlbnNvcmZsb3cuQ3BwU2hhcGVJbmZl", + "cmVuY2VSZXN1bHQuSGFuZGxlRGF0YRqTAQoSSGFuZGxlU2hhcGVBbmRUeXBl", + "EisKBXNoYXBlGAEgASgLMhwudGVuc29yZmxvdy5UZW5zb3JTaGFwZVByb3Rv", + "EiMKBWR0eXBlGAIgASgOMhQudGVuc29yZmxvdy5EYXRhVHlwZRIlCgR0eXBl", + "GAQgASgLMhcudGVuc29yZmxvdy5GdWxsVHlwZURlZkoECAMQBBpsCgpIYW5k", + "bGVEYXRhEg4KBmlzX3NldBgBIAEoCBJOCg5zaGFwZV9hbmRfdHlwZRgCIAMo", + "CzI2LnRlbnNvcmZsb3cuQ3BwU2hhcGVJbmZlcmVuY2VSZXN1bHQuSGFuZGxl", + "U2hhcGVBbmRUeXBlSgQIAhADSgQIAxAEImUKHUNwcFNoYXBlSW5mZXJlbmNl", + "SW5wdXRzTmVlZGVkEhwKFGlucHV0X3RlbnNvcnNfbmVlZGVkGAEgAygFEiYK", + "HmlucHV0X3RlbnNvcnNfYXNfc2hhcGVzX25lZWRlZBgCIAMoBUJhWlxnaXRo", + "dWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dvL3B5", + "dGhvbi9mcmFtZXdvcmsvY3BwX3NoYXBlX2luZmVyZW5jZV9nb19wcm90b/gB", + "AWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.FullTypeReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceResult), global::Tensorflow.CppShapeInferenceResult.Parser, new[]{ "Shape", "HandleData" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType), global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType.Parser, new[]{ "Shape", "Dtype", "Type" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceResult.Types.HandleData), global::Tensorflow.CppShapeInferenceResult.Types.HandleData.Parser, new[]{ "IsSet", "ShapeAndType" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceInputsNeeded), global::Tensorflow.CppShapeInferenceInputsNeeded.Parser, new[]{ "InputTensorsNeeded", "InputTensorsAsShapesNeeded" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class CppShapeInferenceResult : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CppShapeInferenceResult()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CppShapeInferenceResult() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CppShapeInferenceResult(CppShapeInferenceResult other) : this() { + shape_ = other.shape_ != null ? other.shape_.Clone() : null; + handleData_ = other.handleData_ != null ? other.handleData_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CppShapeInferenceResult Clone() { + return new CppShapeInferenceResult(this); + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 1; + private global::Tensorflow.TensorShapeProto shape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorShapeProto Shape { + get { return shape_; } + set { + shape_ = value; + } + } + + /// Field number for the "handle_data" field. + public const int HandleDataFieldNumber = 4; + private global::Tensorflow.CppShapeInferenceResult.Types.HandleData handleData_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CppShapeInferenceResult.Types.HandleData HandleData { + get { return handleData_; } + set { + handleData_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CppShapeInferenceResult); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CppShapeInferenceResult other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Shape, other.Shape)) return false; + if (!object.Equals(HandleData, other.HandleData)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (shape_ != null) hash ^= Shape.GetHashCode(); + if (handleData_ != null) hash ^= HandleData.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (shape_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Shape); + } + if (handleData_ != null) { + output.WriteRawTag(34); + output.WriteMessage(HandleData); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (shape_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Shape); + } + if (handleData_ != null) { + output.WriteRawTag(34); + output.WriteMessage(HandleData); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (shape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + if (handleData_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(HandleData); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CppShapeInferenceResult other) { + if (other == null) { + return; + } + if (other.shape_ != null) { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + Shape.MergeFrom(other.Shape); + } + if (other.handleData_ != null) { + if (handleData_ == null) { + HandleData = new global::Tensorflow.CppShapeInferenceResult.Types.HandleData(); + } + HandleData.MergeFrom(other.HandleData); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 34: { + if (handleData_ == null) { + HandleData = new global::Tensorflow.CppShapeInferenceResult.Types.HandleData(); + } + input.ReadMessage(HandleData); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 34: { + if (handleData_ == null) { + HandleData = new global::Tensorflow.CppShapeInferenceResult.Types.HandleData(); + } + input.ReadMessage(HandleData); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the CppShapeInferenceResult message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public sealed partial class HandleShapeAndType : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HandleShapeAndType()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HandleShapeAndType() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HandleShapeAndType(HandleShapeAndType other) : this() { + shape_ = other.shape_ != null ? other.shape_.Clone() : null; + dtype_ = other.dtype_; + type_ = other.type_ != null ? other.type_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HandleShapeAndType Clone() { + return new HandleShapeAndType(this); + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 1; + private global::Tensorflow.TensorShapeProto shape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorShapeProto Shape { + get { return shape_; } + set { + shape_ = value; + } + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 2; + private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + /// Field number for the "type" field. + public const int TypeFieldNumber = 4; + private global::Tensorflow.FullTypeDef type_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.FullTypeDef Type { + get { return type_; } + set { + type_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HandleShapeAndType); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HandleShapeAndType other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Shape, other.Shape)) return false; + if (Dtype != other.Dtype) return false; + if (!object.Equals(Type, other.Type)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (shape_ != null) hash ^= Shape.GetHashCode(); + if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); + if (type_ != null) hash ^= Type.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (shape_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Shape); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(16); + output.WriteEnum((int) Dtype); + } + if (type_ != null) { + output.WriteRawTag(34); + output.WriteMessage(Type); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (shape_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Shape); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(16); + output.WriteEnum((int) Dtype); + } + if (type_ != null) { + output.WriteRawTag(34); + output.WriteMessage(Type); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (shape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (type_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Type); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HandleShapeAndType other) { + if (other == null) { + return; + } + if (other.shape_ != null) { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + Shape.MergeFrom(other.Shape); + } + if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { + Dtype = other.Dtype; + } + if (other.type_ != null) { + if (type_ == null) { + Type = new global::Tensorflow.FullTypeDef(); + } + Type.MergeFrom(other.Type); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 16: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 34: { + if (type_ == null) { + Type = new global::Tensorflow.FullTypeDef(); + } + input.ReadMessage(Type); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 16: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 34: { + if (type_ == null) { + Type = new global::Tensorflow.FullTypeDef(); + } + input.ReadMessage(Type); + break; + } + } + } + } + #endif + + } + + public sealed partial class HandleData : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HandleData()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HandleData() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HandleData(HandleData other) : this() { + isSet_ = other.isSet_; + shapeAndType_ = other.shapeAndType_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HandleData Clone() { + return new HandleData(this); + } + + /// Field number for the "is_set" field. + public const int IsSetFieldNumber = 1; + private bool isSet_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsSet { + get { return isSet_; } + set { + isSet_ = value; + } + } + + /// Field number for the "shape_and_type" field. + public const int ShapeAndTypeFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_shapeAndType_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType.Parser); + private readonly pbc::RepeatedField shapeAndType_ = new pbc::RepeatedField(); + /// + /// Only valid if <is_set>. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ShapeAndType { + get { return shapeAndType_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HandleData); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HandleData other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (IsSet != other.IsSet) return false; + if(!shapeAndType_.Equals(other.shapeAndType_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (IsSet != false) hash ^= IsSet.GetHashCode(); + hash ^= shapeAndType_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (IsSet != false) { + output.WriteRawTag(8); + output.WriteBool(IsSet); + } + shapeAndType_.WriteTo(output, _repeated_shapeAndType_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (IsSet != false) { + output.WriteRawTag(8); + output.WriteBool(IsSet); + } + shapeAndType_.WriteTo(ref output, _repeated_shapeAndType_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (IsSet != false) { + size += 1 + 1; + } + size += shapeAndType_.CalculateSize(_repeated_shapeAndType_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HandleData other) { + if (other == null) { + return; + } + if (other.IsSet != false) { + IsSet = other.IsSet; + } + shapeAndType_.Add(other.shapeAndType_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + IsSet = input.ReadBool(); + break; + } + case 18: { + shapeAndType_.AddEntriesFrom(input, _repeated_shapeAndType_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + IsSet = input.ReadBool(); + break; + } + case 18: { + shapeAndType_.AddEntriesFrom(ref input, _repeated_shapeAndType_codec); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + public sealed partial class CppShapeInferenceInputsNeeded : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CppShapeInferenceInputsNeeded()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CppShapeInferenceInputsNeeded() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CppShapeInferenceInputsNeeded(CppShapeInferenceInputsNeeded other) : this() { + inputTensorsNeeded_ = other.inputTensorsNeeded_.Clone(); + inputTensorsAsShapesNeeded_ = other.inputTensorsAsShapesNeeded_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CppShapeInferenceInputsNeeded Clone() { + return new CppShapeInferenceInputsNeeded(this); + } + + /// Field number for the "input_tensors_needed" field. + public const int InputTensorsNeededFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_inputTensorsNeeded_codec + = pb::FieldCodec.ForInt32(10); + private readonly pbc::RepeatedField inputTensorsNeeded_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField InputTensorsNeeded { + get { return inputTensorsNeeded_; } + } + + /// Field number for the "input_tensors_as_shapes_needed" field. + public const int InputTensorsAsShapesNeededFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_inputTensorsAsShapesNeeded_codec + = pb::FieldCodec.ForInt32(18); + private readonly pbc::RepeatedField inputTensorsAsShapesNeeded_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField InputTensorsAsShapesNeeded { + get { return inputTensorsAsShapesNeeded_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CppShapeInferenceInputsNeeded); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CppShapeInferenceInputsNeeded other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!inputTensorsNeeded_.Equals(other.inputTensorsNeeded_)) return false; + if(!inputTensorsAsShapesNeeded_.Equals(other.inputTensorsAsShapesNeeded_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= inputTensorsNeeded_.GetHashCode(); + hash ^= inputTensorsAsShapesNeeded_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + inputTensorsNeeded_.WriteTo(output, _repeated_inputTensorsNeeded_codec); + inputTensorsAsShapesNeeded_.WriteTo(output, _repeated_inputTensorsAsShapesNeeded_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + inputTensorsNeeded_.WriteTo(ref output, _repeated_inputTensorsNeeded_codec); + inputTensorsAsShapesNeeded_.WriteTo(ref output, _repeated_inputTensorsAsShapesNeeded_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += inputTensorsNeeded_.CalculateSize(_repeated_inputTensorsNeeded_codec); + size += inputTensorsAsShapesNeeded_.CalculateSize(_repeated_inputTensorsAsShapesNeeded_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CppShapeInferenceInputsNeeded other) { + if (other == null) { + return; + } + inputTensorsNeeded_.Add(other.inputTensorsNeeded_); + inputTensorsAsShapesNeeded_.Add(other.inputTensorsAsShapesNeeded_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + inputTensorsNeeded_.AddEntriesFrom(input, _repeated_inputTensorsNeeded_codec); + break; + } + case 18: + case 16: { + inputTensorsAsShapesNeeded_.AddEntriesFrom(input, _repeated_inputTensorsAsShapesNeeded_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + inputTensorsNeeded_.AddEntriesFrom(ref input, _repeated_inputTensorsNeeded_codec); + break; + } + case 18: + case 16: { + inputTensorsAsShapesNeeded_.AddEntriesFrom(ref input, _repeated_inputTensorsAsShapesNeeded_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/DataService.cs b/src/TensorFlowNET.Core/Protobuf/DataService.cs new file mode 100644 index 000000000..ca59a471d --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/DataService.cs @@ -0,0 +1,1041 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/data_service.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Data { + + /// Holder for reflection information generated from tensorflow/core/protobuf/data_service.proto + public static partial class DataServiceReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/data_service.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static DataServiceReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cit0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvZGF0YV9zZXJ2aWNlLnByb3Rv", + "Eg90ZW5zb3JmbG93LmRhdGEitwEKEVByb2Nlc3NpbmdNb2RlRGVmEkoKD3No", + "YXJkaW5nX3BvbGljeRgBIAEoDjIxLnRlbnNvcmZsb3cuZGF0YS5Qcm9jZXNz", + "aW5nTW9kZURlZi5TaGFyZGluZ1BvbGljeSJWCg5TaGFyZGluZ1BvbGljeRIH", + "CgNPRkYQABILCgdEWU5BTUlDEAESCAoERklMRRACEggKBERBVEEQAxIQCgxG", + "SUxFX09SX0RBVEEQBBIICgRISU5UEAUi+wEKE0RhdGFTZXJ2aWNlTWV0YWRh", + "dGESFgoMZWxlbWVudF9zcGVjGAEgASgMSAASRQoLY29tcHJlc3Npb24YAiAB", + "KA4yMC50ZW5zb3JmbG93LmRhdGEuRGF0YVNlcnZpY2VNZXRhZGF0YS5Db21w", + "cmVzc2lvbhITCgtjYXJkaW5hbGl0eRgDIAEoAyJXCgtDb21wcmVzc2lvbhIb", + "ChdDT01QUkVTU0lPTl9VTlNQRUNJRklFRBAAEhMKD0NPTVBSRVNTSU9OX09G", + "RhABEhYKEkNPTVBSRVNTSU9OX1NOQVBQWRACQhcKFW9wdGlvbmFsX2VsZW1l", + "bnRfc3BlYyIuChhDcm9zc1RyYWluZXJDYWNoZU9wdGlvbnMSEgoKdHJhaW5l", + "cl9pZBgBIAEoCSJNChFEYXRhU2VydmljZUNvbmZpZxI4Cg9kZXBsb3ltZW50", + "X21vZGUYASABKA4yHy50ZW5zb3JmbG93LmRhdGEuRGVwbG95bWVudE1vZGUq", + "iAEKDkRlcGxveW1lbnRNb2RlEh8KG0RFUExPWU1FTlRfTU9ERV9VTlNQRUNJ", + "RklFRBAAEh0KGURFUExPWU1FTlRfTU9ERV9DT0xPQ0FURUQQARIaChZERVBM", + "T1lNRU5UX01PREVfUkVNT1RFEAISGgoWREVQTE9ZTUVOVF9NT0RFX0hZQlJJ", + "RBADQldaVWdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3RlbnNv", + "cmZsb3cvZ28vY29yZS9wcm90b2J1Zi9mb3JfY29yZV9wcm90b3NfZ29fcHJv", + "dG9iBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.Data.DeploymentMode), }, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Data.ProcessingModeDef), global::Tensorflow.Data.ProcessingModeDef.Parser, new[]{ "ShardingPolicy" }, null, new[]{ typeof(global::Tensorflow.Data.ProcessingModeDef.Types.ShardingPolicy) }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Data.DataServiceMetadata), global::Tensorflow.Data.DataServiceMetadata.Parser, new[]{ "ElementSpec", "Compression", "Cardinality" }, new[]{ "OptionalElementSpec" }, new[]{ typeof(global::Tensorflow.Data.DataServiceMetadata.Types.Compression) }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Data.CrossTrainerCacheOptions), global::Tensorflow.Data.CrossTrainerCacheOptions.Parser, new[]{ "TrainerId" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Data.DataServiceConfig), global::Tensorflow.Data.DataServiceConfig.Parser, new[]{ "DeploymentMode" }, null, null, null, null) + })); + } + #endregion + + } + #region Enums + /// + /// tf.data service deployment mode. + /// + public enum DeploymentMode { + [pbr::OriginalName("DEPLOYMENT_MODE_UNSPECIFIED")] Unspecified = 0, + /// + /// tf.data service workers colocate with TF workers. + /// + [pbr::OriginalName("DEPLOYMENT_MODE_COLOCATED")] Colocated = 1, + /// + /// tf.data service workers run in dedicated tf.data hosts. + /// + [pbr::OriginalName("DEPLOYMENT_MODE_REMOTE")] Remote = 2, + /// + /// tf.data service workers run in colocated TF hosts and dedicated tf.data + /// hosts. + /// + [pbr::OriginalName("DEPLOYMENT_MODE_HYBRID")] Hybrid = 3, + } + + #endregion + + #region Messages + /// + /// Next tag: 2 + /// + public sealed partial class ProcessingModeDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ProcessingModeDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Data.DataServiceReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ProcessingModeDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ProcessingModeDef(ProcessingModeDef other) : this() { + shardingPolicy_ = other.shardingPolicy_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ProcessingModeDef Clone() { + return new ProcessingModeDef(this); + } + + /// Field number for the "sharding_policy" field. + public const int ShardingPolicyFieldNumber = 1; + private global::Tensorflow.Data.ProcessingModeDef.Types.ShardingPolicy shardingPolicy_ = global::Tensorflow.Data.ProcessingModeDef.Types.ShardingPolicy.Off; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.Data.ProcessingModeDef.Types.ShardingPolicy ShardingPolicy { + get { return shardingPolicy_; } + set { + shardingPolicy_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ProcessingModeDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ProcessingModeDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ShardingPolicy != other.ShardingPolicy) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ShardingPolicy != global::Tensorflow.Data.ProcessingModeDef.Types.ShardingPolicy.Off) hash ^= ShardingPolicy.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ShardingPolicy != global::Tensorflow.Data.ProcessingModeDef.Types.ShardingPolicy.Off) { + output.WriteRawTag(8); + output.WriteEnum((int) ShardingPolicy); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ShardingPolicy != global::Tensorflow.Data.ProcessingModeDef.Types.ShardingPolicy.Off) { + output.WriteRawTag(8); + output.WriteEnum((int) ShardingPolicy); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ShardingPolicy != global::Tensorflow.Data.ProcessingModeDef.Types.ShardingPolicy.Off) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ShardingPolicy); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ProcessingModeDef other) { + if (other == null) { + return; + } + if (other.ShardingPolicy != global::Tensorflow.Data.ProcessingModeDef.Types.ShardingPolicy.Off) { + ShardingPolicy = other.ShardingPolicy; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + ShardingPolicy = (global::Tensorflow.Data.ProcessingModeDef.Types.ShardingPolicy) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + ShardingPolicy = (global::Tensorflow.Data.ProcessingModeDef.Types.ShardingPolicy) input.ReadEnum(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the ProcessingModeDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Specifies how data is sharded among tf.data service workers. + /// + public enum ShardingPolicy { + /// + /// No sharding will be performed. Each worker produces the entire dataset + /// without any sharding. With this mode, the best practice is to shuffle the + /// dataset nondeterministically so that workers process the dataset in + /// different orders. + /// + [pbr::OriginalName("OFF")] Off = 0, + /// + /// The input dataset is dynamically split among workers at runtime. Each + /// worker gets the next split when it reads data from the dispatcher. There + /// is no fixed sharding with this mode. + /// + [pbr::OriginalName("DYNAMIC")] Dynamic = 1, + /// + /// The following are static sharding policies. The semantics are similar to + /// `tf.data.experimental.AutoShardPolicy`. These policies require: + /// * The tf.data service cluster has a fixed size, and you need to specify + /// the workers in DispatcherConfig. + /// * Each client only reads from the local tf.data service worker. + /// + /// Shards by input files (each worker will get a set of files to process). + /// When this option is selected, make sure that there is at least as many + /// files as workers. If there are fewer input files than workers, a runtime + /// error will be raised. + /// + [pbr::OriginalName("FILE")] File = 2, + /// + /// Shards by elements produced by the dataset. Each worker will process the + /// whole dataset and discard the portion that is not for itself. Note that + /// for this mode to correctly partitions the dataset elements, the dataset + /// needs to produce elements in a deterministic order. + /// + [pbr::OriginalName("DATA")] Data = 3, + /// + /// Attempts FILE-based sharding, falling back to DATA-based sharding on + /// failures. + /// + [pbr::OriginalName("FILE_OR_DATA")] FileOrData = 4, + /// + /// Looks for the presence of `shard(SHARD_HINT, ...)` which is treated as a + /// placeholder to replace with `shard(num_workers, worker_index)`. + /// + [pbr::OriginalName("HINT")] Hint = 5, + } + + } + #endregion + + } + + /// + /// Metadata related to tf.data service datasets. + /// Next tag: 4 + /// + public sealed partial class DataServiceMetadata : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DataServiceMetadata()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Data.DataServiceReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DataServiceMetadata() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DataServiceMetadata(DataServiceMetadata other) : this() { + compression_ = other.compression_; + cardinality_ = other.cardinality_; + switch (other.OptionalElementSpecCase) { + case OptionalElementSpecOneofCase.ElementSpec: + ElementSpec = other.ElementSpec; + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DataServiceMetadata Clone() { + return new DataServiceMetadata(this); + } + + /// Field number for the "element_spec" field. + public const int ElementSpecFieldNumber = 1; + /// + /// Serialized element spec. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString ElementSpec { + get { return optionalElementSpecCase_ == OptionalElementSpecOneofCase.ElementSpec ? (pb::ByteString) optionalElementSpec_ : pb::ByteString.Empty; } + set { + optionalElementSpec_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + optionalElementSpecCase_ = OptionalElementSpecOneofCase.ElementSpec; + } + } + + /// Field number for the "compression" field. + public const int CompressionFieldNumber = 2; + private global::Tensorflow.Data.DataServiceMetadata.Types.Compression compression_ = global::Tensorflow.Data.DataServiceMetadata.Types.Compression.Unspecified; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.Data.DataServiceMetadata.Types.Compression Compression { + get { return compression_; } + set { + compression_ = value; + } + } + + /// Field number for the "cardinality" field. + public const int CardinalityFieldNumber = 3; + private long cardinality_; + /// + /// Cardinality of the dataset. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Cardinality { + get { return cardinality_; } + set { + cardinality_ = value; + } + } + + private object optionalElementSpec_; + /// Enum of possible cases for the "optional_element_spec" oneof. + public enum OptionalElementSpecOneofCase { + None = 0, + ElementSpec = 1, + } + private OptionalElementSpecOneofCase optionalElementSpecCase_ = OptionalElementSpecOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OptionalElementSpecOneofCase OptionalElementSpecCase { + get { return optionalElementSpecCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void ClearOptionalElementSpec() { + optionalElementSpecCase_ = OptionalElementSpecOneofCase.None; + optionalElementSpec_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DataServiceMetadata); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DataServiceMetadata other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ElementSpec != other.ElementSpec) return false; + if (Compression != other.Compression) return false; + if (Cardinality != other.Cardinality) return false; + if (OptionalElementSpecCase != other.OptionalElementSpecCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (optionalElementSpecCase_ == OptionalElementSpecOneofCase.ElementSpec) hash ^= ElementSpec.GetHashCode(); + if (Compression != global::Tensorflow.Data.DataServiceMetadata.Types.Compression.Unspecified) hash ^= Compression.GetHashCode(); + if (Cardinality != 0L) hash ^= Cardinality.GetHashCode(); + hash ^= (int) optionalElementSpecCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (optionalElementSpecCase_ == OptionalElementSpecOneofCase.ElementSpec) { + output.WriteRawTag(10); + output.WriteBytes(ElementSpec); + } + if (Compression != global::Tensorflow.Data.DataServiceMetadata.Types.Compression.Unspecified) { + output.WriteRawTag(16); + output.WriteEnum((int) Compression); + } + if (Cardinality != 0L) { + output.WriteRawTag(24); + output.WriteInt64(Cardinality); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (optionalElementSpecCase_ == OptionalElementSpecOneofCase.ElementSpec) { + output.WriteRawTag(10); + output.WriteBytes(ElementSpec); + } + if (Compression != global::Tensorflow.Data.DataServiceMetadata.Types.Compression.Unspecified) { + output.WriteRawTag(16); + output.WriteEnum((int) Compression); + } + if (Cardinality != 0L) { + output.WriteRawTag(24); + output.WriteInt64(Cardinality); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (optionalElementSpecCase_ == OptionalElementSpecOneofCase.ElementSpec) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(ElementSpec); + } + if (Compression != global::Tensorflow.Data.DataServiceMetadata.Types.Compression.Unspecified) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Compression); + } + if (Cardinality != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Cardinality); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DataServiceMetadata other) { + if (other == null) { + return; + } + if (other.Compression != global::Tensorflow.Data.DataServiceMetadata.Types.Compression.Unspecified) { + Compression = other.Compression; + } + if (other.Cardinality != 0L) { + Cardinality = other.Cardinality; + } + switch (other.OptionalElementSpecCase) { + case OptionalElementSpecOneofCase.ElementSpec: + ElementSpec = other.ElementSpec; + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ElementSpec = input.ReadBytes(); + break; + } + case 16: { + Compression = (global::Tensorflow.Data.DataServiceMetadata.Types.Compression) input.ReadEnum(); + break; + } + case 24: { + Cardinality = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + ElementSpec = input.ReadBytes(); + break; + } + case 16: { + Compression = (global::Tensorflow.Data.DataServiceMetadata.Types.Compression) input.ReadEnum(); + break; + } + case 24: { + Cardinality = input.ReadInt64(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the DataServiceMetadata message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum Compression { + [pbr::OriginalName("COMPRESSION_UNSPECIFIED")] Unspecified = 0, + /// + /// No compression. + /// + [pbr::OriginalName("COMPRESSION_OFF")] Off = 1, + /// + /// Snappy compression as defined in tensorflow/core/platform/snappy.h. + /// + [pbr::OriginalName("COMPRESSION_SNAPPY")] Snappy = 2, + } + + } + #endregion + + } + + public sealed partial class CrossTrainerCacheOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CrossTrainerCacheOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Data.DataServiceReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CrossTrainerCacheOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CrossTrainerCacheOptions(CrossTrainerCacheOptions other) : this() { + trainerId_ = other.trainerId_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CrossTrainerCacheOptions Clone() { + return new CrossTrainerCacheOptions(this); + } + + /// Field number for the "trainer_id" field. + public const int TrainerIdFieldNumber = 1; + private string trainerId_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string TrainerId { + get { return trainerId_; } + set { + trainerId_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CrossTrainerCacheOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CrossTrainerCacheOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (TrainerId != other.TrainerId) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (TrainerId.Length != 0) hash ^= TrainerId.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (TrainerId.Length != 0) { + output.WriteRawTag(10); + output.WriteString(TrainerId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (TrainerId.Length != 0) { + output.WriteRawTag(10); + output.WriteString(TrainerId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (TrainerId.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TrainerId); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CrossTrainerCacheOptions other) { + if (other == null) { + return; + } + if (other.TrainerId.Length != 0) { + TrainerId = other.TrainerId; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + TrainerId = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + TrainerId = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// Data service config available to the client through GetDataServiceConfig RPC. + /// Next tag: 2 + /// + public sealed partial class DataServiceConfig : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DataServiceConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Data.DataServiceReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DataServiceConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DataServiceConfig(DataServiceConfig other) : this() { + deploymentMode_ = other.deploymentMode_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DataServiceConfig Clone() { + return new DataServiceConfig(this); + } + + /// Field number for the "deployment_mode" field. + public const int DeploymentModeFieldNumber = 1; + private global::Tensorflow.Data.DeploymentMode deploymentMode_ = global::Tensorflow.Data.DeploymentMode.Unspecified; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.Data.DeploymentMode DeploymentMode { + get { return deploymentMode_; } + set { + deploymentMode_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DataServiceConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DataServiceConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (DeploymentMode != other.DeploymentMode) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (DeploymentMode != global::Tensorflow.Data.DeploymentMode.Unspecified) hash ^= DeploymentMode.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (DeploymentMode != global::Tensorflow.Data.DeploymentMode.Unspecified) { + output.WriteRawTag(8); + output.WriteEnum((int) DeploymentMode); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (DeploymentMode != global::Tensorflow.Data.DeploymentMode.Unspecified) { + output.WriteRawTag(8); + output.WriteEnum((int) DeploymentMode); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (DeploymentMode != global::Tensorflow.Data.DeploymentMode.Unspecified) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) DeploymentMode); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DataServiceConfig other) { + if (other == null) { + return; + } + if (other.DeploymentMode != global::Tensorflow.Data.DeploymentMode.Unspecified) { + DeploymentMode = other.DeploymentMode; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + DeploymentMode = (global::Tensorflow.Data.DeploymentMode) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + DeploymentMode = (global::Tensorflow.Data.DeploymentMode) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Debug.cs b/src/TensorFlowNET.Core/Protobuf/Debug.cs new file mode 100644 index 000000000..85b3bc6cc --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Debug.cs @@ -0,0 +1,1211 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/debug.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/debug.proto + public static partial class DebugReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/debug.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static DebugReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiR0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvZGVidWcucHJvdG8SCnRlbnNv", + "cmZsb3cijgEKEERlYnVnVGVuc29yV2F0Y2gSEQoJbm9kZV9uYW1lGAEgASgJ", + "EhMKC291dHB1dF9zbG90GAIgASgFEhEKCWRlYnVnX29wcxgDIAMoCRISCgpk", + "ZWJ1Z191cmxzGAQgAygJEisKI3RvbGVyYXRlX2RlYnVnX29wX2NyZWF0aW9u", + "X2ZhaWx1cmVzGAUgASgIIoEBCgxEZWJ1Z09wdGlvbnMSPQoXZGVidWdfdGVu", + "c29yX3dhdGNoX29wdHMYBCADKAsyHC50ZW5zb3JmbG93LkRlYnVnVGVuc29y", + "V2F0Y2gSEwoLZ2xvYmFsX3N0ZXAYCiABKAMSHQoVcmVzZXRfZGlza19ieXRl", + "X3VzYWdlGAsgASgIImoKEkRlYnVnZ2VkU291cmNlRmlsZRIMCgRob3N0GAEg", + "ASgJEhEKCWZpbGVfcGF0aBgCIAEoCRIVCg1sYXN0X21vZGlmaWVkGAMgASgD", + "Eg0KBWJ5dGVzGAQgASgDEg0KBWxpbmVzGAUgAygJIksKE0RlYnVnZ2VkU291", + "cmNlRmlsZXMSNAoMc291cmNlX2ZpbGVzGAEgAygLMh4udGVuc29yZmxvdy5E", + "ZWJ1Z2dlZFNvdXJjZUZpbGVCgwEKGG9yZy50ZW5zb3JmbG93LmZyYW1ld29y", + "a0ILRGVidWdQcm90b3NQAVpVZ2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNv", + "cmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL3Byb3RvYnVmL2Zvcl9jb3JlX3By", + "b3Rvc19nb19wcm90b/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DebugTensorWatch), global::Tensorflow.DebugTensorWatch.Parser, new[]{ "NodeName", "OutputSlot", "DebugOps", "DebugUrls", "TolerateDebugOpCreationFailures" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DebugOptions), global::Tensorflow.DebugOptions.Parser, new[]{ "DebugTensorWatchOpts", "GlobalStep", "ResetDiskByteUsage" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DebuggedSourceFile), global::Tensorflow.DebuggedSourceFile.Parser, new[]{ "Host", "FilePath", "LastModified", "Bytes", "Lines" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DebuggedSourceFiles), global::Tensorflow.DebuggedSourceFiles.Parser, new[]{ "SourceFiles" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Option for watching a node in TensorFlow Debugger (tfdbg). + /// + public sealed partial class DebugTensorWatch : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DebugTensorWatch()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebugTensorWatch() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebugTensorWatch(DebugTensorWatch other) : this() { + nodeName_ = other.nodeName_; + outputSlot_ = other.outputSlot_; + debugOps_ = other.debugOps_.Clone(); + debugUrls_ = other.debugUrls_.Clone(); + tolerateDebugOpCreationFailures_ = other.tolerateDebugOpCreationFailures_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebugTensorWatch Clone() { + return new DebugTensorWatch(this); + } + + /// Field number for the "node_name" field. + public const int NodeNameFieldNumber = 1; + private string nodeName_ = ""; + /// + /// Name of the node to watch. + /// Use "*" for wildcard. But note: currently, regex is not supported in + /// general. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string NodeName { + get { return nodeName_; } + set { + nodeName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "output_slot" field. + public const int OutputSlotFieldNumber = 2; + private int outputSlot_; + /// + /// Output slot to watch. + /// The semantics of output_slot == -1 is that all outputs of the node + /// will be watched (i.e., a wildcard). + /// Other negative values of output_slot are invalid and will lead to + /// errors currently. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int OutputSlot { + get { return outputSlot_; } + set { + outputSlot_ = value; + } + } + + /// Field number for the "debug_ops" field. + public const int DebugOpsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_debugOps_codec + = pb::FieldCodec.ForString(26); + private readonly pbc::RepeatedField debugOps_ = new pbc::RepeatedField(); + /// + /// Name(s) of the debugging op(s). + /// One or more than one probes on a tensor. + /// e.g., {"DebugIdentity", "DebugNanCount"} + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DebugOps { + get { return debugOps_; } + } + + /// Field number for the "debug_urls" field. + public const int DebugUrlsFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_debugUrls_codec + = pb::FieldCodec.ForString(34); + private readonly pbc::RepeatedField debugUrls_ = new pbc::RepeatedField(); + /// + /// URL(s) for debug targets(s). + /// + /// Supported URL formats are: + /// - file:///foo/tfdbg_dump: Writes out Event content to file + /// /foo/tfdbg_dump. Assumes all directories can be created if they don't + /// already exist. + /// - grpc://localhost:11011: Sends an RPC request to an EventListener + /// service running at localhost:11011 with the event. + /// - memcbk:///event_key: Routes tensors to clients using the + /// callback registered with the DebugCallbackRegistry for event_key. + /// + /// Each debug op listed in debug_ops will publish its output tensor (debug + /// signal) to all URLs in debug_urls. + /// + /// N.B. Session::Run() supports concurrent invocations of the same inputs + /// (feed keys), outputs and target nodes. If such concurrent invocations + /// are to be debugged, the callers of Session::Run() must use distinct + /// debug_urls to make sure that the streamed or dumped events do not overlap + /// among the invocations. + /// TODO(cais): More visible documentation of this in g3docs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DebugUrls { + get { return debugUrls_; } + } + + /// Field number for the "tolerate_debug_op_creation_failures" field. + public const int TolerateDebugOpCreationFailuresFieldNumber = 5; + private bool tolerateDebugOpCreationFailures_; + /// + /// Do not error out if debug op creation fails (e.g., due to dtype + /// incompatibility). Instead, just log the failure. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool TolerateDebugOpCreationFailures { + get { return tolerateDebugOpCreationFailures_; } + set { + tolerateDebugOpCreationFailures_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DebugTensorWatch); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DebugTensorWatch other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NodeName != other.NodeName) return false; + if (OutputSlot != other.OutputSlot) return false; + if(!debugOps_.Equals(other.debugOps_)) return false; + if(!debugUrls_.Equals(other.debugUrls_)) return false; + if (TolerateDebugOpCreationFailures != other.TolerateDebugOpCreationFailures) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (NodeName.Length != 0) hash ^= NodeName.GetHashCode(); + if (OutputSlot != 0) hash ^= OutputSlot.GetHashCode(); + hash ^= debugOps_.GetHashCode(); + hash ^= debugUrls_.GetHashCode(); + if (TolerateDebugOpCreationFailures != false) hash ^= TolerateDebugOpCreationFailures.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (NodeName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(NodeName); + } + if (OutputSlot != 0) { + output.WriteRawTag(16); + output.WriteInt32(OutputSlot); + } + debugOps_.WriteTo(output, _repeated_debugOps_codec); + debugUrls_.WriteTo(output, _repeated_debugUrls_codec); + if (TolerateDebugOpCreationFailures != false) { + output.WriteRawTag(40); + output.WriteBool(TolerateDebugOpCreationFailures); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (NodeName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(NodeName); + } + if (OutputSlot != 0) { + output.WriteRawTag(16); + output.WriteInt32(OutputSlot); + } + debugOps_.WriteTo(ref output, _repeated_debugOps_codec); + debugUrls_.WriteTo(ref output, _repeated_debugUrls_codec); + if (TolerateDebugOpCreationFailures != false) { + output.WriteRawTag(40); + output.WriteBool(TolerateDebugOpCreationFailures); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (NodeName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(NodeName); + } + if (OutputSlot != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(OutputSlot); + } + size += debugOps_.CalculateSize(_repeated_debugOps_codec); + size += debugUrls_.CalculateSize(_repeated_debugUrls_codec); + if (TolerateDebugOpCreationFailures != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DebugTensorWatch other) { + if (other == null) { + return; + } + if (other.NodeName.Length != 0) { + NodeName = other.NodeName; + } + if (other.OutputSlot != 0) { + OutputSlot = other.OutputSlot; + } + debugOps_.Add(other.debugOps_); + debugUrls_.Add(other.debugUrls_); + if (other.TolerateDebugOpCreationFailures != false) { + TolerateDebugOpCreationFailures = other.TolerateDebugOpCreationFailures; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + NodeName = input.ReadString(); + break; + } + case 16: { + OutputSlot = input.ReadInt32(); + break; + } + case 26: { + debugOps_.AddEntriesFrom(input, _repeated_debugOps_codec); + break; + } + case 34: { + debugUrls_.AddEntriesFrom(input, _repeated_debugUrls_codec); + break; + } + case 40: { + TolerateDebugOpCreationFailures = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + NodeName = input.ReadString(); + break; + } + case 16: { + OutputSlot = input.ReadInt32(); + break; + } + case 26: { + debugOps_.AddEntriesFrom(ref input, _repeated_debugOps_codec); + break; + } + case 34: { + debugUrls_.AddEntriesFrom(ref input, _repeated_debugUrls_codec); + break; + } + case 40: { + TolerateDebugOpCreationFailures = input.ReadBool(); + break; + } + } + } + } + #endif + + } + + /// + /// Options for initializing DebuggerState in TensorFlow Debugger (tfdbg). + /// + public sealed partial class DebugOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DebugOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebugOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebugOptions(DebugOptions other) : this() { + debugTensorWatchOpts_ = other.debugTensorWatchOpts_.Clone(); + globalStep_ = other.globalStep_; + resetDiskByteUsage_ = other.resetDiskByteUsage_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebugOptions Clone() { + return new DebugOptions(this); + } + + /// Field number for the "debug_tensor_watch_opts" field. + public const int DebugTensorWatchOptsFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_debugTensorWatchOpts_codec + = pb::FieldCodec.ForMessage(34, global::Tensorflow.DebugTensorWatch.Parser); + private readonly pbc::RepeatedField debugTensorWatchOpts_ = new pbc::RepeatedField(); + /// + /// Debugging options + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DebugTensorWatchOpts { + get { return debugTensorWatchOpts_; } + } + + /// Field number for the "global_step" field. + public const int GlobalStepFieldNumber = 10; + private long globalStep_; + /// + /// Caller-specified global step count. + /// Note that this is distinct from the session run count and the executor + /// step count. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long GlobalStep { + get { return globalStep_; } + set { + globalStep_ = value; + } + } + + /// Field number for the "reset_disk_byte_usage" field. + public const int ResetDiskByteUsageFieldNumber = 11; + private bool resetDiskByteUsage_; + /// + /// Whether the total disk usage of tfdbg is to be reset to zero + /// in this Session.run call. This is used by wrappers and hooks + /// such as the local CLI ones to indicate that the dumped tensors + /// are cleaned up from the disk after each Session.run. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool ResetDiskByteUsage { + get { return resetDiskByteUsage_; } + set { + resetDiskByteUsage_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DebugOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DebugOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!debugTensorWatchOpts_.Equals(other.debugTensorWatchOpts_)) return false; + if (GlobalStep != other.GlobalStep) return false; + if (ResetDiskByteUsage != other.ResetDiskByteUsage) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= debugTensorWatchOpts_.GetHashCode(); + if (GlobalStep != 0L) hash ^= GlobalStep.GetHashCode(); + if (ResetDiskByteUsage != false) hash ^= ResetDiskByteUsage.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + debugTensorWatchOpts_.WriteTo(output, _repeated_debugTensorWatchOpts_codec); + if (GlobalStep != 0L) { + output.WriteRawTag(80); + output.WriteInt64(GlobalStep); + } + if (ResetDiskByteUsage != false) { + output.WriteRawTag(88); + output.WriteBool(ResetDiskByteUsage); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + debugTensorWatchOpts_.WriteTo(ref output, _repeated_debugTensorWatchOpts_codec); + if (GlobalStep != 0L) { + output.WriteRawTag(80); + output.WriteInt64(GlobalStep); + } + if (ResetDiskByteUsage != false) { + output.WriteRawTag(88); + output.WriteBool(ResetDiskByteUsage); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += debugTensorWatchOpts_.CalculateSize(_repeated_debugTensorWatchOpts_codec); + if (GlobalStep != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(GlobalStep); + } + if (ResetDiskByteUsage != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DebugOptions other) { + if (other == null) { + return; + } + debugTensorWatchOpts_.Add(other.debugTensorWatchOpts_); + if (other.GlobalStep != 0L) { + GlobalStep = other.GlobalStep; + } + if (other.ResetDiskByteUsage != false) { + ResetDiskByteUsage = other.ResetDiskByteUsage; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 34: { + debugTensorWatchOpts_.AddEntriesFrom(input, _repeated_debugTensorWatchOpts_codec); + break; + } + case 80: { + GlobalStep = input.ReadInt64(); + break; + } + case 88: { + ResetDiskByteUsage = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 34: { + debugTensorWatchOpts_.AddEntriesFrom(ref input, _repeated_debugTensorWatchOpts_codec); + break; + } + case 80: { + GlobalStep = input.ReadInt64(); + break; + } + case 88: { + ResetDiskByteUsage = input.ReadBool(); + break; + } + } + } + } + #endif + + } + + public sealed partial class DebuggedSourceFile : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DebuggedSourceFile()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebuggedSourceFile() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebuggedSourceFile(DebuggedSourceFile other) : this() { + host_ = other.host_; + filePath_ = other.filePath_; + lastModified_ = other.lastModified_; + bytes_ = other.bytes_; + lines_ = other.lines_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebuggedSourceFile Clone() { + return new DebuggedSourceFile(this); + } + + /// Field number for the "host" field. + public const int HostFieldNumber = 1; + private string host_ = ""; + /// + /// The host name on which a source code file is located. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Host { + get { return host_; } + set { + host_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "file_path" field. + public const int FilePathFieldNumber = 2; + private string filePath_ = ""; + /// + /// Path to the source code file. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string FilePath { + get { return filePath_; } + set { + filePath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "last_modified" field. + public const int LastModifiedFieldNumber = 3; + private long lastModified_; + /// + /// The timestamp at which the source code file is last modified. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long LastModified { + get { return lastModified_; } + set { + lastModified_ = value; + } + } + + /// Field number for the "bytes" field. + public const int BytesFieldNumber = 4; + private long bytes_; + /// + /// Byte size of the file. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Bytes { + get { return bytes_; } + set { + bytes_ = value; + } + } + + /// Field number for the "lines" field. + public const int LinesFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_lines_codec + = pb::FieldCodec.ForString(42); + private readonly pbc::RepeatedField lines_ = new pbc::RepeatedField(); + /// + /// Line-by-line content of the source code file. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Lines { + get { return lines_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DebuggedSourceFile); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DebuggedSourceFile other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Host != other.Host) return false; + if (FilePath != other.FilePath) return false; + if (LastModified != other.LastModified) return false; + if (Bytes != other.Bytes) return false; + if(!lines_.Equals(other.lines_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Host.Length != 0) hash ^= Host.GetHashCode(); + if (FilePath.Length != 0) hash ^= FilePath.GetHashCode(); + if (LastModified != 0L) hash ^= LastModified.GetHashCode(); + if (Bytes != 0L) hash ^= Bytes.GetHashCode(); + hash ^= lines_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Host.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Host); + } + if (FilePath.Length != 0) { + output.WriteRawTag(18); + output.WriteString(FilePath); + } + if (LastModified != 0L) { + output.WriteRawTag(24); + output.WriteInt64(LastModified); + } + if (Bytes != 0L) { + output.WriteRawTag(32); + output.WriteInt64(Bytes); + } + lines_.WriteTo(output, _repeated_lines_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Host.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Host); + } + if (FilePath.Length != 0) { + output.WriteRawTag(18); + output.WriteString(FilePath); + } + if (LastModified != 0L) { + output.WriteRawTag(24); + output.WriteInt64(LastModified); + } + if (Bytes != 0L) { + output.WriteRawTag(32); + output.WriteInt64(Bytes); + } + lines_.WriteTo(ref output, _repeated_lines_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Host.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Host); + } + if (FilePath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(FilePath); + } + if (LastModified != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(LastModified); + } + if (Bytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Bytes); + } + size += lines_.CalculateSize(_repeated_lines_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DebuggedSourceFile other) { + if (other == null) { + return; + } + if (other.Host.Length != 0) { + Host = other.Host; + } + if (other.FilePath.Length != 0) { + FilePath = other.FilePath; + } + if (other.LastModified != 0L) { + LastModified = other.LastModified; + } + if (other.Bytes != 0L) { + Bytes = other.Bytes; + } + lines_.Add(other.lines_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Host = input.ReadString(); + break; + } + case 18: { + FilePath = input.ReadString(); + break; + } + case 24: { + LastModified = input.ReadInt64(); + break; + } + case 32: { + Bytes = input.ReadInt64(); + break; + } + case 42: { + lines_.AddEntriesFrom(input, _repeated_lines_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Host = input.ReadString(); + break; + } + case 18: { + FilePath = input.ReadString(); + break; + } + case 24: { + LastModified = input.ReadInt64(); + break; + } + case 32: { + Bytes = input.ReadInt64(); + break; + } + case 42: { + lines_.AddEntriesFrom(ref input, _repeated_lines_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class DebuggedSourceFiles : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DebuggedSourceFiles()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebuggedSourceFiles() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebuggedSourceFiles(DebuggedSourceFiles other) : this() { + sourceFiles_ = other.sourceFiles_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebuggedSourceFiles Clone() { + return new DebuggedSourceFiles(this); + } + + /// Field number for the "source_files" field. + public const int SourceFilesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_sourceFiles_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.DebuggedSourceFile.Parser); + private readonly pbc::RepeatedField sourceFiles_ = new pbc::RepeatedField(); + /// + /// A collection of source code files. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SourceFiles { + get { return sourceFiles_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DebuggedSourceFiles); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DebuggedSourceFiles other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!sourceFiles_.Equals(other.sourceFiles_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= sourceFiles_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + sourceFiles_.WriteTo(output, _repeated_sourceFiles_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + sourceFiles_.WriteTo(ref output, _repeated_sourceFiles_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += sourceFiles_.CalculateSize(_repeated_sourceFiles_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DebuggedSourceFiles other) { + if (other == null) { + return; + } + sourceFiles_.Add(other.sourceFiles_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + sourceFiles_.AddEntriesFrom(input, _repeated_sourceFiles_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + sourceFiles_.AddEntriesFrom(ref input, _repeated_sourceFiles_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/DeviceAttributes.cs b/src/TensorFlowNET.Core/Protobuf/DeviceAttributes.cs new file mode 100644 index 000000000..81d17e932 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/DeviceAttributes.cs @@ -0,0 +1,1227 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/device_attributes.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/device_attributes.proto + public static partial class DeviceAttributesReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/device_attributes.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static DeviceAttributesReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjF0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2RldmljZV9hdHRyaWJ1dGVz", + "LnByb3RvEgp0ZW5zb3JmbG93IkUKEEludGVyY29ubmVjdExpbmsSEQoJZGV2", + "aWNlX2lkGAEgASgFEgwKBHR5cGUYAiABKAkSEAoIc3RyZW5ndGgYAyABKAUi", + "OAoKTG9jYWxMaW5rcxIqCgRsaW5rGAEgAygLMhwudGVuc29yZmxvdy5JbnRl", + "cmNvbm5lY3RMaW5rIloKDkRldmljZUxvY2FsaXR5Eg4KBmJ1c19pZBgBIAEo", + "BRIRCgludW1hX25vZGUYAiABKAUSJQoFbGlua3MYAyABKAsyFi50ZW5zb3Jm", + "bG93LkxvY2FsTGlua3MiwwEKEERldmljZUF0dHJpYnV0ZXMSDAoEbmFtZRgB", + "IAEoCRITCgtkZXZpY2VfdHlwZRgCIAEoCRIUCgxtZW1vcnlfbGltaXQYBCAB", + "KAMSLAoIbG9jYWxpdHkYBSABKAsyGi50ZW5zb3JmbG93LkRldmljZUxvY2Fs", + "aXR5EhMKC2luY2FybmF0aW9uGAYgASgGEhwKFHBoeXNpY2FsX2RldmljZV9k", + "ZXNjGAcgASgJEhUKDXhsYV9nbG9iYWxfaWQYCCABKANCkQEKGG9yZy50ZW5z", + "b3JmbG93LmZyYW1ld29ya0IWRGV2aWNlQXR0cmlidXRlc1Byb3Rvc1ABWlhn", + "aXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dv", + "L2NvcmUvZnJhbWV3b3JrL2RldmljZV9hdHRyaWJ1dGVzX2dvX3Byb3Rv+AEB", + "YgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.InterconnectLink), global::Tensorflow.InterconnectLink.Parser, new[]{ "DeviceId", "Type", "Strength" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.LocalLinks), global::Tensorflow.LocalLinks.Parser, new[]{ "Link" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceLocality), global::Tensorflow.DeviceLocality.Parser, new[]{ "BusId", "NumaNode", "Links" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceAttributes), global::Tensorflow.DeviceAttributes.Parser, new[]{ "Name", "DeviceType", "MemoryLimit", "Locality", "Incarnation", "PhysicalDeviceDesc", "XlaGlobalId" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class InterconnectLink : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new InterconnectLink()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InterconnectLink() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InterconnectLink(InterconnectLink other) : this() { + deviceId_ = other.deviceId_; + type_ = other.type_; + strength_ = other.strength_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InterconnectLink Clone() { + return new InterconnectLink(this); + } + + /// Field number for the "device_id" field. + public const int DeviceIdFieldNumber = 1; + private int deviceId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int DeviceId { + get { return deviceId_; } + set { + deviceId_ = value; + } + } + + /// Field number for the "type" field. + public const int TypeFieldNumber = 2; + private string type_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Type { + get { return type_; } + set { + type_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "strength" field. + public const int StrengthFieldNumber = 3; + private int strength_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int Strength { + get { return strength_; } + set { + strength_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as InterconnectLink); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(InterconnectLink other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (DeviceId != other.DeviceId) return false; + if (Type != other.Type) return false; + if (Strength != other.Strength) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (DeviceId != 0) hash ^= DeviceId.GetHashCode(); + if (Type.Length != 0) hash ^= Type.GetHashCode(); + if (Strength != 0) hash ^= Strength.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (DeviceId != 0) { + output.WriteRawTag(8); + output.WriteInt32(DeviceId); + } + if (Type.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Type); + } + if (Strength != 0) { + output.WriteRawTag(24); + output.WriteInt32(Strength); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (DeviceId != 0) { + output.WriteRawTag(8); + output.WriteInt32(DeviceId); + } + if (Type.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Type); + } + if (Strength != 0) { + output.WriteRawTag(24); + output.WriteInt32(Strength); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (DeviceId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(DeviceId); + } + if (Type.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Type); + } + if (Strength != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Strength); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(InterconnectLink other) { + if (other == null) { + return; + } + if (other.DeviceId != 0) { + DeviceId = other.DeviceId; + } + if (other.Type.Length != 0) { + Type = other.Type; + } + if (other.Strength != 0) { + Strength = other.Strength; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + DeviceId = input.ReadInt32(); + break; + } + case 18: { + Type = input.ReadString(); + break; + } + case 24: { + Strength = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + DeviceId = input.ReadInt32(); + break; + } + case 18: { + Type = input.ReadString(); + break; + } + case 24: { + Strength = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + public sealed partial class LocalLinks : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new LocalLinks()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LocalLinks() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LocalLinks(LocalLinks other) : this() { + link_ = other.link_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LocalLinks Clone() { + return new LocalLinks(this); + } + + /// Field number for the "link" field. + public const int LinkFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_link_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.InterconnectLink.Parser); + private readonly pbc::RepeatedField link_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Link { + get { return link_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as LocalLinks); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(LocalLinks other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!link_.Equals(other.link_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= link_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + link_.WriteTo(output, _repeated_link_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + link_.WriteTo(ref output, _repeated_link_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += link_.CalculateSize(_repeated_link_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(LocalLinks other) { + if (other == null) { + return; + } + link_.Add(other.link_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + link_.AddEntriesFrom(input, _repeated_link_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + link_.AddEntriesFrom(ref input, _repeated_link_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class DeviceLocality : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DeviceLocality()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceLocality() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceLocality(DeviceLocality other) : this() { + busId_ = other.busId_; + numaNode_ = other.numaNode_; + links_ = other.links_ != null ? other.links_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceLocality Clone() { + return new DeviceLocality(this); + } + + /// Field number for the "bus_id" field. + public const int BusIdFieldNumber = 1; + private int busId_; + /// + /// Optional bus locality of device. Default value of 0 means + /// no specific locality. Specific localities are indexed from 1. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int BusId { + get { return busId_; } + set { + busId_ = value; + } + } + + /// Field number for the "numa_node" field. + public const int NumaNodeFieldNumber = 2; + private int numaNode_; + /// + /// Optional NUMA locality of device. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NumaNode { + get { return numaNode_; } + set { + numaNode_ = value; + } + } + + /// Field number for the "links" field. + public const int LinksFieldNumber = 3; + private global::Tensorflow.LocalLinks links_; + /// + /// Optional local interconnect links to other devices. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.LocalLinks Links { + get { return links_; } + set { + links_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DeviceLocality); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DeviceLocality other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (BusId != other.BusId) return false; + if (NumaNode != other.NumaNode) return false; + if (!object.Equals(Links, other.Links)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (BusId != 0) hash ^= BusId.GetHashCode(); + if (NumaNode != 0) hash ^= NumaNode.GetHashCode(); + if (links_ != null) hash ^= Links.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (BusId != 0) { + output.WriteRawTag(8); + output.WriteInt32(BusId); + } + if (NumaNode != 0) { + output.WriteRawTag(16); + output.WriteInt32(NumaNode); + } + if (links_ != null) { + output.WriteRawTag(26); + output.WriteMessage(Links); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (BusId != 0) { + output.WriteRawTag(8); + output.WriteInt32(BusId); + } + if (NumaNode != 0) { + output.WriteRawTag(16); + output.WriteInt32(NumaNode); + } + if (links_ != null) { + output.WriteRawTag(26); + output.WriteMessage(Links); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (BusId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(BusId); + } + if (NumaNode != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumaNode); + } + if (links_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Links); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DeviceLocality other) { + if (other == null) { + return; + } + if (other.BusId != 0) { + BusId = other.BusId; + } + if (other.NumaNode != 0) { + NumaNode = other.NumaNode; + } + if (other.links_ != null) { + if (links_ == null) { + Links = new global::Tensorflow.LocalLinks(); + } + Links.MergeFrom(other.Links); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + BusId = input.ReadInt32(); + break; + } + case 16: { + NumaNode = input.ReadInt32(); + break; + } + case 26: { + if (links_ == null) { + Links = new global::Tensorflow.LocalLinks(); + } + input.ReadMessage(Links); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + BusId = input.ReadInt32(); + break; + } + case 16: { + NumaNode = input.ReadInt32(); + break; + } + case 26: { + if (links_ == null) { + Links = new global::Tensorflow.LocalLinks(); + } + input.ReadMessage(Links); + break; + } + } + } + } + #endif + + } + + public sealed partial class DeviceAttributes : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DeviceAttributes()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceAttributes() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceAttributes(DeviceAttributes other) : this() { + name_ = other.name_; + deviceType_ = other.deviceType_; + memoryLimit_ = other.memoryLimit_; + locality_ = other.locality_ != null ? other.locality_.Clone() : null; + incarnation_ = other.incarnation_; + physicalDeviceDesc_ = other.physicalDeviceDesc_; + xlaGlobalId_ = other.xlaGlobalId_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceAttributes Clone() { + return new DeviceAttributes(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// Fully specified name of the device within a cluster. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "device_type" field. + public const int DeviceTypeFieldNumber = 2; + private string deviceType_ = ""; + /// + /// String representation of device_type. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string DeviceType { + get { return deviceType_; } + set { + deviceType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "memory_limit" field. + public const int MemoryLimitFieldNumber = 4; + private long memoryLimit_; + /// + /// Memory capacity of device in bytes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long MemoryLimit { + get { return memoryLimit_; } + set { + memoryLimit_ = value; + } + } + + /// Field number for the "locality" field. + public const int LocalityFieldNumber = 5; + private global::Tensorflow.DeviceLocality locality_; + /// + /// Platform-specific data about device that may be useful + /// for supporting efficient data transfers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DeviceLocality Locality { + get { return locality_; } + set { + locality_ = value; + } + } + + /// Field number for the "incarnation" field. + public const int IncarnationFieldNumber = 6; + private ulong incarnation_; + /// + /// A device is assigned a global unique number each time it is + /// initialized. "incarnation" should never be 0. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong Incarnation { + get { return incarnation_; } + set { + incarnation_ = value; + } + } + + /// Field number for the "physical_device_desc" field. + public const int PhysicalDeviceDescFieldNumber = 7; + private string physicalDeviceDesc_ = ""; + /// + /// String representation of the physical device that this device maps to. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string PhysicalDeviceDesc { + get { return physicalDeviceDesc_; } + set { + physicalDeviceDesc_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "xla_global_id" field. + public const int XlaGlobalIdFieldNumber = 8; + private long xlaGlobalId_; + /// + /// A physical device ID for use in XLA DeviceAssignments, unique across + /// clients in a multi-client setup. Set to -1 if unavailable, non-negative + /// otherwise. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long XlaGlobalId { + get { return xlaGlobalId_; } + set { + xlaGlobalId_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DeviceAttributes); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DeviceAttributes other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (DeviceType != other.DeviceType) return false; + if (MemoryLimit != other.MemoryLimit) return false; + if (!object.Equals(Locality, other.Locality)) return false; + if (Incarnation != other.Incarnation) return false; + if (PhysicalDeviceDesc != other.PhysicalDeviceDesc) return false; + if (XlaGlobalId != other.XlaGlobalId) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (DeviceType.Length != 0) hash ^= DeviceType.GetHashCode(); + if (MemoryLimit != 0L) hash ^= MemoryLimit.GetHashCode(); + if (locality_ != null) hash ^= Locality.GetHashCode(); + if (Incarnation != 0UL) hash ^= Incarnation.GetHashCode(); + if (PhysicalDeviceDesc.Length != 0) hash ^= PhysicalDeviceDesc.GetHashCode(); + if (XlaGlobalId != 0L) hash ^= XlaGlobalId.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (DeviceType.Length != 0) { + output.WriteRawTag(18); + output.WriteString(DeviceType); + } + if (MemoryLimit != 0L) { + output.WriteRawTag(32); + output.WriteInt64(MemoryLimit); + } + if (locality_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Locality); + } + if (Incarnation != 0UL) { + output.WriteRawTag(49); + output.WriteFixed64(Incarnation); + } + if (PhysicalDeviceDesc.Length != 0) { + output.WriteRawTag(58); + output.WriteString(PhysicalDeviceDesc); + } + if (XlaGlobalId != 0L) { + output.WriteRawTag(64); + output.WriteInt64(XlaGlobalId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (DeviceType.Length != 0) { + output.WriteRawTag(18); + output.WriteString(DeviceType); + } + if (MemoryLimit != 0L) { + output.WriteRawTag(32); + output.WriteInt64(MemoryLimit); + } + if (locality_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Locality); + } + if (Incarnation != 0UL) { + output.WriteRawTag(49); + output.WriteFixed64(Incarnation); + } + if (PhysicalDeviceDesc.Length != 0) { + output.WriteRawTag(58); + output.WriteString(PhysicalDeviceDesc); + } + if (XlaGlobalId != 0L) { + output.WriteRawTag(64); + output.WriteInt64(XlaGlobalId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (DeviceType.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DeviceType); + } + if (MemoryLimit != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(MemoryLimit); + } + if (locality_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Locality); + } + if (Incarnation != 0UL) { + size += 1 + 8; + } + if (PhysicalDeviceDesc.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PhysicalDeviceDesc); + } + if (XlaGlobalId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(XlaGlobalId); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DeviceAttributes other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.DeviceType.Length != 0) { + DeviceType = other.DeviceType; + } + if (other.MemoryLimit != 0L) { + MemoryLimit = other.MemoryLimit; + } + if (other.locality_ != null) { + if (locality_ == null) { + Locality = new global::Tensorflow.DeviceLocality(); + } + Locality.MergeFrom(other.Locality); + } + if (other.Incarnation != 0UL) { + Incarnation = other.Incarnation; + } + if (other.PhysicalDeviceDesc.Length != 0) { + PhysicalDeviceDesc = other.PhysicalDeviceDesc; + } + if (other.XlaGlobalId != 0L) { + XlaGlobalId = other.XlaGlobalId; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + DeviceType = input.ReadString(); + break; + } + case 32: { + MemoryLimit = input.ReadInt64(); + break; + } + case 42: { + if (locality_ == null) { + Locality = new global::Tensorflow.DeviceLocality(); + } + input.ReadMessage(Locality); + break; + } + case 49: { + Incarnation = input.ReadFixed64(); + break; + } + case 58: { + PhysicalDeviceDesc = input.ReadString(); + break; + } + case 64: { + XlaGlobalId = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + DeviceType = input.ReadString(); + break; + } + case 32: { + MemoryLimit = input.ReadInt64(); + break; + } + case 42: { + if (locality_ == null) { + Locality = new global::Tensorflow.DeviceLocality(); + } + input.ReadMessage(Locality); + break; + } + case 49: { + Incarnation = input.ReadFixed64(); + break; + } + case 58: { + PhysicalDeviceDesc = input.ReadString(); + break; + } + case 64: { + XlaGlobalId = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Event.cs b/src/TensorFlowNET.Core/Protobuf/Event.cs new file mode 100644 index 000000000..cd80bf37d --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Event.cs @@ -0,0 +1,2422 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/util/event.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/util/event.proto + public static partial class EventReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/util/event.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static EventReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiB0ZW5zb3JmbG93L2NvcmUvdXRpbC9ldmVudC5wcm90bxIKdGVuc29yZmxv", + "dxondGVuc29yZmxvdy9jb3JlL2ZyYW1ld29yay9zdW1tYXJ5LnByb3RvIr8C", + "CgVFdmVudBIRCgl3YWxsX3RpbWUYASABKAESDAoEc3RlcBgCIAEoAxIWCgxm", + "aWxlX3ZlcnNpb24YAyABKAlIABITCglncmFwaF9kZWYYBCABKAxIABImCgdz", + "dW1tYXJ5GAUgASgLMhMudGVuc29yZmxvdy5TdW1tYXJ5SAASMQoLbG9nX21l", + "c3NhZ2UYBiABKAsyFi50ZW5zb3JmbG93LkxvZ01lc3NhZ2VCAhgBSAASLQoL", + "c2Vzc2lvbl9sb2cYByABKAsyFi50ZW5zb3JmbG93LlNlc3Npb25Mb2dIABI8", + "ChN0YWdnZWRfcnVuX21ldGFkYXRhGAggASgLMh0udGVuc29yZmxvdy5UYWdn", + "ZWRSdW5NZXRhZGF0YUgAEhgKDm1ldGFfZ3JhcGhfZGVmGAkgASgMSABCBgoE", + "d2hhdCKhAQoKTG9nTWVzc2FnZRIrCgVsZXZlbBgBIAEoDjIcLnRlbnNvcmZs", + "b3cuTG9nTWVzc2FnZS5MZXZlbBIPCgdtZXNzYWdlGAIgASgJIlEKBUxldmVs", + "EgsKB1VOS05PV04QABINCglERUJVR0dJTkcQChIICgRJTkZPEBQSCAoEV0FS", + "ThAeEgkKBUVSUk9SECgSCQoFRkFUQUwQMhoCGAE6AhgBIrYBCgpTZXNzaW9u", + "TG9nEjQKBnN0YXR1cxgBIAEoDjIkLnRlbnNvcmZsb3cuU2Vzc2lvbkxvZy5T", + "ZXNzaW9uU3RhdHVzEhcKD2NoZWNrcG9pbnRfcGF0aBgCIAEoCRILCgNtc2cY", + "AyABKAkiTAoNU2Vzc2lvblN0YXR1cxIWChJTVEFUVVNfVU5TUEVDSUZJRUQQ", + "ABIJCgVTVEFSVBABEggKBFNUT1AQAhIOCgpDSEVDS1BPSU5UEAMiNgoRVGFn", + "Z2VkUnVuTWV0YWRhdGESCwoDdGFnGAEgASgJEhQKDHJ1bl9tZXRhZGF0YRgC", + "IAEoDCIkCg5XYXRjaGRvZ0NvbmZpZxISCgp0aW1lb3V0X21zGAEgASgDIiYK", + "EVJlcXVlc3RlZEV4aXRDb2RlEhEKCWV4aXRfY29kZRgBIAEoBSK2AQoWV29y", + "a2VySGVhcnRiZWF0UmVxdWVzdBI1Cg1zaHV0ZG93bl9tb2RlGAEgASgOMh4u", + "dGVuc29yZmxvdy5Xb3JrZXJTaHV0ZG93bk1vZGUSMwoPd2F0Y2hkb2dfY29u", + "ZmlnGAIgASgLMhoudGVuc29yZmxvdy5XYXRjaGRvZ0NvbmZpZxIwCglleGl0", + "X2NvZGUYAyABKAsyHS50ZW5zb3JmbG93LlJlcXVlc3RlZEV4aXRDb2RlIoMB", + "ChdXb3JrZXJIZWFydGJlYXRSZXNwb25zZRIvCg1oZWFsdGhfc3RhdHVzGAEg", + "ASgOMhgudGVuc29yZmxvdy5Xb3JrZXJIZWFsdGgSJQoKd29ya2VyX2xvZxgC", + "IAMoCzIRLnRlbnNvcmZsb3cuRXZlbnQSEAoIaG9zdG5hbWUYAyABKAkqWwoM", + "V29ya2VySGVhbHRoEgYKAk9LEAASHAoYUkVDRUlWRURfU0hVVERPV05fU0lH", + "TkFMEAESEgoOSU5URVJOQUxfRVJST1IQAhIRCg1TSFVUVElOR19ET1dOEAMq", + "awoSV29ya2VyU2h1dGRvd25Nb2RlEgsKB0RFRkFVTFQQABISCg5OT1RfQ09O", + "RklHVVJFRBABEhgKFFdBSVRfRk9SX0NPT1JESU5BVE9SEAISGgoWU0hVVERP", + "V05fQUZURVJfVElNRU9VVBADQnAKE29yZy50ZW5zb3JmbG93LnV0aWxCC0V2", + "ZW50UHJvdG9zUAFaR2dpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93", + "L3RlbnNvcmZsb3cvZ28vY29yZS91dGlsL2V2ZW50X2dvX3Byb3Rv+AEBYgZw", + "cm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.SummaryReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.WorkerHealth), typeof(global::Tensorflow.WorkerShutdownMode), }, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Event), global::Tensorflow.Event.Parser, new[]{ "WallTime", "Step", "FileVersion", "GraphDef", "Summary", "LogMessage", "SessionLog", "TaggedRunMetadata", "MetaGraphDef" }, new[]{ "What" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.LogMessage), global::Tensorflow.LogMessage.Parser, new[]{ "Level", "Message" }, null, new[]{ typeof(global::Tensorflow.LogMessage.Types.Level) }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SessionLog), global::Tensorflow.SessionLog.Parser, new[]{ "Status", "CheckpointPath", "Msg" }, null, new[]{ typeof(global::Tensorflow.SessionLog.Types.SessionStatus) }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TaggedRunMetadata), global::Tensorflow.TaggedRunMetadata.Parser, new[]{ "Tag", "RunMetadata" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.WatchdogConfig), global::Tensorflow.WatchdogConfig.Parser, new[]{ "TimeoutMs" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RequestedExitCode), global::Tensorflow.RequestedExitCode.Parser, new[]{ "ExitCode" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.WorkerHeartbeatRequest), global::Tensorflow.WorkerHeartbeatRequest.Parser, new[]{ "ShutdownMode", "WatchdogConfig", "ExitCode" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.WorkerHeartbeatResponse), global::Tensorflow.WorkerHeartbeatResponse.Parser, new[]{ "HealthStatus", "WorkerLog", "Hostname" }, null, null, null, null) + })); + } + #endregion + + } + #region Enums + /// + /// Current health status of a worker. + /// + public enum WorkerHealth { + /// + /// By default a worker is healthy. + /// + [pbr::OriginalName("OK")] Ok = 0, + [pbr::OriginalName("RECEIVED_SHUTDOWN_SIGNAL")] ReceivedShutdownSignal = 1, + [pbr::OriginalName("INTERNAL_ERROR")] InternalError = 2, + /// + /// Worker has been instructed to shutdown after a timeout. + /// + [pbr::OriginalName("SHUTTING_DOWN")] ShuttingDown = 3, + } + + /// + /// Indicates the behavior of the worker when an internal error or shutdown + /// signal is received. + /// + public enum WorkerShutdownMode { + [pbr::OriginalName("DEFAULT")] Default = 0, + [pbr::OriginalName("NOT_CONFIGURED")] NotConfigured = 1, + [pbr::OriginalName("WAIT_FOR_COORDINATOR")] WaitForCoordinator = 2, + [pbr::OriginalName("SHUTDOWN_AFTER_TIMEOUT")] ShutdownAfterTimeout = 3, + } + + #endregion + + #region Messages + /// + /// Protocol buffer representing an event that happened during + /// the execution of a Brain model. + /// + public sealed partial class Event : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Event()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.EventReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Event() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Event(Event other) : this() { + wallTime_ = other.wallTime_; + step_ = other.step_; + switch (other.WhatCase) { + case WhatOneofCase.FileVersion: + FileVersion = other.FileVersion; + break; + case WhatOneofCase.GraphDef: + GraphDef = other.GraphDef; + break; + case WhatOneofCase.Summary: + Summary = other.Summary.Clone(); + break; + case WhatOneofCase.LogMessage: + LogMessage = other.LogMessage.Clone(); + break; + case WhatOneofCase.SessionLog: + SessionLog = other.SessionLog.Clone(); + break; + case WhatOneofCase.TaggedRunMetadata: + TaggedRunMetadata = other.TaggedRunMetadata.Clone(); + break; + case WhatOneofCase.MetaGraphDef: + MetaGraphDef = other.MetaGraphDef; + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Event Clone() { + return new Event(this); + } + + /// Field number for the "wall_time" field. + public const int WallTimeFieldNumber = 1; + private double wallTime_; + /// + /// Timestamp of the event. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double WallTime { + get { return wallTime_; } + set { + wallTime_ = value; + } + } + + /// Field number for the "step" field. + public const int StepFieldNumber = 2; + private long step_; + /// + /// Global step of the event. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Step { + get { return step_; } + set { + step_ = value; + } + } + + /// Field number for the "file_version" field. + public const int FileVersionFieldNumber = 3; + /// + /// An event file was started, with the specified version. + /// This is use to identify the contents of the record IO files + /// easily. Current version is "brain.Event:2". All versions + /// start with "brain.Event:". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string FileVersion { + get { return whatCase_ == WhatOneofCase.FileVersion ? (string) what_ : ""; } + set { + what_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + whatCase_ = WhatOneofCase.FileVersion; + } + } + + /// Field number for the "graph_def" field. + public const int GraphDefFieldNumber = 4; + /// + /// An encoded version of a GraphDef. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString GraphDef { + get { return whatCase_ == WhatOneofCase.GraphDef ? (pb::ByteString) what_ : pb::ByteString.Empty; } + set { + what_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + whatCase_ = WhatOneofCase.GraphDef; + } + } + + /// Field number for the "summary" field. + public const int SummaryFieldNumber = 5; + /// + /// A summary was generated. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.Summary Summary { + get { return whatCase_ == WhatOneofCase.Summary ? (global::Tensorflow.Summary) what_ : null; } + set { + what_ = value; + whatCase_ = value == null ? WhatOneofCase.None : WhatOneofCase.Summary; + } + } + + /// Field number for the "log_message" field. + public const int LogMessageFieldNumber = 6; + /// + /// The user output a log message. This was theoretically used by the defunct + /// tensorboard_logging module, which has since been removed; this field is + /// now deprecated and should not be used. + /// + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.LogMessage LogMessage { + get { return whatCase_ == WhatOneofCase.LogMessage ? (global::Tensorflow.LogMessage) what_ : null; } + set { + what_ = value; + whatCase_ = value == null ? WhatOneofCase.None : WhatOneofCase.LogMessage; + } + } + + /// Field number for the "session_log" field. + public const int SessionLogFieldNumber = 7; + /// + /// The state of the session which can be used for restarting after crashes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SessionLog SessionLog { + get { return whatCase_ == WhatOneofCase.SessionLog ? (global::Tensorflow.SessionLog) what_ : null; } + set { + what_ = value; + whatCase_ = value == null ? WhatOneofCase.None : WhatOneofCase.SessionLog; + } + } + + /// Field number for the "tagged_run_metadata" field. + public const int TaggedRunMetadataFieldNumber = 8; + /// + /// The metadata returned by running a session.run() call. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TaggedRunMetadata TaggedRunMetadata { + get { return whatCase_ == WhatOneofCase.TaggedRunMetadata ? (global::Tensorflow.TaggedRunMetadata) what_ : null; } + set { + what_ = value; + whatCase_ = value == null ? WhatOneofCase.None : WhatOneofCase.TaggedRunMetadata; + } + } + + /// Field number for the "meta_graph_def" field. + public const int MetaGraphDefFieldNumber = 9; + /// + /// An encoded version of a MetaGraphDef. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString MetaGraphDef { + get { return whatCase_ == WhatOneofCase.MetaGraphDef ? (pb::ByteString) what_ : pb::ByteString.Empty; } + set { + what_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + whatCase_ = WhatOneofCase.MetaGraphDef; + } + } + + private object what_; + /// Enum of possible cases for the "what" oneof. + public enum WhatOneofCase { + None = 0, + FileVersion = 3, + GraphDef = 4, + Summary = 5, + LogMessage = 6, + SessionLog = 7, + TaggedRunMetadata = 8, + MetaGraphDef = 9, + } + private WhatOneofCase whatCase_ = WhatOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WhatOneofCase WhatCase { + get { return whatCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void ClearWhat() { + whatCase_ = WhatOneofCase.None; + what_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Event); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Event other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(WallTime, other.WallTime)) return false; + if (Step != other.Step) return false; + if (FileVersion != other.FileVersion) return false; + if (GraphDef != other.GraphDef) return false; + if (!object.Equals(Summary, other.Summary)) return false; + if (!object.Equals(LogMessage, other.LogMessage)) return false; + if (!object.Equals(SessionLog, other.SessionLog)) return false; + if (!object.Equals(TaggedRunMetadata, other.TaggedRunMetadata)) return false; + if (MetaGraphDef != other.MetaGraphDef) return false; + if (WhatCase != other.WhatCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (WallTime != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(WallTime); + if (Step != 0L) hash ^= Step.GetHashCode(); + if (whatCase_ == WhatOneofCase.FileVersion) hash ^= FileVersion.GetHashCode(); + if (whatCase_ == WhatOneofCase.GraphDef) hash ^= GraphDef.GetHashCode(); + if (whatCase_ == WhatOneofCase.Summary) hash ^= Summary.GetHashCode(); + if (whatCase_ == WhatOneofCase.LogMessage) hash ^= LogMessage.GetHashCode(); + if (whatCase_ == WhatOneofCase.SessionLog) hash ^= SessionLog.GetHashCode(); + if (whatCase_ == WhatOneofCase.TaggedRunMetadata) hash ^= TaggedRunMetadata.GetHashCode(); + if (whatCase_ == WhatOneofCase.MetaGraphDef) hash ^= MetaGraphDef.GetHashCode(); + hash ^= (int) whatCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (WallTime != 0D) { + output.WriteRawTag(9); + output.WriteDouble(WallTime); + } + if (Step != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Step); + } + if (whatCase_ == WhatOneofCase.FileVersion) { + output.WriteRawTag(26); + output.WriteString(FileVersion); + } + if (whatCase_ == WhatOneofCase.GraphDef) { + output.WriteRawTag(34); + output.WriteBytes(GraphDef); + } + if (whatCase_ == WhatOneofCase.Summary) { + output.WriteRawTag(42); + output.WriteMessage(Summary); + } + if (whatCase_ == WhatOneofCase.LogMessage) { + output.WriteRawTag(50); + output.WriteMessage(LogMessage); + } + if (whatCase_ == WhatOneofCase.SessionLog) { + output.WriteRawTag(58); + output.WriteMessage(SessionLog); + } + if (whatCase_ == WhatOneofCase.TaggedRunMetadata) { + output.WriteRawTag(66); + output.WriteMessage(TaggedRunMetadata); + } + if (whatCase_ == WhatOneofCase.MetaGraphDef) { + output.WriteRawTag(74); + output.WriteBytes(MetaGraphDef); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (WallTime != 0D) { + output.WriteRawTag(9); + output.WriteDouble(WallTime); + } + if (Step != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Step); + } + if (whatCase_ == WhatOneofCase.FileVersion) { + output.WriteRawTag(26); + output.WriteString(FileVersion); + } + if (whatCase_ == WhatOneofCase.GraphDef) { + output.WriteRawTag(34); + output.WriteBytes(GraphDef); + } + if (whatCase_ == WhatOneofCase.Summary) { + output.WriteRawTag(42); + output.WriteMessage(Summary); + } + if (whatCase_ == WhatOneofCase.LogMessage) { + output.WriteRawTag(50); + output.WriteMessage(LogMessage); + } + if (whatCase_ == WhatOneofCase.SessionLog) { + output.WriteRawTag(58); + output.WriteMessage(SessionLog); + } + if (whatCase_ == WhatOneofCase.TaggedRunMetadata) { + output.WriteRawTag(66); + output.WriteMessage(TaggedRunMetadata); + } + if (whatCase_ == WhatOneofCase.MetaGraphDef) { + output.WriteRawTag(74); + output.WriteBytes(MetaGraphDef); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (WallTime != 0D) { + size += 1 + 8; + } + if (Step != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Step); + } + if (whatCase_ == WhatOneofCase.FileVersion) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(FileVersion); + } + if (whatCase_ == WhatOneofCase.GraphDef) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(GraphDef); + } + if (whatCase_ == WhatOneofCase.Summary) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Summary); + } + if (whatCase_ == WhatOneofCase.LogMessage) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(LogMessage); + } + if (whatCase_ == WhatOneofCase.SessionLog) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SessionLog); + } + if (whatCase_ == WhatOneofCase.TaggedRunMetadata) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TaggedRunMetadata); + } + if (whatCase_ == WhatOneofCase.MetaGraphDef) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(MetaGraphDef); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Event other) { + if (other == null) { + return; + } + if (other.WallTime != 0D) { + WallTime = other.WallTime; + } + if (other.Step != 0L) { + Step = other.Step; + } + switch (other.WhatCase) { + case WhatOneofCase.FileVersion: + FileVersion = other.FileVersion; + break; + case WhatOneofCase.GraphDef: + GraphDef = other.GraphDef; + break; + case WhatOneofCase.Summary: + if (Summary == null) { + Summary = new global::Tensorflow.Summary(); + } + Summary.MergeFrom(other.Summary); + break; + case WhatOneofCase.LogMessage: + if (LogMessage == null) { + LogMessage = new global::Tensorflow.LogMessage(); + } + LogMessage.MergeFrom(other.LogMessage); + break; + case WhatOneofCase.SessionLog: + if (SessionLog == null) { + SessionLog = new global::Tensorflow.SessionLog(); + } + SessionLog.MergeFrom(other.SessionLog); + break; + case WhatOneofCase.TaggedRunMetadata: + if (TaggedRunMetadata == null) { + TaggedRunMetadata = new global::Tensorflow.TaggedRunMetadata(); + } + TaggedRunMetadata.MergeFrom(other.TaggedRunMetadata); + break; + case WhatOneofCase.MetaGraphDef: + MetaGraphDef = other.MetaGraphDef; + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 9: { + WallTime = input.ReadDouble(); + break; + } + case 16: { + Step = input.ReadInt64(); + break; + } + case 26: { + FileVersion = input.ReadString(); + break; + } + case 34: { + GraphDef = input.ReadBytes(); + break; + } + case 42: { + global::Tensorflow.Summary subBuilder = new global::Tensorflow.Summary(); + if (whatCase_ == WhatOneofCase.Summary) { + subBuilder.MergeFrom(Summary); + } + input.ReadMessage(subBuilder); + Summary = subBuilder; + break; + } + case 50: { + global::Tensorflow.LogMessage subBuilder = new global::Tensorflow.LogMessage(); + if (whatCase_ == WhatOneofCase.LogMessage) { + subBuilder.MergeFrom(LogMessage); + } + input.ReadMessage(subBuilder); + LogMessage = subBuilder; + break; + } + case 58: { + global::Tensorflow.SessionLog subBuilder = new global::Tensorflow.SessionLog(); + if (whatCase_ == WhatOneofCase.SessionLog) { + subBuilder.MergeFrom(SessionLog); + } + input.ReadMessage(subBuilder); + SessionLog = subBuilder; + break; + } + case 66: { + global::Tensorflow.TaggedRunMetadata subBuilder = new global::Tensorflow.TaggedRunMetadata(); + if (whatCase_ == WhatOneofCase.TaggedRunMetadata) { + subBuilder.MergeFrom(TaggedRunMetadata); + } + input.ReadMessage(subBuilder); + TaggedRunMetadata = subBuilder; + break; + } + case 74: { + MetaGraphDef = input.ReadBytes(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 9: { + WallTime = input.ReadDouble(); + break; + } + case 16: { + Step = input.ReadInt64(); + break; + } + case 26: { + FileVersion = input.ReadString(); + break; + } + case 34: { + GraphDef = input.ReadBytes(); + break; + } + case 42: { + global::Tensorflow.Summary subBuilder = new global::Tensorflow.Summary(); + if (whatCase_ == WhatOneofCase.Summary) { + subBuilder.MergeFrom(Summary); + } + input.ReadMessage(subBuilder); + Summary = subBuilder; + break; + } + case 50: { + global::Tensorflow.LogMessage subBuilder = new global::Tensorflow.LogMessage(); + if (whatCase_ == WhatOneofCase.LogMessage) { + subBuilder.MergeFrom(LogMessage); + } + input.ReadMessage(subBuilder); + LogMessage = subBuilder; + break; + } + case 58: { + global::Tensorflow.SessionLog subBuilder = new global::Tensorflow.SessionLog(); + if (whatCase_ == WhatOneofCase.SessionLog) { + subBuilder.MergeFrom(SessionLog); + } + input.ReadMessage(subBuilder); + SessionLog = subBuilder; + break; + } + case 66: { + global::Tensorflow.TaggedRunMetadata subBuilder = new global::Tensorflow.TaggedRunMetadata(); + if (whatCase_ == WhatOneofCase.TaggedRunMetadata) { + subBuilder.MergeFrom(TaggedRunMetadata); + } + input.ReadMessage(subBuilder); + TaggedRunMetadata = subBuilder; + break; + } + case 74: { + MetaGraphDef = input.ReadBytes(); + break; + } + } + } + } + #endif + + } + + /// + /// Protocol buffer used for logging messages to the events file. + /// + /// This was theoretically used by the defunct tensorboard_logging module, which + /// has been removed; this message is now deprecated and should not be used. + /// + [global::System.ObsoleteAttribute] + public sealed partial class LogMessage : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new LogMessage()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.EventReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LogMessage() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LogMessage(LogMessage other) : this() { + level_ = other.level_; + message_ = other.message_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LogMessage Clone() { + return new LogMessage(this); + } + + /// Field number for the "level" field. + public const int LevelFieldNumber = 1; + private global::Tensorflow.LogMessage.Types.Level level_ = global::Tensorflow.LogMessage.Types.Level.Unknown; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.LogMessage.Types.Level Level { + get { return level_; } + set { + level_ = value; + } + } + + /// Field number for the "message" field. + public const int MessageFieldNumber = 2; + private string message_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Message { + get { return message_; } + set { + message_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as LogMessage); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(LogMessage other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Level != other.Level) return false; + if (Message != other.Message) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Level != global::Tensorflow.LogMessage.Types.Level.Unknown) hash ^= Level.GetHashCode(); + if (Message.Length != 0) hash ^= Message.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Level != global::Tensorflow.LogMessage.Types.Level.Unknown) { + output.WriteRawTag(8); + output.WriteEnum((int) Level); + } + if (Message.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Message); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Level != global::Tensorflow.LogMessage.Types.Level.Unknown) { + output.WriteRawTag(8); + output.WriteEnum((int) Level); + } + if (Message.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Message); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Level != global::Tensorflow.LogMessage.Types.Level.Unknown) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Level); + } + if (Message.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Message); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(LogMessage other) { + if (other == null) { + return; + } + if (other.Level != global::Tensorflow.LogMessage.Types.Level.Unknown) { + Level = other.Level; + } + if (other.Message.Length != 0) { + Message = other.Message; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Level = (global::Tensorflow.LogMessage.Types.Level) input.ReadEnum(); + break; + } + case 18: { + Message = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Level = (global::Tensorflow.LogMessage.Types.Level) input.ReadEnum(); + break; + } + case 18: { + Message = input.ReadString(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the LogMessage message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum Level { + [pbr::OriginalName("UNKNOWN")] Unknown = 0, + /// + /// Note: The logging level 10 cannot be named DEBUG. Some software + /// projects compile their C/C++ code with -DDEBUG in debug builds. So the + /// C++ code generated from this file should not have an identifier named + /// DEBUG. + /// + [pbr::OriginalName("DEBUGGING")] Debugging = 10, + [pbr::OriginalName("INFO")] Info = 20, + [pbr::OriginalName("WARN")] Warn = 30, + [pbr::OriginalName("ERROR")] Error = 40, + [pbr::OriginalName("FATAL")] Fatal = 50, + } + + } + #endregion + + } + + /// + /// Protocol buffer used for logging session state. + /// + public sealed partial class SessionLog : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SessionLog()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.EventReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SessionLog() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SessionLog(SessionLog other) : this() { + status_ = other.status_; + checkpointPath_ = other.checkpointPath_; + msg_ = other.msg_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SessionLog Clone() { + return new SessionLog(this); + } + + /// Field number for the "status" field. + public const int StatusFieldNumber = 1; + private global::Tensorflow.SessionLog.Types.SessionStatus status_ = global::Tensorflow.SessionLog.Types.SessionStatus.StatusUnspecified; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SessionLog.Types.SessionStatus Status { + get { return status_; } + set { + status_ = value; + } + } + + /// Field number for the "checkpoint_path" field. + public const int CheckpointPathFieldNumber = 2; + private string checkpointPath_ = ""; + /// + /// This checkpoint_path contains both the path and filename. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string CheckpointPath { + get { return checkpointPath_; } + set { + checkpointPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "msg" field. + public const int MsgFieldNumber = 3; + private string msg_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Msg { + get { return msg_; } + set { + msg_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SessionLog); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SessionLog other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Status != other.Status) return false; + if (CheckpointPath != other.CheckpointPath) return false; + if (Msg != other.Msg) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Status != global::Tensorflow.SessionLog.Types.SessionStatus.StatusUnspecified) hash ^= Status.GetHashCode(); + if (CheckpointPath.Length != 0) hash ^= CheckpointPath.GetHashCode(); + if (Msg.Length != 0) hash ^= Msg.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Status != global::Tensorflow.SessionLog.Types.SessionStatus.StatusUnspecified) { + output.WriteRawTag(8); + output.WriteEnum((int) Status); + } + if (CheckpointPath.Length != 0) { + output.WriteRawTag(18); + output.WriteString(CheckpointPath); + } + if (Msg.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Msg); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Status != global::Tensorflow.SessionLog.Types.SessionStatus.StatusUnspecified) { + output.WriteRawTag(8); + output.WriteEnum((int) Status); + } + if (CheckpointPath.Length != 0) { + output.WriteRawTag(18); + output.WriteString(CheckpointPath); + } + if (Msg.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Msg); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Status != global::Tensorflow.SessionLog.Types.SessionStatus.StatusUnspecified) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Status); + } + if (CheckpointPath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(CheckpointPath); + } + if (Msg.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Msg); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SessionLog other) { + if (other == null) { + return; + } + if (other.Status != global::Tensorflow.SessionLog.Types.SessionStatus.StatusUnspecified) { + Status = other.Status; + } + if (other.CheckpointPath.Length != 0) { + CheckpointPath = other.CheckpointPath; + } + if (other.Msg.Length != 0) { + Msg = other.Msg; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Status = (global::Tensorflow.SessionLog.Types.SessionStatus) input.ReadEnum(); + break; + } + case 18: { + CheckpointPath = input.ReadString(); + break; + } + case 26: { + Msg = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Status = (global::Tensorflow.SessionLog.Types.SessionStatus) input.ReadEnum(); + break; + } + case 18: { + CheckpointPath = input.ReadString(); + break; + } + case 26: { + Msg = input.ReadString(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the SessionLog message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum SessionStatus { + [pbr::OriginalName("STATUS_UNSPECIFIED")] StatusUnspecified = 0, + [pbr::OriginalName("START")] Start = 1, + [pbr::OriginalName("STOP")] Stop = 2, + [pbr::OriginalName("CHECKPOINT")] Checkpoint = 3, + } + + } + #endregion + + } + + /// + /// For logging the metadata output for a single session.run() call. + /// + public sealed partial class TaggedRunMetadata : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TaggedRunMetadata()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.EventReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TaggedRunMetadata() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TaggedRunMetadata(TaggedRunMetadata other) : this() { + tag_ = other.tag_; + runMetadata_ = other.runMetadata_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TaggedRunMetadata Clone() { + return new TaggedRunMetadata(this); + } + + /// Field number for the "tag" field. + public const int TagFieldNumber = 1; + private string tag_ = ""; + /// + /// Tag name associated with this metadata. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Tag { + get { return tag_; } + set { + tag_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "run_metadata" field. + public const int RunMetadataFieldNumber = 2; + private pb::ByteString runMetadata_ = pb::ByteString.Empty; + /// + /// Byte-encoded version of the `RunMetadata` proto in order to allow lazy + /// deserialization. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString RunMetadata { + get { return runMetadata_; } + set { + runMetadata_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TaggedRunMetadata); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TaggedRunMetadata other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Tag != other.Tag) return false; + if (RunMetadata != other.RunMetadata) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Tag.Length != 0) hash ^= Tag.GetHashCode(); + if (RunMetadata.Length != 0) hash ^= RunMetadata.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Tag.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Tag); + } + if (RunMetadata.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(RunMetadata); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Tag.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Tag); + } + if (RunMetadata.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(RunMetadata); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Tag.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Tag); + } + if (RunMetadata.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(RunMetadata); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TaggedRunMetadata other) { + if (other == null) { + return; + } + if (other.Tag.Length != 0) { + Tag = other.Tag; + } + if (other.RunMetadata.Length != 0) { + RunMetadata = other.RunMetadata; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Tag = input.ReadString(); + break; + } + case 18: { + RunMetadata = input.ReadBytes(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Tag = input.ReadString(); + break; + } + case 18: { + RunMetadata = input.ReadBytes(); + break; + } + } + } + } + #endif + + } + + public sealed partial class WatchdogConfig : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WatchdogConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.EventReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WatchdogConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WatchdogConfig(WatchdogConfig other) : this() { + timeoutMs_ = other.timeoutMs_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WatchdogConfig Clone() { + return new WatchdogConfig(this); + } + + /// Field number for the "timeout_ms" field. + public const int TimeoutMsFieldNumber = 1; + private long timeoutMs_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long TimeoutMs { + get { return timeoutMs_; } + set { + timeoutMs_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WatchdogConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WatchdogConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (TimeoutMs != other.TimeoutMs) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (TimeoutMs != 0L) hash ^= TimeoutMs.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (TimeoutMs != 0L) { + output.WriteRawTag(8); + output.WriteInt64(TimeoutMs); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (TimeoutMs != 0L) { + output.WriteRawTag(8); + output.WriteInt64(TimeoutMs); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (TimeoutMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(TimeoutMs); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WatchdogConfig other) { + if (other == null) { + return; + } + if (other.TimeoutMs != 0L) { + TimeoutMs = other.TimeoutMs; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + TimeoutMs = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + TimeoutMs = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + public sealed partial class RequestedExitCode : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RequestedExitCode()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.EventReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RequestedExitCode() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RequestedExitCode(RequestedExitCode other) : this() { + exitCode_ = other.exitCode_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RequestedExitCode Clone() { + return new RequestedExitCode(this); + } + + /// Field number for the "exit_code" field. + public const int ExitCodeFieldNumber = 1; + private int exitCode_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ExitCode { + get { return exitCode_; } + set { + exitCode_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as RequestedExitCode); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(RequestedExitCode other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ExitCode != other.ExitCode) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ExitCode != 0) hash ^= ExitCode.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ExitCode != 0) { + output.WriteRawTag(8); + output.WriteInt32(ExitCode); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ExitCode != 0) { + output.WriteRawTag(8); + output.WriteInt32(ExitCode); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ExitCode != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ExitCode); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(RequestedExitCode other) { + if (other == null) { + return; + } + if (other.ExitCode != 0) { + ExitCode = other.ExitCode; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + ExitCode = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + ExitCode = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + public sealed partial class WorkerHeartbeatRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WorkerHeartbeatRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.EventReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WorkerHeartbeatRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WorkerHeartbeatRequest(WorkerHeartbeatRequest other) : this() { + shutdownMode_ = other.shutdownMode_; + watchdogConfig_ = other.watchdogConfig_ != null ? other.watchdogConfig_.Clone() : null; + exitCode_ = other.exitCode_ != null ? other.exitCode_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WorkerHeartbeatRequest Clone() { + return new WorkerHeartbeatRequest(this); + } + + /// Field number for the "shutdown_mode" field. + public const int ShutdownModeFieldNumber = 1; + private global::Tensorflow.WorkerShutdownMode shutdownMode_ = global::Tensorflow.WorkerShutdownMode.Default; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.WorkerShutdownMode ShutdownMode { + get { return shutdownMode_; } + set { + shutdownMode_ = value; + } + } + + /// Field number for the "watchdog_config" field. + public const int WatchdogConfigFieldNumber = 2; + private global::Tensorflow.WatchdogConfig watchdogConfig_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.WatchdogConfig WatchdogConfig { + get { return watchdogConfig_; } + set { + watchdogConfig_ = value; + } + } + + /// Field number for the "exit_code" field. + public const int ExitCodeFieldNumber = 3; + private global::Tensorflow.RequestedExitCode exitCode_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RequestedExitCode ExitCode { + get { return exitCode_; } + set { + exitCode_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WorkerHeartbeatRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WorkerHeartbeatRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ShutdownMode != other.ShutdownMode) return false; + if (!object.Equals(WatchdogConfig, other.WatchdogConfig)) return false; + if (!object.Equals(ExitCode, other.ExitCode)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ShutdownMode != global::Tensorflow.WorkerShutdownMode.Default) hash ^= ShutdownMode.GetHashCode(); + if (watchdogConfig_ != null) hash ^= WatchdogConfig.GetHashCode(); + if (exitCode_ != null) hash ^= ExitCode.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ShutdownMode != global::Tensorflow.WorkerShutdownMode.Default) { + output.WriteRawTag(8); + output.WriteEnum((int) ShutdownMode); + } + if (watchdogConfig_ != null) { + output.WriteRawTag(18); + output.WriteMessage(WatchdogConfig); + } + if (exitCode_ != null) { + output.WriteRawTag(26); + output.WriteMessage(ExitCode); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ShutdownMode != global::Tensorflow.WorkerShutdownMode.Default) { + output.WriteRawTag(8); + output.WriteEnum((int) ShutdownMode); + } + if (watchdogConfig_ != null) { + output.WriteRawTag(18); + output.WriteMessage(WatchdogConfig); + } + if (exitCode_ != null) { + output.WriteRawTag(26); + output.WriteMessage(ExitCode); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ShutdownMode != global::Tensorflow.WorkerShutdownMode.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ShutdownMode); + } + if (watchdogConfig_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(WatchdogConfig); + } + if (exitCode_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ExitCode); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WorkerHeartbeatRequest other) { + if (other == null) { + return; + } + if (other.ShutdownMode != global::Tensorflow.WorkerShutdownMode.Default) { + ShutdownMode = other.ShutdownMode; + } + if (other.watchdogConfig_ != null) { + if (watchdogConfig_ == null) { + WatchdogConfig = new global::Tensorflow.WatchdogConfig(); + } + WatchdogConfig.MergeFrom(other.WatchdogConfig); + } + if (other.exitCode_ != null) { + if (exitCode_ == null) { + ExitCode = new global::Tensorflow.RequestedExitCode(); + } + ExitCode.MergeFrom(other.ExitCode); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + ShutdownMode = (global::Tensorflow.WorkerShutdownMode) input.ReadEnum(); + break; + } + case 18: { + if (watchdogConfig_ == null) { + WatchdogConfig = new global::Tensorflow.WatchdogConfig(); + } + input.ReadMessage(WatchdogConfig); + break; + } + case 26: { + if (exitCode_ == null) { + ExitCode = new global::Tensorflow.RequestedExitCode(); + } + input.ReadMessage(ExitCode); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + ShutdownMode = (global::Tensorflow.WorkerShutdownMode) input.ReadEnum(); + break; + } + case 18: { + if (watchdogConfig_ == null) { + WatchdogConfig = new global::Tensorflow.WatchdogConfig(); + } + input.ReadMessage(WatchdogConfig); + break; + } + case 26: { + if (exitCode_ == null) { + ExitCode = new global::Tensorflow.RequestedExitCode(); + } + input.ReadMessage(ExitCode); + break; + } + } + } + } + #endif + + } + + public sealed partial class WorkerHeartbeatResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WorkerHeartbeatResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.EventReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WorkerHeartbeatResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WorkerHeartbeatResponse(WorkerHeartbeatResponse other) : this() { + healthStatus_ = other.healthStatus_; + workerLog_ = other.workerLog_.Clone(); + hostname_ = other.hostname_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WorkerHeartbeatResponse Clone() { + return new WorkerHeartbeatResponse(this); + } + + /// Field number for the "health_status" field. + public const int HealthStatusFieldNumber = 1; + private global::Tensorflow.WorkerHealth healthStatus_ = global::Tensorflow.WorkerHealth.Ok; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.WorkerHealth HealthStatus { + get { return healthStatus_; } + set { + healthStatus_ = value; + } + } + + /// Field number for the "worker_log" field. + public const int WorkerLogFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_workerLog_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.Event.Parser); + private readonly pbc::RepeatedField workerLog_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField WorkerLog { + get { return workerLog_; } + } + + /// Field number for the "hostname" field. + public const int HostnameFieldNumber = 3; + private string hostname_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Hostname { + get { return hostname_; } + set { + hostname_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WorkerHeartbeatResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WorkerHeartbeatResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (HealthStatus != other.HealthStatus) return false; + if(!workerLog_.Equals(other.workerLog_)) return false; + if (Hostname != other.Hostname) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (HealthStatus != global::Tensorflow.WorkerHealth.Ok) hash ^= HealthStatus.GetHashCode(); + hash ^= workerLog_.GetHashCode(); + if (Hostname.Length != 0) hash ^= Hostname.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (HealthStatus != global::Tensorflow.WorkerHealth.Ok) { + output.WriteRawTag(8); + output.WriteEnum((int) HealthStatus); + } + workerLog_.WriteTo(output, _repeated_workerLog_codec); + if (Hostname.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Hostname); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (HealthStatus != global::Tensorflow.WorkerHealth.Ok) { + output.WriteRawTag(8); + output.WriteEnum((int) HealthStatus); + } + workerLog_.WriteTo(ref output, _repeated_workerLog_codec); + if (Hostname.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Hostname); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (HealthStatus != global::Tensorflow.WorkerHealth.Ok) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) HealthStatus); + } + size += workerLog_.CalculateSize(_repeated_workerLog_codec); + if (Hostname.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Hostname); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WorkerHeartbeatResponse other) { + if (other == null) { + return; + } + if (other.HealthStatus != global::Tensorflow.WorkerHealth.Ok) { + HealthStatus = other.HealthStatus; + } + workerLog_.Add(other.workerLog_); + if (other.Hostname.Length != 0) { + Hostname = other.Hostname; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + HealthStatus = (global::Tensorflow.WorkerHealth) input.ReadEnum(); + break; + } + case 18: { + workerLog_.AddEntriesFrom(input, _repeated_workerLog_codec); + break; + } + case 26: { + Hostname = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + HealthStatus = (global::Tensorflow.WorkerHealth) input.ReadEnum(); + break; + } + case 18: { + workerLog_.AddEntriesFrom(ref input, _repeated_workerLog_codec); + break; + } + case 26: { + Hostname = input.ReadString(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Executable.cs b/src/TensorFlowNET.Core/Protobuf/Executable.cs new file mode 100644 index 000000000..245c87ffb --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Executable.cs @@ -0,0 +1,340 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/compiler/xla/service/cpu/executable.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Xla.Cpu { + + /// Holder for reflection information generated from tensorflow/compiler/xla/service/cpu/executable.proto + public static partial class ExecutableReflection { + + #region Descriptor + /// File descriptor for tensorflow/compiler/xla/service/cpu/executable.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ExecutableReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjR0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9zZXJ2aWNlL2NwdS9leGVjdXRh", + "YmxlLnByb3RvEgd4bGEuY3B1Gjd0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9z", + "ZXJ2aWNlL2NwdS94bGFfZnJhbWV3b3JrLnByb3RvGil0ZW5zb3JmbG93L2Nv", + "bXBpbGVyL3hsYS9zZXJ2aWNlL2hsby5wcm90byLXAQocWGxhUnVudGltZUNw", + "dUV4ZWN1dGFibGVQcm90bxI+ChZ4bGFfcnVudGltZV9leGVjdXRhYmxlGAEg", + "ASgLMh4ueGxhLlhsYVJ1bnRpbWVFeGVjdXRhYmxlUHJvdG8SQAoVeGxhX2Zy", + "YW1ld29ya19tYXBwaW5nGAIgASgLMiEueGxhLmNwdS5YbGFGcmFtZXdvcmtN", + "YXBwaW5nUHJvdG8SNQoRYnVmZmVyX2Fzc2lnbm1lbnQYAyABKAsyGi54bGEu", + "QnVmZmVyQXNzaWdubWVudFByb3Rv")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Xla.Cpu.XlaFrameworkReflection.Descriptor, global::Xla.HloReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.Cpu.XlaRuntimeCpuExecutableProto), global::Xla.Cpu.XlaRuntimeCpuExecutableProto.Parser, new[]{ "XlaRuntimeExecutable", "XlaFrameworkMapping", "BufferAssignment" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class XlaRuntimeCpuExecutableProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new XlaRuntimeCpuExecutableProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.Cpu.ExecutableReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public XlaRuntimeCpuExecutableProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public XlaRuntimeCpuExecutableProto(XlaRuntimeCpuExecutableProto other) : this() { + xlaRuntimeExecutable_ = other.xlaRuntimeExecutable_ != null ? other.xlaRuntimeExecutable_.Clone() : null; + xlaFrameworkMapping_ = other.xlaFrameworkMapping_ != null ? other.xlaFrameworkMapping_.Clone() : null; + bufferAssignment_ = other.bufferAssignment_ != null ? other.bufferAssignment_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public XlaRuntimeCpuExecutableProto Clone() { + return new XlaRuntimeCpuExecutableProto(this); + } + + /// Field number for the "xla_runtime_executable" field. + public const int XlaRuntimeExecutableFieldNumber = 1; + private global::Xla.XlaRuntimeExecutableProto xlaRuntimeExecutable_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.XlaRuntimeExecutableProto XlaRuntimeExecutable { + get { return xlaRuntimeExecutable_; } + set { + xlaRuntimeExecutable_ = value; + } + } + + /// Field number for the "xla_framework_mapping" field. + public const int XlaFrameworkMappingFieldNumber = 2; + private global::Xla.Cpu.XlaFrameworkMappingProto xlaFrameworkMapping_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.Cpu.XlaFrameworkMappingProto XlaFrameworkMapping { + get { return xlaFrameworkMapping_; } + set { + xlaFrameworkMapping_ = value; + } + } + + /// Field number for the "buffer_assignment" field. + public const int BufferAssignmentFieldNumber = 3; + private global::Xla.BufferAssignmentProto bufferAssignment_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.BufferAssignmentProto BufferAssignment { + get { return bufferAssignment_; } + set { + bufferAssignment_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as XlaRuntimeCpuExecutableProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(XlaRuntimeCpuExecutableProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(XlaRuntimeExecutable, other.XlaRuntimeExecutable)) return false; + if (!object.Equals(XlaFrameworkMapping, other.XlaFrameworkMapping)) return false; + if (!object.Equals(BufferAssignment, other.BufferAssignment)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (xlaRuntimeExecutable_ != null) hash ^= XlaRuntimeExecutable.GetHashCode(); + if (xlaFrameworkMapping_ != null) hash ^= XlaFrameworkMapping.GetHashCode(); + if (bufferAssignment_ != null) hash ^= BufferAssignment.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (xlaRuntimeExecutable_ != null) { + output.WriteRawTag(10); + output.WriteMessage(XlaRuntimeExecutable); + } + if (xlaFrameworkMapping_ != null) { + output.WriteRawTag(18); + output.WriteMessage(XlaFrameworkMapping); + } + if (bufferAssignment_ != null) { + output.WriteRawTag(26); + output.WriteMessage(BufferAssignment); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (xlaRuntimeExecutable_ != null) { + output.WriteRawTag(10); + output.WriteMessage(XlaRuntimeExecutable); + } + if (xlaFrameworkMapping_ != null) { + output.WriteRawTag(18); + output.WriteMessage(XlaFrameworkMapping); + } + if (bufferAssignment_ != null) { + output.WriteRawTag(26); + output.WriteMessage(BufferAssignment); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (xlaRuntimeExecutable_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(XlaRuntimeExecutable); + } + if (xlaFrameworkMapping_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(XlaFrameworkMapping); + } + if (bufferAssignment_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(BufferAssignment); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(XlaRuntimeCpuExecutableProto other) { + if (other == null) { + return; + } + if (other.xlaRuntimeExecutable_ != null) { + if (xlaRuntimeExecutable_ == null) { + XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto(); + } + XlaRuntimeExecutable.MergeFrom(other.XlaRuntimeExecutable); + } + if (other.xlaFrameworkMapping_ != null) { + if (xlaFrameworkMapping_ == null) { + XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto(); + } + XlaFrameworkMapping.MergeFrom(other.XlaFrameworkMapping); + } + if (other.bufferAssignment_ != null) { + if (bufferAssignment_ == null) { + BufferAssignment = new global::Xla.BufferAssignmentProto(); + } + BufferAssignment.MergeFrom(other.BufferAssignment); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (xlaRuntimeExecutable_ == null) { + XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto(); + } + input.ReadMessage(XlaRuntimeExecutable); + break; + } + case 18: { + if (xlaFrameworkMapping_ == null) { + XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto(); + } + input.ReadMessage(XlaFrameworkMapping); + break; + } + case 26: { + if (bufferAssignment_ == null) { + BufferAssignment = new global::Xla.BufferAssignmentProto(); + } + input.ReadMessage(BufferAssignment); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (xlaRuntimeExecutable_ == null) { + XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto(); + } + input.ReadMessage(XlaRuntimeExecutable); + break; + } + case 18: { + if (xlaFrameworkMapping_ == null) { + XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto(); + } + input.ReadMessage(XlaFrameworkMapping); + break; + } + case 26: { + if (bufferAssignment_ == null) { + BufferAssignment = new global::Xla.BufferAssignmentProto(); + } + input.ReadMessage(BufferAssignment); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/FullType.cs b/src/TensorFlowNET.Core/Protobuf/FullType.cs new file mode 100644 index 000000000..dee5571e8 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/FullType.cs @@ -0,0 +1,675 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/full_type.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/full_type.proto + public static partial class FullTypeReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/full_type.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static FullTypeReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cil0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Z1bGxfdHlwZS5wcm90bxIK", + "dGVuc29yZmxvdyJ/CgtGdWxsVHlwZURlZhInCgd0eXBlX2lkGAEgASgOMhYu", + "dGVuc29yZmxvdy5GdWxsVHlwZUlkEiUKBGFyZ3MYAiADKAsyFy50ZW5zb3Jm", + "bG93LkZ1bGxUeXBlRGVmEgsKAXMYAyABKAlIABILCgFpGAQgASgDSABCBgoE", + "YXR0cirDBAoKRnVsbFR5cGVJZBINCglURlRfVU5TRVQQABILCgdURlRfVkFS", + "EAESCwoHVEZUX0FOWRACEg8KC1RGVF9QUk9EVUNUEAMSDQoJVEZUX05BTUVE", + "EAQSEAoMVEZUX0ZPUl9FQUNIEBQSEAoMVEZUX0NBTExBQkxFEGQSDwoKVEZU", + "X1RFTlNPUhDoBxIOCglURlRfQVJSQVkQ6QcSEQoMVEZUX09QVElPTkFMEOoH", + "EhAKC1RGVF9MSVRFUkFMEOsHEhAKC1RGVF9FTkNPREVEEOwHEg0KCFRGVF9C", + "T09MEMgBEg4KCVRGVF9VSU5UOBDJARIPCgpURlRfVUlOVDE2EMoBEg8KClRG", + "VF9VSU5UMzIQywESDwoKVEZUX1VJTlQ2NBDMARINCghURlRfSU5UOBDNARIO", + "CglURlRfSU5UMTYQzgESDgoJVEZUX0lOVDMyEM8BEg4KCVRGVF9JTlQ2NBDQ", + "ARINCghURlRfSEFMRhDRARIOCglURlRfRkxPQVQQ0gESDwoKVEZUX0RPVUJM", + "RRDTARIRCgxURlRfQkZMT0FUMTYQ1wESEgoNVEZUX0NPTVBMRVg2NBDUARIT", + "Cg5URlRfQ09NUExFWDEyOBDVARIPCgpURlRfU1RSSU5HENYBEhAKC1RGVF9E", + "QVRBU0VUEPZOEg8KClRGVF9SQUdHRUQQ904SEQoMVEZUX0lURVJBVE9SEPhO", + "EhMKDlRGVF9NVVRFWF9MT0NLENpPEhcKElRGVF9MRUdBQ1lfVkFSSUFOVBDb", + "T0KBAQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQg5GdWxsVHlwZVByb3Rv", + "c1ABWlBnaXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5zb3Jm", + "bG93L2dvL2NvcmUvZnJhbWV3b3JrL2Z1bGxfdHlwZV9nb19wcm90b/gBAWIG", + "cHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.FullTypeId), }, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FullTypeDef), global::Tensorflow.FullTypeDef.Parser, new[]{ "TypeId", "Args", "S", "I" }, new[]{ "Attr" }, null, null, null) + })); + } + #endregion + + } + #region Enums + /// + /// LINT.IfChange + /// Experimental. Represents the complete type information of a TensorFlow value. + /// + public enum FullTypeId { + /// + /// The default represents an uninitialized values. + /// + [pbr::OriginalName("TFT_UNSET")] TftUnset = 0, + /// + /// Type variables may serve as placeholder for any other type ID in type + /// templates. + /// + /// Examples: + /// TFT_DATASET[TFT_VAR["T"]] is a Dataset returning a type indicated by "T". + /// TFT_TENSOR[TFT_VAR["T"]] is a Tensor of n element type indicated by "T". + /// TFT_TENSOR[TFT_VAR["T"]], TFT_TENSOR[TFT_VAR["T"]] are two tensors of + /// identical element types. + /// TFT_TENSOR[TFT_VAR["P"]], TFT_TENSOR[TFT_VAR["Q"]] are two tensors of + /// independent element types. + /// + [pbr::OriginalName("TFT_VAR")] TftVar = 1, + /// + /// Wildcard type. Describes a parameter of unknown type. In TensorFlow, that + /// can mean either a "Top" type (accepts any type), or a dynamically typed + /// object whose type is unknown in context. + /// Important: "unknown" does not necessarily mean undeterminable! + /// + [pbr::OriginalName("TFT_ANY")] TftAny = 2, + /// + /// The algebraic product type. This is an algebraic type that may be used just + /// for logical grouping. Not to confused with TFT_TUPLE which describes a + /// concrete object of several elements. + /// + /// Example: + /// TFT_DATASET[TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT64]]] + /// is a Dataset producing two tensors, an integer one and a float one. + /// + [pbr::OriginalName("TFT_PRODUCT")] TftProduct = 3, + /// + /// Represents a named field, with the name stored in the attribute. + /// + /// Parametrization: + /// TFT_NAMED[<type>]{<name>} + /// * <type> is the type of the field + /// * <name> is the field name, as string (thpugh can theoretically be an int + /// as well) + /// + /// Example: + /// TFT_RECORD[ + /// TFT_NAMED[TFT_TENSOR[TFT_INT32]]{'foo'}, + /// TFT_NAMED[TFT_TENSOR[TFT_FLOAT32]]{'bar'}, + /// ] + /// is a structure with two fields, an int tensor "foo" and a float tensor + /// "bar". + /// + [pbr::OriginalName("TFT_NAMED")] TftNamed = 4, + /// + /// Template definition. Expands the variables by repeating a template as + /// arguments of container. + /// + /// Parametrization: + /// TFT_FOR_EACH[<container_type>, <template>, <expansions>] + /// * <container_type> is the type of the container that the template will be + /// expanded into + /// * <template> is any type definition that potentially contains type + /// variables + /// * <expansions> is a TFT_VAR and may include more types in the future + /// + /// Example: + /// TFT_FOR_EACH[ + /// TFT_PRODUCT, + /// TFT_TENSOR[TFT_VAR["t"]], + /// TFT_VAR["t"] + /// ] + /// will substitute a T = TFT_INT32 to TFT_PRODUCT[TFT_TENSOR[TFT_INT32]] + /// and a T = (TFT_INT32, TFT_INT64) to + /// TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_INT64]]. + /// + [pbr::OriginalName("TFT_FOR_EACH")] TftForEach = 20, + /// + /// Callable types describe functions and ops. + /// + /// Parametrization: + /// TFT_CALLABLE[<arg type>, <return type>] + /// * <arg type> is the type of the arguments; TFT_PRODUCT represents + /// multiple + /// arguments. + /// * <return type> is the return type; TFT_PRODUCT represents multiple + /// return values (that means that callables returning multiple things + /// don't necessarily return a single tuple). + /// + /// Example: + /// TFT_CALLABLE[ + /// TFT_ANY, + /// TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT64]], + /// ] + /// is a callable with unspecified (for now) input arguments, and + /// two return values of type tensor. + /// + [pbr::OriginalName("TFT_CALLABLE")] TftCallable = 100, + /// + /// The usual Tensor. This is a parametric type. + /// + /// Parametrization: + /// TFT_TENSOR[<element type>, <shape type>] + /// * <element type> is currently limited to one of the element types + /// defined below. + /// * <shape type> is not yet defined, and may only be TFT_UNKNOWN for now. + /// + /// A TFT_SHAPE type will be defined in the future. + /// + /// Example: + /// TFT_TENSOR[TFT_INT32, TFT_UNKNOWN] + /// is a Tensor of int32 element type and unknown shape. + /// + /// TODO(mdan): Define TFT_SHAPE and add more examples. + /// + [pbr::OriginalName("TFT_TENSOR")] TftTensor = 1000, + /// + /// Array (or tensorflow::TensorList in the variant type registry). + /// Note: this is not to be confused with the deprecated `TensorArray*` ops + /// which are not supported by FullType. + /// This type represents a random-access list whose elements can be + /// described by a single type. Although immutable, Array is expected to + /// support efficient mutation semantics (i.e. element update) in the + /// user-facing API. + /// The element type may be generic or even TFT_ANY for a heterogenous list. + /// + /// Parametrization: + /// TFT_ARRAY[<element type>] + /// * <element type> may be any concrete type. + /// + /// Examples: + /// TFT_ARRAY[TFT_TENSOR[TFT_INT32]] is a TensorArray holding int32 Tensors + /// of any shape. + /// TFT_ARRAY[TFT_TENSOR[TFT_UNKNOWN]] is a TensorArray holding Tensors of + /// mixed element types. + /// TFT_ARRAY[TFT_UNKNOWN] is a TensorArray holding any element type. + /// TFT_ARRAY[] is equivalent to TFT_ARRAY[TFT_UNKNOWN]. + /// TFT_ARRAY[TFT_ARRAY[]] is an array or arrays (of unknown types). + /// + [pbr::OriginalName("TFT_ARRAY")] TftArray = 1001, + /// + /// Optional (or tensorflow::OptionalVariant in the variant type registry). + /// This type represents a value that may either hold an element of a single + /// specified type, or nothing at all. + /// + /// Parametrization: + /// TFT_OPTIONAL[<element type>] + /// * <element type> may be any concrete type. + /// + /// Examples: + /// TFT_OPTIONAL[TFT_TENSOR[TFT_INT32]] is an Optional holding an int32 + /// Tensor of any shape. + /// + [pbr::OriginalName("TFT_OPTIONAL")] TftOptional = 1002, + /// + /// Literal types describe compile-time constant values. + /// Literal types may also participate in dependent types. + /// + /// Parametrization: + /// TFT_LITERAL[<value type>]{<value>} + /// * <value type> may be any concrete type compatible that can hold <value> + /// * <value> is the type's attribute, and holds the actual literal value + /// + /// Examples: + /// TFT_LITERAL[TFT_INT32]{1} is the compile-time constant 1. + /// + [pbr::OriginalName("TFT_LITERAL")] TftLiteral = 1003, + /// + /// Encoding types describe a value of a certain type, encoded as a different + /// type. + /// + /// Parametrization: + /// TFT_ENCODED[<encoded type>, <encoding type>] + /// * <encoded type> may be any type + /// * <encoding type> may be any type + /// + /// Examples: + /// TFT_ENCODING[TFT_INT32, TFT_STRING] is an integer encoded as string. + /// + [pbr::OriginalName("TFT_ENCODED")] TftEncoded = 1004, + /// + /// The bool element type. + /// TODO(mdan): Quantized types, legacy representations (e.g. ref) + /// + [pbr::OriginalName("TFT_BOOL")] TftBool = 200, + /// + /// Integer element types. + /// + [pbr::OriginalName("TFT_UINT8")] TftUint8 = 201, + [pbr::OriginalName("TFT_UINT16")] TftUint16 = 202, + [pbr::OriginalName("TFT_UINT32")] TftUint32 = 203, + [pbr::OriginalName("TFT_UINT64")] TftUint64 = 204, + [pbr::OriginalName("TFT_INT8")] TftInt8 = 205, + [pbr::OriginalName("TFT_INT16")] TftInt16 = 206, + [pbr::OriginalName("TFT_INT32")] TftInt32 = 207, + [pbr::OriginalName("TFT_INT64")] TftInt64 = 208, + /// + /// Floating-point element types. + /// + [pbr::OriginalName("TFT_HALF")] TftHalf = 209, + [pbr::OriginalName("TFT_FLOAT")] TftFloat = 210, + [pbr::OriginalName("TFT_DOUBLE")] TftDouble = 211, + [pbr::OriginalName("TFT_BFLOAT16")] TftBfloat16 = 215, + /// + /// Complex element types. + /// TODO(mdan): Represent as TFT_COMPLEX[TFT_DOUBLE] instead? + /// + [pbr::OriginalName("TFT_COMPLEX64")] TftComplex64 = 212, + [pbr::OriginalName("TFT_COMPLEX128")] TftComplex128 = 213, + /// + /// The string element type. + /// + [pbr::OriginalName("TFT_STRING")] TftString = 214, + /// + /// Datasets created by tf.data ops and APIs. Datasets have generator/iterable + /// semantics, that is, one can construct an iterator from them. Like + /// Array, they are considered to return elements that can be described + /// by a single type. Unlike Array, they do not support random access or + /// mutation, and can potentially produce an infinite number of elements. + /// A datasets can produce logical structures (e.g. multiple elements). This + /// is expressed using TFT_PRODUCT. + /// + /// Parametrization: TFT_DATASET[<element type>]. + /// * <element type> may be a concrete type or a type symbol. It represents + /// the data type of the elements produced by the dataset. + /// + /// Examples: + /// TFT_DATSET[TFT_TENSOR[TFT_INT32]] is a Dataset producing single int32 + /// Tensors of unknown shape. + /// TFT_DATSET[TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT32]] is + /// a Dataset producing pairs of Tensors, one integer and one float. + /// Note: The high ID number is to prepare for the eventuality that Datasets + /// will be supported by user types in the future. + /// + [pbr::OriginalName("TFT_DATASET")] TftDataset = 10102, + /// + /// A ragged tensor created by tf.ragged ops and APIs. + /// + /// Parametrization: TFT_RAGGED[<element_type>]. + /// + [pbr::OriginalName("TFT_RAGGED")] TftRagged = 10103, + /// + /// Iterators created by tf.data ops and APIs. Very similar to Datasets, except + /// they are mutable. + /// + /// Parametrization: TFT_ITERATOR[<element type>]. + /// * <element type> may be a concrete type or a type symbol. It represents + /// the data type of the elements produced by the dataset. + /// + [pbr::OriginalName("TFT_ITERATOR")] TftIterator = 10104, + /// + /// A mutex lock tensor, produced by tf.raw_ops.MutexLock. + /// Unlike strict execution models, where ownership of a lock is denoted by + /// "running after the lock has been acquired", in non-strict mode, lock + /// ownership is in the true sense: "the op argument representing the lock is + /// available". + /// Mutex locks are the dynamic counterpart of control dependencies. + /// TODO(mdan): Properly document this thing. + /// + /// Parametrization: TFT_MUTEX_LOCK[]. + /// + [pbr::OriginalName("TFT_MUTEX_LOCK")] TftMutexLock = 10202, + /// + /// The equivalent of a Tensor with DT_VARIANT dtype, kept here to simplify + /// translation. This type should not normally appear after type inference. + /// Note that LEGACY_VARIANT != ANY: TENSOR[INT32] is a subtype of ANY, but is + /// not a subtype of LEGACY_VARIANT. + /// + [pbr::OriginalName("TFT_LEGACY_VARIANT")] TftLegacyVariant = 10203, + } + + #endregion + + #region Messages + /// + /// Highly experimental and very likely to change. + /// This encoding uses tags instead of dedicated messages for regularity. In + /// particular the encoding imposes no restrictions on what the parameters of any + /// type should be, which in particular needs to be true for type symbols. + /// + public sealed partial class FullTypeDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FullTypeDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.FullTypeReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FullTypeDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FullTypeDef(FullTypeDef other) : this() { + typeId_ = other.typeId_; + args_ = other.args_.Clone(); + switch (other.AttrCase) { + case AttrOneofCase.S: + S = other.S; + break; + case AttrOneofCase.I: + I = other.I; + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FullTypeDef Clone() { + return new FullTypeDef(this); + } + + /// Field number for the "type_id" field. + public const int TypeIdFieldNumber = 1; + private global::Tensorflow.FullTypeId typeId_ = global::Tensorflow.FullTypeId.TftUnset; + /// + /// The principal type represented by this object. This may be a concrete type + /// (Tensor, Dataset) a type variable (used for dependent types) a type + /// symbol (Any, Union). See FullTypeId for details. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.FullTypeId TypeId { + get { return typeId_; } + set { + typeId_ = value; + } + } + + /// Field number for the "args" field. + public const int ArgsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_args_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.FullTypeDef.Parser); + private readonly pbc::RepeatedField args_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Args { + get { return args_; } + } + + /// Field number for the "s" field. + public const int SFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string S { + get { return attrCase_ == AttrOneofCase.S ? (string) attr_ : ""; } + set { + attr_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + attrCase_ = AttrOneofCase.S; + } + } + + /// Field number for the "i" field. + public const int IFieldNumber = 4; + /// + /// TODO(mdan): list/tensor, map? Need to reconcile with TFT_RECORD, etc. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long I { + get { return attrCase_ == AttrOneofCase.I ? (long) attr_ : 0L; } + set { + attr_ = value; + attrCase_ = AttrOneofCase.I; + } + } + + private object attr_; + /// Enum of possible cases for the "attr" oneof. + public enum AttrOneofCase { + None = 0, + S = 3, + I = 4, + } + private AttrOneofCase attrCase_ = AttrOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AttrOneofCase AttrCase { + get { return attrCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void ClearAttr() { + attrCase_ = AttrOneofCase.None; + attr_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as FullTypeDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(FullTypeDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (TypeId != other.TypeId) return false; + if(!args_.Equals(other.args_)) return false; + if (S != other.S) return false; + if (I != other.I) return false; + if (AttrCase != other.AttrCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (TypeId != global::Tensorflow.FullTypeId.TftUnset) hash ^= TypeId.GetHashCode(); + hash ^= args_.GetHashCode(); + if (attrCase_ == AttrOneofCase.S) hash ^= S.GetHashCode(); + if (attrCase_ == AttrOneofCase.I) hash ^= I.GetHashCode(); + hash ^= (int) attrCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (TypeId != global::Tensorflow.FullTypeId.TftUnset) { + output.WriteRawTag(8); + output.WriteEnum((int) TypeId); + } + args_.WriteTo(output, _repeated_args_codec); + if (attrCase_ == AttrOneofCase.S) { + output.WriteRawTag(26); + output.WriteString(S); + } + if (attrCase_ == AttrOneofCase.I) { + output.WriteRawTag(32); + output.WriteInt64(I); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (TypeId != global::Tensorflow.FullTypeId.TftUnset) { + output.WriteRawTag(8); + output.WriteEnum((int) TypeId); + } + args_.WriteTo(ref output, _repeated_args_codec); + if (attrCase_ == AttrOneofCase.S) { + output.WriteRawTag(26); + output.WriteString(S); + } + if (attrCase_ == AttrOneofCase.I) { + output.WriteRawTag(32); + output.WriteInt64(I); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (TypeId != global::Tensorflow.FullTypeId.TftUnset) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) TypeId); + } + size += args_.CalculateSize(_repeated_args_codec); + if (attrCase_ == AttrOneofCase.S) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(S); + } + if (attrCase_ == AttrOneofCase.I) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(I); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(FullTypeDef other) { + if (other == null) { + return; + } + if (other.TypeId != global::Tensorflow.FullTypeId.TftUnset) { + TypeId = other.TypeId; + } + args_.Add(other.args_); + switch (other.AttrCase) { + case AttrOneofCase.S: + S = other.S; + break; + case AttrOneofCase.I: + I = other.I; + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + TypeId = (global::Tensorflow.FullTypeId) input.ReadEnum(); + break; + } + case 18: { + args_.AddEntriesFrom(input, _repeated_args_codec); + break; + } + case 26: { + S = input.ReadString(); + break; + } + case 32: { + I = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + TypeId = (global::Tensorflow.FullTypeId) input.ReadEnum(); + break; + } + case 18: { + args_.AddEntriesFrom(ref input, _repeated_args_codec); + break; + } + case 26: { + S = input.ReadString(); + break; + } + case 32: { + I = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Function.cs b/src/TensorFlowNET.Core/Protobuf/Function.cs new file mode 100644 index 000000000..800e64442 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Function.cs @@ -0,0 +1,1384 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/function.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/function.proto + public static partial class FunctionReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/function.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static FunctionReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cih0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Z1bmN0aW9uLnByb3RvEgp0", + "ZW5zb3JmbG93Gip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2F0dHJfdmFs", + "dWUucHJvdG8aKHRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvbm9kZV9kZWYu", + "cHJvdG8aJnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvb3BfZGVmLnByb3Rv", + "IqgBChJGdW5jdGlvbkRlZkxpYnJhcnkSKQoIZnVuY3Rpb24YASADKAsyFy50", + "ZW5zb3JmbG93LkZ1bmN0aW9uRGVmEikKCGdyYWRpZW50GAIgAygLMhcudGVu", + "c29yZmxvdy5HcmFkaWVudERlZhI8ChRyZWdpc3RlcmVkX2dyYWRpZW50cxgD", + "IAMoCzIeLnRlbnNvcmZsb3cuUmVnaXN0ZXJlZEdyYWRpZW50IsQGCgtGdW5j", + "dGlvbkRlZhIkCglzaWduYXR1cmUYASABKAsyES50ZW5zb3JmbG93Lk9wRGVm", + "Ei8KBGF0dHIYBSADKAsyIS50ZW5zb3JmbG93LkZ1bmN0aW9uRGVmLkF0dHJF", + "bnRyeRI2CghhcmdfYXR0chgHIAMoCzIkLnRlbnNvcmZsb3cuRnVuY3Rpb25E", + "ZWYuQXJnQXR0ckVudHJ5ElAKFnJlc291cmNlX2FyZ191bmlxdWVfaWQYCCAD", + "KAsyMC50ZW5zb3JmbG93LkZ1bmN0aW9uRGVmLlJlc291cmNlQXJnVW5pcXVl", + "SWRFbnRyeRIlCghub2RlX2RlZhgDIAMoCzITLnRlbnNvcmZsb3cuTm9kZURl", + "ZhItCgNyZXQYBCADKAsyIC50ZW5zb3JmbG93LkZ1bmN0aW9uRGVmLlJldEVu", + "dHJ5EjwKC2NvbnRyb2xfcmV0GAYgAygLMicudGVuc29yZmxvdy5GdW5jdGlv", + "bkRlZi5Db250cm9sUmV0RW50cnkaQgoJQXR0ckVudHJ5EgsKA2tleRgBIAEo", + "CRIkCgV2YWx1ZRgCIAEoCzIVLnRlbnNvcmZsb3cuQXR0clZhbHVlOgI4ARqI", + "AQoIQXJnQXR0cnMSOAoEYXR0chgBIAMoCzIqLnRlbnNvcmZsb3cuRnVuY3Rp", + "b25EZWYuQXJnQXR0cnMuQXR0ckVudHJ5GkIKCUF0dHJFbnRyeRILCgNrZXkY", + "ASABKAkSJAoFdmFsdWUYAiABKAsyFS50ZW5zb3JmbG93LkF0dHJWYWx1ZToC", + "OAEaUAoMQXJnQXR0ckVudHJ5EgsKA2tleRgBIAEoDRIvCgV2YWx1ZRgCIAEo", + "CzIgLnRlbnNvcmZsb3cuRnVuY3Rpb25EZWYuQXJnQXR0cnM6AjgBGjoKGFJl", + "c291cmNlQXJnVW5pcXVlSWRFbnRyeRILCgNrZXkYASABKA0SDQoFdmFsdWUY", + "AiABKA06AjgBGioKCFJldEVudHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgC", + "IAEoCToCOAEaMQoPQ29udHJvbFJldEVudHJ5EgsKA2tleRgBIAEoCRINCgV2", + "YWx1ZRgCIAEoCToCOAFKBAgCEAMiOwoLR3JhZGllbnREZWYSFQoNZnVuY3Rp", + "b25fbmFtZRgBIAEoCRIVCg1ncmFkaWVudF9mdW5jGAIgASgJIkcKElJlZ2lz", + "dGVyZWRHcmFkaWVudBIVCg1ncmFkaWVudF9mdW5jGAEgASgJEhoKEnJlZ2lz", + "dGVyZWRfb3BfdHlwZRgCIAEoCUKAAQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3", + "b3JrQg5GdW5jdGlvblByb3Rvc1ABWk9naXRodWIuY29tL3RlbnNvcmZsb3cv", + "dGVuc29yZmxvdy90ZW5zb3JmbG93L2dvL2NvcmUvZnJhbWV3b3JrL2Z1bmN0", + "aW9uX2dvX3Byb3Rv+AEBYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.OpDefReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDefLibrary), global::Tensorflow.FunctionDefLibrary.Parser, new[]{ "Function", "Gradient", "RegisteredGradients" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDef), global::Tensorflow.FunctionDef.Parser, new[]{ "Signature", "Attr", "ArgAttr", "ResourceArgUniqueId", "NodeDef", "Ret", "ControlRet" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDef.Types.ArgAttrs), global::Tensorflow.FunctionDef.Types.ArgAttrs.Parser, new[]{ "Attr" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + null, null, null, null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GradientDef), global::Tensorflow.GradientDef.Parser, new[]{ "FunctionName", "GradientFunc" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RegisteredGradient), global::Tensorflow.RegisteredGradient.Parser, new[]{ "GradientFunc", "RegisteredOpType" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// A library is a set of named functions. + /// + public sealed partial class FunctionDefLibrary : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FunctionDefLibrary()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FunctionDefLibrary() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FunctionDefLibrary(FunctionDefLibrary other) : this() { + function_ = other.function_.Clone(); + gradient_ = other.gradient_.Clone(); + registeredGradients_ = other.registeredGradients_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FunctionDefLibrary Clone() { + return new FunctionDefLibrary(this); + } + + /// Field number for the "function" field. + public const int FunctionFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_function_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.FunctionDef.Parser); + private readonly pbc::RepeatedField function_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Function { + get { return function_; } + } + + /// Field number for the "gradient" field. + public const int GradientFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_gradient_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.GradientDef.Parser); + private readonly pbc::RepeatedField gradient_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Gradient { + get { return gradient_; } + } + + /// Field number for the "registered_gradients" field. + public const int RegisteredGradientsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_registeredGradients_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.RegisteredGradient.Parser); + private readonly pbc::RepeatedField registeredGradients_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField RegisteredGradients { + get { return registeredGradients_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as FunctionDefLibrary); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(FunctionDefLibrary other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!function_.Equals(other.function_)) return false; + if(!gradient_.Equals(other.gradient_)) return false; + if(!registeredGradients_.Equals(other.registeredGradients_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= function_.GetHashCode(); + hash ^= gradient_.GetHashCode(); + hash ^= registeredGradients_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + function_.WriteTo(output, _repeated_function_codec); + gradient_.WriteTo(output, _repeated_gradient_codec); + registeredGradients_.WriteTo(output, _repeated_registeredGradients_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + function_.WriteTo(ref output, _repeated_function_codec); + gradient_.WriteTo(ref output, _repeated_gradient_codec); + registeredGradients_.WriteTo(ref output, _repeated_registeredGradients_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += function_.CalculateSize(_repeated_function_codec); + size += gradient_.CalculateSize(_repeated_gradient_codec); + size += registeredGradients_.CalculateSize(_repeated_registeredGradients_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(FunctionDefLibrary other) { + if (other == null) { + return; + } + function_.Add(other.function_); + gradient_.Add(other.gradient_); + registeredGradients_.Add(other.registeredGradients_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + function_.AddEntriesFrom(input, _repeated_function_codec); + break; + } + case 18: { + gradient_.AddEntriesFrom(input, _repeated_gradient_codec); + break; + } + case 26: { + registeredGradients_.AddEntriesFrom(input, _repeated_registeredGradients_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + function_.AddEntriesFrom(ref input, _repeated_function_codec); + break; + } + case 18: { + gradient_.AddEntriesFrom(ref input, _repeated_gradient_codec); + break; + } + case 26: { + registeredGradients_.AddEntriesFrom(ref input, _repeated_registeredGradients_codec); + break; + } + } + } + } + #endif + + } + + /// + /// A function can be instantiated when the runtime can bind every attr + /// with a value. When a GraphDef has a call to a function, it must + /// have binding for every attr defined in the signature. + /// + /// TODO(zhifengc): + /// * device spec, etc. + /// + public sealed partial class FunctionDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FunctionDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FunctionDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FunctionDef(FunctionDef other) : this() { + signature_ = other.signature_ != null ? other.signature_.Clone() : null; + attr_ = other.attr_.Clone(); + argAttr_ = other.argAttr_.Clone(); + resourceArgUniqueId_ = other.resourceArgUniqueId_.Clone(); + nodeDef_ = other.nodeDef_.Clone(); + ret_ = other.ret_.Clone(); + controlRet_ = other.controlRet_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FunctionDef Clone() { + return new FunctionDef(this); + } + + /// Field number for the "signature" field. + public const int SignatureFieldNumber = 1; + private global::Tensorflow.OpDef signature_; + /// + /// The definition of the function's name, arguments, return values, + /// attrs etc. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.OpDef Signature { + get { return signature_; } + set { + signature_ = value; + } + } + + /// Field number for the "attr" field. + public const int AttrFieldNumber = 5; + private static readonly pbc::MapField.Codec _map_attr_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 42); + private readonly pbc::MapField attr_ = new pbc::MapField(); + /// + /// Attributes specific to this function definition. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField Attr { + get { return attr_; } + } + + /// Field number for the "arg_attr" field. + public const int ArgAttrFieldNumber = 7; + private static readonly pbc::MapField.Codec _map_argAttr_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForUInt32(8, 0), pb::FieldCodec.ForMessage(18, global::Tensorflow.FunctionDef.Types.ArgAttrs.Parser), 58); + private readonly pbc::MapField argAttr_ = new pbc::MapField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField ArgAttr { + get { return argAttr_; } + } + + /// Field number for the "resource_arg_unique_id" field. + public const int ResourceArgUniqueIdFieldNumber = 8; + private static readonly pbc::MapField.Codec _map_resourceArgUniqueId_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForUInt32(8, 0), pb::FieldCodec.ForUInt32(16, 0), 66); + private readonly pbc::MapField resourceArgUniqueId_ = new pbc::MapField(); + /// + /// Unique IDs for each resource argument, used to track aliasing resources. If + /// Argument A and Argument B alias each other, then + /// resource_arg_unique_ids[A.index] == resource_arg_unique_ids[B.index]. + /// + /// If this field is empty, none of the arguments could alias; otherwise, every + /// resource argument should have an entry in this field. + /// + /// When instantiated, the unique IDs will be attached to the _Arg nodes' + /// "_resource_arg_unique_id" attribute. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField ResourceArgUniqueId { + get { return resourceArgUniqueId_; } + } + + /// Field number for the "node_def" field. + public const int NodeDefFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_nodeDef_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.NodeDef.Parser); + private readonly pbc::RepeatedField nodeDef_ = new pbc::RepeatedField(); + /// + /// By convention, "op" in node_def is resolved by consulting with a + /// user-defined library first. If not resolved, "func" is assumed to + /// be a builtin op. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField NodeDef { + get { return nodeDef_; } + } + + /// Field number for the "ret" field. + public const int RetFieldNumber = 4; + private static readonly pbc::MapField.Codec _map_ret_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForString(18, ""), 34); + private readonly pbc::MapField ret_ = new pbc::MapField(); + /// + /// A mapping from the output arg names from `signature` to the + /// outputs from `node_def` that should be returned by the function. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField Ret { + get { return ret_; } + } + + /// Field number for the "control_ret" field. + public const int ControlRetFieldNumber = 6; + private static readonly pbc::MapField.Codec _map_controlRet_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForString(18, ""), 50); + private readonly pbc::MapField controlRet_ = new pbc::MapField(); + /// + /// A mapping from control output names from `signature` to node names in + /// `node_def` which should be control outputs of this function. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField ControlRet { + get { return controlRet_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as FunctionDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(FunctionDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Signature, other.Signature)) return false; + if (!Attr.Equals(other.Attr)) return false; + if (!ArgAttr.Equals(other.ArgAttr)) return false; + if (!ResourceArgUniqueId.Equals(other.ResourceArgUniqueId)) return false; + if(!nodeDef_.Equals(other.nodeDef_)) return false; + if (!Ret.Equals(other.Ret)) return false; + if (!ControlRet.Equals(other.ControlRet)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (signature_ != null) hash ^= Signature.GetHashCode(); + hash ^= Attr.GetHashCode(); + hash ^= ArgAttr.GetHashCode(); + hash ^= ResourceArgUniqueId.GetHashCode(); + hash ^= nodeDef_.GetHashCode(); + hash ^= Ret.GetHashCode(); + hash ^= ControlRet.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (signature_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Signature); + } + nodeDef_.WriteTo(output, _repeated_nodeDef_codec); + ret_.WriteTo(output, _map_ret_codec); + attr_.WriteTo(output, _map_attr_codec); + controlRet_.WriteTo(output, _map_controlRet_codec); + argAttr_.WriteTo(output, _map_argAttr_codec); + resourceArgUniqueId_.WriteTo(output, _map_resourceArgUniqueId_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (signature_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Signature); + } + nodeDef_.WriteTo(ref output, _repeated_nodeDef_codec); + ret_.WriteTo(ref output, _map_ret_codec); + attr_.WriteTo(ref output, _map_attr_codec); + controlRet_.WriteTo(ref output, _map_controlRet_codec); + argAttr_.WriteTo(ref output, _map_argAttr_codec); + resourceArgUniqueId_.WriteTo(ref output, _map_resourceArgUniqueId_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (signature_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Signature); + } + size += attr_.CalculateSize(_map_attr_codec); + size += argAttr_.CalculateSize(_map_argAttr_codec); + size += resourceArgUniqueId_.CalculateSize(_map_resourceArgUniqueId_codec); + size += nodeDef_.CalculateSize(_repeated_nodeDef_codec); + size += ret_.CalculateSize(_map_ret_codec); + size += controlRet_.CalculateSize(_map_controlRet_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(FunctionDef other) { + if (other == null) { + return; + } + if (other.signature_ != null) { + if (signature_ == null) { + Signature = new global::Tensorflow.OpDef(); + } + Signature.MergeFrom(other.Signature); + } + attr_.Add(other.attr_); + argAttr_.Add(other.argAttr_); + resourceArgUniqueId_.Add(other.resourceArgUniqueId_); + nodeDef_.Add(other.nodeDef_); + ret_.Add(other.ret_); + controlRet_.Add(other.controlRet_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (signature_ == null) { + Signature = new global::Tensorflow.OpDef(); + } + input.ReadMessage(Signature); + break; + } + case 26: { + nodeDef_.AddEntriesFrom(input, _repeated_nodeDef_codec); + break; + } + case 34: { + ret_.AddEntriesFrom(input, _map_ret_codec); + break; + } + case 42: { + attr_.AddEntriesFrom(input, _map_attr_codec); + break; + } + case 50: { + controlRet_.AddEntriesFrom(input, _map_controlRet_codec); + break; + } + case 58: { + argAttr_.AddEntriesFrom(input, _map_argAttr_codec); + break; + } + case 66: { + resourceArgUniqueId_.AddEntriesFrom(input, _map_resourceArgUniqueId_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (signature_ == null) { + Signature = new global::Tensorflow.OpDef(); + } + input.ReadMessage(Signature); + break; + } + case 26: { + nodeDef_.AddEntriesFrom(ref input, _repeated_nodeDef_codec); + break; + } + case 34: { + ret_.AddEntriesFrom(ref input, _map_ret_codec); + break; + } + case 42: { + attr_.AddEntriesFrom(ref input, _map_attr_codec); + break; + } + case 50: { + controlRet_.AddEntriesFrom(ref input, _map_controlRet_codec); + break; + } + case 58: { + argAttr_.AddEntriesFrom(ref input, _map_argAttr_codec); + break; + } + case 66: { + resourceArgUniqueId_.AddEntriesFrom(ref input, _map_resourceArgUniqueId_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the FunctionDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Attributes for function arguments. These attributes are the same set of + /// valid attributes as to _Arg nodes. + /// + public sealed partial class ArgAttrs : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ArgAttrs()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.FunctionDef.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ArgAttrs() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ArgAttrs(ArgAttrs other) : this() { + attr_ = other.attr_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ArgAttrs Clone() { + return new ArgAttrs(this); + } + + /// Field number for the "attr" field. + public const int AttrFieldNumber = 1; + private static readonly pbc::MapField.Codec _map_attr_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 10); + private readonly pbc::MapField attr_ = new pbc::MapField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField Attr { + get { return attr_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ArgAttrs); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ArgAttrs other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!Attr.Equals(other.Attr)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= Attr.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + attr_.WriteTo(output, _map_attr_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + attr_.WriteTo(ref output, _map_attr_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += attr_.CalculateSize(_map_attr_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ArgAttrs other) { + if (other == null) { + return; + } + attr_.Add(other.attr_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + attr_.AddEntriesFrom(input, _map_attr_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + attr_.AddEntriesFrom(ref input, _map_attr_codec); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// GradientDef defines the gradient function of a function defined in + /// a function library. + /// + /// A gradient function g (specified by gradient_func) for a function f + /// (specified by function_name) must follow the following: + /// + /// The function 'f' must be a numerical function which takes N inputs + /// and produces M outputs. Its gradient function 'g', which is a + /// function taking N + M inputs and produces N outputs. + /// + /// I.e. if we have + /// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), + /// then, g is + /// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, + /// dL/dy1, dL/dy2, ..., dL/dy_M), + /// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the + /// loss function). dL/dx_i is the partial derivative of L with respect + /// to x_i. + /// + public sealed partial class GradientDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GradientDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GradientDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GradientDef(GradientDef other) : this() { + functionName_ = other.functionName_; + gradientFunc_ = other.gradientFunc_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GradientDef Clone() { + return new GradientDef(this); + } + + /// Field number for the "function_name" field. + public const int FunctionNameFieldNumber = 1; + private string functionName_ = ""; + /// + /// The function name. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string FunctionName { + get { return functionName_; } + set { + functionName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "gradient_func" field. + public const int GradientFuncFieldNumber = 2; + private string gradientFunc_ = ""; + /// + /// The gradient function's name. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string GradientFunc { + get { return gradientFunc_; } + set { + gradientFunc_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GradientDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GradientDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (FunctionName != other.FunctionName) return false; + if (GradientFunc != other.GradientFunc) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (FunctionName.Length != 0) hash ^= FunctionName.GetHashCode(); + if (GradientFunc.Length != 0) hash ^= GradientFunc.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (FunctionName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(FunctionName); + } + if (GradientFunc.Length != 0) { + output.WriteRawTag(18); + output.WriteString(GradientFunc); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (FunctionName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(FunctionName); + } + if (GradientFunc.Length != 0) { + output.WriteRawTag(18); + output.WriteString(GradientFunc); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (FunctionName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(FunctionName); + } + if (GradientFunc.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(GradientFunc); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GradientDef other) { + if (other == null) { + return; + } + if (other.FunctionName.Length != 0) { + FunctionName = other.FunctionName; + } + if (other.GradientFunc.Length != 0) { + GradientFunc = other.GradientFunc; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + FunctionName = input.ReadString(); + break; + } + case 18: { + GradientFunc = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + FunctionName = input.ReadString(); + break; + } + case 18: { + GradientFunc = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// RegisteredGradient stores a gradient function that is registered in the + /// gradients library and used in the ops of a function in the function library. + /// Unlike GradientDef, these gradients are identified by op type, and not + /// directly linked to any function. + /// + public sealed partial class RegisteredGradient : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RegisteredGradient()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RegisteredGradient() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RegisteredGradient(RegisteredGradient other) : this() { + gradientFunc_ = other.gradientFunc_; + registeredOpType_ = other.registeredOpType_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RegisteredGradient Clone() { + return new RegisteredGradient(this); + } + + /// Field number for the "gradient_func" field. + public const int GradientFuncFieldNumber = 1; + private string gradientFunc_ = ""; + /// + /// The gradient function's name. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string GradientFunc { + get { return gradientFunc_; } + set { + gradientFunc_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "registered_op_type" field. + public const int RegisteredOpTypeFieldNumber = 2; + private string registeredOpType_ = ""; + /// + /// The gradient function's registered op type. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string RegisteredOpType { + get { return registeredOpType_; } + set { + registeredOpType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as RegisteredGradient); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(RegisteredGradient other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (GradientFunc != other.GradientFunc) return false; + if (RegisteredOpType != other.RegisteredOpType) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (GradientFunc.Length != 0) hash ^= GradientFunc.GetHashCode(); + if (RegisteredOpType.Length != 0) hash ^= RegisteredOpType.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (GradientFunc.Length != 0) { + output.WriteRawTag(10); + output.WriteString(GradientFunc); + } + if (RegisteredOpType.Length != 0) { + output.WriteRawTag(18); + output.WriteString(RegisteredOpType); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (GradientFunc.Length != 0) { + output.WriteRawTag(10); + output.WriteString(GradientFunc); + } + if (RegisteredOpType.Length != 0) { + output.WriteRawTag(18); + output.WriteString(RegisteredOpType); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (GradientFunc.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(GradientFunc); + } + if (RegisteredOpType.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(RegisteredOpType); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(RegisteredGradient other) { + if (other == null) { + return; + } + if (other.GradientFunc.Length != 0) { + GradientFunc = other.GradientFunc; + } + if (other.RegisteredOpType.Length != 0) { + RegisteredOpType = other.RegisteredOpType; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + GradientFunc = input.ReadString(); + break; + } + case 18: { + RegisteredOpType = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + GradientFunc = input.ReadString(); + break; + } + case 18: { + RegisteredOpType = input.ReadString(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Gen.bat b/src/TensorFlowNET.Core/Protobuf/Gen.bat new file mode 100644 index 000000000..6b898bcb8 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Gen.bat @@ -0,0 +1,63 @@ +@ECHO OFF + +set SRC_DIR=D:/development/tf.net/tensorflow-2.11.0 +set DST_DIR=D:/development/tf.net/gen_proto + +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/attr_value.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/node_def.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/versions.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/function.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/graph.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/variable.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/cost_graph.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/step_stats.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/allocation_description.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_description.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/api_def.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/device_attributes.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/graph_transfer_info.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/kernel_def.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/log_memory.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_slice.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/summary.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/full_type.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/op_def.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saver.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_object_graph.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_model.proto +ECHO Download `any.proto` from https://github.com/protocolbuffers/protobuf/tree/master/src/google/protobuf +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/coordination_service.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/coordination_config.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/service_config.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/data_service.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/meta_graph.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/control_flow.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/trackable_object_graph.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/struct.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/verifier_config.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/util/event.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/util/memmapped_file_system.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/tsl/protobuf/histogram.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/compiler/xla/xla.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/compiler/xla/xla_data.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/compiler/xla/service/hlo.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/compiler/xla/pjrt/distributed/protocol.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/compiler/xla/service/gpu/executable.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/compiler/xla/service/cpu/executable.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/compiler/xla/service/cpu/xla_framework.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/framework/cpp_shape_inference.proto + +ECHO protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/keras/protobuf/projector_config.proto +ECHO protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/keras/protobuf/versions.proto +ECHO protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/keras/protobuf/saved_metadata.proto + +PAUSE \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Protobuf/Graph.cs b/src/TensorFlowNET.Core/Protobuf/Graph.cs new file mode 100644 index 000000000..0b7644eba --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Graph.cs @@ -0,0 +1,400 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/graph.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/graph.proto + public static partial class GraphReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/graph.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static GraphReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiV0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2dyYXBoLnByb3RvEgp0ZW5z", + "b3JmbG93Gih0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Z1bmN0aW9uLnBy", + "b3RvGih0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL25vZGVfZGVmLnByb3Rv", + "Gih0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3ZlcnNpb25zLnByb3RvIp0B", + "CghHcmFwaERlZhIhCgRub2RlGAEgAygLMhMudGVuc29yZmxvdy5Ob2RlRGVm", + "EigKCHZlcnNpb25zGAQgASgLMhYudGVuc29yZmxvdy5WZXJzaW9uRGVmEhMK", + "B3ZlcnNpb24YAyABKAVCAhgBEi8KB2xpYnJhcnkYAiABKAsyHi50ZW5zb3Jm", + "bG93LkZ1bmN0aW9uRGVmTGlicmFyeUJ6ChhvcmcudGVuc29yZmxvdy5mcmFt", + "ZXdvcmtCC0dyYXBoUHJvdG9zUAFaTGdpdGh1Yi5jb20vdGVuc29yZmxvdy90", + "ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9mcmFtZXdvcmsvZ3JhcGhf", + "Z29fcHJvdG/4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.FunctionReflection.Descriptor, global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.VersionsReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphDef), global::Tensorflow.GraphDef.Parser, new[]{ "Node", "Versions", "Version", "Library" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Represents the graph of operations + /// + public sealed partial class GraphDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.GraphReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphDef(GraphDef other) : this() { + node_ = other.node_.Clone(); + versions_ = other.versions_ != null ? other.versions_.Clone() : null; + version_ = other.version_; + library_ = other.library_ != null ? other.library_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphDef Clone() { + return new GraphDef(this); + } + + /// Field number for the "node" field. + public const int NodeFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_node_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.NodeDef.Parser); + private readonly pbc::RepeatedField node_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Node { + get { return node_; } + } + + /// Field number for the "versions" field. + public const int VersionsFieldNumber = 4; + private global::Tensorflow.VersionDef versions_; + /// + /// Compatibility versions of the graph. See core/public/version.h for version + /// history. The GraphDef version is distinct from the TensorFlow version, and + /// each release of TensorFlow will support a range of GraphDef versions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.VersionDef Versions { + get { return versions_; } + set { + versions_ = value; + } + } + + /// Field number for the "version" field. + public const int VersionFieldNumber = 3; + private int version_; + /// + /// Deprecated single version field; use versions above instead. Since all + /// GraphDef changes before "versions" was introduced were forward + /// compatible, this field is entirely ignored. + /// + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int Version { + get { return version_; } + set { + version_ = value; + } + } + + /// Field number for the "library" field. + public const int LibraryFieldNumber = 2; + private global::Tensorflow.FunctionDefLibrary library_; + /// + /// "library" provides user-defined functions. + /// + /// Naming: + /// * library.function.name are in a flat namespace. + /// NOTE: We may need to change it to be hierarchical to support + /// different orgs. E.g., + /// { "/google/nn", { ... }}, + /// { "/google/vision", { ... }} + /// { "/org_foo/module_bar", { ... }} + /// map<string, FunctionDefLib> named_lib; + /// * If node[i].op is the name of one function in "library", + /// node[i] is deemed as a function call. Otherwise, node[i].op + /// must be a primitive operation supported by the runtime. + /// + /// Function call semantics: + /// + /// * The callee may start execution as soon as some of its inputs + /// are ready. The caller may want to use Tuple() mechanism to + /// ensure all inputs are ready in the same time. + /// + /// * The consumer of return values may start executing as soon as + /// the return values the consumer depends on are ready. The + /// consumer may want to use Tuple() mechanism to ensure the + /// consumer does not start until all return values of the callee + /// function are ready. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.FunctionDefLibrary Library { + get { return library_; } + set { + library_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GraphDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GraphDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!node_.Equals(other.node_)) return false; + if (!object.Equals(Versions, other.Versions)) return false; + if (Version != other.Version) return false; + if (!object.Equals(Library, other.Library)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= node_.GetHashCode(); + if (versions_ != null) hash ^= Versions.GetHashCode(); + if (Version != 0) hash ^= Version.GetHashCode(); + if (library_ != null) hash ^= Library.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + node_.WriteTo(output, _repeated_node_codec); + if (library_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Library); + } + if (Version != 0) { + output.WriteRawTag(24); + output.WriteInt32(Version); + } + if (versions_ != null) { + output.WriteRawTag(34); + output.WriteMessage(Versions); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + node_.WriteTo(ref output, _repeated_node_codec); + if (library_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Library); + } + if (Version != 0) { + output.WriteRawTag(24); + output.WriteInt32(Version); + } + if (versions_ != null) { + output.WriteRawTag(34); + output.WriteMessage(Versions); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += node_.CalculateSize(_repeated_node_codec); + if (versions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Versions); + } + if (Version != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Version); + } + if (library_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Library); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GraphDef other) { + if (other == null) { + return; + } + node_.Add(other.node_); + if (other.versions_ != null) { + if (versions_ == null) { + Versions = new global::Tensorflow.VersionDef(); + } + Versions.MergeFrom(other.Versions); + } + if (other.Version != 0) { + Version = other.Version; + } + if (other.library_ != null) { + if (library_ == null) { + Library = new global::Tensorflow.FunctionDefLibrary(); + } + Library.MergeFrom(other.Library); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + node_.AddEntriesFrom(input, _repeated_node_codec); + break; + } + case 18: { + if (library_ == null) { + Library = new global::Tensorflow.FunctionDefLibrary(); + } + input.ReadMessage(Library); + break; + } + case 24: { + Version = input.ReadInt32(); + break; + } + case 34: { + if (versions_ == null) { + Versions = new global::Tensorflow.VersionDef(); + } + input.ReadMessage(Versions); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + node_.AddEntriesFrom(ref input, _repeated_node_codec); + break; + } + case 18: { + if (library_ == null) { + Library = new global::Tensorflow.FunctionDefLibrary(); + } + input.ReadMessage(Library); + break; + } + case 24: { + Version = input.ReadInt32(); + break; + } + case 34: { + if (versions_ == null) { + Versions = new global::Tensorflow.VersionDef(); + } + input.ReadMessage(Versions); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/GraphTransferInfo.cs b/src/TensorFlowNET.Core/Protobuf/GraphTransferInfo.cs new file mode 100644 index 000000000..0292e8170 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/GraphTransferInfo.cs @@ -0,0 +1,2356 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/graph_transfer_info.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/graph_transfer_info.proto + public static partial class GraphTransferInfoReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/graph_transfer_info.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static GraphTransferInfoReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjN0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2dyYXBoX3RyYW5zZmVyX2lu", + "Zm8ucHJvdG8SCnRlbnNvcmZsb3caJXRlbnNvcmZsb3cvY29yZS9mcmFtZXdv", + "cmsvdHlwZXMucHJvdG8iPgoWR3JhcGhUcmFuc2Zlck5vZGVJbnB1dBIPCgdu", + "b2RlX2lkGAEgASgFEhMKC291dHB1dF9wb3J0GAIgASgFIpsBChVHcmFwaFRy", + "YW5zZmVyTm9kZUluZm8SDAoEbmFtZRgBIAEoCRIPCgdub2RlX2lkGAIgASgF", + "EhEKCXR5cGVfbmFtZRgDIAEoCRIRCglzb2Nfb3BfaWQYBCABKAUSEgoKcGFk", + "ZGluZ19pZBgFIAEoBRITCgtpbnB1dF9jb3VudBgGIAEoBRIUCgxvdXRwdXRf", + "Y291bnQYByABKAUifQoaR3JhcGhUcmFuc2ZlckNvbnN0Tm9kZUluZm8SDAoE", + "bmFtZRgBIAEoCRIPCgdub2RlX2lkGAIgASgFEg0KBXNoYXBlGAMgAygDEgwK", + "BGRhdGEYBCABKAwSIwoFZHR5cGUYBSABKA4yFC50ZW5zb3JmbG93LkRhdGFU", + "eXBlImUKGkdyYXBoVHJhbnNmZXJOb2RlSW5wdXRJbmZvEg8KB25vZGVfaWQY", + "ASABKAUSNgoKbm9kZV9pbnB1dBgCIAMoCzIiLnRlbnNvcmZsb3cuR3JhcGhU", + "cmFuc2Zlck5vZGVJbnB1dCJFChtHcmFwaFRyYW5zZmVyTm9kZU91dHB1dElu", + "Zm8SDwoHbm9kZV9pZBgBIAEoBRIVCg1tYXhfYnl0ZV9zaXplGAIgAygFImMK", + "H0dyYXBoVHJhbnNmZXJHcmFwaElucHV0Tm9kZUluZm8SDAoEbmFtZRgBIAEo", + "CRINCgVzaGFwZRgCIAMoAxIjCgVkdHlwZRgDIAEoDjIULnRlbnNvcmZsb3cu", + "RGF0YVR5cGUiZAogR3JhcGhUcmFuc2ZlckdyYXBoT3V0cHV0Tm9kZUluZm8S", + "DAoEbmFtZRgBIAEoCRINCgVzaGFwZRgCIAMoAxIjCgVkdHlwZRgDIAEoDjIU", + "LnRlbnNvcmZsb3cuRGF0YVR5cGUijQQKEUdyYXBoVHJhbnNmZXJJbmZvEjQK", + "CW5vZGVfaW5mbxgBIAMoCzIhLnRlbnNvcmZsb3cuR3JhcGhUcmFuc2Zlck5v", + "ZGVJbmZvEj8KD2NvbnN0X25vZGVfaW5mbxgCIAMoCzImLnRlbnNvcmZsb3cu", + "R3JhcGhUcmFuc2ZlckNvbnN0Tm9kZUluZm8SPwoPbm9kZV9pbnB1dF9pbmZv", + "GAMgAygLMiYudGVuc29yZmxvdy5HcmFwaFRyYW5zZmVyTm9kZUlucHV0SW5m", + "bxJBChBub2RlX291dHB1dF9pbmZvGAQgAygLMicudGVuc29yZmxvdy5HcmFw", + "aFRyYW5zZmVyTm9kZU91dHB1dEluZm8SSgoVZ3JhcGhfaW5wdXRfbm9kZV9p", + "bmZvGAUgAygLMisudGVuc29yZmxvdy5HcmFwaFRyYW5zZmVyR3JhcGhJbnB1", + "dE5vZGVJbmZvEkwKFmdyYXBoX291dHB1dF9ub2RlX2luZm8YBiADKAsyLC50", + "ZW5zb3JmbG93LkdyYXBoVHJhbnNmZXJHcmFwaE91dHB1dE5vZGVJbmZvEj4K", + "C2Rlc3RpbmF0aW9uGAcgASgOMikudGVuc29yZmxvdy5HcmFwaFRyYW5zZmVy", + "SW5mby5EZXN0aW5hdGlvbiIjCgtEZXN0aW5hdGlvbhIHCgNOT1AQABILCgdI", + "RVhBR09OEAFCkwEKGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IWR3JhcGhU", + "cmFuc2ZlckluZm9Qcm90b1ABWlpnaXRodWIuY29tL3RlbnNvcmZsb3cvdGVu", + "c29yZmxvdy90ZW5zb3JmbG93L2dvL2NvcmUvZnJhbWV3b3JrL2dyYXBoX3Ry", + "YW5zZmVyX2luZm9fZ29fcHJvdG/4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.TypesReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphTransferNodeInput), global::Tensorflow.GraphTransferNodeInput.Parser, new[]{ "NodeId", "OutputPort" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphTransferNodeInfo), global::Tensorflow.GraphTransferNodeInfo.Parser, new[]{ "Name", "NodeId", "TypeName", "SocOpId", "PaddingId", "InputCount", "OutputCount" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphTransferConstNodeInfo), global::Tensorflow.GraphTransferConstNodeInfo.Parser, new[]{ "Name", "NodeId", "Shape", "Data", "Dtype" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphTransferNodeInputInfo), global::Tensorflow.GraphTransferNodeInputInfo.Parser, new[]{ "NodeId", "NodeInput" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphTransferNodeOutputInfo), global::Tensorflow.GraphTransferNodeOutputInfo.Parser, new[]{ "NodeId", "MaxByteSize" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphTransferGraphInputNodeInfo), global::Tensorflow.GraphTransferGraphInputNodeInfo.Parser, new[]{ "Name", "Shape", "Dtype" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphTransferGraphOutputNodeInfo), global::Tensorflow.GraphTransferGraphOutputNodeInfo.Parser, new[]{ "Name", "Shape", "Dtype" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphTransferInfo), global::Tensorflow.GraphTransferInfo.Parser, new[]{ "NodeInfo", "ConstNodeInfo", "NodeInputInfo", "NodeOutputInfo", "GraphInputNodeInfo", "GraphOutputNodeInfo", "Destination" }, null, new[]{ typeof(global::Tensorflow.GraphTransferInfo.Types.Destination) }, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class GraphTransferNodeInput : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphTransferNodeInput()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.GraphTransferInfoReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferNodeInput() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferNodeInput(GraphTransferNodeInput other) : this() { + nodeId_ = other.nodeId_; + outputPort_ = other.outputPort_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferNodeInput Clone() { + return new GraphTransferNodeInput(this); + } + + /// Field number for the "node_id" field. + public const int NodeIdFieldNumber = 1; + private int nodeId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NodeId { + get { return nodeId_; } + set { + nodeId_ = value; + } + } + + /// Field number for the "output_port" field. + public const int OutputPortFieldNumber = 2; + private int outputPort_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int OutputPort { + get { return outputPort_; } + set { + outputPort_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GraphTransferNodeInput); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GraphTransferNodeInput other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NodeId != other.NodeId) return false; + if (OutputPort != other.OutputPort) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (NodeId != 0) hash ^= NodeId.GetHashCode(); + if (OutputPort != 0) hash ^= OutputPort.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (NodeId != 0) { + output.WriteRawTag(8); + output.WriteInt32(NodeId); + } + if (OutputPort != 0) { + output.WriteRawTag(16); + output.WriteInt32(OutputPort); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (NodeId != 0) { + output.WriteRawTag(8); + output.WriteInt32(NodeId); + } + if (OutputPort != 0) { + output.WriteRawTag(16); + output.WriteInt32(OutputPort); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (NodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NodeId); + } + if (OutputPort != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(OutputPort); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GraphTransferNodeInput other) { + if (other == null) { + return; + } + if (other.NodeId != 0) { + NodeId = other.NodeId; + } + if (other.OutputPort != 0) { + OutputPort = other.OutputPort; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NodeId = input.ReadInt32(); + break; + } + case 16: { + OutputPort = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + NodeId = input.ReadInt32(); + break; + } + case 16: { + OutputPort = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + public sealed partial class GraphTransferNodeInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphTransferNodeInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.GraphTransferInfoReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferNodeInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferNodeInfo(GraphTransferNodeInfo other) : this() { + name_ = other.name_; + nodeId_ = other.nodeId_; + typeName_ = other.typeName_; + socOpId_ = other.socOpId_; + paddingId_ = other.paddingId_; + inputCount_ = other.inputCount_; + outputCount_ = other.outputCount_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferNodeInfo Clone() { + return new GraphTransferNodeInfo(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "node_id" field. + public const int NodeIdFieldNumber = 2; + private int nodeId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NodeId { + get { return nodeId_; } + set { + nodeId_ = value; + } + } + + /// Field number for the "type_name" field. + public const int TypeNameFieldNumber = 3; + private string typeName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string TypeName { + get { return typeName_; } + set { + typeName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "soc_op_id" field. + public const int SocOpIdFieldNumber = 4; + private int socOpId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int SocOpId { + get { return socOpId_; } + set { + socOpId_ = value; + } + } + + /// Field number for the "padding_id" field. + public const int PaddingIdFieldNumber = 5; + private int paddingId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int PaddingId { + get { return paddingId_; } + set { + paddingId_ = value; + } + } + + /// Field number for the "input_count" field. + public const int InputCountFieldNumber = 6; + private int inputCount_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int InputCount { + get { return inputCount_; } + set { + inputCount_ = value; + } + } + + /// Field number for the "output_count" field. + public const int OutputCountFieldNumber = 7; + private int outputCount_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int OutputCount { + get { return outputCount_; } + set { + outputCount_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GraphTransferNodeInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GraphTransferNodeInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (NodeId != other.NodeId) return false; + if (TypeName != other.TypeName) return false; + if (SocOpId != other.SocOpId) return false; + if (PaddingId != other.PaddingId) return false; + if (InputCount != other.InputCount) return false; + if (OutputCount != other.OutputCount) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (NodeId != 0) hash ^= NodeId.GetHashCode(); + if (TypeName.Length != 0) hash ^= TypeName.GetHashCode(); + if (SocOpId != 0) hash ^= SocOpId.GetHashCode(); + if (PaddingId != 0) hash ^= PaddingId.GetHashCode(); + if (InputCount != 0) hash ^= InputCount.GetHashCode(); + if (OutputCount != 0) hash ^= OutputCount.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (NodeId != 0) { + output.WriteRawTag(16); + output.WriteInt32(NodeId); + } + if (TypeName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(TypeName); + } + if (SocOpId != 0) { + output.WriteRawTag(32); + output.WriteInt32(SocOpId); + } + if (PaddingId != 0) { + output.WriteRawTag(40); + output.WriteInt32(PaddingId); + } + if (InputCount != 0) { + output.WriteRawTag(48); + output.WriteInt32(InputCount); + } + if (OutputCount != 0) { + output.WriteRawTag(56); + output.WriteInt32(OutputCount); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (NodeId != 0) { + output.WriteRawTag(16); + output.WriteInt32(NodeId); + } + if (TypeName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(TypeName); + } + if (SocOpId != 0) { + output.WriteRawTag(32); + output.WriteInt32(SocOpId); + } + if (PaddingId != 0) { + output.WriteRawTag(40); + output.WriteInt32(PaddingId); + } + if (InputCount != 0) { + output.WriteRawTag(48); + output.WriteInt32(InputCount); + } + if (OutputCount != 0) { + output.WriteRawTag(56); + output.WriteInt32(OutputCount); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (NodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NodeId); + } + if (TypeName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TypeName); + } + if (SocOpId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(SocOpId); + } + if (PaddingId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(PaddingId); + } + if (InputCount != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(InputCount); + } + if (OutputCount != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(OutputCount); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GraphTransferNodeInfo other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.NodeId != 0) { + NodeId = other.NodeId; + } + if (other.TypeName.Length != 0) { + TypeName = other.TypeName; + } + if (other.SocOpId != 0) { + SocOpId = other.SocOpId; + } + if (other.PaddingId != 0) { + PaddingId = other.PaddingId; + } + if (other.InputCount != 0) { + InputCount = other.InputCount; + } + if (other.OutputCount != 0) { + OutputCount = other.OutputCount; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + NodeId = input.ReadInt32(); + break; + } + case 26: { + TypeName = input.ReadString(); + break; + } + case 32: { + SocOpId = input.ReadInt32(); + break; + } + case 40: { + PaddingId = input.ReadInt32(); + break; + } + case 48: { + InputCount = input.ReadInt32(); + break; + } + case 56: { + OutputCount = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + NodeId = input.ReadInt32(); + break; + } + case 26: { + TypeName = input.ReadString(); + break; + } + case 32: { + SocOpId = input.ReadInt32(); + break; + } + case 40: { + PaddingId = input.ReadInt32(); + break; + } + case 48: { + InputCount = input.ReadInt32(); + break; + } + case 56: { + OutputCount = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + public sealed partial class GraphTransferConstNodeInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphTransferConstNodeInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.GraphTransferInfoReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferConstNodeInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferConstNodeInfo(GraphTransferConstNodeInfo other) : this() { + name_ = other.name_; + nodeId_ = other.nodeId_; + shape_ = other.shape_.Clone(); + data_ = other.data_; + dtype_ = other.dtype_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferConstNodeInfo Clone() { + return new GraphTransferConstNodeInfo(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "node_id" field. + public const int NodeIdFieldNumber = 2; + private int nodeId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NodeId { + get { return nodeId_; } + set { + nodeId_ = value; + } + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_shape_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField shape_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Shape { + get { return shape_; } + } + + /// Field number for the "data" field. + public const int DataFieldNumber = 4; + private pb::ByteString data_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString Data { + get { return data_; } + set { + data_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 5; + private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GraphTransferConstNodeInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GraphTransferConstNodeInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (NodeId != other.NodeId) return false; + if(!shape_.Equals(other.shape_)) return false; + if (Data != other.Data) return false; + if (Dtype != other.Dtype) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (NodeId != 0) hash ^= NodeId.GetHashCode(); + hash ^= shape_.GetHashCode(); + if (Data.Length != 0) hash ^= Data.GetHashCode(); + if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (NodeId != 0) { + output.WriteRawTag(16); + output.WriteInt32(NodeId); + } + shape_.WriteTo(output, _repeated_shape_codec); + if (Data.Length != 0) { + output.WriteRawTag(34); + output.WriteBytes(Data); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(40); + output.WriteEnum((int) Dtype); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (NodeId != 0) { + output.WriteRawTag(16); + output.WriteInt32(NodeId); + } + shape_.WriteTo(ref output, _repeated_shape_codec); + if (Data.Length != 0) { + output.WriteRawTag(34); + output.WriteBytes(Data); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(40); + output.WriteEnum((int) Dtype); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (NodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NodeId); + } + size += shape_.CalculateSize(_repeated_shape_codec); + if (Data.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(Data); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GraphTransferConstNodeInfo other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.NodeId != 0) { + NodeId = other.NodeId; + } + shape_.Add(other.shape_); + if (other.Data.Length != 0) { + Data = other.Data; + } + if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { + Dtype = other.Dtype; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + NodeId = input.ReadInt32(); + break; + } + case 26: + case 24: { + shape_.AddEntriesFrom(input, _repeated_shape_codec); + break; + } + case 34: { + Data = input.ReadBytes(); + break; + } + case 40: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + NodeId = input.ReadInt32(); + break; + } + case 26: + case 24: { + shape_.AddEntriesFrom(ref input, _repeated_shape_codec); + break; + } + case 34: { + Data = input.ReadBytes(); + break; + } + case 40: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + public sealed partial class GraphTransferNodeInputInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphTransferNodeInputInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.GraphTransferInfoReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferNodeInputInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferNodeInputInfo(GraphTransferNodeInputInfo other) : this() { + nodeId_ = other.nodeId_; + nodeInput_ = other.nodeInput_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferNodeInputInfo Clone() { + return new GraphTransferNodeInputInfo(this); + } + + /// Field number for the "node_id" field. + public const int NodeIdFieldNumber = 1; + private int nodeId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NodeId { + get { return nodeId_; } + set { + nodeId_ = value; + } + } + + /// Field number for the "node_input" field. + public const int NodeInputFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_nodeInput_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.GraphTransferNodeInput.Parser); + private readonly pbc::RepeatedField nodeInput_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField NodeInput { + get { return nodeInput_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GraphTransferNodeInputInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GraphTransferNodeInputInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NodeId != other.NodeId) return false; + if(!nodeInput_.Equals(other.nodeInput_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (NodeId != 0) hash ^= NodeId.GetHashCode(); + hash ^= nodeInput_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (NodeId != 0) { + output.WriteRawTag(8); + output.WriteInt32(NodeId); + } + nodeInput_.WriteTo(output, _repeated_nodeInput_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (NodeId != 0) { + output.WriteRawTag(8); + output.WriteInt32(NodeId); + } + nodeInput_.WriteTo(ref output, _repeated_nodeInput_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (NodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NodeId); + } + size += nodeInput_.CalculateSize(_repeated_nodeInput_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GraphTransferNodeInputInfo other) { + if (other == null) { + return; + } + if (other.NodeId != 0) { + NodeId = other.NodeId; + } + nodeInput_.Add(other.nodeInput_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NodeId = input.ReadInt32(); + break; + } + case 18: { + nodeInput_.AddEntriesFrom(input, _repeated_nodeInput_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + NodeId = input.ReadInt32(); + break; + } + case 18: { + nodeInput_.AddEntriesFrom(ref input, _repeated_nodeInput_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class GraphTransferNodeOutputInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphTransferNodeOutputInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.GraphTransferInfoReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferNodeOutputInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferNodeOutputInfo(GraphTransferNodeOutputInfo other) : this() { + nodeId_ = other.nodeId_; + maxByteSize_ = other.maxByteSize_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferNodeOutputInfo Clone() { + return new GraphTransferNodeOutputInfo(this); + } + + /// Field number for the "node_id" field. + public const int NodeIdFieldNumber = 1; + private int nodeId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NodeId { + get { return nodeId_; } + set { + nodeId_ = value; + } + } + + /// Field number for the "max_byte_size" field. + public const int MaxByteSizeFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_maxByteSize_codec + = pb::FieldCodec.ForInt32(18); + private readonly pbc::RepeatedField maxByteSize_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField MaxByteSize { + get { return maxByteSize_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GraphTransferNodeOutputInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GraphTransferNodeOutputInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NodeId != other.NodeId) return false; + if(!maxByteSize_.Equals(other.maxByteSize_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (NodeId != 0) hash ^= NodeId.GetHashCode(); + hash ^= maxByteSize_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (NodeId != 0) { + output.WriteRawTag(8); + output.WriteInt32(NodeId); + } + maxByteSize_.WriteTo(output, _repeated_maxByteSize_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (NodeId != 0) { + output.WriteRawTag(8); + output.WriteInt32(NodeId); + } + maxByteSize_.WriteTo(ref output, _repeated_maxByteSize_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (NodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NodeId); + } + size += maxByteSize_.CalculateSize(_repeated_maxByteSize_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GraphTransferNodeOutputInfo other) { + if (other == null) { + return; + } + if (other.NodeId != 0) { + NodeId = other.NodeId; + } + maxByteSize_.Add(other.maxByteSize_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NodeId = input.ReadInt32(); + break; + } + case 18: + case 16: { + maxByteSize_.AddEntriesFrom(input, _repeated_maxByteSize_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + NodeId = input.ReadInt32(); + break; + } + case 18: + case 16: { + maxByteSize_.AddEntriesFrom(ref input, _repeated_maxByteSize_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class GraphTransferGraphInputNodeInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphTransferGraphInputNodeInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.GraphTransferInfoReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferGraphInputNodeInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferGraphInputNodeInfo(GraphTransferGraphInputNodeInfo other) : this() { + name_ = other.name_; + shape_ = other.shape_.Clone(); + dtype_ = other.dtype_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferGraphInputNodeInfo Clone() { + return new GraphTransferGraphInputNodeInfo(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_shape_codec + = pb::FieldCodec.ForInt64(18); + private readonly pbc::RepeatedField shape_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Shape { + get { return shape_; } + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 3; + private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GraphTransferGraphInputNodeInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GraphTransferGraphInputNodeInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if(!shape_.Equals(other.shape_)) return false; + if (Dtype != other.Dtype) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + hash ^= shape_.GetHashCode(); + if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + shape_.WriteTo(output, _repeated_shape_codec); + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(24); + output.WriteEnum((int) Dtype); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + shape_.WriteTo(ref output, _repeated_shape_codec); + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(24); + output.WriteEnum((int) Dtype); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + size += shape_.CalculateSize(_repeated_shape_codec); + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GraphTransferGraphInputNodeInfo other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + shape_.Add(other.shape_); + if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { + Dtype = other.Dtype; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: + case 16: { + shape_.AddEntriesFrom(input, _repeated_shape_codec); + break; + } + case 24: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: + case 16: { + shape_.AddEntriesFrom(ref input, _repeated_shape_codec); + break; + } + case 24: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + public sealed partial class GraphTransferGraphOutputNodeInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphTransferGraphOutputNodeInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.GraphTransferInfoReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferGraphOutputNodeInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferGraphOutputNodeInfo(GraphTransferGraphOutputNodeInfo other) : this() { + name_ = other.name_; + shape_ = other.shape_.Clone(); + dtype_ = other.dtype_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferGraphOutputNodeInfo Clone() { + return new GraphTransferGraphOutputNodeInfo(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_shape_codec + = pb::FieldCodec.ForInt64(18); + private readonly pbc::RepeatedField shape_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Shape { + get { return shape_; } + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 3; + private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GraphTransferGraphOutputNodeInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GraphTransferGraphOutputNodeInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if(!shape_.Equals(other.shape_)) return false; + if (Dtype != other.Dtype) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + hash ^= shape_.GetHashCode(); + if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + shape_.WriteTo(output, _repeated_shape_codec); + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(24); + output.WriteEnum((int) Dtype); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + shape_.WriteTo(ref output, _repeated_shape_codec); + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(24); + output.WriteEnum((int) Dtype); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + size += shape_.CalculateSize(_repeated_shape_codec); + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GraphTransferGraphOutputNodeInfo other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + shape_.Add(other.shape_); + if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { + Dtype = other.Dtype; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: + case 16: { + shape_.AddEntriesFrom(input, _repeated_shape_codec); + break; + } + case 24: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: + case 16: { + shape_.AddEntriesFrom(ref input, _repeated_shape_codec); + break; + } + case 24: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + /// + /// Protocol buffer representing a handle to a tensorflow resource. Handles are + /// not valid across executions, but can be serialized back and forth from within + /// a single run. + /// + public sealed partial class GraphTransferInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphTransferInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.GraphTransferInfoReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferInfo(GraphTransferInfo other) : this() { + nodeInfo_ = other.nodeInfo_.Clone(); + constNodeInfo_ = other.constNodeInfo_.Clone(); + nodeInputInfo_ = other.nodeInputInfo_.Clone(); + nodeOutputInfo_ = other.nodeOutputInfo_.Clone(); + graphInputNodeInfo_ = other.graphInputNodeInfo_.Clone(); + graphOutputNodeInfo_ = other.graphOutputNodeInfo_.Clone(); + destination_ = other.destination_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GraphTransferInfo Clone() { + return new GraphTransferInfo(this); + } + + /// Field number for the "node_info" field. + public const int NodeInfoFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_nodeInfo_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.GraphTransferNodeInfo.Parser); + private readonly pbc::RepeatedField nodeInfo_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField NodeInfo { + get { return nodeInfo_; } + } + + /// Field number for the "const_node_info" field. + public const int ConstNodeInfoFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_constNodeInfo_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.GraphTransferConstNodeInfo.Parser); + private readonly pbc::RepeatedField constNodeInfo_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ConstNodeInfo { + get { return constNodeInfo_; } + } + + /// Field number for the "node_input_info" field. + public const int NodeInputInfoFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_nodeInputInfo_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.GraphTransferNodeInputInfo.Parser); + private readonly pbc::RepeatedField nodeInputInfo_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField NodeInputInfo { + get { return nodeInputInfo_; } + } + + /// Field number for the "node_output_info" field. + public const int NodeOutputInfoFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_nodeOutputInfo_codec + = pb::FieldCodec.ForMessage(34, global::Tensorflow.GraphTransferNodeOutputInfo.Parser); + private readonly pbc::RepeatedField nodeOutputInfo_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField NodeOutputInfo { + get { return nodeOutputInfo_; } + } + + /// Field number for the "graph_input_node_info" field. + public const int GraphInputNodeInfoFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_graphInputNodeInfo_codec + = pb::FieldCodec.ForMessage(42, global::Tensorflow.GraphTransferGraphInputNodeInfo.Parser); + private readonly pbc::RepeatedField graphInputNodeInfo_ = new pbc::RepeatedField(); + /// + /// Input Node parameters of transferred graph + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField GraphInputNodeInfo { + get { return graphInputNodeInfo_; } + } + + /// Field number for the "graph_output_node_info" field. + public const int GraphOutputNodeInfoFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_graphOutputNodeInfo_codec + = pb::FieldCodec.ForMessage(50, global::Tensorflow.GraphTransferGraphOutputNodeInfo.Parser); + private readonly pbc::RepeatedField graphOutputNodeInfo_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField GraphOutputNodeInfo { + get { return graphOutputNodeInfo_; } + } + + /// Field number for the "destination" field. + public const int DestinationFieldNumber = 7; + private global::Tensorflow.GraphTransferInfo.Types.Destination destination_ = global::Tensorflow.GraphTransferInfo.Types.Destination.Nop; + /// + /// Destination of graph transfer + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.GraphTransferInfo.Types.Destination Destination { + get { return destination_; } + set { + destination_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GraphTransferInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GraphTransferInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!nodeInfo_.Equals(other.nodeInfo_)) return false; + if(!constNodeInfo_.Equals(other.constNodeInfo_)) return false; + if(!nodeInputInfo_.Equals(other.nodeInputInfo_)) return false; + if(!nodeOutputInfo_.Equals(other.nodeOutputInfo_)) return false; + if(!graphInputNodeInfo_.Equals(other.graphInputNodeInfo_)) return false; + if(!graphOutputNodeInfo_.Equals(other.graphOutputNodeInfo_)) return false; + if (Destination != other.Destination) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= nodeInfo_.GetHashCode(); + hash ^= constNodeInfo_.GetHashCode(); + hash ^= nodeInputInfo_.GetHashCode(); + hash ^= nodeOutputInfo_.GetHashCode(); + hash ^= graphInputNodeInfo_.GetHashCode(); + hash ^= graphOutputNodeInfo_.GetHashCode(); + if (Destination != global::Tensorflow.GraphTransferInfo.Types.Destination.Nop) hash ^= Destination.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + nodeInfo_.WriteTo(output, _repeated_nodeInfo_codec); + constNodeInfo_.WriteTo(output, _repeated_constNodeInfo_codec); + nodeInputInfo_.WriteTo(output, _repeated_nodeInputInfo_codec); + nodeOutputInfo_.WriteTo(output, _repeated_nodeOutputInfo_codec); + graphInputNodeInfo_.WriteTo(output, _repeated_graphInputNodeInfo_codec); + graphOutputNodeInfo_.WriteTo(output, _repeated_graphOutputNodeInfo_codec); + if (Destination != global::Tensorflow.GraphTransferInfo.Types.Destination.Nop) { + output.WriteRawTag(56); + output.WriteEnum((int) Destination); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + nodeInfo_.WriteTo(ref output, _repeated_nodeInfo_codec); + constNodeInfo_.WriteTo(ref output, _repeated_constNodeInfo_codec); + nodeInputInfo_.WriteTo(ref output, _repeated_nodeInputInfo_codec); + nodeOutputInfo_.WriteTo(ref output, _repeated_nodeOutputInfo_codec); + graphInputNodeInfo_.WriteTo(ref output, _repeated_graphInputNodeInfo_codec); + graphOutputNodeInfo_.WriteTo(ref output, _repeated_graphOutputNodeInfo_codec); + if (Destination != global::Tensorflow.GraphTransferInfo.Types.Destination.Nop) { + output.WriteRawTag(56); + output.WriteEnum((int) Destination); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += nodeInfo_.CalculateSize(_repeated_nodeInfo_codec); + size += constNodeInfo_.CalculateSize(_repeated_constNodeInfo_codec); + size += nodeInputInfo_.CalculateSize(_repeated_nodeInputInfo_codec); + size += nodeOutputInfo_.CalculateSize(_repeated_nodeOutputInfo_codec); + size += graphInputNodeInfo_.CalculateSize(_repeated_graphInputNodeInfo_codec); + size += graphOutputNodeInfo_.CalculateSize(_repeated_graphOutputNodeInfo_codec); + if (Destination != global::Tensorflow.GraphTransferInfo.Types.Destination.Nop) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Destination); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GraphTransferInfo other) { + if (other == null) { + return; + } + nodeInfo_.Add(other.nodeInfo_); + constNodeInfo_.Add(other.constNodeInfo_); + nodeInputInfo_.Add(other.nodeInputInfo_); + nodeOutputInfo_.Add(other.nodeOutputInfo_); + graphInputNodeInfo_.Add(other.graphInputNodeInfo_); + graphOutputNodeInfo_.Add(other.graphOutputNodeInfo_); + if (other.Destination != global::Tensorflow.GraphTransferInfo.Types.Destination.Nop) { + Destination = other.Destination; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + nodeInfo_.AddEntriesFrom(input, _repeated_nodeInfo_codec); + break; + } + case 18: { + constNodeInfo_.AddEntriesFrom(input, _repeated_constNodeInfo_codec); + break; + } + case 26: { + nodeInputInfo_.AddEntriesFrom(input, _repeated_nodeInputInfo_codec); + break; + } + case 34: { + nodeOutputInfo_.AddEntriesFrom(input, _repeated_nodeOutputInfo_codec); + break; + } + case 42: { + graphInputNodeInfo_.AddEntriesFrom(input, _repeated_graphInputNodeInfo_codec); + break; + } + case 50: { + graphOutputNodeInfo_.AddEntriesFrom(input, _repeated_graphOutputNodeInfo_codec); + break; + } + case 56: { + Destination = (global::Tensorflow.GraphTransferInfo.Types.Destination) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + nodeInfo_.AddEntriesFrom(ref input, _repeated_nodeInfo_codec); + break; + } + case 18: { + constNodeInfo_.AddEntriesFrom(ref input, _repeated_constNodeInfo_codec); + break; + } + case 26: { + nodeInputInfo_.AddEntriesFrom(ref input, _repeated_nodeInputInfo_codec); + break; + } + case 34: { + nodeOutputInfo_.AddEntriesFrom(ref input, _repeated_nodeOutputInfo_codec); + break; + } + case 42: { + graphInputNodeInfo_.AddEntriesFrom(ref input, _repeated_graphInputNodeInfo_codec); + break; + } + case 50: { + graphOutputNodeInfo_.AddEntriesFrom(ref input, _repeated_graphOutputNodeInfo_codec); + break; + } + case 56: { + Destination = (global::Tensorflow.GraphTransferInfo.Types.Destination) input.ReadEnum(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the GraphTransferInfo message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum Destination { + [pbr::OriginalName("NOP")] Nop = 0, + [pbr::OriginalName("HEXAGON")] Hexagon = 1, + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Histogram.cs b/src/TensorFlowNET.Core/Protobuf/Histogram.cs new file mode 100644 index 000000000..7414d1e50 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Histogram.cs @@ -0,0 +1,452 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/tsl/protobuf/histogram.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/tsl/protobuf/histogram.proto + public static partial class HistogramReflection { + + #region Descriptor + /// File descriptor for tensorflow/tsl/protobuf/histogram.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static HistogramReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cid0ZW5zb3JmbG93L3RzbC9wcm90b2J1Zi9oaXN0b2dyYW0ucHJvdG8SCnRl", + "bnNvcmZsb3cihwEKDkhpc3RvZ3JhbVByb3RvEgsKA21pbhgBIAEoARILCgNt", + "YXgYAiABKAESCwoDbnVtGAMgASgBEgsKA3N1bRgEIAEoARITCgtzdW1fc3F1", + "YXJlcxgFIAEoARIYCgxidWNrZXRfbGltaXQYBiADKAFCAhABEhIKBmJ1Y2tl", + "dBgHIAMoAUICEAFCXAoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrUAFaO2dp", + "dGh1Yi5jb20vZ29vZ2xlL3RzbC90c2wvZ28vY29yZS9wcm90b2J1Zi9zdW1t", + "YXJ5X2dvX3Byb3Rv+AEBYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.HistogramProto), global::Tensorflow.HistogramProto.Parser, new[]{ "Min", "Max", "Num", "Sum", "SumSquares", "BucketLimit", "Bucket" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Serialization format for histogram module in + /// tsl/lib/histogram/histogram.h + /// + public sealed partial class HistogramProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HistogramProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.HistogramReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HistogramProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HistogramProto(HistogramProto other) : this() { + min_ = other.min_; + max_ = other.max_; + num_ = other.num_; + sum_ = other.sum_; + sumSquares_ = other.sumSquares_; + bucketLimit_ = other.bucketLimit_.Clone(); + bucket_ = other.bucket_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HistogramProto Clone() { + return new HistogramProto(this); + } + + /// Field number for the "min" field. + public const int MinFieldNumber = 1; + private double min_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double Min { + get { return min_; } + set { + min_ = value; + } + } + + /// Field number for the "max" field. + public const int MaxFieldNumber = 2; + private double max_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double Max { + get { return max_; } + set { + max_ = value; + } + } + + /// Field number for the "num" field. + public const int NumFieldNumber = 3; + private double num_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double Num { + get { return num_; } + set { + num_ = value; + } + } + + /// Field number for the "sum" field. + public const int SumFieldNumber = 4; + private double sum_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double Sum { + get { return sum_; } + set { + sum_ = value; + } + } + + /// Field number for the "sum_squares" field. + public const int SumSquaresFieldNumber = 5; + private double sumSquares_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double SumSquares { + get { return sumSquares_; } + set { + sumSquares_ = value; + } + } + + /// Field number for the "bucket_limit" field. + public const int BucketLimitFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_bucketLimit_codec + = pb::FieldCodec.ForDouble(50); + private readonly pbc::RepeatedField bucketLimit_ = new pbc::RepeatedField(); + /// + /// Parallel arrays encoding the bucket boundaries and the bucket values. + /// bucket(i) is the count for the bucket i. The range for + /// a bucket is: + /// i == 0: -DBL_MAX .. bucket_limit(0) + /// i != 0: bucket_limit(i-1) .. bucket_limit(i) + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField BucketLimit { + get { return bucketLimit_; } + } + + /// Field number for the "bucket" field. + public const int BucketFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_bucket_codec + = pb::FieldCodec.ForDouble(58); + private readonly pbc::RepeatedField bucket_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Bucket { + get { return bucket_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HistogramProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HistogramProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(Min, other.Min)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(Max, other.Max)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(Num, other.Num)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(Sum, other.Sum)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(SumSquares, other.SumSquares)) return false; + if(!bucketLimit_.Equals(other.bucketLimit_)) return false; + if(!bucket_.Equals(other.bucket_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Min != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(Min); + if (Max != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(Max); + if (Num != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(Num); + if (Sum != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(Sum); + if (SumSquares != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(SumSquares); + hash ^= bucketLimit_.GetHashCode(); + hash ^= bucket_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Min != 0D) { + output.WriteRawTag(9); + output.WriteDouble(Min); + } + if (Max != 0D) { + output.WriteRawTag(17); + output.WriteDouble(Max); + } + if (Num != 0D) { + output.WriteRawTag(25); + output.WriteDouble(Num); + } + if (Sum != 0D) { + output.WriteRawTag(33); + output.WriteDouble(Sum); + } + if (SumSquares != 0D) { + output.WriteRawTag(41); + output.WriteDouble(SumSquares); + } + bucketLimit_.WriteTo(output, _repeated_bucketLimit_codec); + bucket_.WriteTo(output, _repeated_bucket_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Min != 0D) { + output.WriteRawTag(9); + output.WriteDouble(Min); + } + if (Max != 0D) { + output.WriteRawTag(17); + output.WriteDouble(Max); + } + if (Num != 0D) { + output.WriteRawTag(25); + output.WriteDouble(Num); + } + if (Sum != 0D) { + output.WriteRawTag(33); + output.WriteDouble(Sum); + } + if (SumSquares != 0D) { + output.WriteRawTag(41); + output.WriteDouble(SumSquares); + } + bucketLimit_.WriteTo(ref output, _repeated_bucketLimit_codec); + bucket_.WriteTo(ref output, _repeated_bucket_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Min != 0D) { + size += 1 + 8; + } + if (Max != 0D) { + size += 1 + 8; + } + if (Num != 0D) { + size += 1 + 8; + } + if (Sum != 0D) { + size += 1 + 8; + } + if (SumSquares != 0D) { + size += 1 + 8; + } + size += bucketLimit_.CalculateSize(_repeated_bucketLimit_codec); + size += bucket_.CalculateSize(_repeated_bucket_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HistogramProto other) { + if (other == null) { + return; + } + if (other.Min != 0D) { + Min = other.Min; + } + if (other.Max != 0D) { + Max = other.Max; + } + if (other.Num != 0D) { + Num = other.Num; + } + if (other.Sum != 0D) { + Sum = other.Sum; + } + if (other.SumSquares != 0D) { + SumSquares = other.SumSquares; + } + bucketLimit_.Add(other.bucketLimit_); + bucket_.Add(other.bucket_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 9: { + Min = input.ReadDouble(); + break; + } + case 17: { + Max = input.ReadDouble(); + break; + } + case 25: { + Num = input.ReadDouble(); + break; + } + case 33: { + Sum = input.ReadDouble(); + break; + } + case 41: { + SumSquares = input.ReadDouble(); + break; + } + case 50: + case 49: { + bucketLimit_.AddEntriesFrom(input, _repeated_bucketLimit_codec); + break; + } + case 58: + case 57: { + bucket_.AddEntriesFrom(input, _repeated_bucket_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 9: { + Min = input.ReadDouble(); + break; + } + case 17: { + Max = input.ReadDouble(); + break; + } + case 25: { + Num = input.ReadDouble(); + break; + } + case 33: { + Sum = input.ReadDouble(); + break; + } + case 41: { + SumSquares = input.ReadDouble(); + break; + } + case 50: + case 49: { + bucketLimit_.AddEntriesFrom(ref input, _repeated_bucketLimit_codec); + break; + } + case 58: + case 57: { + bucket_.AddEntriesFrom(ref input, _repeated_bucket_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Hlo.cs b/src/TensorFlowNET.Core/Protobuf/Hlo.cs new file mode 100644 index 000000000..27aa3faa3 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Hlo.cs @@ -0,0 +1,11996 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/compiler/xla/service/hlo.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Xla { + + /// Holder for reflection information generated from tensorflow/compiler/xla/service/hlo.proto + public static partial class HloReflection { + + #region Descriptor + /// File descriptor for tensorflow/compiler/xla/service/hlo.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static HloReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cil0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9zZXJ2aWNlL2hsby5wcm90bxID", + "eGxhGiZ0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS94bGFfZGF0YS5wcm90byKV", + "FQoTSGxvSW5zdHJ1Y3Rpb25Qcm90bxIMCgRuYW1lGAEgASgJEg4KBm9wY29k", + "ZRgCIAEoCRIeCgVzaGFwZRgDIAEoCzIPLnhsYS5TaGFwZVByb3RvEiEKCG1l", + "dGFkYXRhGAcgASgLMg8ueGxhLk9wTWV0YWRhdGESIgoHbGl0ZXJhbBgIIAEo", + "CzIRLnhsYS5MaXRlcmFsUHJvdG8SGAoQcGFyYW1ldGVyX251bWJlchgJIAEo", + "AxITCgtmdXNpb25fa2luZBgLIAEoCRITCgt0dXBsZV9pbmRleBgNIAEoAxIS", + "CgpkaW1lbnNpb25zGA4gAygDEhsKBndpbmRvdxgPIAEoCzILLnhsYS5XaW5k", + "b3cSRwodY29udm9sdXRpb25fZGltZW5zaW9uX251bWJlcnMYECABKAsyIC54", + "bGEuQ29udm9sdXRpb25EaW1lbnNpb25OdW1iZXJzEhsKE2ZlYXR1cmVfZ3Jv", + "dXBfY291bnQYMiABKAMSGQoRYmF0Y2hfZ3JvdXBfY291bnQYOiABKAMSQgoQ", + "c2xpY2VfZGltZW5zaW9ucxgRIAMoCzIoLnhsYS5IbG9JbnN0cnVjdGlvblBy", + "b3RvLlNsaWNlRGltZW5zaW9ucxIVCg1leHBvbmVudF9iaXRzGBIgASgFEhUK", + "DW1hbnRpc3NhX2JpdHMYEyABKAUSGwoTZHluYW1pY19zbGljZV9zaXplcxgU", + "IAMoAxIqCg5wYWRkaW5nX2NvbmZpZxgVIAEoCzISLnhsYS5QYWRkaW5nQ29u", + "ZmlnEhYKDm91dGZlZWRfY29uZmlnGBYgASgMEi0KDGRpc3RyaWJ1dGlvbhgX", + "IAEoDjIXLnhsYS5SYW5kb21EaXN0cmlidXRpb24SDwoHZXBzaWxvbhgYIAEo", + "AhIVCg1mZWF0dXJlX2luZGV4GBkgASgDEhIKCmNoYW5uZWxfaWQYGiABKAMS", + "FQoNaW5mZWVkX2NvbmZpZxgbIAEoDBIaChJjdXN0b21fY2FsbF90YXJnZXQY", + "HCABKAkSJgoNb3V0ZmVlZF9zaGFwZRgdIAEoCzIPLnhsYS5TaGFwZVByb3Rv", + "EjcKFWRvdF9kaW1lbnNpb25fbnVtYmVycxgeIAEoCzIYLnhsYS5Eb3REaW1l", + "bnNpb25OdW1iZXJzEh4KCGZmdF90eXBlGB8gASgOMgwueGxhLkZmdFR5cGUS", + "EgoKZmZ0X2xlbmd0aBggIAMoAxIcChRjb21wYXJpc29uX2RpcmVjdGlvbhg/", + "IAEoCRI9ChhnYXRoZXJfZGltZW5zaW9uX251bWJlcnMYISABKAsyGy54bGEu", + "R2F0aGVyRGltZW5zaW9uTnVtYmVycxIaChJnYXRoZXJfc2xpY2Vfc2l6ZXMY", + "IiADKAMSCgoCaWQYIyABKAMSEwoLb3BlcmFuZF9pZHMYJCADKAMSHwoXY29u", + "dHJvbF9wcmVkZWNlc3Nvcl9pZHMYJSADKAMSHgoWY2FsbGVkX2NvbXB1dGF0", + "aW9uX2lkcxgmIAMoAxIhCghzaGFyZGluZxgoIAEoCzIPLnhsYS5PcFNoYXJk", + "aW5nEhYKDmJhY2tlbmRfY29uZmlnGCsgASgMEikKDnJlcGxpY2FfZ3JvdXBz", + "GDEgAygLMhEueGxhLlJlcGxpY2FHcm91cBIZCg1hbGxfcmVkdWNlX2lkGC0g", + "ASgDQgIYARIdChV1c2VfZ2xvYmFsX2RldmljZV9pZHMYRyABKAgSGAoQaXNf", + "aG9zdF90cmFuc2ZlchgvIAEoCBIRCglpc19zdGFibGUYPCABKAgSPwoZc2Nh", + "dHRlcl9kaW1lbnNpb25fbnVtYmVycxgwIAEoCzIcLnhsYS5TY2F0dGVyRGlt", + "ZW5zaW9uTnVtYmVycxIuChBwcmVjaXNpb25fY29uZmlnGDMgASgLMhQueGxh", + "LlByZWNpc2lvbkNvbmZpZxIuChNzb3VyY2VfdGFyZ2V0X3BhaXJzGDQgAygL", + "MhEueGxhLlNvdXJjZVRhcmdldBIuChVkb21haW5fZW50cnlfc2hhcmRpbmcY", + "NiABKAsyDy54bGEuT3BTaGFyZGluZxItChRkb21haW5fZXhpdF9zaGFyZGlu", + "Zxg3IAEoCzIPLnhsYS5PcFNoYXJkaW5nEhgKEGNvbnN0cmFpbl9sYXlvdXQY", + "OCABKAgSMwoab3BlcmFuZF9zaGFwZXNfd2l0aF9sYXlvdXQYOSADKAsyDy54", + "bGEuU2hhcGVQcm90bxI9Chh0cmlhbmd1bGFyX3NvbHZlX29wdGlvbnMYOyAB", + "KAsyGy54bGEuVHJpYW5ndWxhclNvbHZlT3B0aW9ucxIuChBjaG9sZXNreV9v", + "cHRpb25zGD4gASgLMhQueGxhLkNob2xlc2t5T3B0aW9ucxI4ChVwYXJhbWV0", + "ZXJfcmVwbGljYXRpb24YPSABKAsyGS54bGEuUGFyYW1ldGVyUmVwbGljYXRp", + "b24SIwobY3VzdG9tX2NhbGxfaGFzX3NpZGVfZWZmZWN0GEEgASgIElEKI2N1", + "c3RvbV9jYWxsX291dHB1dF9vcGVyYW5kX2FsaWFzaW5nGEogAygLMiQueGxh", + "LkN1c3RvbUNhbGxPdXRwdXRPcGVyYW5kQWxpYXNpbmcSNQoUY3VzdG9tX2Nh", + "bGxfc2NoZWR1bGUYTCABKA4yFy54bGEuQ3VzdG9tQ2FsbFNjaGVkdWxlEg0K", + "BWRlbHRhGEIgASgDEhoKEmluZGljZXNfYXJlX3NvcnRlZBhDIAEoCBI0ChNm", + "cm9udGVuZF9hdHRyaWJ1dGVzGEQgASgLMhcueGxhLkZyb250ZW5kQXR0cmli", + "dXRlcxIWCg51bmlxdWVfaW5kaWNlcxhFIAEoCBIrCg1ybmdfYWxnb3JpdGht", + "GEYgASgOMhQueGxhLlJhbmRvbUFsZ29yaXRobRIXCg9jb21wYXJpc29uX3R5", + "cGUYSCABKAkSIQoZaXNfY3Jvc3NfcHJvZ3JhbV9wcmVmZXRjaBhJIAEoCBIm", + "CgxwYWRkaW5nX3R5cGUYSyABKA4yEC54bGEuUGFkZGluZ1R5cGUSOgoXY3Vz", + "dG9tX2NhbGxfYXBpX3ZlcnNpb24YTSABKA4yGS54bGEuQ3VzdG9tQ2FsbEFw", + "aVZlcnNpb24SFgoOYXN5bmNfZ3JvdXBfaWQYTiABKAMSHgoWYXN5bmNfZXhl", + "Y3V0aW9uX3RocmVhZBhPIAEoCRo/Cg9TbGljZURpbWVuc2lvbnMSDQoFc3Rh", + "cnQYASABKAMSDQoFbGltaXQYAiABKAMSDgoGc3RyaWRlGAMgASgDSgQIChAL", + "SgQIDBANSgQIBBAFSgQIBRAGSgQIBhAHSgQILBAtSgQINRA2SgQILhAvSgQI", + "KRAqSgQIKhArSgQIQBBBUg5wYXJhbWV0ZXJfbmFtZVIeZnVzZWRfaW5zdHJ1", + "Y3Rpb25zX2NvbXB1dGF0aW9uUg1vcGVyYW5kX25hbWVzUhljb250cm9sX3By", + "ZWRlY2Vzc29yX25hbWVzUhhjYWxsZWRfY29tcHV0YXRpb25fbmFtZXNSEXJl", + "cGxpY2FfZ3JvdXBfaWRzUhJjdXN0b21fY2FsbF9vcGFxdWVSEmFsbF9yZWR1", + "Y2VfYmFycmllciLpAQoTSGxvQ29tcHV0YXRpb25Qcm90bxIMCgRuYW1lGAEg", + "ASgJEi4KDGluc3RydWN0aW9ucxgCIAMoCzIYLnhsYS5IbG9JbnN0cnVjdGlv", + "blByb3RvEi0KDXByb2dyYW1fc2hhcGUYBCABKAsyFi54bGEuUHJvZ3JhbVNo", + "YXBlUHJvdG8SCgoCaWQYBSABKAMSDwoHcm9vdF9pZBgGIAEoAxIdChVpc19m", + "dXNpb25fY29tcHV0YXRpb24YByABKAgSGAoQZXhlY3V0aW9uX3RocmVhZBgI", + "IAEoCUoECAMQBFIJcm9vdF9uYW1lItgBChBIbG9TY2hlZHVsZVByb3RvEjcK", + "CXNlcXVlbmNlcxgBIAMoCzIkLnhsYS5IbG9TY2hlZHVsZVByb3RvLlNlcXVl", + "bmNlc0VudHJ5Gi4KE0luc3RydWN0aW9uU2VxdWVuY2USFwoPaW5zdHJ1Y3Rp", + "b25faWRzGAEgAygDGlsKDlNlcXVlbmNlc0VudHJ5EgsKA2tleRgBIAEoAxI4", + "CgV2YWx1ZRgCIAEoCzIpLnhsYS5IbG9TY2hlZHVsZVByb3RvLkluc3RydWN0", + "aW9uU2VxdWVuY2U6AjgBItsBChhIbG9JbnB1dE91dHB1dEFsaWFzUHJvdG8S", + "PgoHZW50cmllcxgBIAMoCzItLnhsYS5IbG9JbnB1dE91dHB1dEFsaWFzUHJv", + "dG8uQWxpYXNFbnRyeVByb3RvGn8KD0FsaWFzRW50cnlQcm90bxIaChJvdXRw", + "dXRfc2hhcGVfaW5kZXgYASADKAMSGAoQcGFyYW1ldGVyX251bWJlchgCIAEo", + "AxIdChVwYXJhbWV0ZXJfc2hhcGVfaW5kZXgYAyADKAMSFwoEa2luZBgEIAEo", + "DjIJLnhsYS5LaW5kIvIBChxEeW5hbWljUGFyYW1ldGVyQmluZGluZ1Byb3Rv", + "EjoKB2VudHJpZXMYASADKAsyKS54bGEuRHluYW1pY1BhcmFtZXRlckJpbmRp", + "bmdQcm90by5CaW5kaW5nGpUBCgdCaW5kaW5nEhkKEWR5bmFtaWNfcGFyYW1f", + "bnVtGAEgASgDEhsKE2R5bmFtaWNfcGFyYW1faW5kZXgYAiADKAMSGAoQdGFy", + "Z2V0X3BhcmFtX251bRgDIAEoAxIaChJ0YXJnZXRfcGFyYW1faW5kZXgYBCAD", + "KAMSHAoUdGFyZ2V0X3BhcmFtX2RpbV9udW0YBSABKAMiOAoUQ3Jvc3NQcm9n", + "cmFtUHJlZmV0Y2gSEQoJcGFyYW1ldGVyGAEgASgDEg0KBWluZGV4GAIgAygD", + "IsIHCg5IbG9Nb2R1bGVQcm90bxIMCgRuYW1lGAEgASgJEh4KFmVudHJ5X2Nv", + "bXB1dGF0aW9uX25hbWUYAiABKAkSHAoUZW50cnlfY29tcHV0YXRpb25faWQY", + "BiABKAMSLgoMY29tcHV0YXRpb25zGAMgAygLMhgueGxhLkhsb0NvbXB1dGF0", + "aW9uUHJvdG8SMgoSaG9zdF9wcm9ncmFtX3NoYXBlGAQgASgLMhYueGxhLlBy", + "b2dyYW1TaGFwZVByb3RvEgoKAmlkGAUgASgDEicKCHNjaGVkdWxlGAcgASgL", + "MhUueGxhLkhsb1NjaGVkdWxlUHJvdG8SOQoSaW5wdXRfb3V0cHV0X2FsaWFz", + "GAggASgLMh0ueGxhLkhsb0lucHV0T3V0cHV0QWxpYXNQcm90bxJEChlkeW5h", + "bWljX3BhcmFtZXRlcl9iaW5kaW5nGAkgASgLMiEueGxhLkR5bmFtaWNQYXJh", + "bWV0ZXJCaW5kaW5nUHJvdG8SOwoYY3Jvc3NfcHJvZ3JhbV9wcmVmZXRjaGVz", + "GAogAygLMhkueGxhLkNyb3NzUHJvZ3JhbVByZWZldGNoEhIKCmlzX2R5bmFt", + "aWMYCyABKAgSLQoUc3BtZF9vdXRwdXRfc2hhcmRpbmcYDCABKAsyDy54bGEu", + "T3BTaGFyZGluZxIyChlzcG1kX3BhcmFtZXRlcnNfc2hhcmRpbmdzGA4gAygL", + "Mg8ueGxhLk9wU2hhcmRpbmcSIgoadXNlX2F1dG9fc3BtZF9wYXJ0aXRpb25p", + "bmcYECABKAgSNQoMcHJvZmlsZV9pbmZvGA0gAygLMh8ueGxhLkhsb01vZHVs", + "ZVByb3RvLlByb2ZpbGVJbmZvEjUKEWRldmljZV9hc3NpZ25tZW50GA8gASgL", + "MhoueGxhLkRldmljZUFzc2lnbm1lbnRQcm90bxq8AQoLUHJvZmlsZUluZm8S", + "NQoMcHJvZmlsZV90eXBlGAEgASgOMh8ueGxhLkhsb01vZHVsZVByb3RvLlBy", + "b2ZpbGVUeXBlEhgKEHJlbGF0aXZlX3NwZWVkdXAYAiABKAESKgoOcHJvZmls", + "ZV9zb3VyY2UYAyABKA4yEi54bGEuUHJvZmlsZVNvdXJjZRIwChFjb21waWxh", + "dGlvbl9ldmVudBgEIAEoDjIVLnhsYS5Db21waWxhdGlvbkV2ZW50IkUKC1By", + "b2ZpbGVUeXBlEgsKB0lOVkFMSUQQABIICgRGTEFHEAESCgoGRlVTSU9OEAIS", + "CgoGTEFZT1VUEAMSBwoDRE9UEAQi6AEKEkxvZ2ljYWxCdWZmZXJQcm90bxIK", + "CgJpZBgBIAEoAxIMCgRzaXplGAIgASgDEjQKCmRlZmluZWRfYXQYAyABKAsy", + "IC54bGEuTG9naWNhbEJ1ZmZlclByb3RvLkxvY2F0aW9uEg0KBWNvbG9yGAQg", + "ASgDGnMKCExvY2F0aW9uEhwKEGNvbXB1dGF0aW9uX25hbWUYASABKAlCAhgB", + "EhwKEGluc3RydWN0aW9uX25hbWUYAiABKAlCAhgBEhYKDmluc3RydWN0aW9u", + "X2lkGAQgASgDEhMKC3NoYXBlX2luZGV4GAMgAygDIvgCChVCdWZmZXJBbGxv", + "Y2F0aW9uUHJvdG8SDQoFaW5kZXgYASABKAMSDAoEc2l6ZRgCIAEoAxIXCg9p", + "c190aHJlYWRfbG9jYWwYAyABKAgSEAoIaXNfdHVwbGUYCyABKAgSJgoeaXNf", + "ZW50cnlfY29tcHV0YXRpb25fcGFyYW1ldGVyGAUgASgIEhMKC2lzX2NvbnN0", + "YW50GAwgASgIEhgKEHBhcmFtZXRlcl9udW1iZXIYBiABKAMSHQoVcGFyYW1l", + "dGVyX3NoYXBlX2luZGV4GAogAygDEhYKDm1heWJlX2xpdmVfb3V0GAcgASgI", + "Eg0KBWNvbG9yGAggASgDEjUKCGFzc2lnbmVkGAkgAygLMiMueGxhLkJ1ZmZl", + "ckFsbG9jYXRpb25Qcm90by5Bc3NpZ25lZBpDCghBc3NpZ25lZBIZChFsb2dp", + "Y2FsX2J1ZmZlcl9pZBgBIAEoAxIOCgZvZmZzZXQYAiABKAMSDAoEc2l6ZRgD", + "IAEoAyLWAgoSSGVhcFNpbXVsYXRvclRyYWNlEi0KBmV2ZW50cxgBIAMoCzId", + "LnhsYS5IZWFwU2ltdWxhdG9yVHJhY2UuRXZlbnQSHwoXd2hvbGVfbW9kdWxl", + "X3NpbXVsYXRpb24YAiABKAgSHwoXYnVmZmVyX2FsbG9jYXRpb25faW5kZXgY", + "AyABKAMazgEKBUV2ZW50EjAKBGtpbmQYASABKA4yIi54bGEuSGVhcFNpbXVs", + "YXRvclRyYWNlLkV2ZW50LktpbmQSEQoJYnVmZmVyX2lkGAIgASgDEhgKEGNv", + "bXB1dGF0aW9uX25hbWUYAyABKAkSGAoQaW5zdHJ1Y3Rpb25fbmFtZRgEIAEo", + "CRIfChdzaGFyZV93aXRoX2Nhbm9uaWNhbF9pZBgFIAEoAyIrCgRLaW5kEgkK", + "BUFMTE9DEAASCAoERlJFRRABEg4KClNIQVJFX1dJVEgQAiJNChNIbG9Nb2R1", + "bGVHcm91cFByb3RvEgwKBG5hbWUYASABKAkSKAoLaGxvX21vZHVsZXMYAiAD", + "KAsyEy54bGEuSGxvTW9kdWxlUHJvdG8i1gIKFUJ1ZmZlckFzc2lnbm1lbnRQ", + "cm90bxIwCg9sb2dpY2FsX2J1ZmZlcnMYASADKAsyFy54bGEuTG9naWNhbEJ1", + "ZmZlclByb3RvEj4KDmJ1ZmZlcl9hbGlhc2VzGAIgAygLMiYueGxhLkJ1ZmZl", + "ckFzc2lnbm1lbnRQcm90by5CdWZmZXJBbGlhcxI2ChJidWZmZXJfYWxsb2Nh", + "dGlvbnMYAyADKAsyGi54bGEuQnVmZmVyQWxsb2NhdGlvblByb3RvEjYKFWhl", + "YXBfc2ltdWxhdG9yX3RyYWNlcxgEIAMoCzIXLnhsYS5IZWFwU2ltdWxhdG9y", + "VHJhY2UaWwoLQnVmZmVyQWxpYXMSGAoQc291cmNlX2J1ZmZlcl9pZBgBIAEo", + "AxIyCghsb2NhdGlvbhgCIAEoCzIgLnhsYS5Mb2dpY2FsQnVmZmVyUHJvdG8u", + "TG9jYXRpb24ifgoISGxvUHJvdG8SJwoKaGxvX21vZHVsZRgBIAEoCzITLnhs", + "YS5IbG9Nb2R1bGVQcm90bxI1ChFidWZmZXJfYXNzaWdubWVudBgDIAEoCzIa", + "LnhsYS5CdWZmZXJBc3NpZ25tZW50UHJvdG9KBAgCEANSDGhsb19vcmRlcmlu", + "ZyKOAQoLSGxvU25hcHNob3QSGgoDaGxvGAEgASgLMg0ueGxhLkhsb1Byb3Rv", + "EiQKCWFyZ3VtZW50cxgCIAMoCzIRLnhsYS5MaXRlcmFsUHJvdG8SIQoGcmVz", + "dWx0GAMgASgLMhEueGxhLkxpdGVyYWxQcm90bxIaChJleGVjdXRpb25fcGxh", + "dGZvcm0YBCABKAkiuQEKFkhsb01vZHVsZU1ldGFkYXRhUHJvdG8SGwoTY2Fu", + "b25pY2FsX21vZHVsZV9pZBgBIAEoAxIZChFtb2R1bGVfZ3JvdXBfbmFtZRgC", + "IAEoCRIaChJvcmlnaW5hbF9tb2R1bGVfaWQYAyABKAMSHgoWcGFydGl0aW9u", + "ZWRfbW9kdWxlX2lkcxgEIAMoAxIrCg1wYXNzX21ldGFkYXRhGAUgAygLMhQu", + "eGxhLkhsb1Bhc3NNZXRhZGF0YSLqAQoPSGxvUGFzc01ldGFkYXRhEg8KB3Bh", + "c3NfaWQYASABKAMSEQoJcGFzc19uYW1lGAIgASgJEhUKDXBpcGVsaW5lX25h", + "bWUYAyABKAkSFgoOZHVtcF9maWxlbmFtZXMYBCADKAkSFgoObW9kdWxlX2No", + "YW5nZWQYBSABKAgSEQoJbW9kdWxlX2lkGAYgASgDEh8KF21vZHVsZV9ncm91", + "cF9tb2R1bGVfaWRzGAcgAygDEhwKFHN0YXJ0X3RpbWVzdGFtcF91c2VjGAgg", + "ASgDEhoKEmVuZF90aW1lc3RhbXBfdXNlYxgJIAEoAyKzAwoXRW50cnlGdW5j", + "dGlvbkF0dHJpYnV0ZXMSRwoHYnVmZmVycxgBIAMoCzI2LnhsYS5FbnRyeUZ1", + "bmN0aW9uQXR0cmlidXRlcy5CdWZmZXJQYXJhbWV0ZXJBdHRyaWJ1dGVzEhgK", + "EHJlc3VsdF94bGFfc2hhcGUYAiABKAkaHQoKU2hhcGVJbmRleBIPCgdpbmRp", + "Y2VzGAEgAygDGpUCChlCdWZmZXJQYXJhbWV0ZXJBdHRyaWJ1dGVzEhQKDGxt", + "aGxvX3BhcmFtcxgBIAEoAxIcChRsbWhsb19wYXJhbXNfcHJlc2VudBgGIAEo", + "CBJIChdsbWhsb19wYXJhbV9zaGFwZV9pbmRleBgCIAEoCzInLnhsYS5FbnRy", + "eUZ1bmN0aW9uQXR0cmlidXRlcy5TaGFwZUluZGV4EhsKE2xtaGxvX2NvbnN0", + "YW50X25hbWUYAyABKAkSGAoQbG1obG9fbXVzdF9hbGlhcxgEIAEoCBJDChJs", + "bWhsb19vdXRwdXRfaW5kZXgYBSABKAsyJy54bGEuRW50cnlGdW5jdGlvbkF0", + "dHJpYnV0ZXMuU2hhcGVJbmRleCKpAQoZWGxhUnVudGltZUV4ZWN1dGFibGVQ", + "cm90bxItChBobG9fbW9kdWxlX3Byb3RvGAEgASgLMhMueGxhLkhsb01vZHVs", + "ZVByb3RvEjYKEGVudHJ5X2Z1bmNfYXR0cnMYAiABKAsyHC54bGEuRW50cnlG", + "dW5jdGlvbkF0dHJpYnV0ZXMSEAoIb2JqX2ZpbGUYAyABKAwSEwoLbWxpcl9t", + "b2R1bGUYBCABKAkqUwoSQ3VzdG9tQ2FsbFNjaGVkdWxlEhEKDVNDSEVEVUxF", + "X05PTkUQABITCg9TQ0hFRFVMRV9MQVRFU1QQARIVChFTQ0hFRFVMRV9FQVJM", + "SUVTVBACKpkBChRDdXN0b21DYWxsQXBpVmVyc2lvbhIbChdBUElfVkVSU0lP", + "Tl9VTlNQRUNJRklFRBAAEhgKFEFQSV9WRVJTSU9OX09SSUdJTkFMEAESIAoc", + "QVBJX1ZFUlNJT05fU1RBVFVTX1JFVFVSTklORxACEigKJEFQSV9WRVJTSU9O", + "X1NUQVRVU19SRVRVUk5JTkdfVU5JRklFRBADKjoKBEtpbmQSEwoPVU5ERUZJ", + "TkVEX0FMSUFTEAASDQoJTUFZX0FMSUFTEAESDgoKTVVTVF9BTElBUxACQgP4", + "AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Xla.XlaDataReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Xla.CustomCallSchedule), typeof(global::Xla.CustomCallApiVersion), typeof(global::Xla.Kind), }, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloInstructionProto), global::Xla.HloInstructionProto.Parser, new[]{ "Name", "Opcode", "Shape", "Metadata", "Literal", "ParameterNumber", "FusionKind", "TupleIndex", "Dimensions", "Window", "ConvolutionDimensionNumbers", "FeatureGroupCount", "BatchGroupCount", "SliceDimensions", "ExponentBits", "MantissaBits", "DynamicSliceSizes", "PaddingConfig", "OutfeedConfig", "Distribution", "Epsilon", "FeatureIndex", "ChannelId", "InfeedConfig", "CustomCallTarget", "OutfeedShape", "DotDimensionNumbers", "FftType", "FftLength", "ComparisonDirection", "GatherDimensionNumbers", "GatherSliceSizes", "Id", "OperandIds", "ControlPredecessorIds", "CalledComputationIds", "Sharding", "BackendConfig", "ReplicaGroups", "AllReduceId", "UseGlobalDeviceIds", "IsHostTransfer", "IsStable", "ScatterDimensionNumbers", "PrecisionConfig", "SourceTargetPairs", "DomainEntrySharding", "DomainExitSharding", "ConstrainLayout", "OperandShapesWithLayout", "TriangularSolveOptions", "CholeskyOptions", "ParameterReplication", "CustomCallHasSideEffect", "CustomCallOutputOperandAliasing", "CustomCallSchedule", "Delta", "IndicesAreSorted", "FrontendAttributes", "UniqueIndices", "RngAlgorithm", "ComparisonType", "IsCrossProgramPrefetch", "PaddingType", "CustomCallApiVersion", "AsyncGroupId", "AsyncExecutionThread" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloInstructionProto.Types.SliceDimensions), global::Xla.HloInstructionProto.Types.SliceDimensions.Parser, new[]{ "Start", "Limit", "Stride" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloComputationProto), global::Xla.HloComputationProto.Parser, new[]{ "Name", "Instructions", "ProgramShape", "Id", "RootId", "IsFusionComputation", "ExecutionThread" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloScheduleProto), global::Xla.HloScheduleProto.Parser, new[]{ "Sequences" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloScheduleProto.Types.InstructionSequence), global::Xla.HloScheduleProto.Types.InstructionSequence.Parser, new[]{ "InstructionIds" }, null, null, null, null), + null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloInputOutputAliasProto), global::Xla.HloInputOutputAliasProto.Parser, new[]{ "Entries" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloInputOutputAliasProto.Types.AliasEntryProto), global::Xla.HloInputOutputAliasProto.Types.AliasEntryProto.Parser, new[]{ "OutputShapeIndex", "ParameterNumber", "ParameterShapeIndex", "Kind" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.DynamicParameterBindingProto), global::Xla.DynamicParameterBindingProto.Parser, new[]{ "Entries" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.DynamicParameterBindingProto.Types.Binding), global::Xla.DynamicParameterBindingProto.Types.Binding.Parser, new[]{ "DynamicParamNum", "DynamicParamIndex", "TargetParamNum", "TargetParamIndex", "TargetParamDimNum" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.CrossProgramPrefetch), global::Xla.CrossProgramPrefetch.Parser, new[]{ "Parameter", "Index" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloModuleProto), global::Xla.HloModuleProto.Parser, new[]{ "Name", "EntryComputationName", "EntryComputationId", "Computations", "HostProgramShape", "Id", "Schedule", "InputOutputAlias", "DynamicParameterBinding", "CrossProgramPrefetches", "IsDynamic", "SpmdOutputSharding", "SpmdParametersShardings", "UseAutoSpmdPartitioning", "ProfileInfo", "DeviceAssignment" }, null, new[]{ typeof(global::Xla.HloModuleProto.Types.ProfileType) }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloModuleProto.Types.ProfileInfo), global::Xla.HloModuleProto.Types.ProfileInfo.Parser, new[]{ "ProfileType", "RelativeSpeedup", "ProfileSource", "CompilationEvent" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.LogicalBufferProto), global::Xla.LogicalBufferProto.Parser, new[]{ "Id", "Size", "DefinedAt", "Color" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.LogicalBufferProto.Types.Location), global::Xla.LogicalBufferProto.Types.Location.Parser, new[]{ "ComputationName", "InstructionName", "InstructionId", "ShapeIndex" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.BufferAllocationProto), global::Xla.BufferAllocationProto.Parser, new[]{ "Index", "Size", "IsThreadLocal", "IsTuple", "IsEntryComputationParameter", "IsConstant", "ParameterNumber", "ParameterShapeIndex", "MaybeLiveOut", "Color", "Assigned" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.BufferAllocationProto.Types.Assigned), global::Xla.BufferAllocationProto.Types.Assigned.Parser, new[]{ "LogicalBufferId", "Offset", "Size" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HeapSimulatorTrace), global::Xla.HeapSimulatorTrace.Parser, new[]{ "Events", "WholeModuleSimulation", "BufferAllocationIndex" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HeapSimulatorTrace.Types.Event), global::Xla.HeapSimulatorTrace.Types.Event.Parser, new[]{ "Kind", "BufferId", "ComputationName", "InstructionName", "ShareWithCanonicalId" }, null, new[]{ typeof(global::Xla.HeapSimulatorTrace.Types.Event.Types.Kind) }, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloModuleGroupProto), global::Xla.HloModuleGroupProto.Parser, new[]{ "Name", "HloModules" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.BufferAssignmentProto), global::Xla.BufferAssignmentProto.Parser, new[]{ "LogicalBuffers", "BufferAliases", "BufferAllocations", "HeapSimulatorTraces" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.BufferAssignmentProto.Types.BufferAlias), global::Xla.BufferAssignmentProto.Types.BufferAlias.Parser, new[]{ "SourceBufferId", "Location" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloProto), global::Xla.HloProto.Parser, new[]{ "HloModule", "BufferAssignment" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloSnapshot), global::Xla.HloSnapshot.Parser, new[]{ "Hlo", "Arguments", "Result", "ExecutionPlatform" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloModuleMetadataProto), global::Xla.HloModuleMetadataProto.Parser, new[]{ "CanonicalModuleId", "ModuleGroupName", "OriginalModuleId", "PartitionedModuleIds", "PassMetadata" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HloPassMetadata), global::Xla.HloPassMetadata.Parser, new[]{ "PassId", "PassName", "PipelineName", "DumpFilenames", "ModuleChanged", "ModuleId", "ModuleGroupModuleIds", "StartTimestampUsec", "EndTimestampUsec" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.EntryFunctionAttributes), global::Xla.EntryFunctionAttributes.Parser, new[]{ "Buffers", "ResultXlaShape" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.EntryFunctionAttributes.Types.ShapeIndex), global::Xla.EntryFunctionAttributes.Types.ShapeIndex.Parser, new[]{ "Indices" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.EntryFunctionAttributes.Types.BufferParameterAttributes), global::Xla.EntryFunctionAttributes.Types.BufferParameterAttributes.Parser, new[]{ "LmhloParams", "LmhloParamsPresent", "LmhloParamShapeIndex", "LmhloConstantName", "LmhloMustAlias", "LmhloOutputIndex" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.XlaRuntimeExecutableProto), global::Xla.XlaRuntimeExecutableProto.Parser, new[]{ "HloModuleProto", "EntryFuncAttrs", "ObjFile", "MlirModule" }, null, null, null, null) + })); + } + #endregion + + } + #region Enums + public enum CustomCallSchedule { + [pbr::OriginalName("SCHEDULE_NONE")] ScheduleNone = 0, + [pbr::OriginalName("SCHEDULE_LATEST")] ScheduleLatest = 1, + [pbr::OriginalName("SCHEDULE_EARLIEST")] ScheduleEarliest = 2, + } + + /// + /// The version of the API used by the custom call function. The signatures for + /// each version are given below. + /// TODO(b/189822916): Remove this enum when all clients are migrated to the + /// status-returning API. + /// + public enum CustomCallApiVersion { + [pbr::OriginalName("API_VERSION_UNSPECIFIED")] ApiVersionUnspecified = 0, + /// + /// The first version of the API, with the following signatures: + /// + /// CPU: + /// void do_custom_call(void* out, const void** in); + /// + /// GPU: + /// void do_custom_call(CUstream stream, void** buffers, + /// const char* opaque, size_t opaque_len); + /// + [pbr::OriginalName("API_VERSION_ORIGINAL")] ApiVersionOriginal = 1, + /// + /// When the ability to return success/failure status was added: + /// + /// CPU: + /// void do_custom_call(void* out, const void** in, + /// XlaCustomCallStatus* status); + /// + /// GPU: + /// void do_custom_call(CUstream stream, void** buffers, + /// const char* opaque, size_t opaque_len, + /// XlaCustomCallStatus* status); + /// + [pbr::OriginalName("API_VERSION_STATUS_RETURNING")] ApiVersionStatusReturning = 2, + /// + /// Fixes the API signatures on the CPU side of the version STATUS_RETURNING by + /// adding the opaque string so that the custom call API is consistent across + /// CPUs and GPUs. For GPUs, the behaviors invoked by + /// API_VERSION_STATUS_RETURNING and API_VERSION_STATUS_RETURNING_UNIFIED are + /// the same. + /// + /// CPU: + /// void do_custom_call(void* out, const void** in, + /// const char* opaque, size_t opaque_len, + /// XlaCustomCallStatus* status); + /// + /// GPU: + /// void do_custom_call(CUstream stream, void** buffers, + /// const char* opaque, size_t opaque_len, + /// XlaCustomCallStatus* status); + /// + [pbr::OriginalName("API_VERSION_STATUS_RETURNING_UNIFIED")] ApiVersionStatusReturningUnified = 3, + } + + public enum Kind { + /// + /// Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3 + /// behavior and missing has_*() APIs. + /// + [pbr::OriginalName("UNDEFINED_ALIAS")] UndefinedAlias = 0, + /// + /// The buffers may or may not alias at runtime. + /// + [pbr::OriginalName("MAY_ALIAS")] MayAlias = 1, + /// + /// The buffers must alias at runtime. + /// + [pbr::OriginalName("MUST_ALIAS")] MustAlias = 2, + } + + #endregion + + #region Messages + /// + /// Serialization of HloInstruction. + /// Next ID: 80 + /// + public sealed partial class HloInstructionProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HloInstructionProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloInstructionProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloInstructionProto(HloInstructionProto other) : this() { + name_ = other.name_; + opcode_ = other.opcode_; + shape_ = other.shape_ != null ? other.shape_.Clone() : null; + metadata_ = other.metadata_ != null ? other.metadata_.Clone() : null; + literal_ = other.literal_ != null ? other.literal_.Clone() : null; + parameterNumber_ = other.parameterNumber_; + fusionKind_ = other.fusionKind_; + tupleIndex_ = other.tupleIndex_; + dimensions_ = other.dimensions_.Clone(); + window_ = other.window_ != null ? other.window_.Clone() : null; + convolutionDimensionNumbers_ = other.convolutionDimensionNumbers_ != null ? other.convolutionDimensionNumbers_.Clone() : null; + featureGroupCount_ = other.featureGroupCount_; + batchGroupCount_ = other.batchGroupCount_; + sliceDimensions_ = other.sliceDimensions_.Clone(); + exponentBits_ = other.exponentBits_; + mantissaBits_ = other.mantissaBits_; + dynamicSliceSizes_ = other.dynamicSliceSizes_.Clone(); + paddingConfig_ = other.paddingConfig_ != null ? other.paddingConfig_.Clone() : null; + outfeedConfig_ = other.outfeedConfig_; + distribution_ = other.distribution_; + epsilon_ = other.epsilon_; + featureIndex_ = other.featureIndex_; + channelId_ = other.channelId_; + infeedConfig_ = other.infeedConfig_; + customCallTarget_ = other.customCallTarget_; + outfeedShape_ = other.outfeedShape_ != null ? other.outfeedShape_.Clone() : null; + dotDimensionNumbers_ = other.dotDimensionNumbers_ != null ? other.dotDimensionNumbers_.Clone() : null; + fftType_ = other.fftType_; + fftLength_ = other.fftLength_.Clone(); + comparisonDirection_ = other.comparisonDirection_; + gatherDimensionNumbers_ = other.gatherDimensionNumbers_ != null ? other.gatherDimensionNumbers_.Clone() : null; + gatherSliceSizes_ = other.gatherSliceSizes_.Clone(); + id_ = other.id_; + operandIds_ = other.operandIds_.Clone(); + controlPredecessorIds_ = other.controlPredecessorIds_.Clone(); + calledComputationIds_ = other.calledComputationIds_.Clone(); + sharding_ = other.sharding_ != null ? other.sharding_.Clone() : null; + backendConfig_ = other.backendConfig_; + replicaGroups_ = other.replicaGroups_.Clone(); + allReduceId_ = other.allReduceId_; + useGlobalDeviceIds_ = other.useGlobalDeviceIds_; + isHostTransfer_ = other.isHostTransfer_; + isStable_ = other.isStable_; + scatterDimensionNumbers_ = other.scatterDimensionNumbers_ != null ? other.scatterDimensionNumbers_.Clone() : null; + precisionConfig_ = other.precisionConfig_ != null ? other.precisionConfig_.Clone() : null; + sourceTargetPairs_ = other.sourceTargetPairs_.Clone(); + domainEntrySharding_ = other.domainEntrySharding_ != null ? other.domainEntrySharding_.Clone() : null; + domainExitSharding_ = other.domainExitSharding_ != null ? other.domainExitSharding_.Clone() : null; + constrainLayout_ = other.constrainLayout_; + operandShapesWithLayout_ = other.operandShapesWithLayout_.Clone(); + triangularSolveOptions_ = other.triangularSolveOptions_ != null ? other.triangularSolveOptions_.Clone() : null; + choleskyOptions_ = other.choleskyOptions_ != null ? other.choleskyOptions_.Clone() : null; + parameterReplication_ = other.parameterReplication_ != null ? other.parameterReplication_.Clone() : null; + customCallHasSideEffect_ = other.customCallHasSideEffect_; + customCallOutputOperandAliasing_ = other.customCallOutputOperandAliasing_.Clone(); + customCallSchedule_ = other.customCallSchedule_; + delta_ = other.delta_; + indicesAreSorted_ = other.indicesAreSorted_; + frontendAttributes_ = other.frontendAttributes_ != null ? other.frontendAttributes_.Clone() : null; + uniqueIndices_ = other.uniqueIndices_; + rngAlgorithm_ = other.rngAlgorithm_; + comparisonType_ = other.comparisonType_; + isCrossProgramPrefetch_ = other.isCrossProgramPrefetch_; + paddingType_ = other.paddingType_; + customCallApiVersion_ = other.customCallApiVersion_; + asyncGroupId_ = other.asyncGroupId_; + asyncExecutionThread_ = other.asyncExecutionThread_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloInstructionProto Clone() { + return new HloInstructionProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "opcode" field. + public const int OpcodeFieldNumber = 2; + private string opcode_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Opcode { + get { return opcode_; } + set { + opcode_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 3; + private global::Xla.ShapeProto shape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ShapeProto Shape { + get { return shape_; } + set { + shape_ = value; + } + } + + /// Field number for the "metadata" field. + public const int MetadataFieldNumber = 7; + private global::Xla.OpMetadata metadata_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.OpMetadata Metadata { + get { return metadata_; } + set { + metadata_ = value; + } + } + + /// Field number for the "literal" field. + public const int LiteralFieldNumber = 8; + private global::Xla.LiteralProto literal_; + /// + /// Literal, only present for kConstant. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.LiteralProto Literal { + get { return literal_; } + set { + literal_ = value; + } + } + + /// Field number for the "parameter_number" field. + public const int ParameterNumberFieldNumber = 9; + private long parameterNumber_; + /// + /// Parameter number is only present for kParameter. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ParameterNumber { + get { return parameterNumber_; } + set { + parameterNumber_ = value; + } + } + + /// Field number for the "fusion_kind" field. + public const int FusionKindFieldNumber = 11; + private string fusionKind_ = ""; + /// + /// Fusion state, only present for kFusion. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string FusionKind { + get { return fusionKind_; } + set { + fusionKind_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "tuple_index" field. + public const int TupleIndexFieldNumber = 13; + private long tupleIndex_; + /// + /// Index for kGetTupleElement. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long TupleIndex { + get { return tupleIndex_; } + set { + tupleIndex_ = value; + } + } + + /// Field number for the "dimensions" field. + public const int DimensionsFieldNumber = 14; + private static readonly pb::FieldCodec _repeated_dimensions_codec + = pb::FieldCodec.ForInt64(114); + private readonly pbc::RepeatedField dimensions_ = new pbc::RepeatedField(); + /// + /// Dimensions present for some operations that require reshaping or + /// broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Dimensions { + get { return dimensions_; } + } + + /// Field number for the "window" field. + public const int WindowFieldNumber = 15; + private global::Xla.Window window_; + /// + /// Describes the window in a windowed operation such as convolution. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.Window Window { + get { return window_; } + set { + window_ = value; + } + } + + /// Field number for the "convolution_dimension_numbers" field. + public const int ConvolutionDimensionNumbersFieldNumber = 16; + private global::Xla.ConvolutionDimensionNumbers convolutionDimensionNumbers_; + /// + /// Describes the dimension numbers used for a convolution. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ConvolutionDimensionNumbers ConvolutionDimensionNumbers { + get { return convolutionDimensionNumbers_; } + set { + convolutionDimensionNumbers_ = value; + } + } + + /// Field number for the "feature_group_count" field. + public const int FeatureGroupCountFieldNumber = 50; + private long featureGroupCount_; + /// + /// The number of feature groups. Used for a convolution. Must be a divisor of + /// the input feature dimension and output feature dimension. If not specified, + /// it will use a default value of 1. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long FeatureGroupCount { + get { return featureGroupCount_; } + set { + featureGroupCount_ = value; + } + } + + /// Field number for the "batch_group_count" field. + public const int BatchGroupCountFieldNumber = 58; + private long batchGroupCount_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long BatchGroupCount { + get { return batchGroupCount_; } + set { + batchGroupCount_ = value; + } + } + + /// Field number for the "slice_dimensions" field. + public const int SliceDimensionsFieldNumber = 17; + private static readonly pb::FieldCodec _repeated_sliceDimensions_codec + = pb::FieldCodec.ForMessage(138, global::Xla.HloInstructionProto.Types.SliceDimensions.Parser); + private readonly pbc::RepeatedField sliceDimensions_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SliceDimensions { + get { return sliceDimensions_; } + } + + /// Field number for the "exponent_bits" field. + public const int ExponentBitsFieldNumber = 18; + private int exponentBits_; + /// + /// The bit sizes for a reduce-precision operation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ExponentBits { + get { return exponentBits_; } + set { + exponentBits_ = value; + } + } + + /// Field number for the "mantissa_bits" field. + public const int MantissaBitsFieldNumber = 19; + private int mantissaBits_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int MantissaBits { + get { return mantissaBits_; } + set { + mantissaBits_ = value; + } + } + + /// Field number for the "dynamic_slice_sizes" field. + public const int DynamicSliceSizesFieldNumber = 20; + private static readonly pb::FieldCodec _repeated_dynamicSliceSizes_codec + = pb::FieldCodec.ForInt64(162); + private readonly pbc::RepeatedField dynamicSliceSizes_ = new pbc::RepeatedField(); + /// + /// Describes the [start, start + size) range size for a dynamic slice + /// ('start' is specified dynamically in the second operand of the operation). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DynamicSliceSizes { + get { return dynamicSliceSizes_; } + } + + /// Field number for the "padding_config" field. + public const int PaddingConfigFieldNumber = 21; + private global::Xla.PaddingConfig paddingConfig_; + /// + /// The padding configuration that describes the edge padding and interior + /// padding of this pad instruction. Only set for pad instructions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.PaddingConfig PaddingConfig { + get { return paddingConfig_; } + set { + paddingConfig_ = value; + } + } + + /// Field number for the "outfeed_config" field. + public const int OutfeedConfigFieldNumber = 22; + private pb::ByteString outfeedConfig_ = pb::ByteString.Empty; + /// + /// Outfeed configuration information, only present for kOutfeed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString OutfeedConfig { + get { return outfeedConfig_; } + set { + outfeedConfig_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "distribution" field. + public const int DistributionFieldNumber = 23; + private global::Xla.RandomDistribution distribution_ = global::Xla.RandomDistribution.RngInvalid; + /// + /// The distribution requested for random number generation. + /// Only present for kRng. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.RandomDistribution Distribution { + get { return distribution_; } + set { + distribution_ = value; + } + } + + /// Field number for the "epsilon" field. + public const int EpsilonFieldNumber = 24; + private float epsilon_; + /// + /// A small float number added to the variance to avoid divide-by-zero error. + /// Only present for kBatchNormTraining, kBatchNormInference, and + /// kBatchNormGrad. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public float Epsilon { + get { return epsilon_; } + set { + epsilon_ = value; + } + } + + /// Field number for the "feature_index" field. + public const int FeatureIndexFieldNumber = 25; + private long featureIndex_; + /// + /// An integer value representing the index of the feature dimension. + /// Only present for kBatchNormTraining, kBatchNormInference, and + /// kBatchNormGrad. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long FeatureIndex { + get { return featureIndex_; } + set { + featureIndex_ = value; + } + } + + /// Field number for the "channel_id" field. + public const int ChannelIdFieldNumber = 26; + private long channelId_; + /// + /// Represents a unique identifier for each Send/Recv instruction pair or + /// optionally for collective instructions (AllReduce, CollectivePermute, + /// AllToAll). Non-positive channel_id is equivalent to no channel id. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ChannelId { + get { return channelId_; } + set { + channelId_ = value; + } + } + + /// Field number for the "infeed_config" field. + public const int InfeedConfigFieldNumber = 27; + private pb::ByteString infeedConfig_ = pb::ByteString.Empty; + /// + /// The string representation of the infeed configuration. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString InfeedConfig { + get { return infeedConfig_; } + set { + infeedConfig_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "custom_call_target" field. + public const int CustomCallTargetFieldNumber = 28; + private string customCallTarget_ = ""; + /// + /// Name of a external target (eg, global symbol) to call, only present for + /// kCustomCall. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string CustomCallTarget { + get { return customCallTarget_; } + set { + customCallTarget_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "outfeed_shape" field. + public const int OutfeedShapeFieldNumber = 29; + private global::Xla.ShapeProto outfeedShape_; + /// + /// Shape of outfeed request. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ShapeProto OutfeedShape { + get { return outfeedShape_; } + set { + outfeedShape_ = value; + } + } + + /// Field number for the "dot_dimension_numbers" field. + public const int DotDimensionNumbersFieldNumber = 30; + private global::Xla.DotDimensionNumbers dotDimensionNumbers_; + /// + /// Describes the dimension numbers used for a dot operation + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.DotDimensionNumbers DotDimensionNumbers { + get { return dotDimensionNumbers_; } + set { + dotDimensionNumbers_ = value; + } + } + + /// Field number for the "fft_type" field. + public const int FftTypeFieldNumber = 31; + private global::Xla.FftType fftType_ = global::Xla.FftType.Fft; + /// + /// FFT type (FFT, IFFT, etc). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.FftType FftType { + get { return fftType_; } + set { + fftType_ = value; + } + } + + /// Field number for the "fft_length" field. + public const int FftLengthFieldNumber = 32; + private static readonly pb::FieldCodec _repeated_fftLength_codec + = pb::FieldCodec.ForInt64(258); + private readonly pbc::RepeatedField fftLength_ = new pbc::RepeatedField(); + /// + /// FFT length. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField FftLength { + get { return fftLength_; } + } + + /// Field number for the "comparison_direction" field. + public const int ComparisonDirectionFieldNumber = 63; + private string comparisonDirection_ = ""; + /// + /// Comparison direction only used for kCompare. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ComparisonDirection { + get { return comparisonDirection_; } + set { + comparisonDirection_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "gather_dimension_numbers" field. + public const int GatherDimensionNumbersFieldNumber = 33; + private global::Xla.GatherDimensionNumbers gatherDimensionNumbers_; + /// + /// Gather dimension numbers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.GatherDimensionNumbers GatherDimensionNumbers { + get { return gatherDimensionNumbers_; } + set { + gatherDimensionNumbers_ = value; + } + } + + /// Field number for the "gather_slice_sizes" field. + public const int GatherSliceSizesFieldNumber = 34; + private static readonly pb::FieldCodec _repeated_gatherSliceSizes_codec + = pb::FieldCodec.ForInt64(274); + private readonly pbc::RepeatedField gatherSliceSizes_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField GatherSliceSizes { + get { return gatherSliceSizes_; } + } + + /// Field number for the "id" field. + public const int IdFieldNumber = 35; + private long id_; + /// + /// The id of this instruction. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Id { + get { return id_; } + set { + id_ = value; + } + } + + /// Field number for the "operand_ids" field. + public const int OperandIdsFieldNumber = 36; + private static readonly pb::FieldCodec _repeated_operandIds_codec + = pb::FieldCodec.ForInt64(290); + private readonly pbc::RepeatedField operandIds_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OperandIds { + get { return operandIds_; } + } + + /// Field number for the "control_predecessor_ids" field. + public const int ControlPredecessorIdsFieldNumber = 37; + private static readonly pb::FieldCodec _repeated_controlPredecessorIds_codec + = pb::FieldCodec.ForInt64(298); + private readonly pbc::RepeatedField controlPredecessorIds_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ControlPredecessorIds { + get { return controlPredecessorIds_; } + } + + /// Field number for the "called_computation_ids" field. + public const int CalledComputationIdsFieldNumber = 38; + private static readonly pb::FieldCodec _repeated_calledComputationIds_codec + = pb::FieldCodec.ForInt64(306); + private readonly pbc::RepeatedField calledComputationIds_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField CalledComputationIds { + get { return calledComputationIds_; } + } + + /// Field number for the "sharding" field. + public const int ShardingFieldNumber = 40; + private global::Xla.OpSharding sharding_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.OpSharding Sharding { + get { return sharding_; } + set { + sharding_ = value; + } + } + + /// Field number for the "backend_config" field. + public const int BackendConfigFieldNumber = 43; + private pb::ByteString backendConfig_ = pb::ByteString.Empty; + /// + /// Backend configuration for the instruction. Has backend-specific meaning. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString BackendConfig { + get { return backendConfig_; } + set { + backendConfig_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "replica_groups" field. + public const int ReplicaGroupsFieldNumber = 49; + private static readonly pb::FieldCodec _repeated_replicaGroups_codec + = pb::FieldCodec.ForMessage(394, global::Xla.ReplicaGroup.Parser); + private readonly pbc::RepeatedField replicaGroups_ = new pbc::RepeatedField(); + /// + /// Cross replica op fields. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ReplicaGroups { + get { return replicaGroups_; } + } + + /// Field number for the "all_reduce_id" field. + public const int AllReduceIdFieldNumber = 45; + private long allReduceId_; + /// + /// Deprecated, but keeping it for backward compatibility. Use channel_id. + /// Non-positive all_reduce_id is equivalent to no all_reduce_id. + /// + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllReduceId { + get { return allReduceId_; } + set { + allReduceId_ = value; + } + } + + /// Field number for the "use_global_device_ids" field. + public const int UseGlobalDeviceIdsFieldNumber = 71; + private bool useGlobalDeviceIds_; + /// + /// If true, interprets ids in ReplicaGroup as global device ids, which is + /// a linearized id of `replica_id * partition_count + partition_id`. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UseGlobalDeviceIds { + get { return useGlobalDeviceIds_; } + set { + useGlobalDeviceIds_ = value; + } + } + + /// Field number for the "is_host_transfer" field. + public const int IsHostTransferFieldNumber = 47; + private bool isHostTransfer_; + /// + /// Whether this Send/Recv instruction transfers data to/from the host. Only + /// present for Send and Recv instructions and their SendDone and RecvDone + /// partners. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsHostTransfer { + get { return isHostTransfer_; } + set { + isHostTransfer_ = value; + } + } + + /// Field number for the "is_stable" field. + public const int IsStableFieldNumber = 60; + private bool isStable_; + /// + /// Whether this Sort instruction should be stable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsStable { + get { return isStable_; } + set { + isStable_ = value; + } + } + + /// Field number for the "scatter_dimension_numbers" field. + public const int ScatterDimensionNumbersFieldNumber = 48; + private global::Xla.ScatterDimensionNumbers scatterDimensionNumbers_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ScatterDimensionNumbers ScatterDimensionNumbers { + get { return scatterDimensionNumbers_; } + set { + scatterDimensionNumbers_ = value; + } + } + + /// Field number for the "precision_config" field. + public const int PrecisionConfigFieldNumber = 51; + private global::Xla.PrecisionConfig precisionConfig_; + /// + /// Precision configuration for the instruction. Has backend-specific meaning. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.PrecisionConfig PrecisionConfig { + get { return precisionConfig_; } + set { + precisionConfig_ = value; + } + } + + /// Field number for the "source_target_pairs" field. + public const int SourceTargetPairsFieldNumber = 52; + private static readonly pb::FieldCodec _repeated_sourceTargetPairs_codec + = pb::FieldCodec.ForMessage(418, global::Xla.SourceTarget.Parser); + private readonly pbc::RepeatedField sourceTargetPairs_ = new pbc::RepeatedField(); + /// + /// Collective permute field. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SourceTargetPairs { + get { return sourceTargetPairs_; } + } + + /// Field number for the "domain_entry_sharding" field. + public const int DomainEntryShardingFieldNumber = 54; + private global::Xla.OpSharding domainEntrySharding_; + /// + /// Sharding for kDomain instructions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.OpSharding DomainEntrySharding { + get { return domainEntrySharding_; } + set { + domainEntrySharding_ = value; + } + } + + /// Field number for the "domain_exit_sharding" field. + public const int DomainExitShardingFieldNumber = 55; + private global::Xla.OpSharding domainExitSharding_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.OpSharding DomainExitSharding { + get { return domainExitSharding_; } + set { + domainExitSharding_ = value; + } + } + + /// Field number for the "constrain_layout" field. + public const int ConstrainLayoutFieldNumber = 56; + private bool constrainLayout_; + /// + /// For custom call this indicates that the layouts are constrained. If + /// constrain_layout is true then the 'shape' field must contain a layout, and + /// 'operand_shapes_with_layout' must contain a shape with layout for each + /// operand. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool ConstrainLayout { + get { return constrainLayout_; } + set { + constrainLayout_ = value; + } + } + + /// Field number for the "operand_shapes_with_layout" field. + public const int OperandShapesWithLayoutFieldNumber = 57; + private static readonly pb::FieldCodec _repeated_operandShapesWithLayout_codec + = pb::FieldCodec.ForMessage(458, global::Xla.ShapeProto.Parser); + private readonly pbc::RepeatedField operandShapesWithLayout_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OperandShapesWithLayout { + get { return operandShapesWithLayout_; } + } + + /// Field number for the "triangular_solve_options" field. + public const int TriangularSolveOptionsFieldNumber = 59; + private global::Xla.TriangularSolveOptions triangularSolveOptions_; + /// + /// Options for TriangularSolve + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.TriangularSolveOptions TriangularSolveOptions { + get { return triangularSolveOptions_; } + set { + triangularSolveOptions_ = value; + } + } + + /// Field number for the "cholesky_options" field. + public const int CholeskyOptionsFieldNumber = 62; + private global::Xla.CholeskyOptions choleskyOptions_; + /// + /// Options for Cholesky + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.CholeskyOptions CholeskyOptions { + get { return choleskyOptions_; } + set { + choleskyOptions_ = value; + } + } + + /// Field number for the "parameter_replication" field. + public const int ParameterReplicationFieldNumber = 61; + private global::Xla.ParameterReplication parameterReplication_; + /// + /// Describes how parameters behave with regards to replicas. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ParameterReplication ParameterReplication { + get { return parameterReplication_; } + set { + parameterReplication_ = value; + } + } + + /// Field number for the "custom_call_has_side_effect" field. + public const int CustomCallHasSideEffectFieldNumber = 65; + private bool customCallHasSideEffect_; + /// + /// Whether the kCustomCall instruction has side-effects, only present for + /// kCustomCall. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool CustomCallHasSideEffect { + get { return customCallHasSideEffect_; } + set { + customCallHasSideEffect_ = value; + } + } + + /// Field number for the "custom_call_output_operand_aliasing" field. + public const int CustomCallOutputOperandAliasingFieldNumber = 74; + private static readonly pb::FieldCodec _repeated_customCallOutputOperandAliasing_codec + = pb::FieldCodec.ForMessage(594, global::Xla.CustomCallOutputOperandAliasing.Parser); + private readonly pbc::RepeatedField customCallOutputOperandAliasing_ = new pbc::RepeatedField(); + /// + /// A list of CustomCallOutputOperandAliasing pairs that specifies aliasing + /// buffers between output and operands for kCustomCall. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField CustomCallOutputOperandAliasing { + get { return customCallOutputOperandAliasing_; } + } + + /// Field number for the "custom_call_schedule" field. + public const int CustomCallScheduleFieldNumber = 76; + private global::Xla.CustomCallSchedule customCallSchedule_ = global::Xla.CustomCallSchedule.ScheduleNone; + /// + /// Specifies the desired schedule for the custom-call. The field is only + /// present for custom-call. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.CustomCallSchedule CustomCallSchedule { + get { return customCallSchedule_; } + set { + customCallSchedule_ = value; + } + } + + /// Field number for the "delta" field. + public const int DeltaFieldNumber = 66; + private long delta_; + /// + /// The delta value for kRngGetAndUpdateState. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Delta { + get { return delta_; } + set { + delta_ = value; + } + } + + /// Field number for the "indices_are_sorted" field. + public const int IndicesAreSortedFieldNumber = 67; + private bool indicesAreSorted_; + /// + /// Specifies if the gather/scatter indices are guaranteed to be sorted by the + /// caller. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IndicesAreSorted { + get { return indicesAreSorted_; } + set { + indicesAreSorted_ = value; + } + } + + /// Field number for the "frontend_attributes" field. + public const int FrontendAttributesFieldNumber = 68; + private global::Xla.FrontendAttributes frontendAttributes_; + /// + /// Frontend attributes to pass to the XLA backend. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.FrontendAttributes FrontendAttributes { + get { return frontendAttributes_; } + set { + frontendAttributes_ = value; + } + } + + /// Field number for the "unique_indices" field. + public const int UniqueIndicesFieldNumber = 69; + private bool uniqueIndices_; + /// + /// Specifies if all elements updated are guaranteed to be unique by + /// the caller. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UniqueIndices { + get { return uniqueIndices_; } + set { + uniqueIndices_ = value; + } + } + + /// Field number for the "rng_algorithm" field. + public const int RngAlgorithmFieldNumber = 70; + private global::Xla.RandomAlgorithm rngAlgorithm_ = global::Xla.RandomAlgorithm.RngDefault; + /// + /// RNG algorithm used by kRngBitGenerator. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.RandomAlgorithm RngAlgorithm { + get { return rngAlgorithm_; } + set { + rngAlgorithm_ = value; + } + } + + /// Field number for the "comparison_type" field. + public const int ComparisonTypeFieldNumber = 72; + private string comparisonType_ = ""; + /// + /// The comparison type used for kCompare. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ComparisonType { + get { return comparisonType_; } + set { + comparisonType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "is_cross_program_prefetch" field. + public const int IsCrossProgramPrefetchFieldNumber = 73; + private bool isCrossProgramPrefetch_; + /// + /// Specifies if this is a cross-program-prefetch, used by kCopyStart. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsCrossProgramPrefetch { + get { return isCrossProgramPrefetch_; } + set { + isCrossProgramPrefetch_ = value; + } + } + + /// Field number for the "padding_type" field. + public const int PaddingTypeFieldNumber = 75; + private global::Xla.PaddingType paddingType_ = global::Xla.PaddingType.PaddingInvalid; + /// + /// If a convolution is dynamic, a dynamic padding type will be specified. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.PaddingType PaddingType { + get { return paddingType_; } + set { + paddingType_ = value; + } + } + + /// Field number for the "custom_call_api_version" field. + public const int CustomCallApiVersionFieldNumber = 77; + private global::Xla.CustomCallApiVersion customCallApiVersion_ = global::Xla.CustomCallApiVersion.ApiVersionUnspecified; + /// + /// The API version used by the custom call function. This field is only + /// present for custom-call. + /// TODO(b/189822916): Remove this field when all clients are migrated to the + /// status-returning API. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.CustomCallApiVersion CustomCallApiVersion { + get { return customCallApiVersion_; } + set { + customCallApiVersion_ = value; + } + } + + /// Field number for the "async_group_id" field. + public const int AsyncGroupIdFieldNumber = 78; + private long asyncGroupId_; + /// + /// Represents a unique identifier for an async group which consists of an + /// async start, async done, and zero or more async update operations. + /// Negative async_group_id is equivalent to no async group id. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AsyncGroupId { + get { return asyncGroupId_; } + set { + asyncGroupId_ = value; + } + } + + /// Field number for the "async_execution_thread" field. + public const int AsyncExecutionThreadFieldNumber = 79; + private string asyncExecutionThread_ = ""; + /// + /// Represents a unique execution thread name for one or more async groups. + /// Each HLO module may contain a main thread and one or more parallel threads. + /// Empty async_execution_thread is equivalent to main thread. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string AsyncExecutionThread { + get { return asyncExecutionThread_; } + set { + asyncExecutionThread_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HloInstructionProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HloInstructionProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (Opcode != other.Opcode) return false; + if (!object.Equals(Shape, other.Shape)) return false; + if (!object.Equals(Metadata, other.Metadata)) return false; + if (!object.Equals(Literal, other.Literal)) return false; + if (ParameterNumber != other.ParameterNumber) return false; + if (FusionKind != other.FusionKind) return false; + if (TupleIndex != other.TupleIndex) return false; + if(!dimensions_.Equals(other.dimensions_)) return false; + if (!object.Equals(Window, other.Window)) return false; + if (!object.Equals(ConvolutionDimensionNumbers, other.ConvolutionDimensionNumbers)) return false; + if (FeatureGroupCount != other.FeatureGroupCount) return false; + if (BatchGroupCount != other.BatchGroupCount) return false; + if(!sliceDimensions_.Equals(other.sliceDimensions_)) return false; + if (ExponentBits != other.ExponentBits) return false; + if (MantissaBits != other.MantissaBits) return false; + if(!dynamicSliceSizes_.Equals(other.dynamicSliceSizes_)) return false; + if (!object.Equals(PaddingConfig, other.PaddingConfig)) return false; + if (OutfeedConfig != other.OutfeedConfig) return false; + if (Distribution != other.Distribution) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Epsilon, other.Epsilon)) return false; + if (FeatureIndex != other.FeatureIndex) return false; + if (ChannelId != other.ChannelId) return false; + if (InfeedConfig != other.InfeedConfig) return false; + if (CustomCallTarget != other.CustomCallTarget) return false; + if (!object.Equals(OutfeedShape, other.OutfeedShape)) return false; + if (!object.Equals(DotDimensionNumbers, other.DotDimensionNumbers)) return false; + if (FftType != other.FftType) return false; + if(!fftLength_.Equals(other.fftLength_)) return false; + if (ComparisonDirection != other.ComparisonDirection) return false; + if (!object.Equals(GatherDimensionNumbers, other.GatherDimensionNumbers)) return false; + if(!gatherSliceSizes_.Equals(other.gatherSliceSizes_)) return false; + if (Id != other.Id) return false; + if(!operandIds_.Equals(other.operandIds_)) return false; + if(!controlPredecessorIds_.Equals(other.controlPredecessorIds_)) return false; + if(!calledComputationIds_.Equals(other.calledComputationIds_)) return false; + if (!object.Equals(Sharding, other.Sharding)) return false; + if (BackendConfig != other.BackendConfig) return false; + if(!replicaGroups_.Equals(other.replicaGroups_)) return false; + if (AllReduceId != other.AllReduceId) return false; + if (UseGlobalDeviceIds != other.UseGlobalDeviceIds) return false; + if (IsHostTransfer != other.IsHostTransfer) return false; + if (IsStable != other.IsStable) return false; + if (!object.Equals(ScatterDimensionNumbers, other.ScatterDimensionNumbers)) return false; + if (!object.Equals(PrecisionConfig, other.PrecisionConfig)) return false; + if(!sourceTargetPairs_.Equals(other.sourceTargetPairs_)) return false; + if (!object.Equals(DomainEntrySharding, other.DomainEntrySharding)) return false; + if (!object.Equals(DomainExitSharding, other.DomainExitSharding)) return false; + if (ConstrainLayout != other.ConstrainLayout) return false; + if(!operandShapesWithLayout_.Equals(other.operandShapesWithLayout_)) return false; + if (!object.Equals(TriangularSolveOptions, other.TriangularSolveOptions)) return false; + if (!object.Equals(CholeskyOptions, other.CholeskyOptions)) return false; + if (!object.Equals(ParameterReplication, other.ParameterReplication)) return false; + if (CustomCallHasSideEffect != other.CustomCallHasSideEffect) return false; + if(!customCallOutputOperandAliasing_.Equals(other.customCallOutputOperandAliasing_)) return false; + if (CustomCallSchedule != other.CustomCallSchedule) return false; + if (Delta != other.Delta) return false; + if (IndicesAreSorted != other.IndicesAreSorted) return false; + if (!object.Equals(FrontendAttributes, other.FrontendAttributes)) return false; + if (UniqueIndices != other.UniqueIndices) return false; + if (RngAlgorithm != other.RngAlgorithm) return false; + if (ComparisonType != other.ComparisonType) return false; + if (IsCrossProgramPrefetch != other.IsCrossProgramPrefetch) return false; + if (PaddingType != other.PaddingType) return false; + if (CustomCallApiVersion != other.CustomCallApiVersion) return false; + if (AsyncGroupId != other.AsyncGroupId) return false; + if (AsyncExecutionThread != other.AsyncExecutionThread) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Opcode.Length != 0) hash ^= Opcode.GetHashCode(); + if (shape_ != null) hash ^= Shape.GetHashCode(); + if (metadata_ != null) hash ^= Metadata.GetHashCode(); + if (literal_ != null) hash ^= Literal.GetHashCode(); + if (ParameterNumber != 0L) hash ^= ParameterNumber.GetHashCode(); + if (FusionKind.Length != 0) hash ^= FusionKind.GetHashCode(); + if (TupleIndex != 0L) hash ^= TupleIndex.GetHashCode(); + hash ^= dimensions_.GetHashCode(); + if (window_ != null) hash ^= Window.GetHashCode(); + if (convolutionDimensionNumbers_ != null) hash ^= ConvolutionDimensionNumbers.GetHashCode(); + if (FeatureGroupCount != 0L) hash ^= FeatureGroupCount.GetHashCode(); + if (BatchGroupCount != 0L) hash ^= BatchGroupCount.GetHashCode(); + hash ^= sliceDimensions_.GetHashCode(); + if (ExponentBits != 0) hash ^= ExponentBits.GetHashCode(); + if (MantissaBits != 0) hash ^= MantissaBits.GetHashCode(); + hash ^= dynamicSliceSizes_.GetHashCode(); + if (paddingConfig_ != null) hash ^= PaddingConfig.GetHashCode(); + if (OutfeedConfig.Length != 0) hash ^= OutfeedConfig.GetHashCode(); + if (Distribution != global::Xla.RandomDistribution.RngInvalid) hash ^= Distribution.GetHashCode(); + if (Epsilon != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Epsilon); + if (FeatureIndex != 0L) hash ^= FeatureIndex.GetHashCode(); + if (ChannelId != 0L) hash ^= ChannelId.GetHashCode(); + if (InfeedConfig.Length != 0) hash ^= InfeedConfig.GetHashCode(); + if (CustomCallTarget.Length != 0) hash ^= CustomCallTarget.GetHashCode(); + if (outfeedShape_ != null) hash ^= OutfeedShape.GetHashCode(); + if (dotDimensionNumbers_ != null) hash ^= DotDimensionNumbers.GetHashCode(); + if (FftType != global::Xla.FftType.Fft) hash ^= FftType.GetHashCode(); + hash ^= fftLength_.GetHashCode(); + if (ComparisonDirection.Length != 0) hash ^= ComparisonDirection.GetHashCode(); + if (gatherDimensionNumbers_ != null) hash ^= GatherDimensionNumbers.GetHashCode(); + hash ^= gatherSliceSizes_.GetHashCode(); + if (Id != 0L) hash ^= Id.GetHashCode(); + hash ^= operandIds_.GetHashCode(); + hash ^= controlPredecessorIds_.GetHashCode(); + hash ^= calledComputationIds_.GetHashCode(); + if (sharding_ != null) hash ^= Sharding.GetHashCode(); + if (BackendConfig.Length != 0) hash ^= BackendConfig.GetHashCode(); + hash ^= replicaGroups_.GetHashCode(); + if (AllReduceId != 0L) hash ^= AllReduceId.GetHashCode(); + if (UseGlobalDeviceIds != false) hash ^= UseGlobalDeviceIds.GetHashCode(); + if (IsHostTransfer != false) hash ^= IsHostTransfer.GetHashCode(); + if (IsStable != false) hash ^= IsStable.GetHashCode(); + if (scatterDimensionNumbers_ != null) hash ^= ScatterDimensionNumbers.GetHashCode(); + if (precisionConfig_ != null) hash ^= PrecisionConfig.GetHashCode(); + hash ^= sourceTargetPairs_.GetHashCode(); + if (domainEntrySharding_ != null) hash ^= DomainEntrySharding.GetHashCode(); + if (domainExitSharding_ != null) hash ^= DomainExitSharding.GetHashCode(); + if (ConstrainLayout != false) hash ^= ConstrainLayout.GetHashCode(); + hash ^= operandShapesWithLayout_.GetHashCode(); + if (triangularSolveOptions_ != null) hash ^= TriangularSolveOptions.GetHashCode(); + if (choleskyOptions_ != null) hash ^= CholeskyOptions.GetHashCode(); + if (parameterReplication_ != null) hash ^= ParameterReplication.GetHashCode(); + if (CustomCallHasSideEffect != false) hash ^= CustomCallHasSideEffect.GetHashCode(); + hash ^= customCallOutputOperandAliasing_.GetHashCode(); + if (CustomCallSchedule != global::Xla.CustomCallSchedule.ScheduleNone) hash ^= CustomCallSchedule.GetHashCode(); + if (Delta != 0L) hash ^= Delta.GetHashCode(); + if (IndicesAreSorted != false) hash ^= IndicesAreSorted.GetHashCode(); + if (frontendAttributes_ != null) hash ^= FrontendAttributes.GetHashCode(); + if (UniqueIndices != false) hash ^= UniqueIndices.GetHashCode(); + if (RngAlgorithm != global::Xla.RandomAlgorithm.RngDefault) hash ^= RngAlgorithm.GetHashCode(); + if (ComparisonType.Length != 0) hash ^= ComparisonType.GetHashCode(); + if (IsCrossProgramPrefetch != false) hash ^= IsCrossProgramPrefetch.GetHashCode(); + if (PaddingType != global::Xla.PaddingType.PaddingInvalid) hash ^= PaddingType.GetHashCode(); + if (CustomCallApiVersion != global::Xla.CustomCallApiVersion.ApiVersionUnspecified) hash ^= CustomCallApiVersion.GetHashCode(); + if (AsyncGroupId != 0L) hash ^= AsyncGroupId.GetHashCode(); + if (AsyncExecutionThread.Length != 0) hash ^= AsyncExecutionThread.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Opcode.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Opcode); + } + if (shape_ != null) { + output.WriteRawTag(26); + output.WriteMessage(Shape); + } + if (metadata_ != null) { + output.WriteRawTag(58); + output.WriteMessage(Metadata); + } + if (literal_ != null) { + output.WriteRawTag(66); + output.WriteMessage(Literal); + } + if (ParameterNumber != 0L) { + output.WriteRawTag(72); + output.WriteInt64(ParameterNumber); + } + if (FusionKind.Length != 0) { + output.WriteRawTag(90); + output.WriteString(FusionKind); + } + if (TupleIndex != 0L) { + output.WriteRawTag(104); + output.WriteInt64(TupleIndex); + } + dimensions_.WriteTo(output, _repeated_dimensions_codec); + if (window_ != null) { + output.WriteRawTag(122); + output.WriteMessage(Window); + } + if (convolutionDimensionNumbers_ != null) { + output.WriteRawTag(130, 1); + output.WriteMessage(ConvolutionDimensionNumbers); + } + sliceDimensions_.WriteTo(output, _repeated_sliceDimensions_codec); + if (ExponentBits != 0) { + output.WriteRawTag(144, 1); + output.WriteInt32(ExponentBits); + } + if (MantissaBits != 0) { + output.WriteRawTag(152, 1); + output.WriteInt32(MantissaBits); + } + dynamicSliceSizes_.WriteTo(output, _repeated_dynamicSliceSizes_codec); + if (paddingConfig_ != null) { + output.WriteRawTag(170, 1); + output.WriteMessage(PaddingConfig); + } + if (OutfeedConfig.Length != 0) { + output.WriteRawTag(178, 1); + output.WriteBytes(OutfeedConfig); + } + if (Distribution != global::Xla.RandomDistribution.RngInvalid) { + output.WriteRawTag(184, 1); + output.WriteEnum((int) Distribution); + } + if (Epsilon != 0F) { + output.WriteRawTag(197, 1); + output.WriteFloat(Epsilon); + } + if (FeatureIndex != 0L) { + output.WriteRawTag(200, 1); + output.WriteInt64(FeatureIndex); + } + if (ChannelId != 0L) { + output.WriteRawTag(208, 1); + output.WriteInt64(ChannelId); + } + if (InfeedConfig.Length != 0) { + output.WriteRawTag(218, 1); + output.WriteBytes(InfeedConfig); + } + if (CustomCallTarget.Length != 0) { + output.WriteRawTag(226, 1); + output.WriteString(CustomCallTarget); + } + if (outfeedShape_ != null) { + output.WriteRawTag(234, 1); + output.WriteMessage(OutfeedShape); + } + if (dotDimensionNumbers_ != null) { + output.WriteRawTag(242, 1); + output.WriteMessage(DotDimensionNumbers); + } + if (FftType != global::Xla.FftType.Fft) { + output.WriteRawTag(248, 1); + output.WriteEnum((int) FftType); + } + fftLength_.WriteTo(output, _repeated_fftLength_codec); + if (gatherDimensionNumbers_ != null) { + output.WriteRawTag(138, 2); + output.WriteMessage(GatherDimensionNumbers); + } + gatherSliceSizes_.WriteTo(output, _repeated_gatherSliceSizes_codec); + if (Id != 0L) { + output.WriteRawTag(152, 2); + output.WriteInt64(Id); + } + operandIds_.WriteTo(output, _repeated_operandIds_codec); + controlPredecessorIds_.WriteTo(output, _repeated_controlPredecessorIds_codec); + calledComputationIds_.WriteTo(output, _repeated_calledComputationIds_codec); + if (sharding_ != null) { + output.WriteRawTag(194, 2); + output.WriteMessage(Sharding); + } + if (BackendConfig.Length != 0) { + output.WriteRawTag(218, 2); + output.WriteBytes(BackendConfig); + } + if (AllReduceId != 0L) { + output.WriteRawTag(232, 2); + output.WriteInt64(AllReduceId); + } + if (IsHostTransfer != false) { + output.WriteRawTag(248, 2); + output.WriteBool(IsHostTransfer); + } + if (scatterDimensionNumbers_ != null) { + output.WriteRawTag(130, 3); + output.WriteMessage(ScatterDimensionNumbers); + } + replicaGroups_.WriteTo(output, _repeated_replicaGroups_codec); + if (FeatureGroupCount != 0L) { + output.WriteRawTag(144, 3); + output.WriteInt64(FeatureGroupCount); + } + if (precisionConfig_ != null) { + output.WriteRawTag(154, 3); + output.WriteMessage(PrecisionConfig); + } + sourceTargetPairs_.WriteTo(output, _repeated_sourceTargetPairs_codec); + if (domainEntrySharding_ != null) { + output.WriteRawTag(178, 3); + output.WriteMessage(DomainEntrySharding); + } + if (domainExitSharding_ != null) { + output.WriteRawTag(186, 3); + output.WriteMessage(DomainExitSharding); + } + if (ConstrainLayout != false) { + output.WriteRawTag(192, 3); + output.WriteBool(ConstrainLayout); + } + operandShapesWithLayout_.WriteTo(output, _repeated_operandShapesWithLayout_codec); + if (BatchGroupCount != 0L) { + output.WriteRawTag(208, 3); + output.WriteInt64(BatchGroupCount); + } + if (triangularSolveOptions_ != null) { + output.WriteRawTag(218, 3); + output.WriteMessage(TriangularSolveOptions); + } + if (IsStable != false) { + output.WriteRawTag(224, 3); + output.WriteBool(IsStable); + } + if (parameterReplication_ != null) { + output.WriteRawTag(234, 3); + output.WriteMessage(ParameterReplication); + } + if (choleskyOptions_ != null) { + output.WriteRawTag(242, 3); + output.WriteMessage(CholeskyOptions); + } + if (ComparisonDirection.Length != 0) { + output.WriteRawTag(250, 3); + output.WriteString(ComparisonDirection); + } + if (CustomCallHasSideEffect != false) { + output.WriteRawTag(136, 4); + output.WriteBool(CustomCallHasSideEffect); + } + if (Delta != 0L) { + output.WriteRawTag(144, 4); + output.WriteInt64(Delta); + } + if (IndicesAreSorted != false) { + output.WriteRawTag(152, 4); + output.WriteBool(IndicesAreSorted); + } + if (frontendAttributes_ != null) { + output.WriteRawTag(162, 4); + output.WriteMessage(FrontendAttributes); + } + if (UniqueIndices != false) { + output.WriteRawTag(168, 4); + output.WriteBool(UniqueIndices); + } + if (RngAlgorithm != global::Xla.RandomAlgorithm.RngDefault) { + output.WriteRawTag(176, 4); + output.WriteEnum((int) RngAlgorithm); + } + if (UseGlobalDeviceIds != false) { + output.WriteRawTag(184, 4); + output.WriteBool(UseGlobalDeviceIds); + } + if (ComparisonType.Length != 0) { + output.WriteRawTag(194, 4); + output.WriteString(ComparisonType); + } + if (IsCrossProgramPrefetch != false) { + output.WriteRawTag(200, 4); + output.WriteBool(IsCrossProgramPrefetch); + } + customCallOutputOperandAliasing_.WriteTo(output, _repeated_customCallOutputOperandAliasing_codec); + if (PaddingType != global::Xla.PaddingType.PaddingInvalid) { + output.WriteRawTag(216, 4); + output.WriteEnum((int) PaddingType); + } + if (CustomCallSchedule != global::Xla.CustomCallSchedule.ScheduleNone) { + output.WriteRawTag(224, 4); + output.WriteEnum((int) CustomCallSchedule); + } + if (CustomCallApiVersion != global::Xla.CustomCallApiVersion.ApiVersionUnspecified) { + output.WriteRawTag(232, 4); + output.WriteEnum((int) CustomCallApiVersion); + } + if (AsyncGroupId != 0L) { + output.WriteRawTag(240, 4); + output.WriteInt64(AsyncGroupId); + } + if (AsyncExecutionThread.Length != 0) { + output.WriteRawTag(250, 4); + output.WriteString(AsyncExecutionThread); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Opcode.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Opcode); + } + if (shape_ != null) { + output.WriteRawTag(26); + output.WriteMessage(Shape); + } + if (metadata_ != null) { + output.WriteRawTag(58); + output.WriteMessage(Metadata); + } + if (literal_ != null) { + output.WriteRawTag(66); + output.WriteMessage(Literal); + } + if (ParameterNumber != 0L) { + output.WriteRawTag(72); + output.WriteInt64(ParameterNumber); + } + if (FusionKind.Length != 0) { + output.WriteRawTag(90); + output.WriteString(FusionKind); + } + if (TupleIndex != 0L) { + output.WriteRawTag(104); + output.WriteInt64(TupleIndex); + } + dimensions_.WriteTo(ref output, _repeated_dimensions_codec); + if (window_ != null) { + output.WriteRawTag(122); + output.WriteMessage(Window); + } + if (convolutionDimensionNumbers_ != null) { + output.WriteRawTag(130, 1); + output.WriteMessage(ConvolutionDimensionNumbers); + } + sliceDimensions_.WriteTo(ref output, _repeated_sliceDimensions_codec); + if (ExponentBits != 0) { + output.WriteRawTag(144, 1); + output.WriteInt32(ExponentBits); + } + if (MantissaBits != 0) { + output.WriteRawTag(152, 1); + output.WriteInt32(MantissaBits); + } + dynamicSliceSizes_.WriteTo(ref output, _repeated_dynamicSliceSizes_codec); + if (paddingConfig_ != null) { + output.WriteRawTag(170, 1); + output.WriteMessage(PaddingConfig); + } + if (OutfeedConfig.Length != 0) { + output.WriteRawTag(178, 1); + output.WriteBytes(OutfeedConfig); + } + if (Distribution != global::Xla.RandomDistribution.RngInvalid) { + output.WriteRawTag(184, 1); + output.WriteEnum((int) Distribution); + } + if (Epsilon != 0F) { + output.WriteRawTag(197, 1); + output.WriteFloat(Epsilon); + } + if (FeatureIndex != 0L) { + output.WriteRawTag(200, 1); + output.WriteInt64(FeatureIndex); + } + if (ChannelId != 0L) { + output.WriteRawTag(208, 1); + output.WriteInt64(ChannelId); + } + if (InfeedConfig.Length != 0) { + output.WriteRawTag(218, 1); + output.WriteBytes(InfeedConfig); + } + if (CustomCallTarget.Length != 0) { + output.WriteRawTag(226, 1); + output.WriteString(CustomCallTarget); + } + if (outfeedShape_ != null) { + output.WriteRawTag(234, 1); + output.WriteMessage(OutfeedShape); + } + if (dotDimensionNumbers_ != null) { + output.WriteRawTag(242, 1); + output.WriteMessage(DotDimensionNumbers); + } + if (FftType != global::Xla.FftType.Fft) { + output.WriteRawTag(248, 1); + output.WriteEnum((int) FftType); + } + fftLength_.WriteTo(ref output, _repeated_fftLength_codec); + if (gatherDimensionNumbers_ != null) { + output.WriteRawTag(138, 2); + output.WriteMessage(GatherDimensionNumbers); + } + gatherSliceSizes_.WriteTo(ref output, _repeated_gatherSliceSizes_codec); + if (Id != 0L) { + output.WriteRawTag(152, 2); + output.WriteInt64(Id); + } + operandIds_.WriteTo(ref output, _repeated_operandIds_codec); + controlPredecessorIds_.WriteTo(ref output, _repeated_controlPredecessorIds_codec); + calledComputationIds_.WriteTo(ref output, _repeated_calledComputationIds_codec); + if (sharding_ != null) { + output.WriteRawTag(194, 2); + output.WriteMessage(Sharding); + } + if (BackendConfig.Length != 0) { + output.WriteRawTag(218, 2); + output.WriteBytes(BackendConfig); + } + if (AllReduceId != 0L) { + output.WriteRawTag(232, 2); + output.WriteInt64(AllReduceId); + } + if (IsHostTransfer != false) { + output.WriteRawTag(248, 2); + output.WriteBool(IsHostTransfer); + } + if (scatterDimensionNumbers_ != null) { + output.WriteRawTag(130, 3); + output.WriteMessage(ScatterDimensionNumbers); + } + replicaGroups_.WriteTo(ref output, _repeated_replicaGroups_codec); + if (FeatureGroupCount != 0L) { + output.WriteRawTag(144, 3); + output.WriteInt64(FeatureGroupCount); + } + if (precisionConfig_ != null) { + output.WriteRawTag(154, 3); + output.WriteMessage(PrecisionConfig); + } + sourceTargetPairs_.WriteTo(ref output, _repeated_sourceTargetPairs_codec); + if (domainEntrySharding_ != null) { + output.WriteRawTag(178, 3); + output.WriteMessage(DomainEntrySharding); + } + if (domainExitSharding_ != null) { + output.WriteRawTag(186, 3); + output.WriteMessage(DomainExitSharding); + } + if (ConstrainLayout != false) { + output.WriteRawTag(192, 3); + output.WriteBool(ConstrainLayout); + } + operandShapesWithLayout_.WriteTo(ref output, _repeated_operandShapesWithLayout_codec); + if (BatchGroupCount != 0L) { + output.WriteRawTag(208, 3); + output.WriteInt64(BatchGroupCount); + } + if (triangularSolveOptions_ != null) { + output.WriteRawTag(218, 3); + output.WriteMessage(TriangularSolveOptions); + } + if (IsStable != false) { + output.WriteRawTag(224, 3); + output.WriteBool(IsStable); + } + if (parameterReplication_ != null) { + output.WriteRawTag(234, 3); + output.WriteMessage(ParameterReplication); + } + if (choleskyOptions_ != null) { + output.WriteRawTag(242, 3); + output.WriteMessage(CholeskyOptions); + } + if (ComparisonDirection.Length != 0) { + output.WriteRawTag(250, 3); + output.WriteString(ComparisonDirection); + } + if (CustomCallHasSideEffect != false) { + output.WriteRawTag(136, 4); + output.WriteBool(CustomCallHasSideEffect); + } + if (Delta != 0L) { + output.WriteRawTag(144, 4); + output.WriteInt64(Delta); + } + if (IndicesAreSorted != false) { + output.WriteRawTag(152, 4); + output.WriteBool(IndicesAreSorted); + } + if (frontendAttributes_ != null) { + output.WriteRawTag(162, 4); + output.WriteMessage(FrontendAttributes); + } + if (UniqueIndices != false) { + output.WriteRawTag(168, 4); + output.WriteBool(UniqueIndices); + } + if (RngAlgorithm != global::Xla.RandomAlgorithm.RngDefault) { + output.WriteRawTag(176, 4); + output.WriteEnum((int) RngAlgorithm); + } + if (UseGlobalDeviceIds != false) { + output.WriteRawTag(184, 4); + output.WriteBool(UseGlobalDeviceIds); + } + if (ComparisonType.Length != 0) { + output.WriteRawTag(194, 4); + output.WriteString(ComparisonType); + } + if (IsCrossProgramPrefetch != false) { + output.WriteRawTag(200, 4); + output.WriteBool(IsCrossProgramPrefetch); + } + customCallOutputOperandAliasing_.WriteTo(ref output, _repeated_customCallOutputOperandAliasing_codec); + if (PaddingType != global::Xla.PaddingType.PaddingInvalid) { + output.WriteRawTag(216, 4); + output.WriteEnum((int) PaddingType); + } + if (CustomCallSchedule != global::Xla.CustomCallSchedule.ScheduleNone) { + output.WriteRawTag(224, 4); + output.WriteEnum((int) CustomCallSchedule); + } + if (CustomCallApiVersion != global::Xla.CustomCallApiVersion.ApiVersionUnspecified) { + output.WriteRawTag(232, 4); + output.WriteEnum((int) CustomCallApiVersion); + } + if (AsyncGroupId != 0L) { + output.WriteRawTag(240, 4); + output.WriteInt64(AsyncGroupId); + } + if (AsyncExecutionThread.Length != 0) { + output.WriteRawTag(250, 4); + output.WriteString(AsyncExecutionThread); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Opcode.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Opcode); + } + if (shape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + if (metadata_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Metadata); + } + if (literal_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Literal); + } + if (ParameterNumber != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ParameterNumber); + } + if (FusionKind.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(FusionKind); + } + if (TupleIndex != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(TupleIndex); + } + size += dimensions_.CalculateSize(_repeated_dimensions_codec); + if (window_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Window); + } + if (convolutionDimensionNumbers_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(ConvolutionDimensionNumbers); + } + if (FeatureGroupCount != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(FeatureGroupCount); + } + if (BatchGroupCount != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(BatchGroupCount); + } + size += sliceDimensions_.CalculateSize(_repeated_sliceDimensions_codec); + if (ExponentBits != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(ExponentBits); + } + if (MantissaBits != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(MantissaBits); + } + size += dynamicSliceSizes_.CalculateSize(_repeated_dynamicSliceSizes_codec); + if (paddingConfig_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(PaddingConfig); + } + if (OutfeedConfig.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeBytesSize(OutfeedConfig); + } + if (Distribution != global::Xla.RandomDistribution.RngInvalid) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) Distribution); + } + if (Epsilon != 0F) { + size += 2 + 4; + } + if (FeatureIndex != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(FeatureIndex); + } + if (ChannelId != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(ChannelId); + } + if (InfeedConfig.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeBytesSize(InfeedConfig); + } + if (CustomCallTarget.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(CustomCallTarget); + } + if (outfeedShape_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(OutfeedShape); + } + if (dotDimensionNumbers_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(DotDimensionNumbers); + } + if (FftType != global::Xla.FftType.Fft) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) FftType); + } + size += fftLength_.CalculateSize(_repeated_fftLength_codec); + if (ComparisonDirection.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(ComparisonDirection); + } + if (gatherDimensionNumbers_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(GatherDimensionNumbers); + } + size += gatherSliceSizes_.CalculateSize(_repeated_gatherSliceSizes_codec); + if (Id != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(Id); + } + size += operandIds_.CalculateSize(_repeated_operandIds_codec); + size += controlPredecessorIds_.CalculateSize(_repeated_controlPredecessorIds_codec); + size += calledComputationIds_.CalculateSize(_repeated_calledComputationIds_codec); + if (sharding_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(Sharding); + } + if (BackendConfig.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeBytesSize(BackendConfig); + } + size += replicaGroups_.CalculateSize(_repeated_replicaGroups_codec); + if (AllReduceId != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(AllReduceId); + } + if (UseGlobalDeviceIds != false) { + size += 2 + 1; + } + if (IsHostTransfer != false) { + size += 2 + 1; + } + if (IsStable != false) { + size += 2 + 1; + } + if (scatterDimensionNumbers_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(ScatterDimensionNumbers); + } + if (precisionConfig_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(PrecisionConfig); + } + size += sourceTargetPairs_.CalculateSize(_repeated_sourceTargetPairs_codec); + if (domainEntrySharding_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(DomainEntrySharding); + } + if (domainExitSharding_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(DomainExitSharding); + } + if (ConstrainLayout != false) { + size += 2 + 1; + } + size += operandShapesWithLayout_.CalculateSize(_repeated_operandShapesWithLayout_codec); + if (triangularSolveOptions_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(TriangularSolveOptions); + } + if (choleskyOptions_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(CholeskyOptions); + } + if (parameterReplication_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(ParameterReplication); + } + if (CustomCallHasSideEffect != false) { + size += 2 + 1; + } + size += customCallOutputOperandAliasing_.CalculateSize(_repeated_customCallOutputOperandAliasing_codec); + if (CustomCallSchedule != global::Xla.CustomCallSchedule.ScheduleNone) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) CustomCallSchedule); + } + if (Delta != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(Delta); + } + if (IndicesAreSorted != false) { + size += 2 + 1; + } + if (frontendAttributes_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(FrontendAttributes); + } + if (UniqueIndices != false) { + size += 2 + 1; + } + if (RngAlgorithm != global::Xla.RandomAlgorithm.RngDefault) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) RngAlgorithm); + } + if (ComparisonType.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(ComparisonType); + } + if (IsCrossProgramPrefetch != false) { + size += 2 + 1; + } + if (PaddingType != global::Xla.PaddingType.PaddingInvalid) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) PaddingType); + } + if (CustomCallApiVersion != global::Xla.CustomCallApiVersion.ApiVersionUnspecified) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) CustomCallApiVersion); + } + if (AsyncGroupId != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(AsyncGroupId); + } + if (AsyncExecutionThread.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(AsyncExecutionThread); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HloInstructionProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Opcode.Length != 0) { + Opcode = other.Opcode; + } + if (other.shape_ != null) { + if (shape_ == null) { + Shape = new global::Xla.ShapeProto(); + } + Shape.MergeFrom(other.Shape); + } + if (other.metadata_ != null) { + if (metadata_ == null) { + Metadata = new global::Xla.OpMetadata(); + } + Metadata.MergeFrom(other.Metadata); + } + if (other.literal_ != null) { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + Literal.MergeFrom(other.Literal); + } + if (other.ParameterNumber != 0L) { + ParameterNumber = other.ParameterNumber; + } + if (other.FusionKind.Length != 0) { + FusionKind = other.FusionKind; + } + if (other.TupleIndex != 0L) { + TupleIndex = other.TupleIndex; + } + dimensions_.Add(other.dimensions_); + if (other.window_ != null) { + if (window_ == null) { + Window = new global::Xla.Window(); + } + Window.MergeFrom(other.Window); + } + if (other.convolutionDimensionNumbers_ != null) { + if (convolutionDimensionNumbers_ == null) { + ConvolutionDimensionNumbers = new global::Xla.ConvolutionDimensionNumbers(); + } + ConvolutionDimensionNumbers.MergeFrom(other.ConvolutionDimensionNumbers); + } + if (other.FeatureGroupCount != 0L) { + FeatureGroupCount = other.FeatureGroupCount; + } + if (other.BatchGroupCount != 0L) { + BatchGroupCount = other.BatchGroupCount; + } + sliceDimensions_.Add(other.sliceDimensions_); + if (other.ExponentBits != 0) { + ExponentBits = other.ExponentBits; + } + if (other.MantissaBits != 0) { + MantissaBits = other.MantissaBits; + } + dynamicSliceSizes_.Add(other.dynamicSliceSizes_); + if (other.paddingConfig_ != null) { + if (paddingConfig_ == null) { + PaddingConfig = new global::Xla.PaddingConfig(); + } + PaddingConfig.MergeFrom(other.PaddingConfig); + } + if (other.OutfeedConfig.Length != 0) { + OutfeedConfig = other.OutfeedConfig; + } + if (other.Distribution != global::Xla.RandomDistribution.RngInvalid) { + Distribution = other.Distribution; + } + if (other.Epsilon != 0F) { + Epsilon = other.Epsilon; + } + if (other.FeatureIndex != 0L) { + FeatureIndex = other.FeatureIndex; + } + if (other.ChannelId != 0L) { + ChannelId = other.ChannelId; + } + if (other.InfeedConfig.Length != 0) { + InfeedConfig = other.InfeedConfig; + } + if (other.CustomCallTarget.Length != 0) { + CustomCallTarget = other.CustomCallTarget; + } + if (other.outfeedShape_ != null) { + if (outfeedShape_ == null) { + OutfeedShape = new global::Xla.ShapeProto(); + } + OutfeedShape.MergeFrom(other.OutfeedShape); + } + if (other.dotDimensionNumbers_ != null) { + if (dotDimensionNumbers_ == null) { + DotDimensionNumbers = new global::Xla.DotDimensionNumbers(); + } + DotDimensionNumbers.MergeFrom(other.DotDimensionNumbers); + } + if (other.FftType != global::Xla.FftType.Fft) { + FftType = other.FftType; + } + fftLength_.Add(other.fftLength_); + if (other.ComparisonDirection.Length != 0) { + ComparisonDirection = other.ComparisonDirection; + } + if (other.gatherDimensionNumbers_ != null) { + if (gatherDimensionNumbers_ == null) { + GatherDimensionNumbers = new global::Xla.GatherDimensionNumbers(); + } + GatherDimensionNumbers.MergeFrom(other.GatherDimensionNumbers); + } + gatherSliceSizes_.Add(other.gatherSliceSizes_); + if (other.Id != 0L) { + Id = other.Id; + } + operandIds_.Add(other.operandIds_); + controlPredecessorIds_.Add(other.controlPredecessorIds_); + calledComputationIds_.Add(other.calledComputationIds_); + if (other.sharding_ != null) { + if (sharding_ == null) { + Sharding = new global::Xla.OpSharding(); + } + Sharding.MergeFrom(other.Sharding); + } + if (other.BackendConfig.Length != 0) { + BackendConfig = other.BackendConfig; + } + replicaGroups_.Add(other.replicaGroups_); + if (other.AllReduceId != 0L) { + AllReduceId = other.AllReduceId; + } + if (other.UseGlobalDeviceIds != false) { + UseGlobalDeviceIds = other.UseGlobalDeviceIds; + } + if (other.IsHostTransfer != false) { + IsHostTransfer = other.IsHostTransfer; + } + if (other.IsStable != false) { + IsStable = other.IsStable; + } + if (other.scatterDimensionNumbers_ != null) { + if (scatterDimensionNumbers_ == null) { + ScatterDimensionNumbers = new global::Xla.ScatterDimensionNumbers(); + } + ScatterDimensionNumbers.MergeFrom(other.ScatterDimensionNumbers); + } + if (other.precisionConfig_ != null) { + if (precisionConfig_ == null) { + PrecisionConfig = new global::Xla.PrecisionConfig(); + } + PrecisionConfig.MergeFrom(other.PrecisionConfig); + } + sourceTargetPairs_.Add(other.sourceTargetPairs_); + if (other.domainEntrySharding_ != null) { + if (domainEntrySharding_ == null) { + DomainEntrySharding = new global::Xla.OpSharding(); + } + DomainEntrySharding.MergeFrom(other.DomainEntrySharding); + } + if (other.domainExitSharding_ != null) { + if (domainExitSharding_ == null) { + DomainExitSharding = new global::Xla.OpSharding(); + } + DomainExitSharding.MergeFrom(other.DomainExitSharding); + } + if (other.ConstrainLayout != false) { + ConstrainLayout = other.ConstrainLayout; + } + operandShapesWithLayout_.Add(other.operandShapesWithLayout_); + if (other.triangularSolveOptions_ != null) { + if (triangularSolveOptions_ == null) { + TriangularSolveOptions = new global::Xla.TriangularSolveOptions(); + } + TriangularSolveOptions.MergeFrom(other.TriangularSolveOptions); + } + if (other.choleskyOptions_ != null) { + if (choleskyOptions_ == null) { + CholeskyOptions = new global::Xla.CholeskyOptions(); + } + CholeskyOptions.MergeFrom(other.CholeskyOptions); + } + if (other.parameterReplication_ != null) { + if (parameterReplication_ == null) { + ParameterReplication = new global::Xla.ParameterReplication(); + } + ParameterReplication.MergeFrom(other.ParameterReplication); + } + if (other.CustomCallHasSideEffect != false) { + CustomCallHasSideEffect = other.CustomCallHasSideEffect; + } + customCallOutputOperandAliasing_.Add(other.customCallOutputOperandAliasing_); + if (other.CustomCallSchedule != global::Xla.CustomCallSchedule.ScheduleNone) { + CustomCallSchedule = other.CustomCallSchedule; + } + if (other.Delta != 0L) { + Delta = other.Delta; + } + if (other.IndicesAreSorted != false) { + IndicesAreSorted = other.IndicesAreSorted; + } + if (other.frontendAttributes_ != null) { + if (frontendAttributes_ == null) { + FrontendAttributes = new global::Xla.FrontendAttributes(); + } + FrontendAttributes.MergeFrom(other.FrontendAttributes); + } + if (other.UniqueIndices != false) { + UniqueIndices = other.UniqueIndices; + } + if (other.RngAlgorithm != global::Xla.RandomAlgorithm.RngDefault) { + RngAlgorithm = other.RngAlgorithm; + } + if (other.ComparisonType.Length != 0) { + ComparisonType = other.ComparisonType; + } + if (other.IsCrossProgramPrefetch != false) { + IsCrossProgramPrefetch = other.IsCrossProgramPrefetch; + } + if (other.PaddingType != global::Xla.PaddingType.PaddingInvalid) { + PaddingType = other.PaddingType; + } + if (other.CustomCallApiVersion != global::Xla.CustomCallApiVersion.ApiVersionUnspecified) { + CustomCallApiVersion = other.CustomCallApiVersion; + } + if (other.AsyncGroupId != 0L) { + AsyncGroupId = other.AsyncGroupId; + } + if (other.AsyncExecutionThread.Length != 0) { + AsyncExecutionThread = other.AsyncExecutionThread; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + Opcode = input.ReadString(); + break; + } + case 26: { + if (shape_ == null) { + Shape = new global::Xla.ShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 58: { + if (metadata_ == null) { + Metadata = new global::Xla.OpMetadata(); + } + input.ReadMessage(Metadata); + break; + } + case 66: { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + input.ReadMessage(Literal); + break; + } + case 72: { + ParameterNumber = input.ReadInt64(); + break; + } + case 90: { + FusionKind = input.ReadString(); + break; + } + case 104: { + TupleIndex = input.ReadInt64(); + break; + } + case 114: + case 112: { + dimensions_.AddEntriesFrom(input, _repeated_dimensions_codec); + break; + } + case 122: { + if (window_ == null) { + Window = new global::Xla.Window(); + } + input.ReadMessage(Window); + break; + } + case 130: { + if (convolutionDimensionNumbers_ == null) { + ConvolutionDimensionNumbers = new global::Xla.ConvolutionDimensionNumbers(); + } + input.ReadMessage(ConvolutionDimensionNumbers); + break; + } + case 138: { + sliceDimensions_.AddEntriesFrom(input, _repeated_sliceDimensions_codec); + break; + } + case 144: { + ExponentBits = input.ReadInt32(); + break; + } + case 152: { + MantissaBits = input.ReadInt32(); + break; + } + case 162: + case 160: { + dynamicSliceSizes_.AddEntriesFrom(input, _repeated_dynamicSliceSizes_codec); + break; + } + case 170: { + if (paddingConfig_ == null) { + PaddingConfig = new global::Xla.PaddingConfig(); + } + input.ReadMessage(PaddingConfig); + break; + } + case 178: { + OutfeedConfig = input.ReadBytes(); + break; + } + case 184: { + Distribution = (global::Xla.RandomDistribution) input.ReadEnum(); + break; + } + case 197: { + Epsilon = input.ReadFloat(); + break; + } + case 200: { + FeatureIndex = input.ReadInt64(); + break; + } + case 208: { + ChannelId = input.ReadInt64(); + break; + } + case 218: { + InfeedConfig = input.ReadBytes(); + break; + } + case 226: { + CustomCallTarget = input.ReadString(); + break; + } + case 234: { + if (outfeedShape_ == null) { + OutfeedShape = new global::Xla.ShapeProto(); + } + input.ReadMessage(OutfeedShape); + break; + } + case 242: { + if (dotDimensionNumbers_ == null) { + DotDimensionNumbers = new global::Xla.DotDimensionNumbers(); + } + input.ReadMessage(DotDimensionNumbers); + break; + } + case 248: { + FftType = (global::Xla.FftType) input.ReadEnum(); + break; + } + case 258: + case 256: { + fftLength_.AddEntriesFrom(input, _repeated_fftLength_codec); + break; + } + case 266: { + if (gatherDimensionNumbers_ == null) { + GatherDimensionNumbers = new global::Xla.GatherDimensionNumbers(); + } + input.ReadMessage(GatherDimensionNumbers); + break; + } + case 274: + case 272: { + gatherSliceSizes_.AddEntriesFrom(input, _repeated_gatherSliceSizes_codec); + break; + } + case 280: { + Id = input.ReadInt64(); + break; + } + case 290: + case 288: { + operandIds_.AddEntriesFrom(input, _repeated_operandIds_codec); + break; + } + case 298: + case 296: { + controlPredecessorIds_.AddEntriesFrom(input, _repeated_controlPredecessorIds_codec); + break; + } + case 306: + case 304: { + calledComputationIds_.AddEntriesFrom(input, _repeated_calledComputationIds_codec); + break; + } + case 322: { + if (sharding_ == null) { + Sharding = new global::Xla.OpSharding(); + } + input.ReadMessage(Sharding); + break; + } + case 346: { + BackendConfig = input.ReadBytes(); + break; + } + case 360: { + AllReduceId = input.ReadInt64(); + break; + } + case 376: { + IsHostTransfer = input.ReadBool(); + break; + } + case 386: { + if (scatterDimensionNumbers_ == null) { + ScatterDimensionNumbers = new global::Xla.ScatterDimensionNumbers(); + } + input.ReadMessage(ScatterDimensionNumbers); + break; + } + case 394: { + replicaGroups_.AddEntriesFrom(input, _repeated_replicaGroups_codec); + break; + } + case 400: { + FeatureGroupCount = input.ReadInt64(); + break; + } + case 410: { + if (precisionConfig_ == null) { + PrecisionConfig = new global::Xla.PrecisionConfig(); + } + input.ReadMessage(PrecisionConfig); + break; + } + case 418: { + sourceTargetPairs_.AddEntriesFrom(input, _repeated_sourceTargetPairs_codec); + break; + } + case 434: { + if (domainEntrySharding_ == null) { + DomainEntrySharding = new global::Xla.OpSharding(); + } + input.ReadMessage(DomainEntrySharding); + break; + } + case 442: { + if (domainExitSharding_ == null) { + DomainExitSharding = new global::Xla.OpSharding(); + } + input.ReadMessage(DomainExitSharding); + break; + } + case 448: { + ConstrainLayout = input.ReadBool(); + break; + } + case 458: { + operandShapesWithLayout_.AddEntriesFrom(input, _repeated_operandShapesWithLayout_codec); + break; + } + case 464: { + BatchGroupCount = input.ReadInt64(); + break; + } + case 474: { + if (triangularSolveOptions_ == null) { + TriangularSolveOptions = new global::Xla.TriangularSolveOptions(); + } + input.ReadMessage(TriangularSolveOptions); + break; + } + case 480: { + IsStable = input.ReadBool(); + break; + } + case 490: { + if (parameterReplication_ == null) { + ParameterReplication = new global::Xla.ParameterReplication(); + } + input.ReadMessage(ParameterReplication); + break; + } + case 498: { + if (choleskyOptions_ == null) { + CholeskyOptions = new global::Xla.CholeskyOptions(); + } + input.ReadMessage(CholeskyOptions); + break; + } + case 506: { + ComparisonDirection = input.ReadString(); + break; + } + case 520: { + CustomCallHasSideEffect = input.ReadBool(); + break; + } + case 528: { + Delta = input.ReadInt64(); + break; + } + case 536: { + IndicesAreSorted = input.ReadBool(); + break; + } + case 546: { + if (frontendAttributes_ == null) { + FrontendAttributes = new global::Xla.FrontendAttributes(); + } + input.ReadMessage(FrontendAttributes); + break; + } + case 552: { + UniqueIndices = input.ReadBool(); + break; + } + case 560: { + RngAlgorithm = (global::Xla.RandomAlgorithm) input.ReadEnum(); + break; + } + case 568: { + UseGlobalDeviceIds = input.ReadBool(); + break; + } + case 578: { + ComparisonType = input.ReadString(); + break; + } + case 584: { + IsCrossProgramPrefetch = input.ReadBool(); + break; + } + case 594: { + customCallOutputOperandAliasing_.AddEntriesFrom(input, _repeated_customCallOutputOperandAliasing_codec); + break; + } + case 600: { + PaddingType = (global::Xla.PaddingType) input.ReadEnum(); + break; + } + case 608: { + CustomCallSchedule = (global::Xla.CustomCallSchedule) input.ReadEnum(); + break; + } + case 616: { + CustomCallApiVersion = (global::Xla.CustomCallApiVersion) input.ReadEnum(); + break; + } + case 624: { + AsyncGroupId = input.ReadInt64(); + break; + } + case 634: { + AsyncExecutionThread = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + Opcode = input.ReadString(); + break; + } + case 26: { + if (shape_ == null) { + Shape = new global::Xla.ShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 58: { + if (metadata_ == null) { + Metadata = new global::Xla.OpMetadata(); + } + input.ReadMessage(Metadata); + break; + } + case 66: { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + input.ReadMessage(Literal); + break; + } + case 72: { + ParameterNumber = input.ReadInt64(); + break; + } + case 90: { + FusionKind = input.ReadString(); + break; + } + case 104: { + TupleIndex = input.ReadInt64(); + break; + } + case 114: + case 112: { + dimensions_.AddEntriesFrom(ref input, _repeated_dimensions_codec); + break; + } + case 122: { + if (window_ == null) { + Window = new global::Xla.Window(); + } + input.ReadMessage(Window); + break; + } + case 130: { + if (convolutionDimensionNumbers_ == null) { + ConvolutionDimensionNumbers = new global::Xla.ConvolutionDimensionNumbers(); + } + input.ReadMessage(ConvolutionDimensionNumbers); + break; + } + case 138: { + sliceDimensions_.AddEntriesFrom(ref input, _repeated_sliceDimensions_codec); + break; + } + case 144: { + ExponentBits = input.ReadInt32(); + break; + } + case 152: { + MantissaBits = input.ReadInt32(); + break; + } + case 162: + case 160: { + dynamicSliceSizes_.AddEntriesFrom(ref input, _repeated_dynamicSliceSizes_codec); + break; + } + case 170: { + if (paddingConfig_ == null) { + PaddingConfig = new global::Xla.PaddingConfig(); + } + input.ReadMessage(PaddingConfig); + break; + } + case 178: { + OutfeedConfig = input.ReadBytes(); + break; + } + case 184: { + Distribution = (global::Xla.RandomDistribution) input.ReadEnum(); + break; + } + case 197: { + Epsilon = input.ReadFloat(); + break; + } + case 200: { + FeatureIndex = input.ReadInt64(); + break; + } + case 208: { + ChannelId = input.ReadInt64(); + break; + } + case 218: { + InfeedConfig = input.ReadBytes(); + break; + } + case 226: { + CustomCallTarget = input.ReadString(); + break; + } + case 234: { + if (outfeedShape_ == null) { + OutfeedShape = new global::Xla.ShapeProto(); + } + input.ReadMessage(OutfeedShape); + break; + } + case 242: { + if (dotDimensionNumbers_ == null) { + DotDimensionNumbers = new global::Xla.DotDimensionNumbers(); + } + input.ReadMessage(DotDimensionNumbers); + break; + } + case 248: { + FftType = (global::Xla.FftType) input.ReadEnum(); + break; + } + case 258: + case 256: { + fftLength_.AddEntriesFrom(ref input, _repeated_fftLength_codec); + break; + } + case 266: { + if (gatherDimensionNumbers_ == null) { + GatherDimensionNumbers = new global::Xla.GatherDimensionNumbers(); + } + input.ReadMessage(GatherDimensionNumbers); + break; + } + case 274: + case 272: { + gatherSliceSizes_.AddEntriesFrom(ref input, _repeated_gatherSliceSizes_codec); + break; + } + case 280: { + Id = input.ReadInt64(); + break; + } + case 290: + case 288: { + operandIds_.AddEntriesFrom(ref input, _repeated_operandIds_codec); + break; + } + case 298: + case 296: { + controlPredecessorIds_.AddEntriesFrom(ref input, _repeated_controlPredecessorIds_codec); + break; + } + case 306: + case 304: { + calledComputationIds_.AddEntriesFrom(ref input, _repeated_calledComputationIds_codec); + break; + } + case 322: { + if (sharding_ == null) { + Sharding = new global::Xla.OpSharding(); + } + input.ReadMessage(Sharding); + break; + } + case 346: { + BackendConfig = input.ReadBytes(); + break; + } + case 360: { + AllReduceId = input.ReadInt64(); + break; + } + case 376: { + IsHostTransfer = input.ReadBool(); + break; + } + case 386: { + if (scatterDimensionNumbers_ == null) { + ScatterDimensionNumbers = new global::Xla.ScatterDimensionNumbers(); + } + input.ReadMessage(ScatterDimensionNumbers); + break; + } + case 394: { + replicaGroups_.AddEntriesFrom(ref input, _repeated_replicaGroups_codec); + break; + } + case 400: { + FeatureGroupCount = input.ReadInt64(); + break; + } + case 410: { + if (precisionConfig_ == null) { + PrecisionConfig = new global::Xla.PrecisionConfig(); + } + input.ReadMessage(PrecisionConfig); + break; + } + case 418: { + sourceTargetPairs_.AddEntriesFrom(ref input, _repeated_sourceTargetPairs_codec); + break; + } + case 434: { + if (domainEntrySharding_ == null) { + DomainEntrySharding = new global::Xla.OpSharding(); + } + input.ReadMessage(DomainEntrySharding); + break; + } + case 442: { + if (domainExitSharding_ == null) { + DomainExitSharding = new global::Xla.OpSharding(); + } + input.ReadMessage(DomainExitSharding); + break; + } + case 448: { + ConstrainLayout = input.ReadBool(); + break; + } + case 458: { + operandShapesWithLayout_.AddEntriesFrom(ref input, _repeated_operandShapesWithLayout_codec); + break; + } + case 464: { + BatchGroupCount = input.ReadInt64(); + break; + } + case 474: { + if (triangularSolveOptions_ == null) { + TriangularSolveOptions = new global::Xla.TriangularSolveOptions(); + } + input.ReadMessage(TriangularSolveOptions); + break; + } + case 480: { + IsStable = input.ReadBool(); + break; + } + case 490: { + if (parameterReplication_ == null) { + ParameterReplication = new global::Xla.ParameterReplication(); + } + input.ReadMessage(ParameterReplication); + break; + } + case 498: { + if (choleskyOptions_ == null) { + CholeskyOptions = new global::Xla.CholeskyOptions(); + } + input.ReadMessage(CholeskyOptions); + break; + } + case 506: { + ComparisonDirection = input.ReadString(); + break; + } + case 520: { + CustomCallHasSideEffect = input.ReadBool(); + break; + } + case 528: { + Delta = input.ReadInt64(); + break; + } + case 536: { + IndicesAreSorted = input.ReadBool(); + break; + } + case 546: { + if (frontendAttributes_ == null) { + FrontendAttributes = new global::Xla.FrontendAttributes(); + } + input.ReadMessage(FrontendAttributes); + break; + } + case 552: { + UniqueIndices = input.ReadBool(); + break; + } + case 560: { + RngAlgorithm = (global::Xla.RandomAlgorithm) input.ReadEnum(); + break; + } + case 568: { + UseGlobalDeviceIds = input.ReadBool(); + break; + } + case 578: { + ComparisonType = input.ReadString(); + break; + } + case 584: { + IsCrossProgramPrefetch = input.ReadBool(); + break; + } + case 594: { + customCallOutputOperandAliasing_.AddEntriesFrom(ref input, _repeated_customCallOutputOperandAliasing_codec); + break; + } + case 600: { + PaddingType = (global::Xla.PaddingType) input.ReadEnum(); + break; + } + case 608: { + CustomCallSchedule = (global::Xla.CustomCallSchedule) input.ReadEnum(); + break; + } + case 616: { + CustomCallApiVersion = (global::Xla.CustomCallApiVersion) input.ReadEnum(); + break; + } + case 624: { + AsyncGroupId = input.ReadInt64(); + break; + } + case 634: { + AsyncExecutionThread = input.ReadString(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the HloInstructionProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Describes the [begin, end) index range and stride for slices. + /// + public sealed partial class SliceDimensions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SliceDimensions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloInstructionProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SliceDimensions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SliceDimensions(SliceDimensions other) : this() { + start_ = other.start_; + limit_ = other.limit_; + stride_ = other.stride_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SliceDimensions Clone() { + return new SliceDimensions(this); + } + + /// Field number for the "start" field. + public const int StartFieldNumber = 1; + private long start_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Start { + get { return start_; } + set { + start_ = value; + } + } + + /// Field number for the "limit" field. + public const int LimitFieldNumber = 2; + private long limit_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Limit { + get { return limit_; } + set { + limit_ = value; + } + } + + /// Field number for the "stride" field. + public const int StrideFieldNumber = 3; + private long stride_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Stride { + get { return stride_; } + set { + stride_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SliceDimensions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SliceDimensions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Start != other.Start) return false; + if (Limit != other.Limit) return false; + if (Stride != other.Stride) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Start != 0L) hash ^= Start.GetHashCode(); + if (Limit != 0L) hash ^= Limit.GetHashCode(); + if (Stride != 0L) hash ^= Stride.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Start != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Start); + } + if (Limit != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Limit); + } + if (Stride != 0L) { + output.WriteRawTag(24); + output.WriteInt64(Stride); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Start != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Start); + } + if (Limit != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Limit); + } + if (Stride != 0L) { + output.WriteRawTag(24); + output.WriteInt64(Stride); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Start != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Start); + } + if (Limit != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Limit); + } + if (Stride != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Stride); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SliceDimensions other) { + if (other == null) { + return; + } + if (other.Start != 0L) { + Start = other.Start; + } + if (other.Limit != 0L) { + Limit = other.Limit; + } + if (other.Stride != 0L) { + Stride = other.Stride; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Start = input.ReadInt64(); + break; + } + case 16: { + Limit = input.ReadInt64(); + break; + } + case 24: { + Stride = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Start = input.ReadInt64(); + break; + } + case 16: { + Limit = input.ReadInt64(); + break; + } + case 24: { + Stride = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// Serialization of HloComputation. + /// + public sealed partial class HloComputationProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HloComputationProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloComputationProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloComputationProto(HloComputationProto other) : this() { + name_ = other.name_; + instructions_ = other.instructions_.Clone(); + programShape_ = other.programShape_ != null ? other.programShape_.Clone() : null; + id_ = other.id_; + rootId_ = other.rootId_; + isFusionComputation_ = other.isFusionComputation_; + executionThread_ = other.executionThread_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloComputationProto Clone() { + return new HloComputationProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "instructions" field. + public const int InstructionsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_instructions_codec + = pb::FieldCodec.ForMessage(18, global::Xla.HloInstructionProto.Parser); + private readonly pbc::RepeatedField instructions_ = new pbc::RepeatedField(); + /// + /// The array of instructions is always in a valid dependency order, where + /// operands appear before their users. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Instructions { + get { return instructions_; } + } + + /// Field number for the "program_shape" field. + public const int ProgramShapeFieldNumber = 4; + private global::Xla.ProgramShapeProto programShape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ProgramShapeProto ProgramShape { + get { return programShape_; } + set { + programShape_ = value; + } + } + + /// Field number for the "id" field. + public const int IdFieldNumber = 5; + private long id_; + /// + /// The id of this computation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Id { + get { return id_; } + set { + id_ = value; + } + } + + /// Field number for the "root_id" field. + public const int RootIdFieldNumber = 6; + private long rootId_; + /// + /// The id of the root of the computation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long RootId { + get { return rootId_; } + set { + rootId_ = value; + } + } + + /// Field number for the "is_fusion_computation" field. + public const int IsFusionComputationFieldNumber = 7; + private bool isFusionComputation_; + /// + /// Whether this is a fusion computation. Fusion computations should use this + /// to determine whether they are a fusion in CreateFromProto since the + /// parent fusion_instruction_ may get removed and be nullptr. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsFusionComputation { + get { return isFusionComputation_; } + set { + isFusionComputation_ = value; + } + } + + /// Field number for the "execution_thread" field. + public const int ExecutionThreadFieldNumber = 8; + private string executionThread_ = ""; + /// + /// The name of execution thread this computation belongs to. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ExecutionThread { + get { return executionThread_; } + set { + executionThread_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HloComputationProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HloComputationProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if(!instructions_.Equals(other.instructions_)) return false; + if (!object.Equals(ProgramShape, other.ProgramShape)) return false; + if (Id != other.Id) return false; + if (RootId != other.RootId) return false; + if (IsFusionComputation != other.IsFusionComputation) return false; + if (ExecutionThread != other.ExecutionThread) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + hash ^= instructions_.GetHashCode(); + if (programShape_ != null) hash ^= ProgramShape.GetHashCode(); + if (Id != 0L) hash ^= Id.GetHashCode(); + if (RootId != 0L) hash ^= RootId.GetHashCode(); + if (IsFusionComputation != false) hash ^= IsFusionComputation.GetHashCode(); + if (ExecutionThread.Length != 0) hash ^= ExecutionThread.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + instructions_.WriteTo(output, _repeated_instructions_codec); + if (programShape_ != null) { + output.WriteRawTag(34); + output.WriteMessage(ProgramShape); + } + if (Id != 0L) { + output.WriteRawTag(40); + output.WriteInt64(Id); + } + if (RootId != 0L) { + output.WriteRawTag(48); + output.WriteInt64(RootId); + } + if (IsFusionComputation != false) { + output.WriteRawTag(56); + output.WriteBool(IsFusionComputation); + } + if (ExecutionThread.Length != 0) { + output.WriteRawTag(66); + output.WriteString(ExecutionThread); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + instructions_.WriteTo(ref output, _repeated_instructions_codec); + if (programShape_ != null) { + output.WriteRawTag(34); + output.WriteMessage(ProgramShape); + } + if (Id != 0L) { + output.WriteRawTag(40); + output.WriteInt64(Id); + } + if (RootId != 0L) { + output.WriteRawTag(48); + output.WriteInt64(RootId); + } + if (IsFusionComputation != false) { + output.WriteRawTag(56); + output.WriteBool(IsFusionComputation); + } + if (ExecutionThread.Length != 0) { + output.WriteRawTag(66); + output.WriteString(ExecutionThread); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + size += instructions_.CalculateSize(_repeated_instructions_codec); + if (programShape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ProgramShape); + } + if (Id != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Id); + } + if (RootId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(RootId); + } + if (IsFusionComputation != false) { + size += 1 + 1; + } + if (ExecutionThread.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ExecutionThread); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HloComputationProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + instructions_.Add(other.instructions_); + if (other.programShape_ != null) { + if (programShape_ == null) { + ProgramShape = new global::Xla.ProgramShapeProto(); + } + ProgramShape.MergeFrom(other.ProgramShape); + } + if (other.Id != 0L) { + Id = other.Id; + } + if (other.RootId != 0L) { + RootId = other.RootId; + } + if (other.IsFusionComputation != false) { + IsFusionComputation = other.IsFusionComputation; + } + if (other.ExecutionThread.Length != 0) { + ExecutionThread = other.ExecutionThread; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + instructions_.AddEntriesFrom(input, _repeated_instructions_codec); + break; + } + case 34: { + if (programShape_ == null) { + ProgramShape = new global::Xla.ProgramShapeProto(); + } + input.ReadMessage(ProgramShape); + break; + } + case 40: { + Id = input.ReadInt64(); + break; + } + case 48: { + RootId = input.ReadInt64(); + break; + } + case 56: { + IsFusionComputation = input.ReadBool(); + break; + } + case 66: { + ExecutionThread = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + instructions_.AddEntriesFrom(ref input, _repeated_instructions_codec); + break; + } + case 34: { + if (programShape_ == null) { + ProgramShape = new global::Xla.ProgramShapeProto(); + } + input.ReadMessage(ProgramShape); + break; + } + case 40: { + Id = input.ReadInt64(); + break; + } + case 48: { + RootId = input.ReadInt64(); + break; + } + case 56: { + IsFusionComputation = input.ReadBool(); + break; + } + case 66: { + ExecutionThread = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// Serialization of an HLO schedule. An HLO schedule contains a total order of + /// instructions for each non-fusion computation in the module. + /// + public sealed partial class HloScheduleProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HloScheduleProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloScheduleProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloScheduleProto(HloScheduleProto other) : this() { + sequences_ = other.sequences_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloScheduleProto Clone() { + return new HloScheduleProto(this); + } + + /// Field number for the "sequences" field. + public const int SequencesFieldNumber = 1; + private static readonly pbc::MapField.Codec _map_sequences_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForInt64(8, 0L), pb::FieldCodec.ForMessage(18, global::Xla.HloScheduleProto.Types.InstructionSequence.Parser), 10); + private readonly pbc::MapField sequences_ = new pbc::MapField(); + /// + /// Map from computation id to sequence. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField Sequences { + get { return sequences_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HloScheduleProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HloScheduleProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!Sequences.Equals(other.Sequences)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= Sequences.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + sequences_.WriteTo(output, _map_sequences_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + sequences_.WriteTo(ref output, _map_sequences_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += sequences_.CalculateSize(_map_sequences_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HloScheduleProto other) { + if (other == null) { + return; + } + sequences_.Add(other.sequences_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + sequences_.AddEntriesFrom(input, _map_sequences_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + sequences_.AddEntriesFrom(ref input, _map_sequences_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the HloScheduleProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public sealed partial class InstructionSequence : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new InstructionSequence()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloScheduleProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InstructionSequence() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InstructionSequence(InstructionSequence other) : this() { + instructionIds_ = other.instructionIds_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public InstructionSequence Clone() { + return new InstructionSequence(this); + } + + /// Field number for the "instruction_ids" field. + public const int InstructionIdsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_instructionIds_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField instructionIds_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField InstructionIds { + get { return instructionIds_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as InstructionSequence); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(InstructionSequence other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!instructionIds_.Equals(other.instructionIds_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= instructionIds_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + instructionIds_.WriteTo(output, _repeated_instructionIds_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + instructionIds_.WriteTo(ref output, _repeated_instructionIds_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += instructionIds_.CalculateSize(_repeated_instructionIds_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(InstructionSequence other) { + if (other == null) { + return; + } + instructionIds_.Add(other.instructionIds_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + instructionIds_.AddEntriesFrom(input, _repeated_instructionIds_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + instructionIds_.AddEntriesFrom(ref input, _repeated_instructionIds_codec); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + public sealed partial class HloInputOutputAliasProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HloInputOutputAliasProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloInputOutputAliasProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloInputOutputAliasProto(HloInputOutputAliasProto other) : this() { + entries_ = other.entries_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloInputOutputAliasProto Clone() { + return new HloInputOutputAliasProto(this); + } + + /// Field number for the "entries" field. + public const int EntriesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_entries_codec + = pb::FieldCodec.ForMessage(10, global::Xla.HloInputOutputAliasProto.Types.AliasEntryProto.Parser); + private readonly pbc::RepeatedField entries_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Entries { + get { return entries_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HloInputOutputAliasProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HloInputOutputAliasProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!entries_.Equals(other.entries_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= entries_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + entries_.WriteTo(output, _repeated_entries_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + entries_.WriteTo(ref output, _repeated_entries_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += entries_.CalculateSize(_repeated_entries_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HloInputOutputAliasProto other) { + if (other == null) { + return; + } + entries_.Add(other.entries_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + entries_.AddEntriesFrom(input, _repeated_entries_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + entries_.AddEntriesFrom(ref input, _repeated_entries_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the HloInputOutputAliasProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// The following proto describes a pair of aliased an input + /// (described by parameter number and a ShapeIndex of the parameter) + /// and an output (described by a ShapeIndex of the root + /// instruction). For example: + /// + /// entry = { + /// output_shape_index={1}, + /// parameter_number=0, + /// parameter_shape_index={1, 2}, + /// } + /// + /// This entry indicates that the first paremter's {1, 2} element is + /// aliased with the {1} element of the root instruction. + /// + public sealed partial class AliasEntryProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AliasEntryProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloInputOutputAliasProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AliasEntryProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AliasEntryProto(AliasEntryProto other) : this() { + outputShapeIndex_ = other.outputShapeIndex_.Clone(); + parameterNumber_ = other.parameterNumber_; + parameterShapeIndex_ = other.parameterShapeIndex_.Clone(); + kind_ = other.kind_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AliasEntryProto Clone() { + return new AliasEntryProto(this); + } + + /// Field number for the "output_shape_index" field. + public const int OutputShapeIndexFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_outputShapeIndex_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField outputShapeIndex_ = new pbc::RepeatedField(); + /// + /// ShapeIndex of the root hlo. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OutputShapeIndex { + get { return outputShapeIndex_; } + } + + /// Field number for the "parameter_number" field. + public const int ParameterNumberFieldNumber = 2; + private long parameterNumber_; + /// + /// Number of the parameter in entry computation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ParameterNumber { + get { return parameterNumber_; } + set { + parameterNumber_ = value; + } + } + + /// Field number for the "parameter_shape_index" field. + public const int ParameterShapeIndexFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_parameterShapeIndex_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField parameterShapeIndex_ = new pbc::RepeatedField(); + /// + /// ShapeIndex of the parameter instruction. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ParameterShapeIndex { + get { return parameterShapeIndex_; } + } + + /// Field number for the "kind" field. + public const int KindFieldNumber = 4; + private global::Xla.Kind kind_ = global::Xla.Kind.UndefinedAlias; + /// + /// The kind of alias to be setup. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.Kind Kind { + get { return kind_; } + set { + kind_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as AliasEntryProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(AliasEntryProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!outputShapeIndex_.Equals(other.outputShapeIndex_)) return false; + if (ParameterNumber != other.ParameterNumber) return false; + if(!parameterShapeIndex_.Equals(other.parameterShapeIndex_)) return false; + if (Kind != other.Kind) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= outputShapeIndex_.GetHashCode(); + if (ParameterNumber != 0L) hash ^= ParameterNumber.GetHashCode(); + hash ^= parameterShapeIndex_.GetHashCode(); + if (Kind != global::Xla.Kind.UndefinedAlias) hash ^= Kind.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + outputShapeIndex_.WriteTo(output, _repeated_outputShapeIndex_codec); + if (ParameterNumber != 0L) { + output.WriteRawTag(16); + output.WriteInt64(ParameterNumber); + } + parameterShapeIndex_.WriteTo(output, _repeated_parameterShapeIndex_codec); + if (Kind != global::Xla.Kind.UndefinedAlias) { + output.WriteRawTag(32); + output.WriteEnum((int) Kind); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + outputShapeIndex_.WriteTo(ref output, _repeated_outputShapeIndex_codec); + if (ParameterNumber != 0L) { + output.WriteRawTag(16); + output.WriteInt64(ParameterNumber); + } + parameterShapeIndex_.WriteTo(ref output, _repeated_parameterShapeIndex_codec); + if (Kind != global::Xla.Kind.UndefinedAlias) { + output.WriteRawTag(32); + output.WriteEnum((int) Kind); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += outputShapeIndex_.CalculateSize(_repeated_outputShapeIndex_codec); + if (ParameterNumber != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ParameterNumber); + } + size += parameterShapeIndex_.CalculateSize(_repeated_parameterShapeIndex_codec); + if (Kind != global::Xla.Kind.UndefinedAlias) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Kind); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(AliasEntryProto other) { + if (other == null) { + return; + } + outputShapeIndex_.Add(other.outputShapeIndex_); + if (other.ParameterNumber != 0L) { + ParameterNumber = other.ParameterNumber; + } + parameterShapeIndex_.Add(other.parameterShapeIndex_); + if (other.Kind != global::Xla.Kind.UndefinedAlias) { + Kind = other.Kind; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + outputShapeIndex_.AddEntriesFrom(input, _repeated_outputShapeIndex_codec); + break; + } + case 16: { + ParameterNumber = input.ReadInt64(); + break; + } + case 26: + case 24: { + parameterShapeIndex_.AddEntriesFrom(input, _repeated_parameterShapeIndex_codec); + break; + } + case 32: { + Kind = (global::Xla.Kind) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + outputShapeIndex_.AddEntriesFrom(ref input, _repeated_outputShapeIndex_codec); + break; + } + case 16: { + ParameterNumber = input.ReadInt64(); + break; + } + case 26: + case 24: { + parameterShapeIndex_.AddEntriesFrom(ref input, _repeated_parameterShapeIndex_codec); + break; + } + case 32: { + Kind = (global::Xla.Kind) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + public sealed partial class DynamicParameterBindingProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DynamicParameterBindingProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DynamicParameterBindingProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DynamicParameterBindingProto(DynamicParameterBindingProto other) : this() { + entries_ = other.entries_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DynamicParameterBindingProto Clone() { + return new DynamicParameterBindingProto(this); + } + + /// Field number for the "entries" field. + public const int EntriesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_entries_codec + = pb::FieldCodec.ForMessage(10, global::Xla.DynamicParameterBindingProto.Types.Binding.Parser); + private readonly pbc::RepeatedField entries_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Entries { + get { return entries_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DynamicParameterBindingProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DynamicParameterBindingProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!entries_.Equals(other.entries_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= entries_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + entries_.WriteTo(output, _repeated_entries_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + entries_.WriteTo(ref output, _repeated_entries_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += entries_.CalculateSize(_repeated_entries_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DynamicParameterBindingProto other) { + if (other == null) { + return; + } + entries_.Add(other.entries_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + entries_.AddEntriesFrom(input, _repeated_entries_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + entries_.AddEntriesFrom(ref input, _repeated_entries_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the DynamicParameterBindingProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// A list of bindings which indicates that the `target_dim_num` in + /// the subshape `target_param_index` of parameter `target_param_num` + /// is a dynamic dimension and its real dynamic size is represented + /// by `dynamic_param_index` in parameter `dynamic_param_num`. + /// + /// As an example, imagine we have a program: + /// + /// ENTRY main { + /// a = f32[] parameter(0) + /// b = f32[10] parameter(1) + /// ROOT root = (f32[], f32[10]) tuple(%a, %b) + /// } + /// + /// Let's say 'b' (param index 1) is a dynamic shape whose input has + /// an upperbound of 10 and real size is determined at runtime.'a' + /// represents the real size of b's first dimension. + /// + /// In this case, the fields are set in the following way: + /// dynamic_param_num = 1 + /// dynamic_param_index = {} + /// target_param_num = 0 + /// target_param_index = {} + /// target_param_dim = 0 + /// + public sealed partial class Binding : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Binding()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.DynamicParameterBindingProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Binding() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Binding(Binding other) : this() { + dynamicParamNum_ = other.dynamicParamNum_; + dynamicParamIndex_ = other.dynamicParamIndex_.Clone(); + targetParamNum_ = other.targetParamNum_; + targetParamIndex_ = other.targetParamIndex_.Clone(); + targetParamDimNum_ = other.targetParamDimNum_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Binding Clone() { + return new Binding(this); + } + + /// Field number for the "dynamic_param_num" field. + public const int DynamicParamNumFieldNumber = 1; + private long dynamicParamNum_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long DynamicParamNum { + get { return dynamicParamNum_; } + set { + dynamicParamNum_ = value; + } + } + + /// Field number for the "dynamic_param_index" field. + public const int DynamicParamIndexFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_dynamicParamIndex_codec + = pb::FieldCodec.ForInt64(18); + private readonly pbc::RepeatedField dynamicParamIndex_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DynamicParamIndex { + get { return dynamicParamIndex_; } + } + + /// Field number for the "target_param_num" field. + public const int TargetParamNumFieldNumber = 3; + private long targetParamNum_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long TargetParamNum { + get { return targetParamNum_; } + set { + targetParamNum_ = value; + } + } + + /// Field number for the "target_param_index" field. + public const int TargetParamIndexFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_targetParamIndex_codec + = pb::FieldCodec.ForInt64(34); + private readonly pbc::RepeatedField targetParamIndex_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField TargetParamIndex { + get { return targetParamIndex_; } + } + + /// Field number for the "target_param_dim_num" field. + public const int TargetParamDimNumFieldNumber = 5; + private long targetParamDimNum_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long TargetParamDimNum { + get { return targetParamDimNum_; } + set { + targetParamDimNum_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Binding); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Binding other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (DynamicParamNum != other.DynamicParamNum) return false; + if(!dynamicParamIndex_.Equals(other.dynamicParamIndex_)) return false; + if (TargetParamNum != other.TargetParamNum) return false; + if(!targetParamIndex_.Equals(other.targetParamIndex_)) return false; + if (TargetParamDimNum != other.TargetParamDimNum) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (DynamicParamNum != 0L) hash ^= DynamicParamNum.GetHashCode(); + hash ^= dynamicParamIndex_.GetHashCode(); + if (TargetParamNum != 0L) hash ^= TargetParamNum.GetHashCode(); + hash ^= targetParamIndex_.GetHashCode(); + if (TargetParamDimNum != 0L) hash ^= TargetParamDimNum.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (DynamicParamNum != 0L) { + output.WriteRawTag(8); + output.WriteInt64(DynamicParamNum); + } + dynamicParamIndex_.WriteTo(output, _repeated_dynamicParamIndex_codec); + if (TargetParamNum != 0L) { + output.WriteRawTag(24); + output.WriteInt64(TargetParamNum); + } + targetParamIndex_.WriteTo(output, _repeated_targetParamIndex_codec); + if (TargetParamDimNum != 0L) { + output.WriteRawTag(40); + output.WriteInt64(TargetParamDimNum); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (DynamicParamNum != 0L) { + output.WriteRawTag(8); + output.WriteInt64(DynamicParamNum); + } + dynamicParamIndex_.WriteTo(ref output, _repeated_dynamicParamIndex_codec); + if (TargetParamNum != 0L) { + output.WriteRawTag(24); + output.WriteInt64(TargetParamNum); + } + targetParamIndex_.WriteTo(ref output, _repeated_targetParamIndex_codec); + if (TargetParamDimNum != 0L) { + output.WriteRawTag(40); + output.WriteInt64(TargetParamDimNum); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (DynamicParamNum != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(DynamicParamNum); + } + size += dynamicParamIndex_.CalculateSize(_repeated_dynamicParamIndex_codec); + if (TargetParamNum != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(TargetParamNum); + } + size += targetParamIndex_.CalculateSize(_repeated_targetParamIndex_codec); + if (TargetParamDimNum != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(TargetParamDimNum); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Binding other) { + if (other == null) { + return; + } + if (other.DynamicParamNum != 0L) { + DynamicParamNum = other.DynamicParamNum; + } + dynamicParamIndex_.Add(other.dynamicParamIndex_); + if (other.TargetParamNum != 0L) { + TargetParamNum = other.TargetParamNum; + } + targetParamIndex_.Add(other.targetParamIndex_); + if (other.TargetParamDimNum != 0L) { + TargetParamDimNum = other.TargetParamDimNum; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + DynamicParamNum = input.ReadInt64(); + break; + } + case 18: + case 16: { + dynamicParamIndex_.AddEntriesFrom(input, _repeated_dynamicParamIndex_codec); + break; + } + case 24: { + TargetParamNum = input.ReadInt64(); + break; + } + case 34: + case 32: { + targetParamIndex_.AddEntriesFrom(input, _repeated_targetParamIndex_codec); + break; + } + case 40: { + TargetParamDimNum = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + DynamicParamNum = input.ReadInt64(); + break; + } + case 18: + case 16: { + dynamicParamIndex_.AddEntriesFrom(ref input, _repeated_dynamicParamIndex_codec); + break; + } + case 24: { + TargetParamNum = input.ReadInt64(); + break; + } + case 34: + case 32: { + targetParamIndex_.AddEntriesFrom(ref input, _repeated_targetParamIndex_codec); + break; + } + case 40: { + TargetParamDimNum = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + public sealed partial class CrossProgramPrefetch : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CrossProgramPrefetch()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CrossProgramPrefetch() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CrossProgramPrefetch(CrossProgramPrefetch other) : this() { + parameter_ = other.parameter_; + index_ = other.index_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CrossProgramPrefetch Clone() { + return new CrossProgramPrefetch(this); + } + + /// Field number for the "parameter" field. + public const int ParameterFieldNumber = 1; + private long parameter_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Parameter { + get { return parameter_; } + set { + parameter_ = value; + } + } + + /// Field number for the "index" field. + public const int IndexFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_index_codec + = pb::FieldCodec.ForInt64(18); + private readonly pbc::RepeatedField index_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Index { + get { return index_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CrossProgramPrefetch); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CrossProgramPrefetch other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Parameter != other.Parameter) return false; + if(!index_.Equals(other.index_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Parameter != 0L) hash ^= Parameter.GetHashCode(); + hash ^= index_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Parameter != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Parameter); + } + index_.WriteTo(output, _repeated_index_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Parameter != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Parameter); + } + index_.WriteTo(ref output, _repeated_index_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Parameter != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Parameter); + } + size += index_.CalculateSize(_repeated_index_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CrossProgramPrefetch other) { + if (other == null) { + return; + } + if (other.Parameter != 0L) { + Parameter = other.Parameter; + } + index_.Add(other.index_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Parameter = input.ReadInt64(); + break; + } + case 18: + case 16: { + index_.AddEntriesFrom(input, _repeated_index_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Parameter = input.ReadInt64(); + break; + } + case 18: + case 16: { + index_.AddEntriesFrom(ref input, _repeated_index_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Serialization of HloModule. + /// + public sealed partial class HloModuleProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HloModuleProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloModuleProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloModuleProto(HloModuleProto other) : this() { + name_ = other.name_; + entryComputationName_ = other.entryComputationName_; + entryComputationId_ = other.entryComputationId_; + computations_ = other.computations_.Clone(); + hostProgramShape_ = other.hostProgramShape_ != null ? other.hostProgramShape_.Clone() : null; + id_ = other.id_; + schedule_ = other.schedule_ != null ? other.schedule_.Clone() : null; + inputOutputAlias_ = other.inputOutputAlias_ != null ? other.inputOutputAlias_.Clone() : null; + dynamicParameterBinding_ = other.dynamicParameterBinding_ != null ? other.dynamicParameterBinding_.Clone() : null; + crossProgramPrefetches_ = other.crossProgramPrefetches_.Clone(); + isDynamic_ = other.isDynamic_; + spmdOutputSharding_ = other.spmdOutputSharding_ != null ? other.spmdOutputSharding_.Clone() : null; + spmdParametersShardings_ = other.spmdParametersShardings_.Clone(); + useAutoSpmdPartitioning_ = other.useAutoSpmdPartitioning_; + profileInfo_ = other.profileInfo_.Clone(); + deviceAssignment_ = other.deviceAssignment_ != null ? other.deviceAssignment_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloModuleProto Clone() { + return new HloModuleProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "entry_computation_name" field. + public const int EntryComputationNameFieldNumber = 2; + private string entryComputationName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string EntryComputationName { + get { return entryComputationName_; } + set { + entryComputationName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "entry_computation_id" field. + public const int EntryComputationIdFieldNumber = 6; + private long entryComputationId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long EntryComputationId { + get { return entryComputationId_; } + set { + entryComputationId_ = value; + } + } + + /// Field number for the "computations" field. + public const int ComputationsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_computations_codec + = pb::FieldCodec.ForMessage(26, global::Xla.HloComputationProto.Parser); + private readonly pbc::RepeatedField computations_ = new pbc::RepeatedField(); + /// + /// The array of computations is always in a valid dependency order, where + /// callees appear before their callers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Computations { + get { return computations_; } + } + + /// Field number for the "host_program_shape" field. + public const int HostProgramShapeFieldNumber = 4; + private global::Xla.ProgramShapeProto hostProgramShape_; + /// + /// The host program shape (with layout) of the entry computation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ProgramShapeProto HostProgramShape { + get { return hostProgramShape_; } + set { + hostProgramShape_ = value; + } + } + + /// Field number for the "id" field. + public const int IdFieldNumber = 5; + private long id_; + /// + /// The id of this module. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Id { + get { return id_; } + set { + id_ = value; + } + } + + /// Field number for the "schedule" field. + public const int ScheduleFieldNumber = 7; + private global::Xla.HloScheduleProto schedule_; + /// + /// The schedule for this module. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.HloScheduleProto Schedule { + get { return schedule_; } + set { + schedule_ = value; + } + } + + /// Field number for the "input_output_alias" field. + public const int InputOutputAliasFieldNumber = 8; + private global::Xla.HloInputOutputAliasProto inputOutputAlias_; + /// + /// Describes alias information between inputs and outputs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.HloInputOutputAliasProto InputOutputAlias { + get { return inputOutputAlias_; } + set { + inputOutputAlias_ = value; + } + } + + /// Field number for the "dynamic_parameter_binding" field. + public const int DynamicParameterBindingFieldNumber = 9; + private global::Xla.DynamicParameterBindingProto dynamicParameterBinding_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.DynamicParameterBindingProto DynamicParameterBinding { + get { return dynamicParameterBinding_; } + set { + dynamicParameterBinding_ = value; + } + } + + /// Field number for the "cross_program_prefetches" field. + public const int CrossProgramPrefetchesFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_crossProgramPrefetches_codec + = pb::FieldCodec.ForMessage(82, global::Xla.CrossProgramPrefetch.Parser); + private readonly pbc::RepeatedField crossProgramPrefetches_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField CrossProgramPrefetches { + get { return crossProgramPrefetches_; } + } + + /// Field number for the "is_dynamic" field. + public const int IsDynamicFieldNumber = 11; + private bool isDynamic_; + /// + /// True if the module contains dynamic computation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsDynamic { + get { return isDynamic_; } + set { + isDynamic_ = value; + } + } + + /// Field number for the "spmd_output_sharding" field. + public const int SpmdOutputShardingFieldNumber = 12; + private global::Xla.OpSharding spmdOutputSharding_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.OpSharding SpmdOutputSharding { + get { return spmdOutputSharding_; } + set { + spmdOutputSharding_ = value; + } + } + + /// Field number for the "spmd_parameters_shardings" field. + public const int SpmdParametersShardingsFieldNumber = 14; + private static readonly pb::FieldCodec _repeated_spmdParametersShardings_codec + = pb::FieldCodec.ForMessage(114, global::Xla.OpSharding.Parser); + private readonly pbc::RepeatedField spmdParametersShardings_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SpmdParametersShardings { + get { return spmdParametersShardings_; } + } + + /// Field number for the "use_auto_spmd_partitioning" field. + public const int UseAutoSpmdPartitioningFieldNumber = 16; + private bool useAutoSpmdPartitioning_; + /// + /// Uses AutoSharding pass or not. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UseAutoSpmdPartitioning { + get { return useAutoSpmdPartitioning_; } + set { + useAutoSpmdPartitioning_ = value; + } + } + + /// Field number for the "profile_info" field. + public const int ProfileInfoFieldNumber = 13; + private static readonly pb::FieldCodec _repeated_profileInfo_codec + = pb::FieldCodec.ForMessage(106, global::Xla.HloModuleProto.Types.ProfileInfo.Parser); + private readonly pbc::RepeatedField profileInfo_ = new pbc::RepeatedField(); + /// + /// Profile information for the HLO module. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ProfileInfo { + get { return profileInfo_; } + } + + /// Field number for the "device_assignment" field. + public const int DeviceAssignmentFieldNumber = 15; + private global::Xla.DeviceAssignmentProto deviceAssignment_; + /// + /// DeviceAssignment object information. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.DeviceAssignmentProto DeviceAssignment { + get { return deviceAssignment_; } + set { + deviceAssignment_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HloModuleProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HloModuleProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (EntryComputationName != other.EntryComputationName) return false; + if (EntryComputationId != other.EntryComputationId) return false; + if(!computations_.Equals(other.computations_)) return false; + if (!object.Equals(HostProgramShape, other.HostProgramShape)) return false; + if (Id != other.Id) return false; + if (!object.Equals(Schedule, other.Schedule)) return false; + if (!object.Equals(InputOutputAlias, other.InputOutputAlias)) return false; + if (!object.Equals(DynamicParameterBinding, other.DynamicParameterBinding)) return false; + if(!crossProgramPrefetches_.Equals(other.crossProgramPrefetches_)) return false; + if (IsDynamic != other.IsDynamic) return false; + if (!object.Equals(SpmdOutputSharding, other.SpmdOutputSharding)) return false; + if(!spmdParametersShardings_.Equals(other.spmdParametersShardings_)) return false; + if (UseAutoSpmdPartitioning != other.UseAutoSpmdPartitioning) return false; + if(!profileInfo_.Equals(other.profileInfo_)) return false; + if (!object.Equals(DeviceAssignment, other.DeviceAssignment)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (EntryComputationName.Length != 0) hash ^= EntryComputationName.GetHashCode(); + if (EntryComputationId != 0L) hash ^= EntryComputationId.GetHashCode(); + hash ^= computations_.GetHashCode(); + if (hostProgramShape_ != null) hash ^= HostProgramShape.GetHashCode(); + if (Id != 0L) hash ^= Id.GetHashCode(); + if (schedule_ != null) hash ^= Schedule.GetHashCode(); + if (inputOutputAlias_ != null) hash ^= InputOutputAlias.GetHashCode(); + if (dynamicParameterBinding_ != null) hash ^= DynamicParameterBinding.GetHashCode(); + hash ^= crossProgramPrefetches_.GetHashCode(); + if (IsDynamic != false) hash ^= IsDynamic.GetHashCode(); + if (spmdOutputSharding_ != null) hash ^= SpmdOutputSharding.GetHashCode(); + hash ^= spmdParametersShardings_.GetHashCode(); + if (UseAutoSpmdPartitioning != false) hash ^= UseAutoSpmdPartitioning.GetHashCode(); + hash ^= profileInfo_.GetHashCode(); + if (deviceAssignment_ != null) hash ^= DeviceAssignment.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (EntryComputationName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(EntryComputationName); + } + computations_.WriteTo(output, _repeated_computations_codec); + if (hostProgramShape_ != null) { + output.WriteRawTag(34); + output.WriteMessage(HostProgramShape); + } + if (Id != 0L) { + output.WriteRawTag(40); + output.WriteInt64(Id); + } + if (EntryComputationId != 0L) { + output.WriteRawTag(48); + output.WriteInt64(EntryComputationId); + } + if (schedule_ != null) { + output.WriteRawTag(58); + output.WriteMessage(Schedule); + } + if (inputOutputAlias_ != null) { + output.WriteRawTag(66); + output.WriteMessage(InputOutputAlias); + } + if (dynamicParameterBinding_ != null) { + output.WriteRawTag(74); + output.WriteMessage(DynamicParameterBinding); + } + crossProgramPrefetches_.WriteTo(output, _repeated_crossProgramPrefetches_codec); + if (IsDynamic != false) { + output.WriteRawTag(88); + output.WriteBool(IsDynamic); + } + if (spmdOutputSharding_ != null) { + output.WriteRawTag(98); + output.WriteMessage(SpmdOutputSharding); + } + profileInfo_.WriteTo(output, _repeated_profileInfo_codec); + spmdParametersShardings_.WriteTo(output, _repeated_spmdParametersShardings_codec); + if (deviceAssignment_ != null) { + output.WriteRawTag(122); + output.WriteMessage(DeviceAssignment); + } + if (UseAutoSpmdPartitioning != false) { + output.WriteRawTag(128, 1); + output.WriteBool(UseAutoSpmdPartitioning); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (EntryComputationName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(EntryComputationName); + } + computations_.WriteTo(ref output, _repeated_computations_codec); + if (hostProgramShape_ != null) { + output.WriteRawTag(34); + output.WriteMessage(HostProgramShape); + } + if (Id != 0L) { + output.WriteRawTag(40); + output.WriteInt64(Id); + } + if (EntryComputationId != 0L) { + output.WriteRawTag(48); + output.WriteInt64(EntryComputationId); + } + if (schedule_ != null) { + output.WriteRawTag(58); + output.WriteMessage(Schedule); + } + if (inputOutputAlias_ != null) { + output.WriteRawTag(66); + output.WriteMessage(InputOutputAlias); + } + if (dynamicParameterBinding_ != null) { + output.WriteRawTag(74); + output.WriteMessage(DynamicParameterBinding); + } + crossProgramPrefetches_.WriteTo(ref output, _repeated_crossProgramPrefetches_codec); + if (IsDynamic != false) { + output.WriteRawTag(88); + output.WriteBool(IsDynamic); + } + if (spmdOutputSharding_ != null) { + output.WriteRawTag(98); + output.WriteMessage(SpmdOutputSharding); + } + profileInfo_.WriteTo(ref output, _repeated_profileInfo_codec); + spmdParametersShardings_.WriteTo(ref output, _repeated_spmdParametersShardings_codec); + if (deviceAssignment_ != null) { + output.WriteRawTag(122); + output.WriteMessage(DeviceAssignment); + } + if (UseAutoSpmdPartitioning != false) { + output.WriteRawTag(128, 1); + output.WriteBool(UseAutoSpmdPartitioning); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (EntryComputationName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(EntryComputationName); + } + if (EntryComputationId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(EntryComputationId); + } + size += computations_.CalculateSize(_repeated_computations_codec); + if (hostProgramShape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(HostProgramShape); + } + if (Id != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Id); + } + if (schedule_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Schedule); + } + if (inputOutputAlias_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(InputOutputAlias); + } + if (dynamicParameterBinding_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DynamicParameterBinding); + } + size += crossProgramPrefetches_.CalculateSize(_repeated_crossProgramPrefetches_codec); + if (IsDynamic != false) { + size += 1 + 1; + } + if (spmdOutputSharding_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SpmdOutputSharding); + } + size += spmdParametersShardings_.CalculateSize(_repeated_spmdParametersShardings_codec); + if (UseAutoSpmdPartitioning != false) { + size += 2 + 1; + } + size += profileInfo_.CalculateSize(_repeated_profileInfo_codec); + if (deviceAssignment_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DeviceAssignment); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HloModuleProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.EntryComputationName.Length != 0) { + EntryComputationName = other.EntryComputationName; + } + if (other.EntryComputationId != 0L) { + EntryComputationId = other.EntryComputationId; + } + computations_.Add(other.computations_); + if (other.hostProgramShape_ != null) { + if (hostProgramShape_ == null) { + HostProgramShape = new global::Xla.ProgramShapeProto(); + } + HostProgramShape.MergeFrom(other.HostProgramShape); + } + if (other.Id != 0L) { + Id = other.Id; + } + if (other.schedule_ != null) { + if (schedule_ == null) { + Schedule = new global::Xla.HloScheduleProto(); + } + Schedule.MergeFrom(other.Schedule); + } + if (other.inputOutputAlias_ != null) { + if (inputOutputAlias_ == null) { + InputOutputAlias = new global::Xla.HloInputOutputAliasProto(); + } + InputOutputAlias.MergeFrom(other.InputOutputAlias); + } + if (other.dynamicParameterBinding_ != null) { + if (dynamicParameterBinding_ == null) { + DynamicParameterBinding = new global::Xla.DynamicParameterBindingProto(); + } + DynamicParameterBinding.MergeFrom(other.DynamicParameterBinding); + } + crossProgramPrefetches_.Add(other.crossProgramPrefetches_); + if (other.IsDynamic != false) { + IsDynamic = other.IsDynamic; + } + if (other.spmdOutputSharding_ != null) { + if (spmdOutputSharding_ == null) { + SpmdOutputSharding = new global::Xla.OpSharding(); + } + SpmdOutputSharding.MergeFrom(other.SpmdOutputSharding); + } + spmdParametersShardings_.Add(other.spmdParametersShardings_); + if (other.UseAutoSpmdPartitioning != false) { + UseAutoSpmdPartitioning = other.UseAutoSpmdPartitioning; + } + profileInfo_.Add(other.profileInfo_); + if (other.deviceAssignment_ != null) { + if (deviceAssignment_ == null) { + DeviceAssignment = new global::Xla.DeviceAssignmentProto(); + } + DeviceAssignment.MergeFrom(other.DeviceAssignment); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + EntryComputationName = input.ReadString(); + break; + } + case 26: { + computations_.AddEntriesFrom(input, _repeated_computations_codec); + break; + } + case 34: { + if (hostProgramShape_ == null) { + HostProgramShape = new global::Xla.ProgramShapeProto(); + } + input.ReadMessage(HostProgramShape); + break; + } + case 40: { + Id = input.ReadInt64(); + break; + } + case 48: { + EntryComputationId = input.ReadInt64(); + break; + } + case 58: { + if (schedule_ == null) { + Schedule = new global::Xla.HloScheduleProto(); + } + input.ReadMessage(Schedule); + break; + } + case 66: { + if (inputOutputAlias_ == null) { + InputOutputAlias = new global::Xla.HloInputOutputAliasProto(); + } + input.ReadMessage(InputOutputAlias); + break; + } + case 74: { + if (dynamicParameterBinding_ == null) { + DynamicParameterBinding = new global::Xla.DynamicParameterBindingProto(); + } + input.ReadMessage(DynamicParameterBinding); + break; + } + case 82: { + crossProgramPrefetches_.AddEntriesFrom(input, _repeated_crossProgramPrefetches_codec); + break; + } + case 88: { + IsDynamic = input.ReadBool(); + break; + } + case 98: { + if (spmdOutputSharding_ == null) { + SpmdOutputSharding = new global::Xla.OpSharding(); + } + input.ReadMessage(SpmdOutputSharding); + break; + } + case 106: { + profileInfo_.AddEntriesFrom(input, _repeated_profileInfo_codec); + break; + } + case 114: { + spmdParametersShardings_.AddEntriesFrom(input, _repeated_spmdParametersShardings_codec); + break; + } + case 122: { + if (deviceAssignment_ == null) { + DeviceAssignment = new global::Xla.DeviceAssignmentProto(); + } + input.ReadMessage(DeviceAssignment); + break; + } + case 128: { + UseAutoSpmdPartitioning = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + EntryComputationName = input.ReadString(); + break; + } + case 26: { + computations_.AddEntriesFrom(ref input, _repeated_computations_codec); + break; + } + case 34: { + if (hostProgramShape_ == null) { + HostProgramShape = new global::Xla.ProgramShapeProto(); + } + input.ReadMessage(HostProgramShape); + break; + } + case 40: { + Id = input.ReadInt64(); + break; + } + case 48: { + EntryComputationId = input.ReadInt64(); + break; + } + case 58: { + if (schedule_ == null) { + Schedule = new global::Xla.HloScheduleProto(); + } + input.ReadMessage(Schedule); + break; + } + case 66: { + if (inputOutputAlias_ == null) { + InputOutputAlias = new global::Xla.HloInputOutputAliasProto(); + } + input.ReadMessage(InputOutputAlias); + break; + } + case 74: { + if (dynamicParameterBinding_ == null) { + DynamicParameterBinding = new global::Xla.DynamicParameterBindingProto(); + } + input.ReadMessage(DynamicParameterBinding); + break; + } + case 82: { + crossProgramPrefetches_.AddEntriesFrom(ref input, _repeated_crossProgramPrefetches_codec); + break; + } + case 88: { + IsDynamic = input.ReadBool(); + break; + } + case 98: { + if (spmdOutputSharding_ == null) { + SpmdOutputSharding = new global::Xla.OpSharding(); + } + input.ReadMessage(SpmdOutputSharding); + break; + } + case 106: { + profileInfo_.AddEntriesFrom(ref input, _repeated_profileInfo_codec); + break; + } + case 114: { + spmdParametersShardings_.AddEntriesFrom(ref input, _repeated_spmdParametersShardings_codec); + break; + } + case 122: { + if (deviceAssignment_ == null) { + DeviceAssignment = new global::Xla.DeviceAssignmentProto(); + } + input.ReadMessage(DeviceAssignment); + break; + } + case 128: { + UseAutoSpmdPartitioning = input.ReadBool(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the HloModuleProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// The type of optimization profile in use for module-level optimizations. + /// + public enum ProfileType { + [pbr::OriginalName("INVALID")] Invalid = 0, + [pbr::OriginalName("FLAG")] Flag = 1, + [pbr::OriginalName("FUSION")] Fusion = 2, + [pbr::OriginalName("LAYOUT")] Layout = 3, + [pbr::OriginalName("DOT")] Dot = 4, + } + + /// + /// Information about the optimization profile that this module contains. + /// + public sealed partial class ProfileInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ProfileInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloModuleProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ProfileInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ProfileInfo(ProfileInfo other) : this() { + profileType_ = other.profileType_; + relativeSpeedup_ = other.relativeSpeedup_; + profileSource_ = other.profileSource_; + compilationEvent_ = other.compilationEvent_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ProfileInfo Clone() { + return new ProfileInfo(this); + } + + /// Field number for the "profile_type" field. + public const int ProfileTypeFieldNumber = 1; + private global::Xla.HloModuleProto.Types.ProfileType profileType_ = global::Xla.HloModuleProto.Types.ProfileType.Invalid; + /// + /// The optimization profiles that this module contains. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.HloModuleProto.Types.ProfileType ProfileType { + get { return profileType_; } + set { + profileType_ = value; + } + } + + /// Field number for the "relative_speedup" field. + public const int RelativeSpeedupFieldNumber = 2; + private double relativeSpeedup_; + /// + /// Speedup of tuned config compared to default config. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double RelativeSpeedup { + get { return relativeSpeedup_; } + set { + relativeSpeedup_ = value; + } + } + + /// Field number for the "profile_source" field. + public const int ProfileSourceFieldNumber = 3; + private global::Xla.ProfileSource profileSource_ = global::Xla.ProfileSource.UnknownSource; + /// + /// The source of the optimization profile that this module contains. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ProfileSource ProfileSource { + get { return profileSource_; } + set { + profileSource_ = value; + } + } + + /// Field number for the "compilation_event" field. + public const int CompilationEventFieldNumber = 4; + private global::Xla.CompilationEvent compilationEvent_ = global::Xla.CompilationEvent.UnknownEvent; + /// + /// The compilation event that triggered the use of the profile. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.CompilationEvent CompilationEvent { + get { return compilationEvent_; } + set { + compilationEvent_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ProfileInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ProfileInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ProfileType != other.ProfileType) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(RelativeSpeedup, other.RelativeSpeedup)) return false; + if (ProfileSource != other.ProfileSource) return false; + if (CompilationEvent != other.CompilationEvent) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ProfileType != global::Xla.HloModuleProto.Types.ProfileType.Invalid) hash ^= ProfileType.GetHashCode(); + if (RelativeSpeedup != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(RelativeSpeedup); + if (ProfileSource != global::Xla.ProfileSource.UnknownSource) hash ^= ProfileSource.GetHashCode(); + if (CompilationEvent != global::Xla.CompilationEvent.UnknownEvent) hash ^= CompilationEvent.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ProfileType != global::Xla.HloModuleProto.Types.ProfileType.Invalid) { + output.WriteRawTag(8); + output.WriteEnum((int) ProfileType); + } + if (RelativeSpeedup != 0D) { + output.WriteRawTag(17); + output.WriteDouble(RelativeSpeedup); + } + if (ProfileSource != global::Xla.ProfileSource.UnknownSource) { + output.WriteRawTag(24); + output.WriteEnum((int) ProfileSource); + } + if (CompilationEvent != global::Xla.CompilationEvent.UnknownEvent) { + output.WriteRawTag(32); + output.WriteEnum((int) CompilationEvent); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ProfileType != global::Xla.HloModuleProto.Types.ProfileType.Invalid) { + output.WriteRawTag(8); + output.WriteEnum((int) ProfileType); + } + if (RelativeSpeedup != 0D) { + output.WriteRawTag(17); + output.WriteDouble(RelativeSpeedup); + } + if (ProfileSource != global::Xla.ProfileSource.UnknownSource) { + output.WriteRawTag(24); + output.WriteEnum((int) ProfileSource); + } + if (CompilationEvent != global::Xla.CompilationEvent.UnknownEvent) { + output.WriteRawTag(32); + output.WriteEnum((int) CompilationEvent); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ProfileType != global::Xla.HloModuleProto.Types.ProfileType.Invalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ProfileType); + } + if (RelativeSpeedup != 0D) { + size += 1 + 8; + } + if (ProfileSource != global::Xla.ProfileSource.UnknownSource) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ProfileSource); + } + if (CompilationEvent != global::Xla.CompilationEvent.UnknownEvent) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) CompilationEvent); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ProfileInfo other) { + if (other == null) { + return; + } + if (other.ProfileType != global::Xla.HloModuleProto.Types.ProfileType.Invalid) { + ProfileType = other.ProfileType; + } + if (other.RelativeSpeedup != 0D) { + RelativeSpeedup = other.RelativeSpeedup; + } + if (other.ProfileSource != global::Xla.ProfileSource.UnknownSource) { + ProfileSource = other.ProfileSource; + } + if (other.CompilationEvent != global::Xla.CompilationEvent.UnknownEvent) { + CompilationEvent = other.CompilationEvent; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + ProfileType = (global::Xla.HloModuleProto.Types.ProfileType) input.ReadEnum(); + break; + } + case 17: { + RelativeSpeedup = input.ReadDouble(); + break; + } + case 24: { + ProfileSource = (global::Xla.ProfileSource) input.ReadEnum(); + break; + } + case 32: { + CompilationEvent = (global::Xla.CompilationEvent) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + ProfileType = (global::Xla.HloModuleProto.Types.ProfileType) input.ReadEnum(); + break; + } + case 17: { + RelativeSpeedup = input.ReadDouble(); + break; + } + case 24: { + ProfileSource = (global::Xla.ProfileSource) input.ReadEnum(); + break; + } + case 32: { + CompilationEvent = (global::Xla.CompilationEvent) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// Serialization of LogicalBuffer. + /// + public sealed partial class LogicalBufferProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new LogicalBufferProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LogicalBufferProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LogicalBufferProto(LogicalBufferProto other) : this() { + id_ = other.id_; + size_ = other.size_; + definedAt_ = other.definedAt_ != null ? other.definedAt_.Clone() : null; + color_ = other.color_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LogicalBufferProto Clone() { + return new LogicalBufferProto(this); + } + + /// Field number for the "id" field. + public const int IdFieldNumber = 1; + private long id_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Id { + get { return id_; } + set { + id_ = value; + } + } + + /// Field number for the "size" field. + public const int SizeFieldNumber = 2; + private long size_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Size { + get { return size_; } + set { + size_ = value; + } + } + + /// Field number for the "defined_at" field. + public const int DefinedAtFieldNumber = 3; + private global::Xla.LogicalBufferProto.Types.Location definedAt_; + /// + /// The location where the buffer is defined. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.LogicalBufferProto.Types.Location DefinedAt { + get { return definedAt_; } + set { + definedAt_ = value; + } + } + + /// Field number for the "color" field. + public const int ColorFieldNumber = 4; + private long color_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Color { + get { return color_; } + set { + color_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as LogicalBufferProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(LogicalBufferProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Id != other.Id) return false; + if (Size != other.Size) return false; + if (!object.Equals(DefinedAt, other.DefinedAt)) return false; + if (Color != other.Color) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Id != 0L) hash ^= Id.GetHashCode(); + if (Size != 0L) hash ^= Size.GetHashCode(); + if (definedAt_ != null) hash ^= DefinedAt.GetHashCode(); + if (Color != 0L) hash ^= Color.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Id != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Id); + } + if (Size != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Size); + } + if (definedAt_ != null) { + output.WriteRawTag(26); + output.WriteMessage(DefinedAt); + } + if (Color != 0L) { + output.WriteRawTag(32); + output.WriteInt64(Color); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Id != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Id); + } + if (Size != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Size); + } + if (definedAt_ != null) { + output.WriteRawTag(26); + output.WriteMessage(DefinedAt); + } + if (Color != 0L) { + output.WriteRawTag(32); + output.WriteInt64(Color); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Id != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Id); + } + if (Size != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Size); + } + if (definedAt_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DefinedAt); + } + if (Color != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Color); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(LogicalBufferProto other) { + if (other == null) { + return; + } + if (other.Id != 0L) { + Id = other.Id; + } + if (other.Size != 0L) { + Size = other.Size; + } + if (other.definedAt_ != null) { + if (definedAt_ == null) { + DefinedAt = new global::Xla.LogicalBufferProto.Types.Location(); + } + DefinedAt.MergeFrom(other.DefinedAt); + } + if (other.Color != 0L) { + Color = other.Color; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Id = input.ReadInt64(); + break; + } + case 16: { + Size = input.ReadInt64(); + break; + } + case 26: { + if (definedAt_ == null) { + DefinedAt = new global::Xla.LogicalBufferProto.Types.Location(); + } + input.ReadMessage(DefinedAt); + break; + } + case 32: { + Color = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Id = input.ReadInt64(); + break; + } + case 16: { + Size = input.ReadInt64(); + break; + } + case 26: { + if (definedAt_ == null) { + DefinedAt = new global::Xla.LogicalBufferProto.Types.Location(); + } + input.ReadMessage(DefinedAt); + break; + } + case 32: { + Color = input.ReadInt64(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the LogicalBufferProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Location represents an instruction and its shape index, which uniquely + /// identifies a point where a buffer is needed. + /// + public sealed partial class Location : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Location()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.LogicalBufferProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Location() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Location(Location other) : this() { + computationName_ = other.computationName_; + instructionName_ = other.instructionName_; + instructionId_ = other.instructionId_; + shapeIndex_ = other.shapeIndex_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Location Clone() { + return new Location(this); + } + + /// Field number for the "computation_name" field. + public const int ComputationNameFieldNumber = 1; + private string computationName_ = ""; + /// + /// NOTE: module_name isn't necessary, since all LogicalBuffers are + /// associated with a single HloModule. + /// TODO(b/239098765): Remove instruction_name and computation_name. + /// + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ComputationName { + get { return computationName_; } + set { + computationName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "instruction_name" field. + public const int InstructionNameFieldNumber = 2; + private string instructionName_ = ""; + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string InstructionName { + get { return instructionName_; } + set { + instructionName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "instruction_id" field. + public const int InstructionIdFieldNumber = 4; + private long instructionId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long InstructionId { + get { return instructionId_; } + set { + instructionId_ = value; + } + } + + /// Field number for the "shape_index" field. + public const int ShapeIndexFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_shapeIndex_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField shapeIndex_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ShapeIndex { + get { return shapeIndex_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Location); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Location other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ComputationName != other.ComputationName) return false; + if (InstructionName != other.InstructionName) return false; + if (InstructionId != other.InstructionId) return false; + if(!shapeIndex_.Equals(other.shapeIndex_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ComputationName.Length != 0) hash ^= ComputationName.GetHashCode(); + if (InstructionName.Length != 0) hash ^= InstructionName.GetHashCode(); + if (InstructionId != 0L) hash ^= InstructionId.GetHashCode(); + hash ^= shapeIndex_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ComputationName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ComputationName); + } + if (InstructionName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(InstructionName); + } + shapeIndex_.WriteTo(output, _repeated_shapeIndex_codec); + if (InstructionId != 0L) { + output.WriteRawTag(32); + output.WriteInt64(InstructionId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ComputationName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ComputationName); + } + if (InstructionName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(InstructionName); + } + shapeIndex_.WriteTo(ref output, _repeated_shapeIndex_codec); + if (InstructionId != 0L) { + output.WriteRawTag(32); + output.WriteInt64(InstructionId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ComputationName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ComputationName); + } + if (InstructionName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(InstructionName); + } + if (InstructionId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(InstructionId); + } + size += shapeIndex_.CalculateSize(_repeated_shapeIndex_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Location other) { + if (other == null) { + return; + } + if (other.ComputationName.Length != 0) { + ComputationName = other.ComputationName; + } + if (other.InstructionName.Length != 0) { + InstructionName = other.InstructionName; + } + if (other.InstructionId != 0L) { + InstructionId = other.InstructionId; + } + shapeIndex_.Add(other.shapeIndex_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ComputationName = input.ReadString(); + break; + } + case 18: { + InstructionName = input.ReadString(); + break; + } + case 26: + case 24: { + shapeIndex_.AddEntriesFrom(input, _repeated_shapeIndex_codec); + break; + } + case 32: { + InstructionId = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + ComputationName = input.ReadString(); + break; + } + case 18: { + InstructionName = input.ReadString(); + break; + } + case 26: + case 24: { + shapeIndex_.AddEntriesFrom(ref input, _repeated_shapeIndex_codec); + break; + } + case 32: { + InstructionId = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// Serialization of BufferAllocation. + /// + public sealed partial class BufferAllocationProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BufferAllocationProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[8]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BufferAllocationProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BufferAllocationProto(BufferAllocationProto other) : this() { + index_ = other.index_; + size_ = other.size_; + isThreadLocal_ = other.isThreadLocal_; + isTuple_ = other.isTuple_; + isEntryComputationParameter_ = other.isEntryComputationParameter_; + isConstant_ = other.isConstant_; + parameterNumber_ = other.parameterNumber_; + parameterShapeIndex_ = other.parameterShapeIndex_.Clone(); + maybeLiveOut_ = other.maybeLiveOut_; + color_ = other.color_; + assigned_ = other.assigned_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BufferAllocationProto Clone() { + return new BufferAllocationProto(this); + } + + /// Field number for the "index" field. + public const int IndexFieldNumber = 1; + private long index_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Index { + get { return index_; } + set { + index_ = value; + } + } + + /// Field number for the "size" field. + public const int SizeFieldNumber = 2; + private long size_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Size { + get { return size_; } + set { + size_ = value; + } + } + + /// Field number for the "is_thread_local" field. + public const int IsThreadLocalFieldNumber = 3; + private bool isThreadLocal_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsThreadLocal { + get { return isThreadLocal_; } + set { + isThreadLocal_ = value; + } + } + + /// Field number for the "is_tuple" field. + public const int IsTupleFieldNumber = 11; + private bool isTuple_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsTuple { + get { return isTuple_; } + set { + isTuple_ = value; + } + } + + /// Field number for the "is_entry_computation_parameter" field. + public const int IsEntryComputationParameterFieldNumber = 5; + private bool isEntryComputationParameter_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsEntryComputationParameter { + get { return isEntryComputationParameter_; } + set { + isEntryComputationParameter_ = value; + } + } + + /// Field number for the "is_constant" field. + public const int IsConstantFieldNumber = 12; + private bool isConstant_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsConstant { + get { return isConstant_; } + set { + isConstant_ = value; + } + } + + /// Field number for the "parameter_number" field. + public const int ParameterNumberFieldNumber = 6; + private long parameterNumber_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ParameterNumber { + get { return parameterNumber_; } + set { + parameterNumber_ = value; + } + } + + /// Field number for the "parameter_shape_index" field. + public const int ParameterShapeIndexFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_parameterShapeIndex_codec + = pb::FieldCodec.ForInt64(82); + private readonly pbc::RepeatedField parameterShapeIndex_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ParameterShapeIndex { + get { return parameterShapeIndex_; } + } + + /// Field number for the "maybe_live_out" field. + public const int MaybeLiveOutFieldNumber = 7; + private bool maybeLiveOut_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool MaybeLiveOut { + get { return maybeLiveOut_; } + set { + maybeLiveOut_ = value; + } + } + + /// Field number for the "color" field. + public const int ColorFieldNumber = 8; + private long color_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Color { + get { return color_; } + set { + color_ = value; + } + } + + /// Field number for the "assigned" field. + public const int AssignedFieldNumber = 9; + private static readonly pb::FieldCodec _repeated_assigned_codec + = pb::FieldCodec.ForMessage(74, global::Xla.BufferAllocationProto.Types.Assigned.Parser); + private readonly pbc::RepeatedField assigned_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Assigned { + get { return assigned_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as BufferAllocationProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(BufferAllocationProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Index != other.Index) return false; + if (Size != other.Size) return false; + if (IsThreadLocal != other.IsThreadLocal) return false; + if (IsTuple != other.IsTuple) return false; + if (IsEntryComputationParameter != other.IsEntryComputationParameter) return false; + if (IsConstant != other.IsConstant) return false; + if (ParameterNumber != other.ParameterNumber) return false; + if(!parameterShapeIndex_.Equals(other.parameterShapeIndex_)) return false; + if (MaybeLiveOut != other.MaybeLiveOut) return false; + if (Color != other.Color) return false; + if(!assigned_.Equals(other.assigned_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Index != 0L) hash ^= Index.GetHashCode(); + if (Size != 0L) hash ^= Size.GetHashCode(); + if (IsThreadLocal != false) hash ^= IsThreadLocal.GetHashCode(); + if (IsTuple != false) hash ^= IsTuple.GetHashCode(); + if (IsEntryComputationParameter != false) hash ^= IsEntryComputationParameter.GetHashCode(); + if (IsConstant != false) hash ^= IsConstant.GetHashCode(); + if (ParameterNumber != 0L) hash ^= ParameterNumber.GetHashCode(); + hash ^= parameterShapeIndex_.GetHashCode(); + if (MaybeLiveOut != false) hash ^= MaybeLiveOut.GetHashCode(); + if (Color != 0L) hash ^= Color.GetHashCode(); + hash ^= assigned_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Index != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Index); + } + if (Size != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Size); + } + if (IsThreadLocal != false) { + output.WriteRawTag(24); + output.WriteBool(IsThreadLocal); + } + if (IsEntryComputationParameter != false) { + output.WriteRawTag(40); + output.WriteBool(IsEntryComputationParameter); + } + if (ParameterNumber != 0L) { + output.WriteRawTag(48); + output.WriteInt64(ParameterNumber); + } + if (MaybeLiveOut != false) { + output.WriteRawTag(56); + output.WriteBool(MaybeLiveOut); + } + if (Color != 0L) { + output.WriteRawTag(64); + output.WriteInt64(Color); + } + assigned_.WriteTo(output, _repeated_assigned_codec); + parameterShapeIndex_.WriteTo(output, _repeated_parameterShapeIndex_codec); + if (IsTuple != false) { + output.WriteRawTag(88); + output.WriteBool(IsTuple); + } + if (IsConstant != false) { + output.WriteRawTag(96); + output.WriteBool(IsConstant); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Index != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Index); + } + if (Size != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Size); + } + if (IsThreadLocal != false) { + output.WriteRawTag(24); + output.WriteBool(IsThreadLocal); + } + if (IsEntryComputationParameter != false) { + output.WriteRawTag(40); + output.WriteBool(IsEntryComputationParameter); + } + if (ParameterNumber != 0L) { + output.WriteRawTag(48); + output.WriteInt64(ParameterNumber); + } + if (MaybeLiveOut != false) { + output.WriteRawTag(56); + output.WriteBool(MaybeLiveOut); + } + if (Color != 0L) { + output.WriteRawTag(64); + output.WriteInt64(Color); + } + assigned_.WriteTo(ref output, _repeated_assigned_codec); + parameterShapeIndex_.WriteTo(ref output, _repeated_parameterShapeIndex_codec); + if (IsTuple != false) { + output.WriteRawTag(88); + output.WriteBool(IsTuple); + } + if (IsConstant != false) { + output.WriteRawTag(96); + output.WriteBool(IsConstant); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Index != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Index); + } + if (Size != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Size); + } + if (IsThreadLocal != false) { + size += 1 + 1; + } + if (IsTuple != false) { + size += 1 + 1; + } + if (IsEntryComputationParameter != false) { + size += 1 + 1; + } + if (IsConstant != false) { + size += 1 + 1; + } + if (ParameterNumber != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ParameterNumber); + } + size += parameterShapeIndex_.CalculateSize(_repeated_parameterShapeIndex_codec); + if (MaybeLiveOut != false) { + size += 1 + 1; + } + if (Color != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Color); + } + size += assigned_.CalculateSize(_repeated_assigned_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(BufferAllocationProto other) { + if (other == null) { + return; + } + if (other.Index != 0L) { + Index = other.Index; + } + if (other.Size != 0L) { + Size = other.Size; + } + if (other.IsThreadLocal != false) { + IsThreadLocal = other.IsThreadLocal; + } + if (other.IsTuple != false) { + IsTuple = other.IsTuple; + } + if (other.IsEntryComputationParameter != false) { + IsEntryComputationParameter = other.IsEntryComputationParameter; + } + if (other.IsConstant != false) { + IsConstant = other.IsConstant; + } + if (other.ParameterNumber != 0L) { + ParameterNumber = other.ParameterNumber; + } + parameterShapeIndex_.Add(other.parameterShapeIndex_); + if (other.MaybeLiveOut != false) { + MaybeLiveOut = other.MaybeLiveOut; + } + if (other.Color != 0L) { + Color = other.Color; + } + assigned_.Add(other.assigned_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Index = input.ReadInt64(); + break; + } + case 16: { + Size = input.ReadInt64(); + break; + } + case 24: { + IsThreadLocal = input.ReadBool(); + break; + } + case 40: { + IsEntryComputationParameter = input.ReadBool(); + break; + } + case 48: { + ParameterNumber = input.ReadInt64(); + break; + } + case 56: { + MaybeLiveOut = input.ReadBool(); + break; + } + case 64: { + Color = input.ReadInt64(); + break; + } + case 74: { + assigned_.AddEntriesFrom(input, _repeated_assigned_codec); + break; + } + case 82: + case 80: { + parameterShapeIndex_.AddEntriesFrom(input, _repeated_parameterShapeIndex_codec); + break; + } + case 88: { + IsTuple = input.ReadBool(); + break; + } + case 96: { + IsConstant = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Index = input.ReadInt64(); + break; + } + case 16: { + Size = input.ReadInt64(); + break; + } + case 24: { + IsThreadLocal = input.ReadBool(); + break; + } + case 40: { + IsEntryComputationParameter = input.ReadBool(); + break; + } + case 48: { + ParameterNumber = input.ReadInt64(); + break; + } + case 56: { + MaybeLiveOut = input.ReadBool(); + break; + } + case 64: { + Color = input.ReadInt64(); + break; + } + case 74: { + assigned_.AddEntriesFrom(ref input, _repeated_assigned_codec); + break; + } + case 82: + case 80: { + parameterShapeIndex_.AddEntriesFrom(ref input, _repeated_parameterShapeIndex_codec); + break; + } + case 88: { + IsTuple = input.ReadBool(); + break; + } + case 96: { + IsConstant = input.ReadBool(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the BufferAllocationProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Assigned represents a single LogicalBuffer that is assigned to this + /// BufferAllocation. + /// + public sealed partial class Assigned : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Assigned()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.BufferAllocationProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Assigned() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Assigned(Assigned other) : this() { + logicalBufferId_ = other.logicalBufferId_; + offset_ = other.offset_; + size_ = other.size_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Assigned Clone() { + return new Assigned(this); + } + + /// Field number for the "logical_buffer_id" field. + public const int LogicalBufferIdFieldNumber = 1; + private long logicalBufferId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long LogicalBufferId { + get { return logicalBufferId_; } + set { + logicalBufferId_ = value; + } + } + + /// Field number for the "offset" field. + public const int OffsetFieldNumber = 2; + private long offset_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Offset { + get { return offset_; } + set { + offset_ = value; + } + } + + /// Field number for the "size" field. + public const int SizeFieldNumber = 3; + private long size_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Size { + get { return size_; } + set { + size_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Assigned); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Assigned other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (LogicalBufferId != other.LogicalBufferId) return false; + if (Offset != other.Offset) return false; + if (Size != other.Size) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (LogicalBufferId != 0L) hash ^= LogicalBufferId.GetHashCode(); + if (Offset != 0L) hash ^= Offset.GetHashCode(); + if (Size != 0L) hash ^= Size.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (LogicalBufferId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(LogicalBufferId); + } + if (Offset != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Offset); + } + if (Size != 0L) { + output.WriteRawTag(24); + output.WriteInt64(Size); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (LogicalBufferId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(LogicalBufferId); + } + if (Offset != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Offset); + } + if (Size != 0L) { + output.WriteRawTag(24); + output.WriteInt64(Size); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (LogicalBufferId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(LogicalBufferId); + } + if (Offset != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Offset); + } + if (Size != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Size); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Assigned other) { + if (other == null) { + return; + } + if (other.LogicalBufferId != 0L) { + LogicalBufferId = other.LogicalBufferId; + } + if (other.Offset != 0L) { + Offset = other.Offset; + } + if (other.Size != 0L) { + Size = other.Size; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + LogicalBufferId = input.ReadInt64(); + break; + } + case 16: { + Offset = input.ReadInt64(); + break; + } + case 24: { + Size = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + LogicalBufferId = input.ReadInt64(); + break; + } + case 16: { + Offset = input.ReadInt64(); + break; + } + case 24: { + Size = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// A trace of a HeapSimulator run. + /// + public sealed partial class HeapSimulatorTrace : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HeapSimulatorTrace()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[9]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeapSimulatorTrace() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeapSimulatorTrace(HeapSimulatorTrace other) : this() { + events_ = other.events_.Clone(); + wholeModuleSimulation_ = other.wholeModuleSimulation_; + bufferAllocationIndex_ = other.bufferAllocationIndex_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeapSimulatorTrace Clone() { + return new HeapSimulatorTrace(this); + } + + /// Field number for the "events" field. + public const int EventsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_events_codec + = pb::FieldCodec.ForMessage(10, global::Xla.HeapSimulatorTrace.Types.Event.Parser); + private readonly pbc::RepeatedField events_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Events { + get { return events_; } + } + + /// Field number for the "whole_module_simulation" field. + public const int WholeModuleSimulationFieldNumber = 2; + private bool wholeModuleSimulation_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool WholeModuleSimulation { + get { return wholeModuleSimulation_; } + set { + wholeModuleSimulation_ = value; + } + } + + /// Field number for the "buffer_allocation_index" field. + public const int BufferAllocationIndexFieldNumber = 3; + private long bufferAllocationIndex_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long BufferAllocationIndex { + get { return bufferAllocationIndex_; } + set { + bufferAllocationIndex_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HeapSimulatorTrace); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HeapSimulatorTrace other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!events_.Equals(other.events_)) return false; + if (WholeModuleSimulation != other.WholeModuleSimulation) return false; + if (BufferAllocationIndex != other.BufferAllocationIndex) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= events_.GetHashCode(); + if (WholeModuleSimulation != false) hash ^= WholeModuleSimulation.GetHashCode(); + if (BufferAllocationIndex != 0L) hash ^= BufferAllocationIndex.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + events_.WriteTo(output, _repeated_events_codec); + if (WholeModuleSimulation != false) { + output.WriteRawTag(16); + output.WriteBool(WholeModuleSimulation); + } + if (BufferAllocationIndex != 0L) { + output.WriteRawTag(24); + output.WriteInt64(BufferAllocationIndex); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + events_.WriteTo(ref output, _repeated_events_codec); + if (WholeModuleSimulation != false) { + output.WriteRawTag(16); + output.WriteBool(WholeModuleSimulation); + } + if (BufferAllocationIndex != 0L) { + output.WriteRawTag(24); + output.WriteInt64(BufferAllocationIndex); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += events_.CalculateSize(_repeated_events_codec); + if (WholeModuleSimulation != false) { + size += 1 + 1; + } + if (BufferAllocationIndex != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(BufferAllocationIndex); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HeapSimulatorTrace other) { + if (other == null) { + return; + } + events_.Add(other.events_); + if (other.WholeModuleSimulation != false) { + WholeModuleSimulation = other.WholeModuleSimulation; + } + if (other.BufferAllocationIndex != 0L) { + BufferAllocationIndex = other.BufferAllocationIndex; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + events_.AddEntriesFrom(input, _repeated_events_codec); + break; + } + case 16: { + WholeModuleSimulation = input.ReadBool(); + break; + } + case 24: { + BufferAllocationIndex = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + events_.AddEntriesFrom(ref input, _repeated_events_codec); + break; + } + case 16: { + WholeModuleSimulation = input.ReadBool(); + break; + } + case 24: { + BufferAllocationIndex = input.ReadInt64(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the HeapSimulatorTrace message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// The trace includes a list of events, where each event describes one action + /// performed by the heap simulator. + /// + public sealed partial class Event : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Event()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HeapSimulatorTrace.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Event() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Event(Event other) : this() { + kind_ = other.kind_; + bufferId_ = other.bufferId_; + computationName_ = other.computationName_; + instructionName_ = other.instructionName_; + shareWithCanonicalId_ = other.shareWithCanonicalId_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Event Clone() { + return new Event(this); + } + + /// Field number for the "kind" field. + public const int KindFieldNumber = 1; + private global::Xla.HeapSimulatorTrace.Types.Event.Types.Kind kind_ = global::Xla.HeapSimulatorTrace.Types.Event.Types.Kind.Alloc; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.HeapSimulatorTrace.Types.Event.Types.Kind Kind { + get { return kind_; } + set { + kind_ = value; + } + } + + /// Field number for the "buffer_id" field. + public const int BufferIdFieldNumber = 2; + private long bufferId_; + /// + /// The id of the LogicalBuffer that the event applies to. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long BufferId { + get { return bufferId_; } + set { + bufferId_ = value; + } + } + + /// Field number for the "computation_name" field. + public const int ComputationNameFieldNumber = 3; + private string computationName_ = ""; + /// + /// The HloInstruction that the simulation was processing that caused this + /// event to occur, identified by its computation and instruction name. E.g. + /// buffers defined by instruction A are allocated when processing A. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ComputationName { + get { return computationName_; } + set { + computationName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "instruction_name" field. + public const int InstructionNameFieldNumber = 4; + private string instructionName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string InstructionName { + get { return instructionName_; } + set { + instructionName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "share_with_canonical_id" field. + public const int ShareWithCanonicalIdFieldNumber = 5; + private long shareWithCanonicalId_; + /// + /// The id of the canonical LogicalBuffer that the buffer shares with. Only + /// set for SHARE_WITH events. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ShareWithCanonicalId { + get { return shareWithCanonicalId_; } + set { + shareWithCanonicalId_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Event); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Event other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Kind != other.Kind) return false; + if (BufferId != other.BufferId) return false; + if (ComputationName != other.ComputationName) return false; + if (InstructionName != other.InstructionName) return false; + if (ShareWithCanonicalId != other.ShareWithCanonicalId) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Kind != global::Xla.HeapSimulatorTrace.Types.Event.Types.Kind.Alloc) hash ^= Kind.GetHashCode(); + if (BufferId != 0L) hash ^= BufferId.GetHashCode(); + if (ComputationName.Length != 0) hash ^= ComputationName.GetHashCode(); + if (InstructionName.Length != 0) hash ^= InstructionName.GetHashCode(); + if (ShareWithCanonicalId != 0L) hash ^= ShareWithCanonicalId.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Kind != global::Xla.HeapSimulatorTrace.Types.Event.Types.Kind.Alloc) { + output.WriteRawTag(8); + output.WriteEnum((int) Kind); + } + if (BufferId != 0L) { + output.WriteRawTag(16); + output.WriteInt64(BufferId); + } + if (ComputationName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(ComputationName); + } + if (InstructionName.Length != 0) { + output.WriteRawTag(34); + output.WriteString(InstructionName); + } + if (ShareWithCanonicalId != 0L) { + output.WriteRawTag(40); + output.WriteInt64(ShareWithCanonicalId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Kind != global::Xla.HeapSimulatorTrace.Types.Event.Types.Kind.Alloc) { + output.WriteRawTag(8); + output.WriteEnum((int) Kind); + } + if (BufferId != 0L) { + output.WriteRawTag(16); + output.WriteInt64(BufferId); + } + if (ComputationName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(ComputationName); + } + if (InstructionName.Length != 0) { + output.WriteRawTag(34); + output.WriteString(InstructionName); + } + if (ShareWithCanonicalId != 0L) { + output.WriteRawTag(40); + output.WriteInt64(ShareWithCanonicalId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Kind != global::Xla.HeapSimulatorTrace.Types.Event.Types.Kind.Alloc) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Kind); + } + if (BufferId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(BufferId); + } + if (ComputationName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ComputationName); + } + if (InstructionName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(InstructionName); + } + if (ShareWithCanonicalId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ShareWithCanonicalId); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Event other) { + if (other == null) { + return; + } + if (other.Kind != global::Xla.HeapSimulatorTrace.Types.Event.Types.Kind.Alloc) { + Kind = other.Kind; + } + if (other.BufferId != 0L) { + BufferId = other.BufferId; + } + if (other.ComputationName.Length != 0) { + ComputationName = other.ComputationName; + } + if (other.InstructionName.Length != 0) { + InstructionName = other.InstructionName; + } + if (other.ShareWithCanonicalId != 0L) { + ShareWithCanonicalId = other.ShareWithCanonicalId; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Kind = (global::Xla.HeapSimulatorTrace.Types.Event.Types.Kind) input.ReadEnum(); + break; + } + case 16: { + BufferId = input.ReadInt64(); + break; + } + case 26: { + ComputationName = input.ReadString(); + break; + } + case 34: { + InstructionName = input.ReadString(); + break; + } + case 40: { + ShareWithCanonicalId = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Kind = (global::Xla.HeapSimulatorTrace.Types.Event.Types.Kind) input.ReadEnum(); + break; + } + case 16: { + BufferId = input.ReadInt64(); + break; + } + case 26: { + ComputationName = input.ReadString(); + break; + } + case 34: { + InstructionName = input.ReadString(); + break; + } + case 40: { + ShareWithCanonicalId = input.ReadInt64(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the Event message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum Kind { + /// + /// A memory region was allocated for the buffer. + /// + [pbr::OriginalName("ALLOC")] Alloc = 0, + /// + /// A memory region was freed for the buffer. + /// + [pbr::OriginalName("FREE")] Free = 1, + /// + /// A buffer was shared with another (canonical) buffer. This is similar to + /// ALLOC, except that instead of allocating a new region of memory, the + /// memory region of the canonical buffer is directly re-used. Multiple + /// buffers may share with the same canonical buffer. The lifetime of the + /// canonical buffer is extended to the union of all lifetimes. + /// + [pbr::OriginalName("SHARE_WITH")] ShareWith = 2, + } + + } + #endregion + + } + + } + #endregion + + } + + /// + /// An abstraction representing a set of HLO module built to run concurrently + /// across different devices. + /// + public sealed partial class HloModuleGroupProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HloModuleGroupProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[10]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloModuleGroupProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloModuleGroupProto(HloModuleGroupProto other) : this() { + name_ = other.name_; + hloModules_ = other.hloModules_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloModuleGroupProto Clone() { + return new HloModuleGroupProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "hlo_modules" field. + public const int HloModulesFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_hloModules_codec + = pb::FieldCodec.ForMessage(18, global::Xla.HloModuleProto.Parser); + private readonly pbc::RepeatedField hloModules_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField HloModules { + get { return hloModules_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HloModuleGroupProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HloModuleGroupProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if(!hloModules_.Equals(other.hloModules_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + hash ^= hloModules_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + hloModules_.WriteTo(output, _repeated_hloModules_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + hloModules_.WriteTo(ref output, _repeated_hloModules_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + size += hloModules_.CalculateSize(_repeated_hloModules_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HloModuleGroupProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + hloModules_.Add(other.hloModules_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + hloModules_.AddEntriesFrom(input, _repeated_hloModules_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + hloModules_.AddEntriesFrom(ref input, _repeated_hloModules_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Serialization of BufferAssignment. + /// + public sealed partial class BufferAssignmentProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BufferAssignmentProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[11]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BufferAssignmentProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BufferAssignmentProto(BufferAssignmentProto other) : this() { + logicalBuffers_ = other.logicalBuffers_.Clone(); + bufferAliases_ = other.bufferAliases_.Clone(); + bufferAllocations_ = other.bufferAllocations_.Clone(); + heapSimulatorTraces_ = other.heapSimulatorTraces_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BufferAssignmentProto Clone() { + return new BufferAssignmentProto(this); + } + + /// Field number for the "logical_buffers" field. + public const int LogicalBuffersFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_logicalBuffers_codec + = pb::FieldCodec.ForMessage(10, global::Xla.LogicalBufferProto.Parser); + private readonly pbc::RepeatedField logicalBuffers_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField LogicalBuffers { + get { return logicalBuffers_; } + } + + /// Field number for the "buffer_aliases" field. + public const int BufferAliasesFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_bufferAliases_codec + = pb::FieldCodec.ForMessage(18, global::Xla.BufferAssignmentProto.Types.BufferAlias.Parser); + private readonly pbc::RepeatedField bufferAliases_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField BufferAliases { + get { return bufferAliases_; } + } + + /// Field number for the "buffer_allocations" field. + public const int BufferAllocationsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_bufferAllocations_codec + = pb::FieldCodec.ForMessage(26, global::Xla.BufferAllocationProto.Parser); + private readonly pbc::RepeatedField bufferAllocations_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField BufferAllocations { + get { return bufferAllocations_; } + } + + /// Field number for the "heap_simulator_traces" field. + public const int HeapSimulatorTracesFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_heapSimulatorTraces_codec + = pb::FieldCodec.ForMessage(34, global::Xla.HeapSimulatorTrace.Parser); + private readonly pbc::RepeatedField heapSimulatorTraces_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField HeapSimulatorTraces { + get { return heapSimulatorTraces_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as BufferAssignmentProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(BufferAssignmentProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!logicalBuffers_.Equals(other.logicalBuffers_)) return false; + if(!bufferAliases_.Equals(other.bufferAliases_)) return false; + if(!bufferAllocations_.Equals(other.bufferAllocations_)) return false; + if(!heapSimulatorTraces_.Equals(other.heapSimulatorTraces_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= logicalBuffers_.GetHashCode(); + hash ^= bufferAliases_.GetHashCode(); + hash ^= bufferAllocations_.GetHashCode(); + hash ^= heapSimulatorTraces_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + logicalBuffers_.WriteTo(output, _repeated_logicalBuffers_codec); + bufferAliases_.WriteTo(output, _repeated_bufferAliases_codec); + bufferAllocations_.WriteTo(output, _repeated_bufferAllocations_codec); + heapSimulatorTraces_.WriteTo(output, _repeated_heapSimulatorTraces_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + logicalBuffers_.WriteTo(ref output, _repeated_logicalBuffers_codec); + bufferAliases_.WriteTo(ref output, _repeated_bufferAliases_codec); + bufferAllocations_.WriteTo(ref output, _repeated_bufferAllocations_codec); + heapSimulatorTraces_.WriteTo(ref output, _repeated_heapSimulatorTraces_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += logicalBuffers_.CalculateSize(_repeated_logicalBuffers_codec); + size += bufferAliases_.CalculateSize(_repeated_bufferAliases_codec); + size += bufferAllocations_.CalculateSize(_repeated_bufferAllocations_codec); + size += heapSimulatorTraces_.CalculateSize(_repeated_heapSimulatorTraces_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(BufferAssignmentProto other) { + if (other == null) { + return; + } + logicalBuffers_.Add(other.logicalBuffers_); + bufferAliases_.Add(other.bufferAliases_); + bufferAllocations_.Add(other.bufferAllocations_); + heapSimulatorTraces_.Add(other.heapSimulatorTraces_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + logicalBuffers_.AddEntriesFrom(input, _repeated_logicalBuffers_codec); + break; + } + case 18: { + bufferAliases_.AddEntriesFrom(input, _repeated_bufferAliases_codec); + break; + } + case 26: { + bufferAllocations_.AddEntriesFrom(input, _repeated_bufferAllocations_codec); + break; + } + case 34: { + heapSimulatorTraces_.AddEntriesFrom(input, _repeated_heapSimulatorTraces_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + logicalBuffers_.AddEntriesFrom(ref input, _repeated_logicalBuffers_codec); + break; + } + case 18: { + bufferAliases_.AddEntriesFrom(ref input, _repeated_bufferAliases_codec); + break; + } + case 26: { + bufferAllocations_.AddEntriesFrom(ref input, _repeated_bufferAllocations_codec); + break; + } + case 34: { + heapSimulatorTraces_.AddEntriesFrom(ref input, _repeated_heapSimulatorTraces_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the BufferAssignmentProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Alias represents a source LogicalBuffer, and the buffer location that + /// aliases it. + /// + public sealed partial class BufferAlias : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BufferAlias()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.BufferAssignmentProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BufferAlias() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BufferAlias(BufferAlias other) : this() { + sourceBufferId_ = other.sourceBufferId_; + location_ = other.location_ != null ? other.location_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BufferAlias Clone() { + return new BufferAlias(this); + } + + /// Field number for the "source_buffer_id" field. + public const int SourceBufferIdFieldNumber = 1; + private long sourceBufferId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long SourceBufferId { + get { return sourceBufferId_; } + set { + sourceBufferId_ = value; + } + } + + /// Field number for the "location" field. + public const int LocationFieldNumber = 2; + private global::Xla.LogicalBufferProto.Types.Location location_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.LogicalBufferProto.Types.Location Location { + get { return location_; } + set { + location_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as BufferAlias); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(BufferAlias other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (SourceBufferId != other.SourceBufferId) return false; + if (!object.Equals(Location, other.Location)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (SourceBufferId != 0L) hash ^= SourceBufferId.GetHashCode(); + if (location_ != null) hash ^= Location.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (SourceBufferId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(SourceBufferId); + } + if (location_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Location); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (SourceBufferId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(SourceBufferId); + } + if (location_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Location); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (SourceBufferId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(SourceBufferId); + } + if (location_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Location); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(BufferAlias other) { + if (other == null) { + return; + } + if (other.SourceBufferId != 0L) { + SourceBufferId = other.SourceBufferId; + } + if (other.location_ != null) { + if (location_ == null) { + Location = new global::Xla.LogicalBufferProto.Types.Location(); + } + Location.MergeFrom(other.Location); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + SourceBufferId = input.ReadInt64(); + break; + } + case 18: { + if (location_ == null) { + Location = new global::Xla.LogicalBufferProto.Types.Location(); + } + input.ReadMessage(Location); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + SourceBufferId = input.ReadInt64(); + break; + } + case 18: { + if (location_ == null) { + Location = new global::Xla.LogicalBufferProto.Types.Location(); + } + input.ReadMessage(Location); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// Grouping message that contains all of the information above. + /// + public sealed partial class HloProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HloProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[12]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloProto(HloProto other) : this() { + hloModule_ = other.hloModule_ != null ? other.hloModule_.Clone() : null; + bufferAssignment_ = other.bufferAssignment_ != null ? other.bufferAssignment_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloProto Clone() { + return new HloProto(this); + } + + /// Field number for the "hlo_module" field. + public const int HloModuleFieldNumber = 1; + private global::Xla.HloModuleProto hloModule_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.HloModuleProto HloModule { + get { return hloModule_; } + set { + hloModule_ = value; + } + } + + /// Field number for the "buffer_assignment" field. + public const int BufferAssignmentFieldNumber = 3; + private global::Xla.BufferAssignmentProto bufferAssignment_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.BufferAssignmentProto BufferAssignment { + get { return bufferAssignment_; } + set { + bufferAssignment_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HloProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HloProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(HloModule, other.HloModule)) return false; + if (!object.Equals(BufferAssignment, other.BufferAssignment)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (hloModule_ != null) hash ^= HloModule.GetHashCode(); + if (bufferAssignment_ != null) hash ^= BufferAssignment.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (hloModule_ != null) { + output.WriteRawTag(10); + output.WriteMessage(HloModule); + } + if (bufferAssignment_ != null) { + output.WriteRawTag(26); + output.WriteMessage(BufferAssignment); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (hloModule_ != null) { + output.WriteRawTag(10); + output.WriteMessage(HloModule); + } + if (bufferAssignment_ != null) { + output.WriteRawTag(26); + output.WriteMessage(BufferAssignment); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (hloModule_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(HloModule); + } + if (bufferAssignment_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(BufferAssignment); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HloProto other) { + if (other == null) { + return; + } + if (other.hloModule_ != null) { + if (hloModule_ == null) { + HloModule = new global::Xla.HloModuleProto(); + } + HloModule.MergeFrom(other.HloModule); + } + if (other.bufferAssignment_ != null) { + if (bufferAssignment_ == null) { + BufferAssignment = new global::Xla.BufferAssignmentProto(); + } + BufferAssignment.MergeFrom(other.BufferAssignment); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (hloModule_ == null) { + HloModule = new global::Xla.HloModuleProto(); + } + input.ReadMessage(HloModule); + break; + } + case 26: { + if (bufferAssignment_ == null) { + BufferAssignment = new global::Xla.BufferAssignmentProto(); + } + input.ReadMessage(BufferAssignment); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (hloModule_ == null) { + HloModule = new global::Xla.HloModuleProto(); + } + input.ReadMessage(HloModule); + break; + } + case 26: { + if (bufferAssignment_ == null) { + BufferAssignment = new global::Xla.BufferAssignmentProto(); + } + input.ReadMessage(BufferAssignment); + break; + } + } + } + } + #endif + + } + + /// + /// Encapsulates HloProto together with the arguments, result, and + /// execution_platform. This message is used for purposes such as + /// analysis/replay/file-storage. + /// + public sealed partial class HloSnapshot : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HloSnapshot()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[13]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloSnapshot() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloSnapshot(HloSnapshot other) : this() { + hlo_ = other.hlo_ != null ? other.hlo_.Clone() : null; + arguments_ = other.arguments_.Clone(); + result_ = other.result_ != null ? other.result_.Clone() : null; + executionPlatform_ = other.executionPlatform_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloSnapshot Clone() { + return new HloSnapshot(this); + } + + /// Field number for the "hlo" field. + public const int HloFieldNumber = 1; + private global::Xla.HloProto hlo_; + /// + /// The hlo graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.HloProto Hlo { + get { return hlo_; } + set { + hlo_ = value; + } + } + + /// Field number for the "arguments" field. + public const int ArgumentsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_arguments_codec + = pb::FieldCodec.ForMessage(18, global::Xla.LiteralProto.Parser); + private readonly pbc::RepeatedField arguments_ = new pbc::RepeatedField(); + /// + /// The arguments passed to the graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Arguments { + get { return arguments_; } + } + + /// Field number for the "result" field. + public const int ResultFieldNumber = 3; + private global::Xla.LiteralProto result_; + /// + /// The result of the graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.LiteralProto Result { + get { return result_; } + set { + result_ = value; + } + } + + /// Field number for the "execution_platform" field. + public const int ExecutionPlatformFieldNumber = 4; + private string executionPlatform_ = ""; + /// + /// The name of the platform used to run the graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ExecutionPlatform { + get { return executionPlatform_; } + set { + executionPlatform_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HloSnapshot); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HloSnapshot other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Hlo, other.Hlo)) return false; + if(!arguments_.Equals(other.arguments_)) return false; + if (!object.Equals(Result, other.Result)) return false; + if (ExecutionPlatform != other.ExecutionPlatform) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (hlo_ != null) hash ^= Hlo.GetHashCode(); + hash ^= arguments_.GetHashCode(); + if (result_ != null) hash ^= Result.GetHashCode(); + if (ExecutionPlatform.Length != 0) hash ^= ExecutionPlatform.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (hlo_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Hlo); + } + arguments_.WriteTo(output, _repeated_arguments_codec); + if (result_ != null) { + output.WriteRawTag(26); + output.WriteMessage(Result); + } + if (ExecutionPlatform.Length != 0) { + output.WriteRawTag(34); + output.WriteString(ExecutionPlatform); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (hlo_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Hlo); + } + arguments_.WriteTo(ref output, _repeated_arguments_codec); + if (result_ != null) { + output.WriteRawTag(26); + output.WriteMessage(Result); + } + if (ExecutionPlatform.Length != 0) { + output.WriteRawTag(34); + output.WriteString(ExecutionPlatform); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (hlo_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Hlo); + } + size += arguments_.CalculateSize(_repeated_arguments_codec); + if (result_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Result); + } + if (ExecutionPlatform.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ExecutionPlatform); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HloSnapshot other) { + if (other == null) { + return; + } + if (other.hlo_ != null) { + if (hlo_ == null) { + Hlo = new global::Xla.HloProto(); + } + Hlo.MergeFrom(other.Hlo); + } + arguments_.Add(other.arguments_); + if (other.result_ != null) { + if (result_ == null) { + Result = new global::Xla.LiteralProto(); + } + Result.MergeFrom(other.Result); + } + if (other.ExecutionPlatform.Length != 0) { + ExecutionPlatform = other.ExecutionPlatform; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (hlo_ == null) { + Hlo = new global::Xla.HloProto(); + } + input.ReadMessage(Hlo); + break; + } + case 18: { + arguments_.AddEntriesFrom(input, _repeated_arguments_codec); + break; + } + case 26: { + if (result_ == null) { + Result = new global::Xla.LiteralProto(); + } + input.ReadMessage(Result); + break; + } + case 34: { + ExecutionPlatform = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (hlo_ == null) { + Hlo = new global::Xla.HloProto(); + } + input.ReadMessage(Hlo); + break; + } + case 18: { + arguments_.AddEntriesFrom(ref input, _repeated_arguments_codec); + break; + } + case 26: { + if (result_ == null) { + Result = new global::Xla.LiteralProto(); + } + input.ReadMessage(Result); + break; + } + case 34: { + ExecutionPlatform = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// Metadata for an HLO module. Dumped after HLO passes and before LLO lowering + /// with filename module_####.metadata.textproto, where #### is + /// canonical_module_id. + /// + public sealed partial class HloModuleMetadataProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HloModuleMetadataProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[14]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloModuleMetadataProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloModuleMetadataProto(HloModuleMetadataProto other) : this() { + canonicalModuleId_ = other.canonicalModuleId_; + moduleGroupName_ = other.moduleGroupName_; + originalModuleId_ = other.originalModuleId_; + partitionedModuleIds_ = other.partitionedModuleIds_.Clone(); + passMetadata_ = other.passMetadata_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloModuleMetadataProto Clone() { + return new HloModuleMetadataProto(this); + } + + /// Field number for the "canonical_module_id" field. + public const int CanonicalModuleIdFieldNumber = 1; + private long canonicalModuleId_; + /// + /// Uniquely identifies an HloModuleMetadata. Equal to the first unique_id + /// of the module (a module may go through multiple unique_ids). If a module + /// is partitioned into multiple modules, those modules will each have a new + /// HloModuleMetadata with a different canonical_module_id. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long CanonicalModuleId { + get { return canonicalModuleId_; } + set { + canonicalModuleId_ = value; + } + } + + /// Field number for the "module_group_name" field. + public const int ModuleGroupNameFieldNumber = 2; + private string moduleGroupName_ = ""; + /// + /// Name of the module group that the module is part of. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ModuleGroupName { + get { return moduleGroupName_; } + set { + moduleGroupName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "original_module_id" field. + public const int OriginalModuleIdFieldNumber = 3; + private long originalModuleId_; + /// + /// The canonical module id of the module that this one is partitioned from, + /// if applicable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long OriginalModuleId { + get { return originalModuleId_; } + set { + originalModuleId_ = value; + } + } + + /// Field number for the "partitioned_module_ids" field. + public const int PartitionedModuleIdsFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_partitionedModuleIds_codec + = pb::FieldCodec.ForInt64(34); + private readonly pbc::RepeatedField partitionedModuleIds_ = new pbc::RepeatedField(); + /// + /// The canonical module ids of the modules that this one is partitioned into, + /// if applicable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField PartitionedModuleIds { + get { return partitionedModuleIds_; } + } + + /// Field number for the "pass_metadata" field. + public const int PassMetadataFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_passMetadata_codec + = pb::FieldCodec.ForMessage(42, global::Xla.HloPassMetadata.Parser); + private readonly pbc::RepeatedField passMetadata_ = new pbc::RepeatedField(); + /// + /// Metadata for the HLO passes that are run on the module. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField PassMetadata { + get { return passMetadata_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HloModuleMetadataProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HloModuleMetadataProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (CanonicalModuleId != other.CanonicalModuleId) return false; + if (ModuleGroupName != other.ModuleGroupName) return false; + if (OriginalModuleId != other.OriginalModuleId) return false; + if(!partitionedModuleIds_.Equals(other.partitionedModuleIds_)) return false; + if(!passMetadata_.Equals(other.passMetadata_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (CanonicalModuleId != 0L) hash ^= CanonicalModuleId.GetHashCode(); + if (ModuleGroupName.Length != 0) hash ^= ModuleGroupName.GetHashCode(); + if (OriginalModuleId != 0L) hash ^= OriginalModuleId.GetHashCode(); + hash ^= partitionedModuleIds_.GetHashCode(); + hash ^= passMetadata_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (CanonicalModuleId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(CanonicalModuleId); + } + if (ModuleGroupName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ModuleGroupName); + } + if (OriginalModuleId != 0L) { + output.WriteRawTag(24); + output.WriteInt64(OriginalModuleId); + } + partitionedModuleIds_.WriteTo(output, _repeated_partitionedModuleIds_codec); + passMetadata_.WriteTo(output, _repeated_passMetadata_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (CanonicalModuleId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(CanonicalModuleId); + } + if (ModuleGroupName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ModuleGroupName); + } + if (OriginalModuleId != 0L) { + output.WriteRawTag(24); + output.WriteInt64(OriginalModuleId); + } + partitionedModuleIds_.WriteTo(ref output, _repeated_partitionedModuleIds_codec); + passMetadata_.WriteTo(ref output, _repeated_passMetadata_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (CanonicalModuleId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(CanonicalModuleId); + } + if (ModuleGroupName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ModuleGroupName); + } + if (OriginalModuleId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(OriginalModuleId); + } + size += partitionedModuleIds_.CalculateSize(_repeated_partitionedModuleIds_codec); + size += passMetadata_.CalculateSize(_repeated_passMetadata_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HloModuleMetadataProto other) { + if (other == null) { + return; + } + if (other.CanonicalModuleId != 0L) { + CanonicalModuleId = other.CanonicalModuleId; + } + if (other.ModuleGroupName.Length != 0) { + ModuleGroupName = other.ModuleGroupName; + } + if (other.OriginalModuleId != 0L) { + OriginalModuleId = other.OriginalModuleId; + } + partitionedModuleIds_.Add(other.partitionedModuleIds_); + passMetadata_.Add(other.passMetadata_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + CanonicalModuleId = input.ReadInt64(); + break; + } + case 18: { + ModuleGroupName = input.ReadString(); + break; + } + case 24: { + OriginalModuleId = input.ReadInt64(); + break; + } + case 34: + case 32: { + partitionedModuleIds_.AddEntriesFrom(input, _repeated_partitionedModuleIds_codec); + break; + } + case 42: { + passMetadata_.AddEntriesFrom(input, _repeated_passMetadata_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + CanonicalModuleId = input.ReadInt64(); + break; + } + case 18: { + ModuleGroupName = input.ReadString(); + break; + } + case 24: { + OriginalModuleId = input.ReadInt64(); + break; + } + case 34: + case 32: { + partitionedModuleIds_.AddEntriesFrom(ref input, _repeated_partitionedModuleIds_codec); + break; + } + case 42: { + passMetadata_.AddEntriesFrom(ref input, _repeated_passMetadata_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Metadata for one run of an HLO pass on a module. Provides more information + /// when processing debug dumps of HloProtos about the order of HLO passes and + /// various other stats like duration. `pass_id` may also be used to identify a + /// particular run of a pass in debug info that propagates through stages of + /// compilation. + /// + public sealed partial class HloPassMetadata : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HloPassMetadata()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[15]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloPassMetadata() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloPassMetadata(HloPassMetadata other) : this() { + passId_ = other.passId_; + passName_ = other.passName_; + pipelineName_ = other.pipelineName_; + dumpFilenames_ = other.dumpFilenames_.Clone(); + moduleChanged_ = other.moduleChanged_; + moduleId_ = other.moduleId_; + moduleGroupModuleIds_ = other.moduleGroupModuleIds_.Clone(); + startTimestampUsec_ = other.startTimestampUsec_; + endTimestampUsec_ = other.endTimestampUsec_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HloPassMetadata Clone() { + return new HloPassMetadata(this); + } + + /// Field number for the "pass_id" field. + public const int PassIdFieldNumber = 1; + private long passId_; + /// + /// For a given module, pass_id uniquely identifies a run of an HLO pass on + /// that module. Note that a pass_id may not always refer to the same pass + /// because the order of passes during compilation may change. For finding + /// metadata for a particular pass, pass_name and pipeline_name would be more + /// reliable, although note that they may not be unique. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long PassId { + get { return passId_; } + set { + passId_ = value; + } + } + + /// Field number for the "pass_name" field. + public const int PassNameFieldNumber = 2; + private string passName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string PassName { + get { return passName_; } + set { + passName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pipeline_name" field. + public const int PipelineNameFieldNumber = 3; + private string pipelineName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string PipelineName { + get { return pipelineName_; } + set { + pipelineName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "dump_filenames" field. + public const int DumpFilenamesFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_dumpFilenames_codec + = pb::FieldCodec.ForString(34); + private readonly pbc::RepeatedField dumpFilenames_ = new pbc::RepeatedField(); + /// + /// Filenames of the dumps of the module after this pass ran. Module may be + /// dumped in multiple formats, and the order of formats in this field will + /// stay consistent across passes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DumpFilenames { + get { return dumpFilenames_; } + } + + /// Field number for the "module_changed" field. + public const int ModuleChangedFieldNumber = 5; + private bool moduleChanged_; + /// + /// Return value of pass.Run(). True if this pass changed the module, or, in + /// the case where the module was run through this pass as part of a module + /// group, true if this pass changed any module in the same module group. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool ModuleChanged { + get { return moduleChanged_; } + set { + moduleChanged_ = value; + } + } + + /// Field number for the "module_id" field. + public const int ModuleIdFieldNumber = 6; + private long moduleId_; + /// + /// The unique_id of the module that this pass is run on. May be different from + /// the canonical_module_id of the HloModuleMetadata that this HloPassMetadata + /// is inside. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ModuleId { + get { return moduleId_; } + set { + moduleId_ = value; + } + } + + /// Field number for the "module_group_module_ids" field. + public const int ModuleGroupModuleIdsFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_moduleGroupModuleIds_codec + = pb::FieldCodec.ForInt64(58); + private readonly pbc::RepeatedField moduleGroupModuleIds_ = new pbc::RepeatedField(); + /// + /// If the module went through this pass as part of a module group, this is + /// set as the ids of all the modules in the module group. Empty otherwise. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ModuleGroupModuleIds { + get { return moduleGroupModuleIds_; } + } + + /// Field number for the "start_timestamp_usec" field. + public const int StartTimestampUsecFieldNumber = 8; + private long startTimestampUsec_; + /// + /// Timestamp before and after the pass is run. Note they may be equal. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long StartTimestampUsec { + get { return startTimestampUsec_; } + set { + startTimestampUsec_ = value; + } + } + + /// Field number for the "end_timestamp_usec" field. + public const int EndTimestampUsecFieldNumber = 9; + private long endTimestampUsec_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long EndTimestampUsec { + get { return endTimestampUsec_; } + set { + endTimestampUsec_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HloPassMetadata); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HloPassMetadata other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (PassId != other.PassId) return false; + if (PassName != other.PassName) return false; + if (PipelineName != other.PipelineName) return false; + if(!dumpFilenames_.Equals(other.dumpFilenames_)) return false; + if (ModuleChanged != other.ModuleChanged) return false; + if (ModuleId != other.ModuleId) return false; + if(!moduleGroupModuleIds_.Equals(other.moduleGroupModuleIds_)) return false; + if (StartTimestampUsec != other.StartTimestampUsec) return false; + if (EndTimestampUsec != other.EndTimestampUsec) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (PassId != 0L) hash ^= PassId.GetHashCode(); + if (PassName.Length != 0) hash ^= PassName.GetHashCode(); + if (PipelineName.Length != 0) hash ^= PipelineName.GetHashCode(); + hash ^= dumpFilenames_.GetHashCode(); + if (ModuleChanged != false) hash ^= ModuleChanged.GetHashCode(); + if (ModuleId != 0L) hash ^= ModuleId.GetHashCode(); + hash ^= moduleGroupModuleIds_.GetHashCode(); + if (StartTimestampUsec != 0L) hash ^= StartTimestampUsec.GetHashCode(); + if (EndTimestampUsec != 0L) hash ^= EndTimestampUsec.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (PassId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(PassId); + } + if (PassName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(PassName); + } + if (PipelineName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(PipelineName); + } + dumpFilenames_.WriteTo(output, _repeated_dumpFilenames_codec); + if (ModuleChanged != false) { + output.WriteRawTag(40); + output.WriteBool(ModuleChanged); + } + if (ModuleId != 0L) { + output.WriteRawTag(48); + output.WriteInt64(ModuleId); + } + moduleGroupModuleIds_.WriteTo(output, _repeated_moduleGroupModuleIds_codec); + if (StartTimestampUsec != 0L) { + output.WriteRawTag(64); + output.WriteInt64(StartTimestampUsec); + } + if (EndTimestampUsec != 0L) { + output.WriteRawTag(72); + output.WriteInt64(EndTimestampUsec); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (PassId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(PassId); + } + if (PassName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(PassName); + } + if (PipelineName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(PipelineName); + } + dumpFilenames_.WriteTo(ref output, _repeated_dumpFilenames_codec); + if (ModuleChanged != false) { + output.WriteRawTag(40); + output.WriteBool(ModuleChanged); + } + if (ModuleId != 0L) { + output.WriteRawTag(48); + output.WriteInt64(ModuleId); + } + moduleGroupModuleIds_.WriteTo(ref output, _repeated_moduleGroupModuleIds_codec); + if (StartTimestampUsec != 0L) { + output.WriteRawTag(64); + output.WriteInt64(StartTimestampUsec); + } + if (EndTimestampUsec != 0L) { + output.WriteRawTag(72); + output.WriteInt64(EndTimestampUsec); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (PassId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(PassId); + } + if (PassName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PassName); + } + if (PipelineName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PipelineName); + } + size += dumpFilenames_.CalculateSize(_repeated_dumpFilenames_codec); + if (ModuleChanged != false) { + size += 1 + 1; + } + if (ModuleId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ModuleId); + } + size += moduleGroupModuleIds_.CalculateSize(_repeated_moduleGroupModuleIds_codec); + if (StartTimestampUsec != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(StartTimestampUsec); + } + if (EndTimestampUsec != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(EndTimestampUsec); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HloPassMetadata other) { + if (other == null) { + return; + } + if (other.PassId != 0L) { + PassId = other.PassId; + } + if (other.PassName.Length != 0) { + PassName = other.PassName; + } + if (other.PipelineName.Length != 0) { + PipelineName = other.PipelineName; + } + dumpFilenames_.Add(other.dumpFilenames_); + if (other.ModuleChanged != false) { + ModuleChanged = other.ModuleChanged; + } + if (other.ModuleId != 0L) { + ModuleId = other.ModuleId; + } + moduleGroupModuleIds_.Add(other.moduleGroupModuleIds_); + if (other.StartTimestampUsec != 0L) { + StartTimestampUsec = other.StartTimestampUsec; + } + if (other.EndTimestampUsec != 0L) { + EndTimestampUsec = other.EndTimestampUsec; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + PassId = input.ReadInt64(); + break; + } + case 18: { + PassName = input.ReadString(); + break; + } + case 26: { + PipelineName = input.ReadString(); + break; + } + case 34: { + dumpFilenames_.AddEntriesFrom(input, _repeated_dumpFilenames_codec); + break; + } + case 40: { + ModuleChanged = input.ReadBool(); + break; + } + case 48: { + ModuleId = input.ReadInt64(); + break; + } + case 58: + case 56: { + moduleGroupModuleIds_.AddEntriesFrom(input, _repeated_moduleGroupModuleIds_codec); + break; + } + case 64: { + StartTimestampUsec = input.ReadInt64(); + break; + } + case 72: { + EndTimestampUsec = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + PassId = input.ReadInt64(); + break; + } + case 18: { + PassName = input.ReadString(); + break; + } + case 26: { + PipelineName = input.ReadString(); + break; + } + case 34: { + dumpFilenames_.AddEntriesFrom(ref input, _repeated_dumpFilenames_codec); + break; + } + case 40: { + ModuleChanged = input.ReadBool(); + break; + } + case 48: { + ModuleId = input.ReadInt64(); + break; + } + case 58: + case 56: { + moduleGroupModuleIds_.AddEntriesFrom(ref input, _repeated_moduleGroupModuleIds_codec); + break; + } + case 64: { + StartTimestampUsec = input.ReadInt64(); + break; + } + case 72: { + EndTimestampUsec = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + /// + /// Encodes attributes for an entry function. + /// + public sealed partial class EntryFunctionAttributes : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new EntryFunctionAttributes()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[16]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public EntryFunctionAttributes() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public EntryFunctionAttributes(EntryFunctionAttributes other) : this() { + buffers_ = other.buffers_.Clone(); + resultXlaShape_ = other.resultXlaShape_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public EntryFunctionAttributes Clone() { + return new EntryFunctionAttributes(this); + } + + /// Field number for the "buffers" field. + public const int BuffersFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_buffers_codec + = pb::FieldCodec.ForMessage(10, global::Xla.EntryFunctionAttributes.Types.BufferParameterAttributes.Parser); + private readonly pbc::RepeatedField buffers_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Buffers { + get { return buffers_; } + } + + /// Field number for the "result_xla_shape" field. + public const int ResultXlaShapeFieldNumber = 2; + private string resultXlaShape_ = ""; + /// + /// xla::Shape in string format. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ResultXlaShape { + get { return resultXlaShape_; } + set { + resultXlaShape_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as EntryFunctionAttributes); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(EntryFunctionAttributes other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!buffers_.Equals(other.buffers_)) return false; + if (ResultXlaShape != other.ResultXlaShape) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= buffers_.GetHashCode(); + if (ResultXlaShape.Length != 0) hash ^= ResultXlaShape.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + buffers_.WriteTo(output, _repeated_buffers_codec); + if (ResultXlaShape.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ResultXlaShape); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + buffers_.WriteTo(ref output, _repeated_buffers_codec); + if (ResultXlaShape.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ResultXlaShape); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += buffers_.CalculateSize(_repeated_buffers_codec); + if (ResultXlaShape.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ResultXlaShape); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(EntryFunctionAttributes other) { + if (other == null) { + return; + } + buffers_.Add(other.buffers_); + if (other.ResultXlaShape.Length != 0) { + ResultXlaShape = other.ResultXlaShape; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + buffers_.AddEntriesFrom(input, _repeated_buffers_codec); + break; + } + case 18: { + ResultXlaShape = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + buffers_.AddEntriesFrom(ref input, _repeated_buffers_codec); + break; + } + case 18: { + ResultXlaShape = input.ReadString(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the EntryFunctionAttributes message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Acts as the underlying container for an xla::ShapeIndex. + /// + public sealed partial class ShapeIndex : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ShapeIndex()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.EntryFunctionAttributes.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShapeIndex() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShapeIndex(ShapeIndex other) : this() { + indices_ = other.indices_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShapeIndex Clone() { + return new ShapeIndex(this); + } + + /// Field number for the "indices" field. + public const int IndicesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_indices_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField indices_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Indices { + get { return indices_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ShapeIndex); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ShapeIndex other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!indices_.Equals(other.indices_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= indices_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + indices_.WriteTo(output, _repeated_indices_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + indices_.WriteTo(ref output, _repeated_indices_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += indices_.CalculateSize(_repeated_indices_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ShapeIndex other) { + if (other == null) { + return; + } + indices_.Add(other.indices_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + indices_.AddEntriesFrom(input, _repeated_indices_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + indices_.AddEntriesFrom(ref input, _repeated_indices_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Encodes attributes for a single buffer parameter. + /// + public sealed partial class BufferParameterAttributes : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BufferParameterAttributes()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.EntryFunctionAttributes.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BufferParameterAttributes() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BufferParameterAttributes(BufferParameterAttributes other) : this() { + lmhloParams_ = other.lmhloParams_; + lmhloParamsPresent_ = other.lmhloParamsPresent_; + lmhloParamShapeIndex_ = other.lmhloParamShapeIndex_ != null ? other.lmhloParamShapeIndex_.Clone() : null; + lmhloConstantName_ = other.lmhloConstantName_; + lmhloMustAlias_ = other.lmhloMustAlias_; + lmhloOutputIndex_ = other.lmhloOutputIndex_ != null ? other.lmhloOutputIndex_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BufferParameterAttributes Clone() { + return new BufferParameterAttributes(this); + } + + /// Field number for the "lmhlo_params" field. + public const int LmhloParamsFieldNumber = 1; + private long lmhloParams_; + /// + /// Represents an lmhlo.params function argument attribute. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long LmhloParams { + get { return lmhloParams_; } + set { + lmhloParams_ = value; + } + } + + /// Field number for the "lmhlo_params_present" field. + public const int LmhloParamsPresentFieldNumber = 6; + private bool lmhloParamsPresent_; + /// + /// TODO(hanbinyoon): Deprecate when optional fields are available in proto3 + /// (Protocol Buffers v3.15.0). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool LmhloParamsPresent { + get { return lmhloParamsPresent_; } + set { + lmhloParamsPresent_ = value; + } + } + + /// Field number for the "lmhlo_param_shape_index" field. + public const int LmhloParamShapeIndexFieldNumber = 2; + private global::Xla.EntryFunctionAttributes.Types.ShapeIndex lmhloParamShapeIndex_; + /// + /// Represents an lmhlo.param_shape_index function argument attribute. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.EntryFunctionAttributes.Types.ShapeIndex LmhloParamShapeIndex { + get { return lmhloParamShapeIndex_; } + set { + lmhloParamShapeIndex_ = value; + } + } + + /// Field number for the "lmhlo_constant_name" field. + public const int LmhloConstantNameFieldNumber = 3; + private string lmhloConstantName_ = ""; + /// + /// Represents an lmhlo.constant_name function argument attribute. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string LmhloConstantName { + get { return lmhloConstantName_; } + set { + lmhloConstantName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "lmhlo_must_alias" field. + public const int LmhloMustAliasFieldNumber = 4; + private bool lmhloMustAlias_; + /// + /// Represents an lmhlo.must_alias function argument attribute. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool LmhloMustAlias { + get { return lmhloMustAlias_; } + set { + lmhloMustAlias_ = value; + } + } + + /// Field number for the "lmhlo_output_index" field. + public const int LmhloOutputIndexFieldNumber = 5; + private global::Xla.EntryFunctionAttributes.Types.ShapeIndex lmhloOutputIndex_; + /// + /// Represents an lmhlo.params function argument attribute. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.EntryFunctionAttributes.Types.ShapeIndex LmhloOutputIndex { + get { return lmhloOutputIndex_; } + set { + lmhloOutputIndex_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as BufferParameterAttributes); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(BufferParameterAttributes other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (LmhloParams != other.LmhloParams) return false; + if (LmhloParamsPresent != other.LmhloParamsPresent) return false; + if (!object.Equals(LmhloParamShapeIndex, other.LmhloParamShapeIndex)) return false; + if (LmhloConstantName != other.LmhloConstantName) return false; + if (LmhloMustAlias != other.LmhloMustAlias) return false; + if (!object.Equals(LmhloOutputIndex, other.LmhloOutputIndex)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (LmhloParams != 0L) hash ^= LmhloParams.GetHashCode(); + if (LmhloParamsPresent != false) hash ^= LmhloParamsPresent.GetHashCode(); + if (lmhloParamShapeIndex_ != null) hash ^= LmhloParamShapeIndex.GetHashCode(); + if (LmhloConstantName.Length != 0) hash ^= LmhloConstantName.GetHashCode(); + if (LmhloMustAlias != false) hash ^= LmhloMustAlias.GetHashCode(); + if (lmhloOutputIndex_ != null) hash ^= LmhloOutputIndex.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (LmhloParams != 0L) { + output.WriteRawTag(8); + output.WriteInt64(LmhloParams); + } + if (lmhloParamShapeIndex_ != null) { + output.WriteRawTag(18); + output.WriteMessage(LmhloParamShapeIndex); + } + if (LmhloConstantName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(LmhloConstantName); + } + if (LmhloMustAlias != false) { + output.WriteRawTag(32); + output.WriteBool(LmhloMustAlias); + } + if (lmhloOutputIndex_ != null) { + output.WriteRawTag(42); + output.WriteMessage(LmhloOutputIndex); + } + if (LmhloParamsPresent != false) { + output.WriteRawTag(48); + output.WriteBool(LmhloParamsPresent); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (LmhloParams != 0L) { + output.WriteRawTag(8); + output.WriteInt64(LmhloParams); + } + if (lmhloParamShapeIndex_ != null) { + output.WriteRawTag(18); + output.WriteMessage(LmhloParamShapeIndex); + } + if (LmhloConstantName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(LmhloConstantName); + } + if (LmhloMustAlias != false) { + output.WriteRawTag(32); + output.WriteBool(LmhloMustAlias); + } + if (lmhloOutputIndex_ != null) { + output.WriteRawTag(42); + output.WriteMessage(LmhloOutputIndex); + } + if (LmhloParamsPresent != false) { + output.WriteRawTag(48); + output.WriteBool(LmhloParamsPresent); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (LmhloParams != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(LmhloParams); + } + if (LmhloParamsPresent != false) { + size += 1 + 1; + } + if (lmhloParamShapeIndex_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(LmhloParamShapeIndex); + } + if (LmhloConstantName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(LmhloConstantName); + } + if (LmhloMustAlias != false) { + size += 1 + 1; + } + if (lmhloOutputIndex_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(LmhloOutputIndex); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(BufferParameterAttributes other) { + if (other == null) { + return; + } + if (other.LmhloParams != 0L) { + LmhloParams = other.LmhloParams; + } + if (other.LmhloParamsPresent != false) { + LmhloParamsPresent = other.LmhloParamsPresent; + } + if (other.lmhloParamShapeIndex_ != null) { + if (lmhloParamShapeIndex_ == null) { + LmhloParamShapeIndex = new global::Xla.EntryFunctionAttributes.Types.ShapeIndex(); + } + LmhloParamShapeIndex.MergeFrom(other.LmhloParamShapeIndex); + } + if (other.LmhloConstantName.Length != 0) { + LmhloConstantName = other.LmhloConstantName; + } + if (other.LmhloMustAlias != false) { + LmhloMustAlias = other.LmhloMustAlias; + } + if (other.lmhloOutputIndex_ != null) { + if (lmhloOutputIndex_ == null) { + LmhloOutputIndex = new global::Xla.EntryFunctionAttributes.Types.ShapeIndex(); + } + LmhloOutputIndex.MergeFrom(other.LmhloOutputIndex); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + LmhloParams = input.ReadInt64(); + break; + } + case 18: { + if (lmhloParamShapeIndex_ == null) { + LmhloParamShapeIndex = new global::Xla.EntryFunctionAttributes.Types.ShapeIndex(); + } + input.ReadMessage(LmhloParamShapeIndex); + break; + } + case 26: { + LmhloConstantName = input.ReadString(); + break; + } + case 32: { + LmhloMustAlias = input.ReadBool(); + break; + } + case 42: { + if (lmhloOutputIndex_ == null) { + LmhloOutputIndex = new global::Xla.EntryFunctionAttributes.Types.ShapeIndex(); + } + input.ReadMessage(LmhloOutputIndex); + break; + } + case 48: { + LmhloParamsPresent = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + LmhloParams = input.ReadInt64(); + break; + } + case 18: { + if (lmhloParamShapeIndex_ == null) { + LmhloParamShapeIndex = new global::Xla.EntryFunctionAttributes.Types.ShapeIndex(); + } + input.ReadMessage(LmhloParamShapeIndex); + break; + } + case 26: { + LmhloConstantName = input.ReadString(); + break; + } + case 32: { + LmhloMustAlias = input.ReadBool(); + break; + } + case 42: { + if (lmhloOutputIndex_ == null) { + LmhloOutputIndex = new global::Xla.EntryFunctionAttributes.Types.ShapeIndex(); + } + input.ReadMessage(LmhloOutputIndex); + break; + } + case 48: { + LmhloParamsPresent = input.ReadBool(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// Encodes the underlying Xla runtime executable compiled from the XLA module. + /// + public sealed partial class XlaRuntimeExecutableProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new XlaRuntimeExecutableProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.HloReflection.Descriptor.MessageTypes[17]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public XlaRuntimeExecutableProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public XlaRuntimeExecutableProto(XlaRuntimeExecutableProto other) : this() { + hloModuleProto_ = other.hloModuleProto_ != null ? other.hloModuleProto_.Clone() : null; + entryFuncAttrs_ = other.entryFuncAttrs_ != null ? other.entryFuncAttrs_.Clone() : null; + objFile_ = other.objFile_; + mlirModule_ = other.mlirModule_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public XlaRuntimeExecutableProto Clone() { + return new XlaRuntimeExecutableProto(this); + } + + /// Field number for the "hlo_module_proto" field. + public const int HloModuleProtoFieldNumber = 1; + private global::Xla.HloModuleProto hloModuleProto_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.HloModuleProto HloModuleProto { + get { return hloModuleProto_; } + set { + hloModuleProto_ = value; + } + } + + /// Field number for the "entry_func_attrs" field. + public const int EntryFuncAttrsFieldNumber = 2; + private global::Xla.EntryFunctionAttributes entryFuncAttrs_; + /// + /// XLA-specific attributes of the executable's entry function. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.EntryFunctionAttributes EntryFuncAttrs { + get { return entryFuncAttrs_; } + set { + entryFuncAttrs_ = value; + } + } + + /// Field number for the "obj_file" field. + public const int ObjFileFieldNumber = 3; + private pb::ByteString objFile_ = pb::ByteString.Empty; + /// + /// Serialized object file compiled from the XLA module. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString ObjFile { + get { return objFile_; } + set { + objFile_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "mlir_module" field. + public const int MlirModuleFieldNumber = 4; + private string mlirModule_ = ""; + /// + /// Serialized MLIR module corresponding to compiled object file. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string MlirModule { + get { return mlirModule_; } + set { + mlirModule_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as XlaRuntimeExecutableProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(XlaRuntimeExecutableProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(HloModuleProto, other.HloModuleProto)) return false; + if (!object.Equals(EntryFuncAttrs, other.EntryFuncAttrs)) return false; + if (ObjFile != other.ObjFile) return false; + if (MlirModule != other.MlirModule) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (hloModuleProto_ != null) hash ^= HloModuleProto.GetHashCode(); + if (entryFuncAttrs_ != null) hash ^= EntryFuncAttrs.GetHashCode(); + if (ObjFile.Length != 0) hash ^= ObjFile.GetHashCode(); + if (MlirModule.Length != 0) hash ^= MlirModule.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (hloModuleProto_ != null) { + output.WriteRawTag(10); + output.WriteMessage(HloModuleProto); + } + if (entryFuncAttrs_ != null) { + output.WriteRawTag(18); + output.WriteMessage(EntryFuncAttrs); + } + if (ObjFile.Length != 0) { + output.WriteRawTag(26); + output.WriteBytes(ObjFile); + } + if (MlirModule.Length != 0) { + output.WriteRawTag(34); + output.WriteString(MlirModule); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (hloModuleProto_ != null) { + output.WriteRawTag(10); + output.WriteMessage(HloModuleProto); + } + if (entryFuncAttrs_ != null) { + output.WriteRawTag(18); + output.WriteMessage(EntryFuncAttrs); + } + if (ObjFile.Length != 0) { + output.WriteRawTag(26); + output.WriteBytes(ObjFile); + } + if (MlirModule.Length != 0) { + output.WriteRawTag(34); + output.WriteString(MlirModule); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (hloModuleProto_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(HloModuleProto); + } + if (entryFuncAttrs_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(EntryFuncAttrs); + } + if (ObjFile.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(ObjFile); + } + if (MlirModule.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MlirModule); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(XlaRuntimeExecutableProto other) { + if (other == null) { + return; + } + if (other.hloModuleProto_ != null) { + if (hloModuleProto_ == null) { + HloModuleProto = new global::Xla.HloModuleProto(); + } + HloModuleProto.MergeFrom(other.HloModuleProto); + } + if (other.entryFuncAttrs_ != null) { + if (entryFuncAttrs_ == null) { + EntryFuncAttrs = new global::Xla.EntryFunctionAttributes(); + } + EntryFuncAttrs.MergeFrom(other.EntryFuncAttrs); + } + if (other.ObjFile.Length != 0) { + ObjFile = other.ObjFile; + } + if (other.MlirModule.Length != 0) { + MlirModule = other.MlirModule; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (hloModuleProto_ == null) { + HloModuleProto = new global::Xla.HloModuleProto(); + } + input.ReadMessage(HloModuleProto); + break; + } + case 18: { + if (entryFuncAttrs_ == null) { + EntryFuncAttrs = new global::Xla.EntryFunctionAttributes(); + } + input.ReadMessage(EntryFuncAttrs); + break; + } + case 26: { + ObjFile = input.ReadBytes(); + break; + } + case 34: { + MlirModule = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (hloModuleProto_ == null) { + HloModuleProto = new global::Xla.HloModuleProto(); + } + input.ReadMessage(HloModuleProto); + break; + } + case 18: { + if (entryFuncAttrs_ == null) { + EntryFuncAttrs = new global::Xla.EntryFunctionAttributes(); + } + input.ReadMessage(EntryFuncAttrs); + break; + } + case 26: { + ObjFile = input.ReadBytes(); + break; + } + case 34: { + MlirModule = input.ReadString(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs b/src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs new file mode 100644 index 000000000..aa91a6256 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs @@ -0,0 +1,26 @@ +namespace Tensorflow +{ + /// + /// In order for a object to be serialized to and from MetaGraphDef, + /// the class must implement to_proto() and from_proto() methods + /// + public interface IProtoBuf + { + string Name { get; } + + /// + /// Converts a `Variable` to a `VariableDef` protocol buffer. + /// + /// + /// + TProtoDef to_proto(string export_scope); + + /// + /// Returns a `Variable` object created from `variable_def`. + /// + /// + /// + /// + TDef from_proto(TProtoDef proto, string import_scope); + } +} diff --git a/src/TensorFlowNET.Core/Protobuf/Iterator.cs b/src/TensorFlowNET.Core/Protobuf/Iterator.cs new file mode 100644 index 000000000..bc06fae1d --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Iterator.cs @@ -0,0 +1,205 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/iterator.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/iterator.proto + public static partial class IteratorReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/iterator.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static IteratorReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cih0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2l0ZXJhdG9yLnByb3RvEgp0", + "ZW5zb3JmbG93IjYKFUl0ZXJhdG9yU3RhdGVNZXRhZGF0YRIPCgd2ZXJzaW9u", + "GAEgASgJEgwKBGtleXMYAiADKAlCaQoTb3JnLnRlbnNvcmZsb3cudXRpbEIO", + "SXRlcmF0b3JQcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNv", + "cmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.IteratorStateMetadata), global::Tensorflow.IteratorStateMetadata.Parser, new[]{ "Version", "Keys" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer representing the metadata for an iterator's state stored + /// as a Variant tensor. + /// + public sealed partial class IteratorStateMetadata : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new IteratorStateMetadata()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.IteratorReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public IteratorStateMetadata() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public IteratorStateMetadata(IteratorStateMetadata other) : this() { + version_ = other.version_; + keys_ = other.keys_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public IteratorStateMetadata Clone() { + return new IteratorStateMetadata(this); + } + + /// Field number for the "version" field. + public const int VersionFieldNumber = 1; + private string version_ = ""; + /// + /// A user-specified version string. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Version { + get { return version_; } + set { + version_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "keys" field. + public const int KeysFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_keys_codec + = pb::FieldCodec.ForString(18); + private readonly pbc::RepeatedField keys_ = new pbc::RepeatedField(); + /// + /// Keys for tensors in the VariantTensorDataProto. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Keys { + get { return keys_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as IteratorStateMetadata); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(IteratorStateMetadata other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Version != other.Version) return false; + if(!keys_.Equals(other.keys_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Version.Length != 0) hash ^= Version.GetHashCode(); + hash ^= keys_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Version.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Version); + } + keys_.WriteTo(output, _repeated_keys_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Version.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Version); + } + size += keys_.CalculateSize(_repeated_keys_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(IteratorStateMetadata other) { + if (other == null) { + return; + } + if (other.Version.Length != 0) { + Version = other.Version; + } + keys_.Add(other.keys_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Version = input.ReadString(); + break; + } + case 18: { + keys_.AddEntriesFrom(input, _repeated_keys_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/KernelDef.cs b/src/TensorFlowNET.Core/Protobuf/KernelDef.cs new file mode 100644 index 000000000..06928ad44 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/KernelDef.cs @@ -0,0 +1,857 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/kernel_def.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/kernel_def.proto + public static partial class KernelDefReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/kernel_def.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static KernelDefReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2tlcm5lbF9kZWYucHJvdG8S", + "CnRlbnNvcmZsb3caKnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvYXR0cl92", + "YWx1ZS5wcm90byLvAQoJS2VybmVsRGVmEgoKAm9wGAEgASgJEhMKC2Rldmlj", + "ZV90eXBlGAIgASgJEjgKCmNvbnN0cmFpbnQYAyADKAsyJC50ZW5zb3JmbG93", + "Lktlcm5lbERlZi5BdHRyQ29uc3RyYWludBIXCg9ob3N0X21lbW9yeV9hcmcY", + "BCADKAkSDQoFbGFiZWwYBSABKAkSEAoIcHJpb3JpdHkYBiABKAUaTQoOQXR0", + "ckNvbnN0cmFpbnQSDAoEbmFtZRgBIAEoCRItCg5hbGxvd2VkX3ZhbHVlcxgC", + "IAEoCzIVLnRlbnNvcmZsb3cuQXR0clZhbHVlIjMKCktlcm5lbExpc3QSJQoG", + "a2VybmVsGAEgAygLMhUudGVuc29yZmxvdy5LZXJuZWxEZWZCgwEKGG9yZy50", + "ZW5zb3JmbG93LmZyYW1ld29ya0IPS2VybmVsRGVmUHJvdG9zUAFaUWdpdGh1", + "Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29y", + "ZS9mcmFtZXdvcmsva2VybmVsX2RlZl9nb19wcm90b/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.KernelDef), global::Tensorflow.KernelDef.Parser, new[]{ "Op", "DeviceType", "Constraint", "HostMemoryArg", "Label", "Priority" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.KernelDef.Types.AttrConstraint), global::Tensorflow.KernelDef.Types.AttrConstraint.Parser, new[]{ "Name", "AllowedValues" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.KernelList), global::Tensorflow.KernelList.Parser, new[]{ "Kernel" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class KernelDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new KernelDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.KernelDefReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KernelDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KernelDef(KernelDef other) : this() { + op_ = other.op_; + deviceType_ = other.deviceType_; + constraint_ = other.constraint_.Clone(); + hostMemoryArg_ = other.hostMemoryArg_.Clone(); + label_ = other.label_; + priority_ = other.priority_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KernelDef Clone() { + return new KernelDef(this); + } + + /// Field number for the "op" field. + public const int OpFieldNumber = 1; + private string op_ = ""; + /// + /// Must match the name of an Op. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Op { + get { return op_; } + set { + op_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "device_type" field. + public const int DeviceTypeFieldNumber = 2; + private string deviceType_ = ""; + /// + /// Type of device this kernel runs on. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string DeviceType { + get { return deviceType_; } + set { + deviceType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "constraint" field. + public const int ConstraintFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_constraint_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.KernelDef.Types.AttrConstraint.Parser); + private readonly pbc::RepeatedField constraint_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Constraint { + get { return constraint_; } + } + + /// Field number for the "host_memory_arg" field. + public const int HostMemoryArgFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_hostMemoryArg_codec + = pb::FieldCodec.ForString(34); + private readonly pbc::RepeatedField hostMemoryArg_ = new pbc::RepeatedField(); + /// + /// Names of the Op's input_/output_args that reside in host memory + /// instead of device memory. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField HostMemoryArg { + get { return hostMemoryArg_; } + } + + /// Field number for the "label" field. + public const int LabelFieldNumber = 5; + private string label_ = ""; + /// + /// This allows experimental kernels to be registered for an op that + /// won't be used unless the user specifies a "_kernel" attr with + /// value matching this. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Label { + get { return label_; } + set { + label_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "priority" field. + public const int PriorityFieldNumber = 6; + private int priority_; + /// + /// Prioritization of kernel amongst different devices. By default we assume + /// priority is 0. The higher the priority the better. By default (i.e. if + /// this is not set), we prefer GPU kernels over CPU. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int Priority { + get { return priority_; } + set { + priority_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as KernelDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(KernelDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Op != other.Op) return false; + if (DeviceType != other.DeviceType) return false; + if(!constraint_.Equals(other.constraint_)) return false; + if(!hostMemoryArg_.Equals(other.hostMemoryArg_)) return false; + if (Label != other.Label) return false; + if (Priority != other.Priority) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Op.Length != 0) hash ^= Op.GetHashCode(); + if (DeviceType.Length != 0) hash ^= DeviceType.GetHashCode(); + hash ^= constraint_.GetHashCode(); + hash ^= hostMemoryArg_.GetHashCode(); + if (Label.Length != 0) hash ^= Label.GetHashCode(); + if (Priority != 0) hash ^= Priority.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Op.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Op); + } + if (DeviceType.Length != 0) { + output.WriteRawTag(18); + output.WriteString(DeviceType); + } + constraint_.WriteTo(output, _repeated_constraint_codec); + hostMemoryArg_.WriteTo(output, _repeated_hostMemoryArg_codec); + if (Label.Length != 0) { + output.WriteRawTag(42); + output.WriteString(Label); + } + if (Priority != 0) { + output.WriteRawTag(48); + output.WriteInt32(Priority); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Op.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Op); + } + if (DeviceType.Length != 0) { + output.WriteRawTag(18); + output.WriteString(DeviceType); + } + constraint_.WriteTo(ref output, _repeated_constraint_codec); + hostMemoryArg_.WriteTo(ref output, _repeated_hostMemoryArg_codec); + if (Label.Length != 0) { + output.WriteRawTag(42); + output.WriteString(Label); + } + if (Priority != 0) { + output.WriteRawTag(48); + output.WriteInt32(Priority); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Op.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Op); + } + if (DeviceType.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DeviceType); + } + size += constraint_.CalculateSize(_repeated_constraint_codec); + size += hostMemoryArg_.CalculateSize(_repeated_hostMemoryArg_codec); + if (Label.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Label); + } + if (Priority != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Priority); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(KernelDef other) { + if (other == null) { + return; + } + if (other.Op.Length != 0) { + Op = other.Op; + } + if (other.DeviceType.Length != 0) { + DeviceType = other.DeviceType; + } + constraint_.Add(other.constraint_); + hostMemoryArg_.Add(other.hostMemoryArg_); + if (other.Label.Length != 0) { + Label = other.Label; + } + if (other.Priority != 0) { + Priority = other.Priority; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Op = input.ReadString(); + break; + } + case 18: { + DeviceType = input.ReadString(); + break; + } + case 26: { + constraint_.AddEntriesFrom(input, _repeated_constraint_codec); + break; + } + case 34: { + hostMemoryArg_.AddEntriesFrom(input, _repeated_hostMemoryArg_codec); + break; + } + case 42: { + Label = input.ReadString(); + break; + } + case 48: { + Priority = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Op = input.ReadString(); + break; + } + case 18: { + DeviceType = input.ReadString(); + break; + } + case 26: { + constraint_.AddEntriesFrom(ref input, _repeated_constraint_codec); + break; + } + case 34: { + hostMemoryArg_.AddEntriesFrom(ref input, _repeated_hostMemoryArg_codec); + break; + } + case 42: { + Label = input.ReadString(); + break; + } + case 48: { + Priority = input.ReadInt32(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the KernelDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public sealed partial class AttrConstraint : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AttrConstraint()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.KernelDef.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AttrConstraint() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AttrConstraint(AttrConstraint other) : this() { + name_ = other.name_; + allowedValues_ = other.allowedValues_ != null ? other.allowedValues_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AttrConstraint Clone() { + return new AttrConstraint(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// Name of an attr from the Op. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "allowed_values" field. + public const int AllowedValuesFieldNumber = 2; + private global::Tensorflow.AttrValue allowedValues_; + /// + /// A list of values that this kernel supports for this attr. + /// Like OpDef.AttrDef.allowed_values, except for kernels instead of Ops. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.AttrValue AllowedValues { + get { return allowedValues_; } + set { + allowedValues_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as AttrConstraint); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(AttrConstraint other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (!object.Equals(AllowedValues, other.AllowedValues)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (allowedValues_ != null) hash ^= AllowedValues.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (allowedValues_ != null) { + output.WriteRawTag(18); + output.WriteMessage(AllowedValues); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (allowedValues_ != null) { + output.WriteRawTag(18); + output.WriteMessage(AllowedValues); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (allowedValues_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(AllowedValues); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(AttrConstraint other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.allowedValues_ != null) { + if (allowedValues_ == null) { + AllowedValues = new global::Tensorflow.AttrValue(); + } + AllowedValues.MergeFrom(other.AllowedValues); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + if (allowedValues_ == null) { + AllowedValues = new global::Tensorflow.AttrValue(); + } + input.ReadMessage(AllowedValues); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + if (allowedValues_ == null) { + AllowedValues = new global::Tensorflow.AttrValue(); + } + input.ReadMessage(AllowedValues); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// A collection of KernelDefs + /// + public sealed partial class KernelList : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new KernelList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.KernelDefReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KernelList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KernelList(KernelList other) : this() { + kernel_ = other.kernel_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KernelList Clone() { + return new KernelList(this); + } + + /// Field number for the "kernel" field. + public const int KernelFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_kernel_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.KernelDef.Parser); + private readonly pbc::RepeatedField kernel_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Kernel { + get { return kernel_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as KernelList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(KernelList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!kernel_.Equals(other.kernel_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= kernel_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + kernel_.WriteTo(output, _repeated_kernel_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + kernel_.WriteTo(ref output, _repeated_kernel_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += kernel_.CalculateSize(_repeated_kernel_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(KernelList other) { + if (other == null) { + return; + } + kernel_.Add(other.kernel_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + kernel_.AddEntriesFrom(input, _repeated_kernel_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + kernel_.AddEntriesFrom(ref input, _repeated_kernel_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/LogMemory.cs b/src/TensorFlowNET.Core/Protobuf/LogMemory.cs new file mode 100644 index 000000000..af16b3122 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/LogMemory.cs @@ -0,0 +1,1882 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/log_memory.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/log_memory.proto + public static partial class LogMemoryReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/log_memory.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static LogMemoryReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2xvZ19tZW1vcnkucHJvdG8S", + "CnRlbnNvcmZsb3caMnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvdGVuc29y", + "X2Rlc2NyaXB0aW9uLnByb3RvIjAKDU1lbW9yeUxvZ1N0ZXASDwoHc3RlcF9p", + "ZBgBIAEoAxIOCgZoYW5kbGUYAiABKAkicAoZTWVtb3J5TG9nVGVuc29yQWxs", + "b2NhdGlvbhIPCgdzdGVwX2lkGAEgASgDEhMKC2tlcm5lbF9uYW1lGAIgASgJ", + "Ei0KBnRlbnNvchgDIAEoCzIdLnRlbnNvcmZsb3cuVGVuc29yRGVzY3JpcHRp", + "b24iTAobTWVtb3J5TG9nVGVuc29yRGVhbGxvY2F0aW9uEhUKDWFsbG9jYXRp", + "b25faWQYASABKAMSFgoOYWxsb2NhdG9yX25hbWUYAiABKAkiewoVTWVtb3J5", + "TG9nVGVuc29yT3V0cHV0Eg8KB3N0ZXBfaWQYASABKAMSEwoLa2VybmVsX25h", + "bWUYAiABKAkSDQoFaW5kZXgYAyABKAUSLQoGdGVuc29yGAQgASgLMh0udGVu", + "c29yZmxvdy5UZW5zb3JEZXNjcmlwdGlvbiKLAQoWTWVtb3J5TG9nUmF3QWxs", + "b2NhdGlvbhIPCgdzdGVwX2lkGAEgASgDEhEKCW9wZXJhdGlvbhgCIAEoCRIR", + "CgludW1fYnl0ZXMYAyABKAMSCwoDcHRyGAQgASgEEhUKDWFsbG9jYXRpb25f", + "aWQYBSABKAMSFgoOYWxsb2NhdG9yX25hbWUYBiABKAkifwoYTWVtb3J5TG9n", + "UmF3RGVhbGxvY2F0aW9uEg8KB3N0ZXBfaWQYASABKAMSEQoJb3BlcmF0aW9u", + "GAIgASgJEhUKDWFsbG9jYXRpb25faWQYAyABKAMSFgoOYWxsb2NhdG9yX25h", + "bWUYBCABKAkSEAoIZGVmZXJyZWQYBSABKAhCgwEKGG9yZy50ZW5zb3JmbG93", + "LmZyYW1ld29ya0IPTG9nTWVtb3J5UHJvdG9zUAFaUWdpdGh1Yi5jb20vdGVu", + "c29yZmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9mcmFtZXdv", + "cmsvbG9nX21lbW9yeV9nb19wcm90b/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.TensorDescriptionReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MemoryLogStep), global::Tensorflow.MemoryLogStep.Parser, new[]{ "StepId", "Handle" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MemoryLogTensorAllocation), global::Tensorflow.MemoryLogTensorAllocation.Parser, new[]{ "StepId", "KernelName", "Tensor" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MemoryLogTensorDeallocation), global::Tensorflow.MemoryLogTensorDeallocation.Parser, new[]{ "AllocationId", "AllocatorName" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MemoryLogTensorOutput), global::Tensorflow.MemoryLogTensorOutput.Parser, new[]{ "StepId", "KernelName", "Index", "Tensor" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MemoryLogRawAllocation), global::Tensorflow.MemoryLogRawAllocation.Parser, new[]{ "StepId", "Operation", "NumBytes", "Ptr", "AllocationId", "AllocatorName" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MemoryLogRawDeallocation), global::Tensorflow.MemoryLogRawDeallocation.Parser, new[]{ "StepId", "Operation", "AllocationId", "AllocatorName", "Deferred" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class MemoryLogStep : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MemoryLogStep()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.LogMemoryReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogStep() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogStep(MemoryLogStep other) : this() { + stepId_ = other.stepId_; + handle_ = other.handle_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogStep Clone() { + return new MemoryLogStep(this); + } + + /// Field number for the "step_id" field. + public const int StepIdFieldNumber = 1; + private long stepId_; + /// + /// Process-unique step id. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long StepId { + get { return stepId_; } + set { + stepId_ = value; + } + } + + /// Field number for the "handle" field. + public const int HandleFieldNumber = 2; + private string handle_ = ""; + /// + /// Handle describing the feeds and fetches of the step. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Handle { + get { return handle_; } + set { + handle_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MemoryLogStep); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MemoryLogStep other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (StepId != other.StepId) return false; + if (Handle != other.Handle) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (StepId != 0L) hash ^= StepId.GetHashCode(); + if (Handle.Length != 0) hash ^= Handle.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (StepId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(StepId); + } + if (Handle.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Handle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (StepId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(StepId); + } + if (Handle.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Handle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (StepId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(StepId); + } + if (Handle.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Handle); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MemoryLogStep other) { + if (other == null) { + return; + } + if (other.StepId != 0L) { + StepId = other.StepId; + } + if (other.Handle.Length != 0) { + Handle = other.Handle; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + StepId = input.ReadInt64(); + break; + } + case 18: { + Handle = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + StepId = input.ReadInt64(); + break; + } + case 18: { + Handle = input.ReadString(); + break; + } + } + } + } + #endif + + } + + public sealed partial class MemoryLogTensorAllocation : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MemoryLogTensorAllocation()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.LogMemoryReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogTensorAllocation() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogTensorAllocation(MemoryLogTensorAllocation other) : this() { + stepId_ = other.stepId_; + kernelName_ = other.kernelName_; + tensor_ = other.tensor_ != null ? other.tensor_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogTensorAllocation Clone() { + return new MemoryLogTensorAllocation(this); + } + + /// Field number for the "step_id" field. + public const int StepIdFieldNumber = 1; + private long stepId_; + /// + /// Process-unique step id. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long StepId { + get { return stepId_; } + set { + stepId_ = value; + } + } + + /// Field number for the "kernel_name" field. + public const int KernelNameFieldNumber = 2; + private string kernelName_ = ""; + /// + /// Name of the kernel making the allocation as set in GraphDef, + /// e.g., "affine2/weights/Assign". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string KernelName { + get { return kernelName_; } + set { + kernelName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "tensor" field. + public const int TensorFieldNumber = 3; + private global::Tensorflow.TensorDescription tensor_; + /// + /// Allocated tensor details. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorDescription Tensor { + get { return tensor_; } + set { + tensor_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MemoryLogTensorAllocation); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MemoryLogTensorAllocation other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (StepId != other.StepId) return false; + if (KernelName != other.KernelName) return false; + if (!object.Equals(Tensor, other.Tensor)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (StepId != 0L) hash ^= StepId.GetHashCode(); + if (KernelName.Length != 0) hash ^= KernelName.GetHashCode(); + if (tensor_ != null) hash ^= Tensor.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (StepId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(StepId); + } + if (KernelName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(KernelName); + } + if (tensor_ != null) { + output.WriteRawTag(26); + output.WriteMessage(Tensor); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (StepId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(StepId); + } + if (KernelName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(KernelName); + } + if (tensor_ != null) { + output.WriteRawTag(26); + output.WriteMessage(Tensor); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (StepId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(StepId); + } + if (KernelName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(KernelName); + } + if (tensor_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Tensor); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MemoryLogTensorAllocation other) { + if (other == null) { + return; + } + if (other.StepId != 0L) { + StepId = other.StepId; + } + if (other.KernelName.Length != 0) { + KernelName = other.KernelName; + } + if (other.tensor_ != null) { + if (tensor_ == null) { + Tensor = new global::Tensorflow.TensorDescription(); + } + Tensor.MergeFrom(other.Tensor); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + StepId = input.ReadInt64(); + break; + } + case 18: { + KernelName = input.ReadString(); + break; + } + case 26: { + if (tensor_ == null) { + Tensor = new global::Tensorflow.TensorDescription(); + } + input.ReadMessage(Tensor); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + StepId = input.ReadInt64(); + break; + } + case 18: { + KernelName = input.ReadString(); + break; + } + case 26: { + if (tensor_ == null) { + Tensor = new global::Tensorflow.TensorDescription(); + } + input.ReadMessage(Tensor); + break; + } + } + } + } + #endif + + } + + public sealed partial class MemoryLogTensorDeallocation : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MemoryLogTensorDeallocation()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.LogMemoryReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogTensorDeallocation() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogTensorDeallocation(MemoryLogTensorDeallocation other) : this() { + allocationId_ = other.allocationId_; + allocatorName_ = other.allocatorName_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogTensorDeallocation Clone() { + return new MemoryLogTensorDeallocation(this); + } + + /// Field number for the "allocation_id" field. + public const int AllocationIdFieldNumber = 1; + private long allocationId_; + /// + /// Id of the tensor buffer being deallocated, used to match to a + /// corresponding allocation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllocationId { + get { return allocationId_; } + set { + allocationId_ = value; + } + } + + /// Field number for the "allocator_name" field. + public const int AllocatorNameFieldNumber = 2; + private string allocatorName_ = ""; + /// + /// Name of the allocator used. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string AllocatorName { + get { return allocatorName_; } + set { + allocatorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MemoryLogTensorDeallocation); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MemoryLogTensorDeallocation other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (AllocationId != other.AllocationId) return false; + if (AllocatorName != other.AllocatorName) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (AllocationId != 0L) hash ^= AllocationId.GetHashCode(); + if (AllocatorName.Length != 0) hash ^= AllocatorName.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (AllocationId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(AllocationId); + } + if (AllocatorName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(AllocatorName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (AllocationId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(AllocationId); + } + if (AllocatorName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(AllocatorName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (AllocationId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AllocationId); + } + if (AllocatorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(AllocatorName); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MemoryLogTensorDeallocation other) { + if (other == null) { + return; + } + if (other.AllocationId != 0L) { + AllocationId = other.AllocationId; + } + if (other.AllocatorName.Length != 0) { + AllocatorName = other.AllocatorName; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + AllocationId = input.ReadInt64(); + break; + } + case 18: { + AllocatorName = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + AllocationId = input.ReadInt64(); + break; + } + case 18: { + AllocatorName = input.ReadString(); + break; + } + } + } + } + #endif + + } + + public sealed partial class MemoryLogTensorOutput : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MemoryLogTensorOutput()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.LogMemoryReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogTensorOutput() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogTensorOutput(MemoryLogTensorOutput other) : this() { + stepId_ = other.stepId_; + kernelName_ = other.kernelName_; + index_ = other.index_; + tensor_ = other.tensor_ != null ? other.tensor_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogTensorOutput Clone() { + return new MemoryLogTensorOutput(this); + } + + /// Field number for the "step_id" field. + public const int StepIdFieldNumber = 1; + private long stepId_; + /// + /// Process-unique step id. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long StepId { + get { return stepId_; } + set { + stepId_ = value; + } + } + + /// Field number for the "kernel_name" field. + public const int KernelNameFieldNumber = 2; + private string kernelName_ = ""; + /// + /// Name of the kernel producing an output as set in GraphDef, e.g., + /// "affine2/weights/Assign". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string KernelName { + get { return kernelName_; } + set { + kernelName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "index" field. + public const int IndexFieldNumber = 3; + private int index_; + /// + /// Index of the output being set. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int Index { + get { return index_; } + set { + index_ = value; + } + } + + /// Field number for the "tensor" field. + public const int TensorFieldNumber = 4; + private global::Tensorflow.TensorDescription tensor_; + /// + /// Output tensor details. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorDescription Tensor { + get { return tensor_; } + set { + tensor_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MemoryLogTensorOutput); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MemoryLogTensorOutput other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (StepId != other.StepId) return false; + if (KernelName != other.KernelName) return false; + if (Index != other.Index) return false; + if (!object.Equals(Tensor, other.Tensor)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (StepId != 0L) hash ^= StepId.GetHashCode(); + if (KernelName.Length != 0) hash ^= KernelName.GetHashCode(); + if (Index != 0) hash ^= Index.GetHashCode(); + if (tensor_ != null) hash ^= Tensor.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (StepId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(StepId); + } + if (KernelName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(KernelName); + } + if (Index != 0) { + output.WriteRawTag(24); + output.WriteInt32(Index); + } + if (tensor_ != null) { + output.WriteRawTag(34); + output.WriteMessage(Tensor); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (StepId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(StepId); + } + if (KernelName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(KernelName); + } + if (Index != 0) { + output.WriteRawTag(24); + output.WriteInt32(Index); + } + if (tensor_ != null) { + output.WriteRawTag(34); + output.WriteMessage(Tensor); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (StepId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(StepId); + } + if (KernelName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(KernelName); + } + if (Index != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Index); + } + if (tensor_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Tensor); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MemoryLogTensorOutput other) { + if (other == null) { + return; + } + if (other.StepId != 0L) { + StepId = other.StepId; + } + if (other.KernelName.Length != 0) { + KernelName = other.KernelName; + } + if (other.Index != 0) { + Index = other.Index; + } + if (other.tensor_ != null) { + if (tensor_ == null) { + Tensor = new global::Tensorflow.TensorDescription(); + } + Tensor.MergeFrom(other.Tensor); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + StepId = input.ReadInt64(); + break; + } + case 18: { + KernelName = input.ReadString(); + break; + } + case 24: { + Index = input.ReadInt32(); + break; + } + case 34: { + if (tensor_ == null) { + Tensor = new global::Tensorflow.TensorDescription(); + } + input.ReadMessage(Tensor); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + StepId = input.ReadInt64(); + break; + } + case 18: { + KernelName = input.ReadString(); + break; + } + case 24: { + Index = input.ReadInt32(); + break; + } + case 34: { + if (tensor_ == null) { + Tensor = new global::Tensorflow.TensorDescription(); + } + input.ReadMessage(Tensor); + break; + } + } + } + } + #endif + + } + + public sealed partial class MemoryLogRawAllocation : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MemoryLogRawAllocation()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.LogMemoryReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogRawAllocation() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogRawAllocation(MemoryLogRawAllocation other) : this() { + stepId_ = other.stepId_; + operation_ = other.operation_; + numBytes_ = other.numBytes_; + ptr_ = other.ptr_; + allocationId_ = other.allocationId_; + allocatorName_ = other.allocatorName_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogRawAllocation Clone() { + return new MemoryLogRawAllocation(this); + } + + /// Field number for the "step_id" field. + public const int StepIdFieldNumber = 1; + private long stepId_; + /// + /// Process-unique step id. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long StepId { + get { return stepId_; } + set { + stepId_ = value; + } + } + + /// Field number for the "operation" field. + public const int OperationFieldNumber = 2; + private string operation_ = ""; + /// + /// Name of the operation making the allocation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Operation { + get { return operation_; } + set { + operation_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "num_bytes" field. + public const int NumBytesFieldNumber = 3; + private long numBytes_; + /// + /// Number of bytes in the allocation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long NumBytes { + get { return numBytes_; } + set { + numBytes_ = value; + } + } + + /// Field number for the "ptr" field. + public const int PtrFieldNumber = 4; + private ulong ptr_; + /// + /// Address of the allocation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong Ptr { + get { return ptr_; } + set { + ptr_ = value; + } + } + + /// Field number for the "allocation_id" field. + public const int AllocationIdFieldNumber = 5; + private long allocationId_; + /// + /// Id of the tensor buffer being allocated, used to match to a + /// corresponding deallocation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllocationId { + get { return allocationId_; } + set { + allocationId_ = value; + } + } + + /// Field number for the "allocator_name" field. + public const int AllocatorNameFieldNumber = 6; + private string allocatorName_ = ""; + /// + /// Name of the allocator used. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string AllocatorName { + get { return allocatorName_; } + set { + allocatorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MemoryLogRawAllocation); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MemoryLogRawAllocation other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (StepId != other.StepId) return false; + if (Operation != other.Operation) return false; + if (NumBytes != other.NumBytes) return false; + if (Ptr != other.Ptr) return false; + if (AllocationId != other.AllocationId) return false; + if (AllocatorName != other.AllocatorName) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (StepId != 0L) hash ^= StepId.GetHashCode(); + if (Operation.Length != 0) hash ^= Operation.GetHashCode(); + if (NumBytes != 0L) hash ^= NumBytes.GetHashCode(); + if (Ptr != 0UL) hash ^= Ptr.GetHashCode(); + if (AllocationId != 0L) hash ^= AllocationId.GetHashCode(); + if (AllocatorName.Length != 0) hash ^= AllocatorName.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (StepId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(StepId); + } + if (Operation.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Operation); + } + if (NumBytes != 0L) { + output.WriteRawTag(24); + output.WriteInt64(NumBytes); + } + if (Ptr != 0UL) { + output.WriteRawTag(32); + output.WriteUInt64(Ptr); + } + if (AllocationId != 0L) { + output.WriteRawTag(40); + output.WriteInt64(AllocationId); + } + if (AllocatorName.Length != 0) { + output.WriteRawTag(50); + output.WriteString(AllocatorName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (StepId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(StepId); + } + if (Operation.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Operation); + } + if (NumBytes != 0L) { + output.WriteRawTag(24); + output.WriteInt64(NumBytes); + } + if (Ptr != 0UL) { + output.WriteRawTag(32); + output.WriteUInt64(Ptr); + } + if (AllocationId != 0L) { + output.WriteRawTag(40); + output.WriteInt64(AllocationId); + } + if (AllocatorName.Length != 0) { + output.WriteRawTag(50); + output.WriteString(AllocatorName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (StepId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(StepId); + } + if (Operation.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Operation); + } + if (NumBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(NumBytes); + } + if (Ptr != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(Ptr); + } + if (AllocationId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AllocationId); + } + if (AllocatorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(AllocatorName); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MemoryLogRawAllocation other) { + if (other == null) { + return; + } + if (other.StepId != 0L) { + StepId = other.StepId; + } + if (other.Operation.Length != 0) { + Operation = other.Operation; + } + if (other.NumBytes != 0L) { + NumBytes = other.NumBytes; + } + if (other.Ptr != 0UL) { + Ptr = other.Ptr; + } + if (other.AllocationId != 0L) { + AllocationId = other.AllocationId; + } + if (other.AllocatorName.Length != 0) { + AllocatorName = other.AllocatorName; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + StepId = input.ReadInt64(); + break; + } + case 18: { + Operation = input.ReadString(); + break; + } + case 24: { + NumBytes = input.ReadInt64(); + break; + } + case 32: { + Ptr = input.ReadUInt64(); + break; + } + case 40: { + AllocationId = input.ReadInt64(); + break; + } + case 50: { + AllocatorName = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + StepId = input.ReadInt64(); + break; + } + case 18: { + Operation = input.ReadString(); + break; + } + case 24: { + NumBytes = input.ReadInt64(); + break; + } + case 32: { + Ptr = input.ReadUInt64(); + break; + } + case 40: { + AllocationId = input.ReadInt64(); + break; + } + case 50: { + AllocatorName = input.ReadString(); + break; + } + } + } + } + #endif + + } + + public sealed partial class MemoryLogRawDeallocation : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MemoryLogRawDeallocation()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.LogMemoryReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogRawDeallocation() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogRawDeallocation(MemoryLogRawDeallocation other) : this() { + stepId_ = other.stepId_; + operation_ = other.operation_; + allocationId_ = other.allocationId_; + allocatorName_ = other.allocatorName_; + deferred_ = other.deferred_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryLogRawDeallocation Clone() { + return new MemoryLogRawDeallocation(this); + } + + /// Field number for the "step_id" field. + public const int StepIdFieldNumber = 1; + private long stepId_; + /// + /// Process-unique step id. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long StepId { + get { return stepId_; } + set { + stepId_ = value; + } + } + + /// Field number for the "operation" field. + public const int OperationFieldNumber = 2; + private string operation_ = ""; + /// + /// Name of the operation making the deallocation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Operation { + get { return operation_; } + set { + operation_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "allocation_id" field. + public const int AllocationIdFieldNumber = 3; + private long allocationId_; + /// + /// Id of the tensor buffer being deallocated, used to match to a + /// corresponding allocation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllocationId { + get { return allocationId_; } + set { + allocationId_ = value; + } + } + + /// Field number for the "allocator_name" field. + public const int AllocatorNameFieldNumber = 4; + private string allocatorName_ = ""; + /// + /// Name of the allocator used. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string AllocatorName { + get { return allocatorName_; } + set { + allocatorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "deferred" field. + public const int DeferredFieldNumber = 5; + private bool deferred_; + /// + /// True if the deallocation is queued and will be performed later, + /// e.g. for GPU lazy freeing of buffers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Deferred { + get { return deferred_; } + set { + deferred_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MemoryLogRawDeallocation); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MemoryLogRawDeallocation other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (StepId != other.StepId) return false; + if (Operation != other.Operation) return false; + if (AllocationId != other.AllocationId) return false; + if (AllocatorName != other.AllocatorName) return false; + if (Deferred != other.Deferred) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (StepId != 0L) hash ^= StepId.GetHashCode(); + if (Operation.Length != 0) hash ^= Operation.GetHashCode(); + if (AllocationId != 0L) hash ^= AllocationId.GetHashCode(); + if (AllocatorName.Length != 0) hash ^= AllocatorName.GetHashCode(); + if (Deferred != false) hash ^= Deferred.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (StepId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(StepId); + } + if (Operation.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Operation); + } + if (AllocationId != 0L) { + output.WriteRawTag(24); + output.WriteInt64(AllocationId); + } + if (AllocatorName.Length != 0) { + output.WriteRawTag(34); + output.WriteString(AllocatorName); + } + if (Deferred != false) { + output.WriteRawTag(40); + output.WriteBool(Deferred); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (StepId != 0L) { + output.WriteRawTag(8); + output.WriteInt64(StepId); + } + if (Operation.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Operation); + } + if (AllocationId != 0L) { + output.WriteRawTag(24); + output.WriteInt64(AllocationId); + } + if (AllocatorName.Length != 0) { + output.WriteRawTag(34); + output.WriteString(AllocatorName); + } + if (Deferred != false) { + output.WriteRawTag(40); + output.WriteBool(Deferred); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (StepId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(StepId); + } + if (Operation.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Operation); + } + if (AllocationId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AllocationId); + } + if (AllocatorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(AllocatorName); + } + if (Deferred != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MemoryLogRawDeallocation other) { + if (other == null) { + return; + } + if (other.StepId != 0L) { + StepId = other.StepId; + } + if (other.Operation.Length != 0) { + Operation = other.Operation; + } + if (other.AllocationId != 0L) { + AllocationId = other.AllocationId; + } + if (other.AllocatorName.Length != 0) { + AllocatorName = other.AllocatorName; + } + if (other.Deferred != false) { + Deferred = other.Deferred; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + StepId = input.ReadInt64(); + break; + } + case 18: { + Operation = input.ReadString(); + break; + } + case 24: { + AllocationId = input.ReadInt64(); + break; + } + case 34: { + AllocatorName = input.ReadString(); + break; + } + case 40: { + Deferred = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + StepId = input.ReadInt64(); + break; + } + case 18: { + Operation = input.ReadString(); + break; + } + case 24: { + AllocationId = input.ReadInt64(); + break; + } + case 34: { + AllocatorName = input.ReadString(); + break; + } + case 40: { + Deferred = input.ReadBool(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/MemmappedFileSystem.cs b/src/TensorFlowNET.Core/Protobuf/MemmappedFileSystem.cs new file mode 100644 index 000000000..b47599ea9 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/MemmappedFileSystem.cs @@ -0,0 +1,495 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/util/memmapped_file_system.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/util/memmapped_file_system.proto + public static partial class MemmappedFileSystemReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/util/memmapped_file_system.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static MemmappedFileSystemReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjB0ZW5zb3JmbG93L2NvcmUvdXRpbC9tZW1tYXBwZWRfZmlsZV9zeXN0ZW0u", + "cHJvdG8SCnRlbnNvcmZsb3ciUwojTWVtbWFwcGVkRmlsZVN5c3RlbURpcmVj", + "dG9yeUVsZW1lbnQSDgoGb2Zmc2V0GAEgASgEEgwKBG5hbWUYAiABKAkSDgoG", + "bGVuZ3RoGAMgASgEImAKHE1lbW1hcHBlZEZpbGVTeXN0ZW1EaXJlY3RvcnkS", + "QAoHZWxlbWVudBgBIAMoCzIvLnRlbnNvcmZsb3cuTWVtbWFwcGVkRmlsZVN5", + "c3RlbURpcmVjdG9yeUVsZW1lbnRCA/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MemmappedFileSystemDirectoryElement), global::Tensorflow.MemmappedFileSystemDirectoryElement.Parser, new[]{ "Offset", "Name", "Length" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MemmappedFileSystemDirectory), global::Tensorflow.MemmappedFileSystemDirectory.Parser, new[]{ "Element" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// A message that describes one region of memmapped file. + /// + public sealed partial class MemmappedFileSystemDirectoryElement : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MemmappedFileSystemDirectoryElement()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MemmappedFileSystemReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemmappedFileSystemDirectoryElement() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemmappedFileSystemDirectoryElement(MemmappedFileSystemDirectoryElement other) : this() { + offset_ = other.offset_; + name_ = other.name_; + length_ = other.length_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemmappedFileSystemDirectoryElement Clone() { + return new MemmappedFileSystemDirectoryElement(this); + } + + /// Field number for the "offset" field. + public const int OffsetFieldNumber = 1; + private ulong offset_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong Offset { + get { return offset_; } + set { + offset_ = value; + } + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 2; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "length" field. + public const int LengthFieldNumber = 3; + private ulong length_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong Length { + get { return length_; } + set { + length_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MemmappedFileSystemDirectoryElement); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MemmappedFileSystemDirectoryElement other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Offset != other.Offset) return false; + if (Name != other.Name) return false; + if (Length != other.Length) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Offset != 0UL) hash ^= Offset.GetHashCode(); + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Length != 0UL) hash ^= Length.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Offset != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(Offset); + } + if (Name.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Name); + } + if (Length != 0UL) { + output.WriteRawTag(24); + output.WriteUInt64(Length); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Offset != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(Offset); + } + if (Name.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Name); + } + if (Length != 0UL) { + output.WriteRawTag(24); + output.WriteUInt64(Length); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Offset != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(Offset); + } + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Length != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(Length); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MemmappedFileSystemDirectoryElement other) { + if (other == null) { + return; + } + if (other.Offset != 0UL) { + Offset = other.Offset; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Length != 0UL) { + Length = other.Length; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Offset = input.ReadUInt64(); + break; + } + case 18: { + Name = input.ReadString(); + break; + } + case 24: { + Length = input.ReadUInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Offset = input.ReadUInt64(); + break; + } + case 18: { + Name = input.ReadString(); + break; + } + case 24: { + Length = input.ReadUInt64(); + break; + } + } + } + } + #endif + + } + + /// + /// A directory of regions in a memmapped file. + /// + public sealed partial class MemmappedFileSystemDirectory : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MemmappedFileSystemDirectory()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MemmappedFileSystemReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemmappedFileSystemDirectory() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemmappedFileSystemDirectory(MemmappedFileSystemDirectory other) : this() { + element_ = other.element_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemmappedFileSystemDirectory Clone() { + return new MemmappedFileSystemDirectory(this); + } + + /// Field number for the "element" field. + public const int ElementFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_element_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.MemmappedFileSystemDirectoryElement.Parser); + private readonly pbc::RepeatedField element_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Element { + get { return element_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MemmappedFileSystemDirectory); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MemmappedFileSystemDirectory other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!element_.Equals(other.element_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= element_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + element_.WriteTo(output, _repeated_element_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + element_.WriteTo(ref output, _repeated_element_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += element_.CalculateSize(_repeated_element_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MemmappedFileSystemDirectory other) { + if (other == null) { + return; + } + element_.Add(other.element_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + element_.AddEntriesFrom(input, _repeated_element_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + element_.AddEntriesFrom(ref input, _repeated_element_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/MetaGraph.cs b/src/TensorFlowNET.Core/Protobuf/MetaGraph.cs new file mode 100644 index 000000000..4cd62e025 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/MetaGraph.cs @@ -0,0 +1,4009 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/meta_graph.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/meta_graph.proto + public static partial class MetaGraphReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/meta_graph.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static MetaGraphReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cil0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvbWV0YV9ncmFwaC5wcm90bxIK", + "dGVuc29yZmxvdxoZZ29vZ2xlL3Byb3RvYnVmL2FueS5wcm90bxoldGVuc29y", + "Zmxvdy9jb3JlL2ZyYW1ld29yay9ncmFwaC5wcm90bxomdGVuc29yZmxvdy9j", + "b3JlL2ZyYW1ld29yay9vcF9kZWYucHJvdG8aLHRlbnNvcmZsb3cvY29yZS9m", + "cmFtZXdvcmsvdGVuc29yX3NoYXBlLnByb3RvGiV0ZW5zb3JmbG93L2NvcmUv", + "ZnJhbWV3b3JrL3R5cGVzLnByb3RvGjF0ZW5zb3JmbG93L2NvcmUvcHJvdG9i", + "dWYvc2F2ZWRfb2JqZWN0X2dyYXBoLnByb3RvGiR0ZW5zb3JmbG93L2NvcmUv", + "cHJvdG9idWYvc2F2ZXIucHJvdG8aJXRlbnNvcmZsb3cvY29yZS9wcm90b2J1", + "Zi9zdHJ1Y3QucHJvdG8iqAcKDE1ldGFHcmFwaERlZhI7Cg1tZXRhX2luZm9f", + "ZGVmGAEgASgLMiQudGVuc29yZmxvdy5NZXRhR3JhcGhEZWYuTWV0YUluZm9E", + "ZWYSJwoJZ3JhcGhfZGVmGAIgASgLMhQudGVuc29yZmxvdy5HcmFwaERlZhIn", + "CglzYXZlcl9kZWYYAyABKAsyFC50ZW5zb3JmbG93LlNhdmVyRGVmEkMKDmNv", + "bGxlY3Rpb25fZGVmGAQgAygLMisudGVuc29yZmxvdy5NZXRhR3JhcGhEZWYu", + "Q29sbGVjdGlvbkRlZkVudHJ5EkEKDXNpZ25hdHVyZV9kZWYYBSADKAsyKi50", + "ZW5zb3JmbG93Lk1ldGFHcmFwaERlZi5TaWduYXR1cmVEZWZFbnRyeRIwCg5h", + "c3NldF9maWxlX2RlZhgGIAMoCzIYLnRlbnNvcmZsb3cuQXNzZXRGaWxlRGVm", + "EjYKEG9iamVjdF9ncmFwaF9kZWYYByABKAsyHC50ZW5zb3JmbG93LlNhdmVk", + "T2JqZWN0R3JhcGga9gIKC01ldGFJbmZvRGVmEhoKEm1ldGFfZ3JhcGhfdmVy", + "c2lvbhgBIAEoCRIsChBzdHJpcHBlZF9vcF9saXN0GAIgASgLMhIudGVuc29y", + "Zmxvdy5PcExpc3QSJgoIYW55X2luZm8YAyABKAsyFC5nb29nbGUucHJvdG9i", + "dWYuQW55EgwKBHRhZ3MYBCADKAkSGgoSdGVuc29yZmxvd192ZXJzaW9uGAUg", + "ASgJEh4KFnRlbnNvcmZsb3dfZ2l0X3ZlcnNpb24YBiABKAkSHgoWc3RyaXBw", + "ZWRfZGVmYXVsdF9hdHRycxgHIAEoCBJTChBmdW5jdGlvbl9hbGlhc2VzGAgg", + "AygLMjkudGVuc29yZmxvdy5NZXRhR3JhcGhEZWYuTWV0YUluZm9EZWYuRnVu", + "Y3Rpb25BbGlhc2VzRW50cnkaNgoURnVuY3Rpb25BbGlhc2VzRW50cnkSCwoD", + "a2V5GAEgASgJEg0KBXZhbHVlGAIgASgJOgI4ARpPChJDb2xsZWN0aW9uRGVm", + "RW50cnkSCwoDa2V5GAEgASgJEigKBXZhbHVlGAIgASgLMhkudGVuc29yZmxv", + "dy5Db2xsZWN0aW9uRGVmOgI4ARpNChFTaWduYXR1cmVEZWZFbnRyeRILCgNr", + "ZXkYASABKAkSJwoFdmFsdWUYAiABKAsyGC50ZW5zb3JmbG93LlNpZ25hdHVy", + "ZURlZjoCOAEi3wMKDUNvbGxlY3Rpb25EZWYSNwoJbm9kZV9saXN0GAEgASgL", + "MiIudGVuc29yZmxvdy5Db2xsZWN0aW9uRGVmLk5vZGVMaXN0SAASOQoKYnl0", + "ZXNfbGlzdBgCIAEoCzIjLnRlbnNvcmZsb3cuQ29sbGVjdGlvbkRlZi5CeXRl", + "c0xpc3RIABI5CgppbnQ2NF9saXN0GAMgASgLMiMudGVuc29yZmxvdy5Db2xs", + "ZWN0aW9uRGVmLkludDY0TGlzdEgAEjkKCmZsb2F0X2xpc3QYBCABKAsyIy50", + "ZW5zb3JmbG93LkNvbGxlY3Rpb25EZWYuRmxvYXRMaXN0SAASNQoIYW55X2xp", + "c3QYBSABKAsyIS50ZW5zb3JmbG93LkNvbGxlY3Rpb25EZWYuQW55TGlzdEgA", + "GhkKCE5vZGVMaXN0Eg0KBXZhbHVlGAEgAygJGhoKCUJ5dGVzTGlzdBINCgV2", + "YWx1ZRgBIAMoDBoeCglJbnQ2NExpc3QSEQoFdmFsdWUYASADKANCAhABGh4K", + "CUZsb2F0TGlzdBIRCgV2YWx1ZRgBIAMoAkICEAEaLgoHQW55TGlzdBIjCgV2", + "YWx1ZRgBIAMoCzIULmdvb2dsZS5wcm90b2J1Zi5BbnlCBgoEa2luZCLRAwoK", + "VGVuc29ySW5mbxIOCgRuYW1lGAEgASgJSAASNgoKY29vX3NwYXJzZRgEIAEo", + "CzIgLnRlbnNvcmZsb3cuVGVuc29ySW5mby5Db29TcGFyc2VIABJCChBjb21w", + "b3NpdGVfdGVuc29yGAUgASgLMiYudGVuc29yZmxvdy5UZW5zb3JJbmZvLkNv", + "bXBvc2l0ZVRlbnNvckgAEiMKBWR0eXBlGAIgASgOMhQudGVuc29yZmxvdy5E", + "YXRhVHlwZRIyCgx0ZW5zb3Jfc2hhcGUYAyABKAsyHC50ZW5zb3JmbG93LlRl", + "bnNvclNoYXBlUHJvdG8aZQoJQ29vU3BhcnNlEhoKEnZhbHVlc190ZW5zb3Jf", + "bmFtZRgBIAEoCRIbChNpbmRpY2VzX3RlbnNvcl9uYW1lGAIgASgJEh8KF2Rl", + "bnNlX3NoYXBlX3RlbnNvcl9uYW1lGAMgASgJGmsKD0NvbXBvc2l0ZVRlbnNv", + "chIsCgl0eXBlX3NwZWMYASABKAsyGS50ZW5zb3JmbG93LlR5cGVTcGVjUHJv", + "dG8SKgoKY29tcG9uZW50cxgCIAMoCzIWLnRlbnNvcmZsb3cuVGVuc29ySW5m", + "b0IKCghlbmNvZGluZyKgAgoMU2lnbmF0dXJlRGVmEjQKBmlucHV0cxgBIAMo", + "CzIkLnRlbnNvcmZsb3cuU2lnbmF0dXJlRGVmLklucHV0c0VudHJ5EjYKB291", + "dHB1dHMYAiADKAsyJS50ZW5zb3JmbG93LlNpZ25hdHVyZURlZi5PdXRwdXRz", + "RW50cnkSEwoLbWV0aG9kX25hbWUYAyABKAkaRQoLSW5wdXRzRW50cnkSCwoD", + "a2V5GAEgASgJEiUKBXZhbHVlGAIgASgLMhYudGVuc29yZmxvdy5UZW5zb3JJ", + "bmZvOgI4ARpGCgxPdXRwdXRzRW50cnkSCwoDa2V5GAEgASgJEiUKBXZhbHVl", + "GAIgASgLMhYudGVuc29yZmxvdy5UZW5zb3JJbmZvOgI4ASJNCgxBc3NldEZp", + "bGVEZWYSKwoLdGVuc29yX2luZm8YASABKAsyFi50ZW5zb3JmbG93LlRlbnNv", + "ckluZm8SEAoIZmlsZW5hbWUYAiABKAlChwEKGG9yZy50ZW5zb3JmbG93LmZy", + "YW1ld29ya0IPTWV0YUdyYXBoUHJvdG9zUAFaVWdpdGh1Yi5jb20vdGVuc29y", + "Zmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9wcm90b2J1Zi9m", + "b3JfY29yZV9wcm90b3NfZ29fcHJvdG/4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Google.Protobuf.WellKnownTypes.AnyReflection.Descriptor, global::Tensorflow.GraphReflection.Descriptor, global::Tensorflow.OpDefReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, global::Tensorflow.SavedObjectGraphReflection.Descriptor, global::Tensorflow.SaverReflection.Descriptor, global::Tensorflow.StructReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MetaGraphDef), global::Tensorflow.MetaGraphDef.Parser, new[]{ "MetaInfoDef", "GraphDef", "SaverDef", "CollectionDef", "SignatureDef", "AssetFileDef", "ObjectGraphDef" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MetaGraphDef.Types.MetaInfoDef), global::Tensorflow.MetaGraphDef.Types.MetaInfoDef.Parser, new[]{ "MetaGraphVersion", "StrippedOpList", "AnyInfo", "Tags", "TensorflowVersion", "TensorflowGitVersion", "StrippedDefaultAttrs", "FunctionAliases" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + null, null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CollectionDef), global::Tensorflow.CollectionDef.Parser, new[]{ "NodeList", "BytesList", "Int64List", "FloatList", "AnyList" }, new[]{ "Kind" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CollectionDef.Types.NodeList), global::Tensorflow.CollectionDef.Types.NodeList.Parser, new[]{ "Value" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CollectionDef.Types.BytesList), global::Tensorflow.CollectionDef.Types.BytesList.Parser, new[]{ "Value" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CollectionDef.Types.Int64List), global::Tensorflow.CollectionDef.Types.Int64List.Parser, new[]{ "Value" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CollectionDef.Types.FloatList), global::Tensorflow.CollectionDef.Types.FloatList.Parser, new[]{ "Value" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CollectionDef.Types.AnyList), global::Tensorflow.CollectionDef.Types.AnyList.Parser, new[]{ "Value" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorInfo), global::Tensorflow.TensorInfo.Parser, new[]{ "Name", "CooSparse", "CompositeTensor", "Dtype", "TensorShape" }, new[]{ "Encoding" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorInfo.Types.CooSparse), global::Tensorflow.TensorInfo.Types.CooSparse.Parser, new[]{ "ValuesTensorName", "IndicesTensorName", "DenseShapeTensorName" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorInfo.Types.CompositeTensor), global::Tensorflow.TensorInfo.Types.CompositeTensor.Parser, new[]{ "TypeSpec", "Components" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SignatureDef), global::Tensorflow.SignatureDef.Parser, new[]{ "Inputs", "Outputs", "MethodName" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.AssetFileDef), global::Tensorflow.AssetFileDef.Parser, new[]{ "TensorInfo", "Filename" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer containing the following which are necessary to restart + /// training, run inference. It can be used to serialize/de-serialize memory + /// objects necessary for running computation in a graph when crossing the + /// process boundary. It can be used for long term storage of graphs, + /// cross-language execution of graphs, etc. + /// MetaInfoDef + /// GraphDef + /// SaverDef + /// CollectionDef + /// TensorInfo + /// SignatureDef + /// + public sealed partial class MetaGraphDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MetaGraphDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MetaGraphReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MetaGraphDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MetaGraphDef(MetaGraphDef other) : this() { + metaInfoDef_ = other.metaInfoDef_ != null ? other.metaInfoDef_.Clone() : null; + graphDef_ = other.graphDef_ != null ? other.graphDef_.Clone() : null; + saverDef_ = other.saverDef_ != null ? other.saverDef_.Clone() : null; + collectionDef_ = other.collectionDef_.Clone(); + signatureDef_ = other.signatureDef_.Clone(); + assetFileDef_ = other.assetFileDef_.Clone(); + objectGraphDef_ = other.objectGraphDef_ != null ? other.objectGraphDef_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MetaGraphDef Clone() { + return new MetaGraphDef(this); + } + + /// Field number for the "meta_info_def" field. + public const int MetaInfoDefFieldNumber = 1; + private global::Tensorflow.MetaGraphDef.Types.MetaInfoDef metaInfoDef_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.MetaGraphDef.Types.MetaInfoDef MetaInfoDef { + get { return metaInfoDef_; } + set { + metaInfoDef_ = value; + } + } + + /// Field number for the "graph_def" field. + public const int GraphDefFieldNumber = 2; + private global::Tensorflow.GraphDef graphDef_; + /// + /// GraphDef. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.GraphDef GraphDef { + get { return graphDef_; } + set { + graphDef_ = value; + } + } + + /// Field number for the "saver_def" field. + public const int SaverDefFieldNumber = 3; + private global::Tensorflow.SaverDef saverDef_; + /// + /// SaverDef. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SaverDef SaverDef { + get { return saverDef_; } + set { + saverDef_ = value; + } + } + + /// Field number for the "collection_def" field. + public const int CollectionDefFieldNumber = 4; + private static readonly pbc::MapField.Codec _map_collectionDef_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.CollectionDef.Parser), 34); + private readonly pbc::MapField collectionDef_ = new pbc::MapField(); + /// + /// collection_def: Map from collection name to collections. + /// See CollectionDef section for details. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField CollectionDef { + get { return collectionDef_; } + } + + /// Field number for the "signature_def" field. + public const int SignatureDefFieldNumber = 5; + private static readonly pbc::MapField.Codec _map_signatureDef_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.SignatureDef.Parser), 42); + private readonly pbc::MapField signatureDef_ = new pbc::MapField(); + /// + /// signature_def: Map from user supplied key for a signature to a single + /// SignatureDef. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField SignatureDef { + get { return signatureDef_; } + } + + /// Field number for the "asset_file_def" field. + public const int AssetFileDefFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_assetFileDef_codec + = pb::FieldCodec.ForMessage(50, global::Tensorflow.AssetFileDef.Parser); + private readonly pbc::RepeatedField assetFileDef_ = new pbc::RepeatedField(); + /// + /// Asset file def to be used with the defined graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField AssetFileDef { + get { return assetFileDef_; } + } + + /// Field number for the "object_graph_def" field. + public const int ObjectGraphDefFieldNumber = 7; + private global::Tensorflow.SavedObjectGraph objectGraphDef_; + /// + /// Extra information about the structure of functions and stateful objects. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SavedObjectGraph ObjectGraphDef { + get { return objectGraphDef_; } + set { + objectGraphDef_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MetaGraphDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MetaGraphDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(MetaInfoDef, other.MetaInfoDef)) return false; + if (!object.Equals(GraphDef, other.GraphDef)) return false; + if (!object.Equals(SaverDef, other.SaverDef)) return false; + if (!CollectionDef.Equals(other.CollectionDef)) return false; + if (!SignatureDef.Equals(other.SignatureDef)) return false; + if(!assetFileDef_.Equals(other.assetFileDef_)) return false; + if (!object.Equals(ObjectGraphDef, other.ObjectGraphDef)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (metaInfoDef_ != null) hash ^= MetaInfoDef.GetHashCode(); + if (graphDef_ != null) hash ^= GraphDef.GetHashCode(); + if (saverDef_ != null) hash ^= SaverDef.GetHashCode(); + hash ^= CollectionDef.GetHashCode(); + hash ^= SignatureDef.GetHashCode(); + hash ^= assetFileDef_.GetHashCode(); + if (objectGraphDef_ != null) hash ^= ObjectGraphDef.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (metaInfoDef_ != null) { + output.WriteRawTag(10); + output.WriteMessage(MetaInfoDef); + } + if (graphDef_ != null) { + output.WriteRawTag(18); + output.WriteMessage(GraphDef); + } + if (saverDef_ != null) { + output.WriteRawTag(26); + output.WriteMessage(SaverDef); + } + collectionDef_.WriteTo(output, _map_collectionDef_codec); + signatureDef_.WriteTo(output, _map_signatureDef_codec); + assetFileDef_.WriteTo(output, _repeated_assetFileDef_codec); + if (objectGraphDef_ != null) { + output.WriteRawTag(58); + output.WriteMessage(ObjectGraphDef); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (metaInfoDef_ != null) { + output.WriteRawTag(10); + output.WriteMessage(MetaInfoDef); + } + if (graphDef_ != null) { + output.WriteRawTag(18); + output.WriteMessage(GraphDef); + } + if (saverDef_ != null) { + output.WriteRawTag(26); + output.WriteMessage(SaverDef); + } + collectionDef_.WriteTo(ref output, _map_collectionDef_codec); + signatureDef_.WriteTo(ref output, _map_signatureDef_codec); + assetFileDef_.WriteTo(ref output, _repeated_assetFileDef_codec); + if (objectGraphDef_ != null) { + output.WriteRawTag(58); + output.WriteMessage(ObjectGraphDef); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (metaInfoDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(MetaInfoDef); + } + if (graphDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(GraphDef); + } + if (saverDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SaverDef); + } + size += collectionDef_.CalculateSize(_map_collectionDef_codec); + size += signatureDef_.CalculateSize(_map_signatureDef_codec); + size += assetFileDef_.CalculateSize(_repeated_assetFileDef_codec); + if (objectGraphDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ObjectGraphDef); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MetaGraphDef other) { + if (other == null) { + return; + } + if (other.metaInfoDef_ != null) { + if (metaInfoDef_ == null) { + MetaInfoDef = new global::Tensorflow.MetaGraphDef.Types.MetaInfoDef(); + } + MetaInfoDef.MergeFrom(other.MetaInfoDef); + } + if (other.graphDef_ != null) { + if (graphDef_ == null) { + GraphDef = new global::Tensorflow.GraphDef(); + } + GraphDef.MergeFrom(other.GraphDef); + } + if (other.saverDef_ != null) { + if (saverDef_ == null) { + SaverDef = new global::Tensorflow.SaverDef(); + } + SaverDef.MergeFrom(other.SaverDef); + } + collectionDef_.Add(other.collectionDef_); + signatureDef_.Add(other.signatureDef_); + assetFileDef_.Add(other.assetFileDef_); + if (other.objectGraphDef_ != null) { + if (objectGraphDef_ == null) { + ObjectGraphDef = new global::Tensorflow.SavedObjectGraph(); + } + ObjectGraphDef.MergeFrom(other.ObjectGraphDef); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (metaInfoDef_ == null) { + MetaInfoDef = new global::Tensorflow.MetaGraphDef.Types.MetaInfoDef(); + } + input.ReadMessage(MetaInfoDef); + break; + } + case 18: { + if (graphDef_ == null) { + GraphDef = new global::Tensorflow.GraphDef(); + } + input.ReadMessage(GraphDef); + break; + } + case 26: { + if (saverDef_ == null) { + SaverDef = new global::Tensorflow.SaverDef(); + } + input.ReadMessage(SaverDef); + break; + } + case 34: { + collectionDef_.AddEntriesFrom(input, _map_collectionDef_codec); + break; + } + case 42: { + signatureDef_.AddEntriesFrom(input, _map_signatureDef_codec); + break; + } + case 50: { + assetFileDef_.AddEntriesFrom(input, _repeated_assetFileDef_codec); + break; + } + case 58: { + if (objectGraphDef_ == null) { + ObjectGraphDef = new global::Tensorflow.SavedObjectGraph(); + } + input.ReadMessage(ObjectGraphDef); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (metaInfoDef_ == null) { + MetaInfoDef = new global::Tensorflow.MetaGraphDef.Types.MetaInfoDef(); + } + input.ReadMessage(MetaInfoDef); + break; + } + case 18: { + if (graphDef_ == null) { + GraphDef = new global::Tensorflow.GraphDef(); + } + input.ReadMessage(GraphDef); + break; + } + case 26: { + if (saverDef_ == null) { + SaverDef = new global::Tensorflow.SaverDef(); + } + input.ReadMessage(SaverDef); + break; + } + case 34: { + collectionDef_.AddEntriesFrom(ref input, _map_collectionDef_codec); + break; + } + case 42: { + signatureDef_.AddEntriesFrom(ref input, _map_signatureDef_codec); + break; + } + case 50: { + assetFileDef_.AddEntriesFrom(ref input, _repeated_assetFileDef_codec); + break; + } + case 58: { + if (objectGraphDef_ == null) { + ObjectGraphDef = new global::Tensorflow.SavedObjectGraph(); + } + input.ReadMessage(ObjectGraphDef); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the MetaGraphDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Meta information regarding the graph to be exported. To be used by users + /// of this protocol buffer to encode information regarding their meta graph. + /// + public sealed partial class MetaInfoDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MetaInfoDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MetaGraphDef.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MetaInfoDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MetaInfoDef(MetaInfoDef other) : this() { + metaGraphVersion_ = other.metaGraphVersion_; + strippedOpList_ = other.strippedOpList_ != null ? other.strippedOpList_.Clone() : null; + anyInfo_ = other.anyInfo_ != null ? other.anyInfo_.Clone() : null; + tags_ = other.tags_.Clone(); + tensorflowVersion_ = other.tensorflowVersion_; + tensorflowGitVersion_ = other.tensorflowGitVersion_; + strippedDefaultAttrs_ = other.strippedDefaultAttrs_; + functionAliases_ = other.functionAliases_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MetaInfoDef Clone() { + return new MetaInfoDef(this); + } + + /// Field number for the "meta_graph_version" field. + public const int MetaGraphVersionFieldNumber = 1; + private string metaGraphVersion_ = ""; + /// + /// User specified Version string. Can be the name of the model and revision, + /// steps this model has been trained to, etc. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string MetaGraphVersion { + get { return metaGraphVersion_; } + set { + metaGraphVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "stripped_op_list" field. + public const int StrippedOpListFieldNumber = 2; + private global::Tensorflow.OpList strippedOpList_; + /// + /// A copy of the OpDefs used by the producer of this graph_def. + /// Descriptions and Ops not used in graph_def are stripped out. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.OpList StrippedOpList { + get { return strippedOpList_; } + set { + strippedOpList_ = value; + } + } + + /// Field number for the "any_info" field. + public const int AnyInfoFieldNumber = 3; + private global::Google.Protobuf.WellKnownTypes.Any anyInfo_; + /// + /// A serialized protobuf. Can be the time this meta graph is created, or + /// modified, or name of the model. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Google.Protobuf.WellKnownTypes.Any AnyInfo { + get { return anyInfo_; } + set { + anyInfo_ = value; + } + } + + /// Field number for the "tags" field. + public const int TagsFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_tags_codec + = pb::FieldCodec.ForString(34); + private readonly pbc::RepeatedField tags_ = new pbc::RepeatedField(); + /// + /// User supplied tag(s) on the meta_graph and included graph_def. + /// + /// MetaGraphDefs should be tagged with their capabilities or use-cases. + /// Examples: "train", "serve", "gpu", "tpu", etc. + /// These tags enable loaders to access the MetaGraph(s) appropriate for a + /// specific use-case or runtime environment. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Tags { + get { return tags_; } + } + + /// Field number for the "tensorflow_version" field. + public const int TensorflowVersionFieldNumber = 5; + private string tensorflowVersion_ = ""; + /// + /// The __version__ string of the tensorflow build used to write this graph. + /// This will be populated by the framework, which will overwrite any user + /// supplied value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string TensorflowVersion { + get { return tensorflowVersion_; } + set { + tensorflowVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "tensorflow_git_version" field. + public const int TensorflowGitVersionFieldNumber = 6; + private string tensorflowGitVersion_ = ""; + /// + /// The __git_version__ string of the tensorflow build used to write this + /// graph. This will be populated by the framework, which will overwrite any + /// user supplied value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string TensorflowGitVersion { + get { return tensorflowGitVersion_; } + set { + tensorflowGitVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "stripped_default_attrs" field. + public const int StrippedDefaultAttrsFieldNumber = 7; + private bool strippedDefaultAttrs_; + /// + /// A flag to denote whether default-valued attrs have been stripped from + /// the nodes in this graph_def. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool StrippedDefaultAttrs { + get { return strippedDefaultAttrs_; } + set { + strippedDefaultAttrs_ = value; + } + } + + /// Field number for the "function_aliases" field. + public const int FunctionAliasesFieldNumber = 8; + private static readonly pbc::MapField.Codec _map_functionAliases_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForString(18, ""), 66); + private readonly pbc::MapField functionAliases_ = new pbc::MapField(); + /// + /// FunctionDef name to aliases mapping. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField FunctionAliases { + get { return functionAliases_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MetaInfoDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MetaInfoDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (MetaGraphVersion != other.MetaGraphVersion) return false; + if (!object.Equals(StrippedOpList, other.StrippedOpList)) return false; + if (!object.Equals(AnyInfo, other.AnyInfo)) return false; + if(!tags_.Equals(other.tags_)) return false; + if (TensorflowVersion != other.TensorflowVersion) return false; + if (TensorflowGitVersion != other.TensorflowGitVersion) return false; + if (StrippedDefaultAttrs != other.StrippedDefaultAttrs) return false; + if (!FunctionAliases.Equals(other.FunctionAliases)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (MetaGraphVersion.Length != 0) hash ^= MetaGraphVersion.GetHashCode(); + if (strippedOpList_ != null) hash ^= StrippedOpList.GetHashCode(); + if (anyInfo_ != null) hash ^= AnyInfo.GetHashCode(); + hash ^= tags_.GetHashCode(); + if (TensorflowVersion.Length != 0) hash ^= TensorflowVersion.GetHashCode(); + if (TensorflowGitVersion.Length != 0) hash ^= TensorflowGitVersion.GetHashCode(); + if (StrippedDefaultAttrs != false) hash ^= StrippedDefaultAttrs.GetHashCode(); + hash ^= FunctionAliases.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (MetaGraphVersion.Length != 0) { + output.WriteRawTag(10); + output.WriteString(MetaGraphVersion); + } + if (strippedOpList_ != null) { + output.WriteRawTag(18); + output.WriteMessage(StrippedOpList); + } + if (anyInfo_ != null) { + output.WriteRawTag(26); + output.WriteMessage(AnyInfo); + } + tags_.WriteTo(output, _repeated_tags_codec); + if (TensorflowVersion.Length != 0) { + output.WriteRawTag(42); + output.WriteString(TensorflowVersion); + } + if (TensorflowGitVersion.Length != 0) { + output.WriteRawTag(50); + output.WriteString(TensorflowGitVersion); + } + if (StrippedDefaultAttrs != false) { + output.WriteRawTag(56); + output.WriteBool(StrippedDefaultAttrs); + } + functionAliases_.WriteTo(output, _map_functionAliases_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (MetaGraphVersion.Length != 0) { + output.WriteRawTag(10); + output.WriteString(MetaGraphVersion); + } + if (strippedOpList_ != null) { + output.WriteRawTag(18); + output.WriteMessage(StrippedOpList); + } + if (anyInfo_ != null) { + output.WriteRawTag(26); + output.WriteMessage(AnyInfo); + } + tags_.WriteTo(ref output, _repeated_tags_codec); + if (TensorflowVersion.Length != 0) { + output.WriteRawTag(42); + output.WriteString(TensorflowVersion); + } + if (TensorflowGitVersion.Length != 0) { + output.WriteRawTag(50); + output.WriteString(TensorflowGitVersion); + } + if (StrippedDefaultAttrs != false) { + output.WriteRawTag(56); + output.WriteBool(StrippedDefaultAttrs); + } + functionAliases_.WriteTo(ref output, _map_functionAliases_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (MetaGraphVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MetaGraphVersion); + } + if (strippedOpList_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(StrippedOpList); + } + if (anyInfo_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(AnyInfo); + } + size += tags_.CalculateSize(_repeated_tags_codec); + if (TensorflowVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TensorflowVersion); + } + if (TensorflowGitVersion.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TensorflowGitVersion); + } + if (StrippedDefaultAttrs != false) { + size += 1 + 1; + } + size += functionAliases_.CalculateSize(_map_functionAliases_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MetaInfoDef other) { + if (other == null) { + return; + } + if (other.MetaGraphVersion.Length != 0) { + MetaGraphVersion = other.MetaGraphVersion; + } + if (other.strippedOpList_ != null) { + if (strippedOpList_ == null) { + StrippedOpList = new global::Tensorflow.OpList(); + } + StrippedOpList.MergeFrom(other.StrippedOpList); + } + if (other.anyInfo_ != null) { + if (anyInfo_ == null) { + AnyInfo = new global::Google.Protobuf.WellKnownTypes.Any(); + } + AnyInfo.MergeFrom(other.AnyInfo); + } + tags_.Add(other.tags_); + if (other.TensorflowVersion.Length != 0) { + TensorflowVersion = other.TensorflowVersion; + } + if (other.TensorflowGitVersion.Length != 0) { + TensorflowGitVersion = other.TensorflowGitVersion; + } + if (other.StrippedDefaultAttrs != false) { + StrippedDefaultAttrs = other.StrippedDefaultAttrs; + } + functionAliases_.Add(other.functionAliases_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + MetaGraphVersion = input.ReadString(); + break; + } + case 18: { + if (strippedOpList_ == null) { + StrippedOpList = new global::Tensorflow.OpList(); + } + input.ReadMessage(StrippedOpList); + break; + } + case 26: { + if (anyInfo_ == null) { + AnyInfo = new global::Google.Protobuf.WellKnownTypes.Any(); + } + input.ReadMessage(AnyInfo); + break; + } + case 34: { + tags_.AddEntriesFrom(input, _repeated_tags_codec); + break; + } + case 42: { + TensorflowVersion = input.ReadString(); + break; + } + case 50: { + TensorflowGitVersion = input.ReadString(); + break; + } + case 56: { + StrippedDefaultAttrs = input.ReadBool(); + break; + } + case 66: { + functionAliases_.AddEntriesFrom(input, _map_functionAliases_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + MetaGraphVersion = input.ReadString(); + break; + } + case 18: { + if (strippedOpList_ == null) { + StrippedOpList = new global::Tensorflow.OpList(); + } + input.ReadMessage(StrippedOpList); + break; + } + case 26: { + if (anyInfo_ == null) { + AnyInfo = new global::Google.Protobuf.WellKnownTypes.Any(); + } + input.ReadMessage(AnyInfo); + break; + } + case 34: { + tags_.AddEntriesFrom(ref input, _repeated_tags_codec); + break; + } + case 42: { + TensorflowVersion = input.ReadString(); + break; + } + case 50: { + TensorflowGitVersion = input.ReadString(); + break; + } + case 56: { + StrippedDefaultAttrs = input.ReadBool(); + break; + } + case 66: { + functionAliases_.AddEntriesFrom(ref input, _map_functionAliases_codec); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// CollectionDef should cover most collections. + /// To add a user-defined collection, do one of the following: + /// 1. For simple data types, such as string, int, float: + /// tf.add_to_collection("your_collection_name", your_simple_value) + /// strings will be stored as bytes_list. + /// + /// 2. For Protobuf types, there are three ways to add them: + /// 1) tf.add_to_collection("your_collection_name", + /// your_proto.SerializeToString()) + /// + /// collection_def { + /// key: "user_defined_bytes_collection" + /// value { + /// bytes_list { + /// value: "queue_name: \"test_queue\"\n" + /// } + /// } + /// } + /// + /// or + /// + /// 2) tf.add_to_collection("your_collection_name", str(your_proto)) + /// + /// collection_def { + /// key: "user_defined_string_collection" + /// value { + /// bytes_list { + /// value: "\n\ntest_queue" + /// } + /// } + /// } + /// + /// or + /// + /// 3) any_buf = any_pb2.Any() + /// tf.add_to_collection("your_collection_name", + /// any_buf.Pack(your_proto)) + /// + /// collection_def { + /// key: "user_defined_any_collection" + /// value { + /// any_list { + /// value { + /// type_url: "type.googleapis.com/tensorflow.QueueRunnerDef" + /// value: "\n\ntest_queue" + /// } + /// } + /// } + /// } + /// + /// 3. For Python objects, implement to_proto() and from_proto(), and register + /// them in the following manner: + /// ops.register_proto_function("your_collection_name", + /// proto_type, + /// to_proto=YourPythonObject.to_proto, + /// from_proto=YourPythonObject.from_proto) + /// These functions will be invoked to serialize and de-serialize the + /// collection. For example, + /// ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES, + /// proto_type=variable_pb2.VariableDef, + /// to_proto=Variable.to_proto, + /// from_proto=Variable.from_proto) + /// + public sealed partial class CollectionDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CollectionDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MetaGraphReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CollectionDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CollectionDef(CollectionDef other) : this() { + switch (other.KindCase) { + case KindOneofCase.NodeList: + NodeList = other.NodeList.Clone(); + break; + case KindOneofCase.BytesList: + BytesList = other.BytesList.Clone(); + break; + case KindOneofCase.Int64List: + Int64List = other.Int64List.Clone(); + break; + case KindOneofCase.FloatList: + FloatList = other.FloatList.Clone(); + break; + case KindOneofCase.AnyList: + AnyList = other.AnyList.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CollectionDef Clone() { + return new CollectionDef(this); + } + + /// Field number for the "node_list" field. + public const int NodeListFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CollectionDef.Types.NodeList NodeList { + get { return kindCase_ == KindOneofCase.NodeList ? (global::Tensorflow.CollectionDef.Types.NodeList) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.NodeList; + } + } + + /// Field number for the "bytes_list" field. + public const int BytesListFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CollectionDef.Types.BytesList BytesList { + get { return kindCase_ == KindOneofCase.BytesList ? (global::Tensorflow.CollectionDef.Types.BytesList) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.BytesList; + } + } + + /// Field number for the "int64_list" field. + public const int Int64ListFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CollectionDef.Types.Int64List Int64List { + get { return kindCase_ == KindOneofCase.Int64List ? (global::Tensorflow.CollectionDef.Types.Int64List) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.Int64List; + } + } + + /// Field number for the "float_list" field. + public const int FloatListFieldNumber = 4; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CollectionDef.Types.FloatList FloatList { + get { return kindCase_ == KindOneofCase.FloatList ? (global::Tensorflow.CollectionDef.Types.FloatList) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.FloatList; + } + } + + /// Field number for the "any_list" field. + public const int AnyListFieldNumber = 5; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CollectionDef.Types.AnyList AnyList { + get { return kindCase_ == KindOneofCase.AnyList ? (global::Tensorflow.CollectionDef.Types.AnyList) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.AnyList; + } + } + + private object kind_; + /// Enum of possible cases for the "kind" oneof. + public enum KindOneofCase { + None = 0, + NodeList = 1, + BytesList = 2, + Int64List = 3, + FloatList = 4, + AnyList = 5, + } + private KindOneofCase kindCase_ = KindOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KindOneofCase KindCase { + get { return kindCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void ClearKind() { + kindCase_ = KindOneofCase.None; + kind_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CollectionDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CollectionDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(NodeList, other.NodeList)) return false; + if (!object.Equals(BytesList, other.BytesList)) return false; + if (!object.Equals(Int64List, other.Int64List)) return false; + if (!object.Equals(FloatList, other.FloatList)) return false; + if (!object.Equals(AnyList, other.AnyList)) return false; + if (KindCase != other.KindCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (kindCase_ == KindOneofCase.NodeList) hash ^= NodeList.GetHashCode(); + if (kindCase_ == KindOneofCase.BytesList) hash ^= BytesList.GetHashCode(); + if (kindCase_ == KindOneofCase.Int64List) hash ^= Int64List.GetHashCode(); + if (kindCase_ == KindOneofCase.FloatList) hash ^= FloatList.GetHashCode(); + if (kindCase_ == KindOneofCase.AnyList) hash ^= AnyList.GetHashCode(); + hash ^= (int) kindCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (kindCase_ == KindOneofCase.NodeList) { + output.WriteRawTag(10); + output.WriteMessage(NodeList); + } + if (kindCase_ == KindOneofCase.BytesList) { + output.WriteRawTag(18); + output.WriteMessage(BytesList); + } + if (kindCase_ == KindOneofCase.Int64List) { + output.WriteRawTag(26); + output.WriteMessage(Int64List); + } + if (kindCase_ == KindOneofCase.FloatList) { + output.WriteRawTag(34); + output.WriteMessage(FloatList); + } + if (kindCase_ == KindOneofCase.AnyList) { + output.WriteRawTag(42); + output.WriteMessage(AnyList); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (kindCase_ == KindOneofCase.NodeList) { + output.WriteRawTag(10); + output.WriteMessage(NodeList); + } + if (kindCase_ == KindOneofCase.BytesList) { + output.WriteRawTag(18); + output.WriteMessage(BytesList); + } + if (kindCase_ == KindOneofCase.Int64List) { + output.WriteRawTag(26); + output.WriteMessage(Int64List); + } + if (kindCase_ == KindOneofCase.FloatList) { + output.WriteRawTag(34); + output.WriteMessage(FloatList); + } + if (kindCase_ == KindOneofCase.AnyList) { + output.WriteRawTag(42); + output.WriteMessage(AnyList); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (kindCase_ == KindOneofCase.NodeList) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(NodeList); + } + if (kindCase_ == KindOneofCase.BytesList) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(BytesList); + } + if (kindCase_ == KindOneofCase.Int64List) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Int64List); + } + if (kindCase_ == KindOneofCase.FloatList) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FloatList); + } + if (kindCase_ == KindOneofCase.AnyList) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(AnyList); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CollectionDef other) { + if (other == null) { + return; + } + switch (other.KindCase) { + case KindOneofCase.NodeList: + if (NodeList == null) { + NodeList = new global::Tensorflow.CollectionDef.Types.NodeList(); + } + NodeList.MergeFrom(other.NodeList); + break; + case KindOneofCase.BytesList: + if (BytesList == null) { + BytesList = new global::Tensorflow.CollectionDef.Types.BytesList(); + } + BytesList.MergeFrom(other.BytesList); + break; + case KindOneofCase.Int64List: + if (Int64List == null) { + Int64List = new global::Tensorflow.CollectionDef.Types.Int64List(); + } + Int64List.MergeFrom(other.Int64List); + break; + case KindOneofCase.FloatList: + if (FloatList == null) { + FloatList = new global::Tensorflow.CollectionDef.Types.FloatList(); + } + FloatList.MergeFrom(other.FloatList); + break; + case KindOneofCase.AnyList: + if (AnyList == null) { + AnyList = new global::Tensorflow.CollectionDef.Types.AnyList(); + } + AnyList.MergeFrom(other.AnyList); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.CollectionDef.Types.NodeList subBuilder = new global::Tensorflow.CollectionDef.Types.NodeList(); + if (kindCase_ == KindOneofCase.NodeList) { + subBuilder.MergeFrom(NodeList); + } + input.ReadMessage(subBuilder); + NodeList = subBuilder; + break; + } + case 18: { + global::Tensorflow.CollectionDef.Types.BytesList subBuilder = new global::Tensorflow.CollectionDef.Types.BytesList(); + if (kindCase_ == KindOneofCase.BytesList) { + subBuilder.MergeFrom(BytesList); + } + input.ReadMessage(subBuilder); + BytesList = subBuilder; + break; + } + case 26: { + global::Tensorflow.CollectionDef.Types.Int64List subBuilder = new global::Tensorflow.CollectionDef.Types.Int64List(); + if (kindCase_ == KindOneofCase.Int64List) { + subBuilder.MergeFrom(Int64List); + } + input.ReadMessage(subBuilder); + Int64List = subBuilder; + break; + } + case 34: { + global::Tensorflow.CollectionDef.Types.FloatList subBuilder = new global::Tensorflow.CollectionDef.Types.FloatList(); + if (kindCase_ == KindOneofCase.FloatList) { + subBuilder.MergeFrom(FloatList); + } + input.ReadMessage(subBuilder); + FloatList = subBuilder; + break; + } + case 42: { + global::Tensorflow.CollectionDef.Types.AnyList subBuilder = new global::Tensorflow.CollectionDef.Types.AnyList(); + if (kindCase_ == KindOneofCase.AnyList) { + subBuilder.MergeFrom(AnyList); + } + input.ReadMessage(subBuilder); + AnyList = subBuilder; + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + global::Tensorflow.CollectionDef.Types.NodeList subBuilder = new global::Tensorflow.CollectionDef.Types.NodeList(); + if (kindCase_ == KindOneofCase.NodeList) { + subBuilder.MergeFrom(NodeList); + } + input.ReadMessage(subBuilder); + NodeList = subBuilder; + break; + } + case 18: { + global::Tensorflow.CollectionDef.Types.BytesList subBuilder = new global::Tensorflow.CollectionDef.Types.BytesList(); + if (kindCase_ == KindOneofCase.BytesList) { + subBuilder.MergeFrom(BytesList); + } + input.ReadMessage(subBuilder); + BytesList = subBuilder; + break; + } + case 26: { + global::Tensorflow.CollectionDef.Types.Int64List subBuilder = new global::Tensorflow.CollectionDef.Types.Int64List(); + if (kindCase_ == KindOneofCase.Int64List) { + subBuilder.MergeFrom(Int64List); + } + input.ReadMessage(subBuilder); + Int64List = subBuilder; + break; + } + case 34: { + global::Tensorflow.CollectionDef.Types.FloatList subBuilder = new global::Tensorflow.CollectionDef.Types.FloatList(); + if (kindCase_ == KindOneofCase.FloatList) { + subBuilder.MergeFrom(FloatList); + } + input.ReadMessage(subBuilder); + FloatList = subBuilder; + break; + } + case 42: { + global::Tensorflow.CollectionDef.Types.AnyList subBuilder = new global::Tensorflow.CollectionDef.Types.AnyList(); + if (kindCase_ == KindOneofCase.AnyList) { + subBuilder.MergeFrom(AnyList); + } + input.ReadMessage(subBuilder); + AnyList = subBuilder; + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the CollectionDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// NodeList is used for collecting nodes in graph. For example + /// collection_def { + /// key: "summaries" + /// value { + /// node_list { + /// value: "input_producer/ScalarSummary:0" + /// value: "shuffle_batch/ScalarSummary:0" + /// value: "ImageSummary:0" + /// } + /// } + /// + public sealed partial class NodeList : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NodeList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CollectionDef.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NodeList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NodeList(NodeList other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NodeList Clone() { + return new NodeList(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForString(10); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as NodeList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(NodeList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + value_.WriteTo(ref output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(NodeList other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + value_.AddEntriesFrom(ref input, _repeated_value_codec); + break; + } + } + } + } + #endif + + } + + /// + /// BytesList is used for collecting strings and serialized protobufs. For + /// example: + /// collection_def { + /// key: "trainable_variables" + /// value { + /// bytes_list { + /// value: "\n\017conv1/weights:0\022\024conv1/weights/Assign + /// \032\024conv1/weights/read:0" + /// value: "\n\016conv1/biases:0\022\023conv1/biases/Assign\032 + /// \023conv1/biases/read:0" + /// } + /// } + /// } + /// + public sealed partial class BytesList : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BytesList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CollectionDef.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BytesList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BytesList(BytesList other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BytesList Clone() { + return new BytesList(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForBytes(10); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as BytesList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(BytesList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + value_.WriteTo(ref output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(BytesList other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + value_.AddEntriesFrom(ref input, _repeated_value_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Int64List is used for collecting int, int64 and long values. + /// + public sealed partial class Int64List : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Int64List()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CollectionDef.Descriptor.NestedTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Int64List() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Int64List(Int64List other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Int64List Clone() { + return new Int64List(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Int64List); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Int64List other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + value_.WriteTo(ref output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Int64List other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + value_.AddEntriesFrom(ref input, _repeated_value_codec); + break; + } + } + } + } + #endif + + } + + /// + /// FloatList is used for collecting float values. + /// + public sealed partial class FloatList : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FloatList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CollectionDef.Descriptor.NestedTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FloatList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FloatList(FloatList other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FloatList Clone() { + return new FloatList(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForFloat(10); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as FloatList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(FloatList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + value_.WriteTo(ref output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(FloatList other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 13: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 13: { + value_.AddEntriesFrom(ref input, _repeated_value_codec); + break; + } + } + } + } + #endif + + } + + /// + /// AnyList is used for collecting Any protos. + /// + public sealed partial class AnyList : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AnyList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.CollectionDef.Descriptor.NestedTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AnyList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AnyList(AnyList other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AnyList Clone() { + return new AnyList(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForMessage(10, global::Google.Protobuf.WellKnownTypes.Any.Parser); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as AnyList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(AnyList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + value_.WriteTo(ref output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(AnyList other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + value_.AddEntriesFrom(ref input, _repeated_value_codec); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// Information about a Tensor necessary for feeding or retrieval. + /// + public sealed partial class TensorInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MetaGraphReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorInfo(TensorInfo other) : this() { + dtype_ = other.dtype_; + tensorShape_ = other.tensorShape_ != null ? other.tensorShape_.Clone() : null; + switch (other.EncodingCase) { + case EncodingOneofCase.Name: + Name = other.Name; + break; + case EncodingOneofCase.CooSparse: + CooSparse = other.CooSparse.Clone(); + break; + case EncodingOneofCase.CompositeTensor: + CompositeTensor = other.CompositeTensor.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorInfo Clone() { + return new TensorInfo(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + /// + /// For dense `Tensor`s, the name of the tensor in the graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return encodingCase_ == EncodingOneofCase.Name ? (string) encoding_ : ""; } + set { + encoding_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + encodingCase_ = EncodingOneofCase.Name; + } + } + + /// Field number for the "coo_sparse" field. + public const int CooSparseFieldNumber = 4; + /// + /// There are many possible encodings of sparse matrices + /// (https://en.wikipedia.org/wiki/Sparse_matrix). Currently, TensorFlow + /// uses only the COO encoding. This is supported and documented in the + /// SparseTensor Python class. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorInfo.Types.CooSparse CooSparse { + get { return encodingCase_ == EncodingOneofCase.CooSparse ? (global::Tensorflow.TensorInfo.Types.CooSparse) encoding_ : null; } + set { + encoding_ = value; + encodingCase_ = value == null ? EncodingOneofCase.None : EncodingOneofCase.CooSparse; + } + } + + /// Field number for the "composite_tensor" field. + public const int CompositeTensorFieldNumber = 5; + /// + /// Generic encoding for CompositeTensors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorInfo.Types.CompositeTensor CompositeTensor { + get { return encodingCase_ == EncodingOneofCase.CompositeTensor ? (global::Tensorflow.TensorInfo.Types.CompositeTensor) encoding_ : null; } + set { + encoding_ = value; + encodingCase_ = value == null ? EncodingOneofCase.None : EncodingOneofCase.CompositeTensor; + } + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 2; + private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + /// Field number for the "tensor_shape" field. + public const int TensorShapeFieldNumber = 3; + private global::Tensorflow.TensorShapeProto tensorShape_; + /// + /// The static shape should be recorded here, to the extent that it can + /// be known in advance. In the case of a SparseTensor, this field describes + /// the logical shape of the represented tensor (aka dense_shape). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorShapeProto TensorShape { + get { return tensorShape_; } + set { + tensorShape_ = value; + } + } + + private object encoding_; + /// Enum of possible cases for the "encoding" oneof. + public enum EncodingOneofCase { + None = 0, + Name = 1, + CooSparse = 4, + CompositeTensor = 5, + } + private EncodingOneofCase encodingCase_ = EncodingOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public EncodingOneofCase EncodingCase { + get { return encodingCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void ClearEncoding() { + encodingCase_ = EncodingOneofCase.None; + encoding_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TensorInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TensorInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (!object.Equals(CooSparse, other.CooSparse)) return false; + if (!object.Equals(CompositeTensor, other.CompositeTensor)) return false; + if (Dtype != other.Dtype) return false; + if (!object.Equals(TensorShape, other.TensorShape)) return false; + if (EncodingCase != other.EncodingCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (encodingCase_ == EncodingOneofCase.Name) hash ^= Name.GetHashCode(); + if (encodingCase_ == EncodingOneofCase.CooSparse) hash ^= CooSparse.GetHashCode(); + if (encodingCase_ == EncodingOneofCase.CompositeTensor) hash ^= CompositeTensor.GetHashCode(); + if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); + if (tensorShape_ != null) hash ^= TensorShape.GetHashCode(); + hash ^= (int) encodingCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (encodingCase_ == EncodingOneofCase.Name) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(16); + output.WriteEnum((int) Dtype); + } + if (tensorShape_ != null) { + output.WriteRawTag(26); + output.WriteMessage(TensorShape); + } + if (encodingCase_ == EncodingOneofCase.CooSparse) { + output.WriteRawTag(34); + output.WriteMessage(CooSparse); + } + if (encodingCase_ == EncodingOneofCase.CompositeTensor) { + output.WriteRawTag(42); + output.WriteMessage(CompositeTensor); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (encodingCase_ == EncodingOneofCase.Name) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(16); + output.WriteEnum((int) Dtype); + } + if (tensorShape_ != null) { + output.WriteRawTag(26); + output.WriteMessage(TensorShape); + } + if (encodingCase_ == EncodingOneofCase.CooSparse) { + output.WriteRawTag(34); + output.WriteMessage(CooSparse); + } + if (encodingCase_ == EncodingOneofCase.CompositeTensor) { + output.WriteRawTag(42); + output.WriteMessage(CompositeTensor); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (encodingCase_ == EncodingOneofCase.Name) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (encodingCase_ == EncodingOneofCase.CooSparse) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(CooSparse); + } + if (encodingCase_ == EncodingOneofCase.CompositeTensor) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(CompositeTensor); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (tensorShape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TensorShape); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TensorInfo other) { + if (other == null) { + return; + } + if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { + Dtype = other.Dtype; + } + if (other.tensorShape_ != null) { + if (tensorShape_ == null) { + TensorShape = new global::Tensorflow.TensorShapeProto(); + } + TensorShape.MergeFrom(other.TensorShape); + } + switch (other.EncodingCase) { + case EncodingOneofCase.Name: + Name = other.Name; + break; + case EncodingOneofCase.CooSparse: + if (CooSparse == null) { + CooSparse = new global::Tensorflow.TensorInfo.Types.CooSparse(); + } + CooSparse.MergeFrom(other.CooSparse); + break; + case EncodingOneofCase.CompositeTensor: + if (CompositeTensor == null) { + CompositeTensor = new global::Tensorflow.TensorInfo.Types.CompositeTensor(); + } + CompositeTensor.MergeFrom(other.CompositeTensor); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 26: { + if (tensorShape_ == null) { + TensorShape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(TensorShape); + break; + } + case 34: { + global::Tensorflow.TensorInfo.Types.CooSparse subBuilder = new global::Tensorflow.TensorInfo.Types.CooSparse(); + if (encodingCase_ == EncodingOneofCase.CooSparse) { + subBuilder.MergeFrom(CooSparse); + } + input.ReadMessage(subBuilder); + CooSparse = subBuilder; + break; + } + case 42: { + global::Tensorflow.TensorInfo.Types.CompositeTensor subBuilder = new global::Tensorflow.TensorInfo.Types.CompositeTensor(); + if (encodingCase_ == EncodingOneofCase.CompositeTensor) { + subBuilder.MergeFrom(CompositeTensor); + } + input.ReadMessage(subBuilder); + CompositeTensor = subBuilder; + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 26: { + if (tensorShape_ == null) { + TensorShape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(TensorShape); + break; + } + case 34: { + global::Tensorflow.TensorInfo.Types.CooSparse subBuilder = new global::Tensorflow.TensorInfo.Types.CooSparse(); + if (encodingCase_ == EncodingOneofCase.CooSparse) { + subBuilder.MergeFrom(CooSparse); + } + input.ReadMessage(subBuilder); + CooSparse = subBuilder; + break; + } + case 42: { + global::Tensorflow.TensorInfo.Types.CompositeTensor subBuilder = new global::Tensorflow.TensorInfo.Types.CompositeTensor(); + if (encodingCase_ == EncodingOneofCase.CompositeTensor) { + subBuilder.MergeFrom(CompositeTensor); + } + input.ReadMessage(subBuilder); + CompositeTensor = subBuilder; + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the TensorInfo message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// For sparse tensors, The COO encoding stores a triple of values, indices, + /// and shape. + /// + public sealed partial class CooSparse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CooSparse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TensorInfo.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CooSparse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CooSparse(CooSparse other) : this() { + valuesTensorName_ = other.valuesTensorName_; + indicesTensorName_ = other.indicesTensorName_; + denseShapeTensorName_ = other.denseShapeTensorName_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CooSparse Clone() { + return new CooSparse(this); + } + + /// Field number for the "values_tensor_name" field. + public const int ValuesTensorNameFieldNumber = 1; + private string valuesTensorName_ = ""; + /// + /// The shape of the values Tensor is [?]. Its dtype must be the dtype of + /// the SparseTensor as a whole, given in the enclosing TensorInfo. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ValuesTensorName { + get { return valuesTensorName_; } + set { + valuesTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "indices_tensor_name" field. + public const int IndicesTensorNameFieldNumber = 2; + private string indicesTensorName_ = ""; + /// + /// The indices Tensor must have dtype int64 and shape [?, ?]. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string IndicesTensorName { + get { return indicesTensorName_; } + set { + indicesTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "dense_shape_tensor_name" field. + public const int DenseShapeTensorNameFieldNumber = 3; + private string denseShapeTensorName_ = ""; + /// + /// The dynamic logical shape represented by the SparseTensor is recorded in + /// the Tensor referenced here. It must have dtype int64 and shape [?]. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string DenseShapeTensorName { + get { return denseShapeTensorName_; } + set { + denseShapeTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CooSparse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CooSparse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ValuesTensorName != other.ValuesTensorName) return false; + if (IndicesTensorName != other.IndicesTensorName) return false; + if (DenseShapeTensorName != other.DenseShapeTensorName) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ValuesTensorName.Length != 0) hash ^= ValuesTensorName.GetHashCode(); + if (IndicesTensorName.Length != 0) hash ^= IndicesTensorName.GetHashCode(); + if (DenseShapeTensorName.Length != 0) hash ^= DenseShapeTensorName.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ValuesTensorName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ValuesTensorName); + } + if (IndicesTensorName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(IndicesTensorName); + } + if (DenseShapeTensorName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(DenseShapeTensorName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ValuesTensorName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ValuesTensorName); + } + if (IndicesTensorName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(IndicesTensorName); + } + if (DenseShapeTensorName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(DenseShapeTensorName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ValuesTensorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ValuesTensorName); + } + if (IndicesTensorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(IndicesTensorName); + } + if (DenseShapeTensorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DenseShapeTensorName); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CooSparse other) { + if (other == null) { + return; + } + if (other.ValuesTensorName.Length != 0) { + ValuesTensorName = other.ValuesTensorName; + } + if (other.IndicesTensorName.Length != 0) { + IndicesTensorName = other.IndicesTensorName; + } + if (other.DenseShapeTensorName.Length != 0) { + DenseShapeTensorName = other.DenseShapeTensorName; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ValuesTensorName = input.ReadString(); + break; + } + case 18: { + IndicesTensorName = input.ReadString(); + break; + } + case 26: { + DenseShapeTensorName = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + ValuesTensorName = input.ReadString(); + break; + } + case 18: { + IndicesTensorName = input.ReadString(); + break; + } + case 26: { + DenseShapeTensorName = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// Generic encoding for composite tensors. + /// + public sealed partial class CompositeTensor : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CompositeTensor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TensorInfo.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CompositeTensor() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CompositeTensor(CompositeTensor other) : this() { + typeSpec_ = other.typeSpec_ != null ? other.typeSpec_.Clone() : null; + components_ = other.components_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CompositeTensor Clone() { + return new CompositeTensor(this); + } + + /// Field number for the "type_spec" field. + public const int TypeSpecFieldNumber = 1; + private global::Tensorflow.TypeSpecProto typeSpec_; + /// + /// The serialized TypeSpec for the composite tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TypeSpecProto TypeSpec { + get { return typeSpec_; } + set { + typeSpec_ = value; + } + } + + /// Field number for the "components" field. + public const int ComponentsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_components_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.TensorInfo.Parser); + private readonly pbc::RepeatedField components_ = new pbc::RepeatedField(); + /// + /// A TensorInfo for each flattened component tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Components { + get { return components_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CompositeTensor); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CompositeTensor other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(TypeSpec, other.TypeSpec)) return false; + if(!components_.Equals(other.components_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (typeSpec_ != null) hash ^= TypeSpec.GetHashCode(); + hash ^= components_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (typeSpec_ != null) { + output.WriteRawTag(10); + output.WriteMessage(TypeSpec); + } + components_.WriteTo(output, _repeated_components_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (typeSpec_ != null) { + output.WriteRawTag(10); + output.WriteMessage(TypeSpec); + } + components_.WriteTo(ref output, _repeated_components_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (typeSpec_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TypeSpec); + } + size += components_.CalculateSize(_repeated_components_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CompositeTensor other) { + if (other == null) { + return; + } + if (other.typeSpec_ != null) { + if (typeSpec_ == null) { + TypeSpec = new global::Tensorflow.TypeSpecProto(); + } + TypeSpec.MergeFrom(other.TypeSpec); + } + components_.Add(other.components_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (typeSpec_ == null) { + TypeSpec = new global::Tensorflow.TypeSpecProto(); + } + input.ReadMessage(TypeSpec); + break; + } + case 18: { + components_.AddEntriesFrom(input, _repeated_components_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (typeSpec_ == null) { + TypeSpec = new global::Tensorflow.TypeSpecProto(); + } + input.ReadMessage(TypeSpec); + break; + } + case 18: { + components_.AddEntriesFrom(ref input, _repeated_components_codec); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// SignatureDef defines the signature of a computation supported by a TensorFlow + /// graph. + /// + /// For example, a model with two loss computations, sharing a single input, + /// might have the following signature_def map, in a MetaGraphDef message. + /// + /// Note that across the two SignatureDefs "loss_A" and "loss_B", the input key, + /// output key, and method_name are identical, and will be used by system(s) that + /// implement or rely upon this particular loss method. The output tensor names + /// differ, demonstrating how different outputs can exist for the same method. + /// + /// signature_def { + /// key: "loss_A" + /// value { + /// inputs { + /// key: "input" + /// value { + /// name: "input:0" + /// dtype: DT_STRING + /// tensor_shape: ... + /// } + /// } + /// outputs { + /// key: "loss_output" + /// value { + /// name: "loss_output_A:0" + /// dtype: DT_FLOAT + /// tensor_shape: ... + /// } + /// } + /// method_name: "some/package/compute_loss" + /// } + /// ... + /// } + /// signature_def { + /// key: "loss_B" + /// value { + /// inputs { + /// key: "input" + /// value { + /// name: "input:0" + /// dtype: DT_STRING + /// tensor_shape: ... + /// } + /// } + /// outputs { + /// key: "loss_output" + /// value { + /// name: "loss_output_B:0" + /// dtype: DT_FLOAT + /// tensor_shape: ... + /// } + /// } + /// method_name: "some/package/compute_loss" + /// } + /// ... + /// } + /// + public sealed partial class SignatureDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SignatureDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MetaGraphReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SignatureDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SignatureDef(SignatureDef other) : this() { + inputs_ = other.inputs_.Clone(); + outputs_ = other.outputs_.Clone(); + methodName_ = other.methodName_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SignatureDef Clone() { + return new SignatureDef(this); + } + + /// Field number for the "inputs" field. + public const int InputsFieldNumber = 1; + private static readonly pbc::MapField.Codec _map_inputs_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.TensorInfo.Parser), 10); + private readonly pbc::MapField inputs_ = new pbc::MapField(); + /// + /// Named input parameters. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField Inputs { + get { return inputs_; } + } + + /// Field number for the "outputs" field. + public const int OutputsFieldNumber = 2; + private static readonly pbc::MapField.Codec _map_outputs_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.TensorInfo.Parser), 18); + private readonly pbc::MapField outputs_ = new pbc::MapField(); + /// + /// Named output parameters. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField Outputs { + get { return outputs_; } + } + + /// Field number for the "method_name" field. + public const int MethodNameFieldNumber = 3; + private string methodName_ = ""; + /// + /// Extensible method_name information enabling third-party users to mark a + /// SignatureDef as supporting a particular method. This enables producers and + /// consumers of SignatureDefs, e.g. a model definition library and a serving + /// library to have a clear hand-off regarding the semantics of a computation. + /// + /// Note that multiple SignatureDefs in a single MetaGraphDef may have the same + /// method_name. This is commonly used to support multi-headed computation, + /// where a single graph computation may return multiple results. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string MethodName { + get { return methodName_; } + set { + methodName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SignatureDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SignatureDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!Inputs.Equals(other.Inputs)) return false; + if (!Outputs.Equals(other.Outputs)) return false; + if (MethodName != other.MethodName) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= Inputs.GetHashCode(); + hash ^= Outputs.GetHashCode(); + if (MethodName.Length != 0) hash ^= MethodName.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + inputs_.WriteTo(output, _map_inputs_codec); + outputs_.WriteTo(output, _map_outputs_codec); + if (MethodName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(MethodName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + inputs_.WriteTo(ref output, _map_inputs_codec); + outputs_.WriteTo(ref output, _map_outputs_codec); + if (MethodName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(MethodName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += inputs_.CalculateSize(_map_inputs_codec); + size += outputs_.CalculateSize(_map_outputs_codec); + if (MethodName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MethodName); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SignatureDef other) { + if (other == null) { + return; + } + inputs_.Add(other.inputs_); + outputs_.Add(other.outputs_); + if (other.MethodName.Length != 0) { + MethodName = other.MethodName; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + inputs_.AddEntriesFrom(input, _map_inputs_codec); + break; + } + case 18: { + outputs_.AddEntriesFrom(input, _map_outputs_codec); + break; + } + case 26: { + MethodName = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + inputs_.AddEntriesFrom(ref input, _map_inputs_codec); + break; + } + case 18: { + outputs_.AddEntriesFrom(ref input, _map_outputs_codec); + break; + } + case 26: { + MethodName = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// An asset file def for a single file or a set of sharded files with the same + /// name. + /// + public sealed partial class AssetFileDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AssetFileDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MetaGraphReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AssetFileDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AssetFileDef(AssetFileDef other) : this() { + tensorInfo_ = other.tensorInfo_ != null ? other.tensorInfo_.Clone() : null; + filename_ = other.filename_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AssetFileDef Clone() { + return new AssetFileDef(this); + } + + /// Field number for the "tensor_info" field. + public const int TensorInfoFieldNumber = 1; + private global::Tensorflow.TensorInfo tensorInfo_; + /// + /// The tensor to bind the asset filename to. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorInfo TensorInfo { + get { return tensorInfo_; } + set { + tensorInfo_ = value; + } + } + + /// Field number for the "filename" field. + public const int FilenameFieldNumber = 2; + private string filename_ = ""; + /// + /// The filename within an assets directory. Note: does not include the path + /// prefix, i.e. directories. For an asset at /tmp/path/vocab.txt, the filename + /// would be "vocab.txt". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Filename { + get { return filename_; } + set { + filename_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as AssetFileDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(AssetFileDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(TensorInfo, other.TensorInfo)) return false; + if (Filename != other.Filename) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (tensorInfo_ != null) hash ^= TensorInfo.GetHashCode(); + if (Filename.Length != 0) hash ^= Filename.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (tensorInfo_ != null) { + output.WriteRawTag(10); + output.WriteMessage(TensorInfo); + } + if (Filename.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Filename); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (tensorInfo_ != null) { + output.WriteRawTag(10); + output.WriteMessage(TensorInfo); + } + if (Filename.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Filename); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (tensorInfo_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TensorInfo); + } + if (Filename.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Filename); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(AssetFileDef other) { + if (other == null) { + return; + } + if (other.tensorInfo_ != null) { + if (tensorInfo_ == null) { + TensorInfo = new global::Tensorflow.TensorInfo(); + } + TensorInfo.MergeFrom(other.TensorInfo); + } + if (other.Filename.Length != 0) { + Filename = other.Filename; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (tensorInfo_ == null) { + TensorInfo = new global::Tensorflow.TensorInfo(); + } + input.ReadMessage(TensorInfo); + break; + } + case 18: { + Filename = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (tensorInfo_ == null) { + TensorInfo = new global::Tensorflow.TensorInfo(); + } + input.ReadMessage(TensorInfo); + break; + } + case 18: { + Filename = input.ReadString(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/NodeDef.cs b/src/TensorFlowNET.Core/Protobuf/NodeDef.cs index af40fd62b..657ef46eb 100644 --- a/src/TensorFlowNET.Core/Protobuf/NodeDef.cs +++ b/src/TensorFlowNET.Core/Protobuf/NodeDef.cs @@ -1,8 +1,8 @@ // // Generated by the protocol buffer compiler. DO NOT EDIT! -// source: node_def.proto +// source: tensorflow/core/framework/node_def.proto // -#pragma warning disable 1591, 0612, 3021 +#pragma warning disable 1591, 0612, 3021, 8981 #region Designer generated code using pb = global::Google.Protobuf; @@ -11,11 +11,11 @@ using scg = global::System.Collections.Generic; namespace Tensorflow { - /// Holder for reflection information generated from node_def.proto + /// Holder for reflection information generated from tensorflow/core/framework/node_def.proto public static partial class NodeDefReflection { #region Descriptor - /// File descriptor for node_def.proto + /// File descriptor for tensorflow/core/framework/node_def.proto public static pbr::FileDescriptor Descriptor { get { return descriptor; } } @@ -24,41 +24,56 @@ public static partial class NodeDefReflection { static NodeDefReflection() { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( - "Cg5ub2RlX2RlZi5wcm90bxIKdGVuc29yZmxvdxoQYXR0cl92YWx1ZS5wcm90", - "byKzAQoHTm9kZURlZhIMCgRuYW1lGAEgASgJEgoKAm9wGAIgASgJEg0KBWlu", - "cHV0GAMgAygJEg4KBmRldmljZRgEIAEoCRIrCgRhdHRyGAUgAygLMh0udGVu", - "c29yZmxvdy5Ob2RlRGVmLkF0dHJFbnRyeRpCCglBdHRyRW50cnkSCwoDa2V5", - "GAEgASgJEiQKBXZhbHVlGAIgASgLMhUudGVuc29yZmxvdy5BdHRyVmFsdWU6", - "AjgBQmkKGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IJTm9kZVByb3RvUAFa", - "PWdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cv", - "Z28vY29yZS9mcmFtZXdvcmv4AQFiBnByb3RvMw==")); + "Cih0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL25vZGVfZGVmLnByb3RvEgp0", + "ZW5zb3JmbG93Gip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2F0dHJfdmFs", + "dWUucHJvdG8aKXRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvZnVsbF90eXBl", + "LnByb3RvIoYDCgdOb2RlRGVmEgwKBG5hbWUYASABKAkSCgoCb3AYAiABKAkS", + "DQoFaW5wdXQYAyADKAkSDgoGZGV2aWNlGAQgASgJEisKBGF0dHIYBSADKAsy", + "HS50ZW5zb3JmbG93Lk5vZGVEZWYuQXR0ckVudHJ5EkoKF2V4cGVyaW1lbnRh", + "bF9kZWJ1Z19pbmZvGAYgASgLMikudGVuc29yZmxvdy5Ob2RlRGVmLkV4cGVy", + "aW1lbnRhbERlYnVnSW5mbxIyChFleHBlcmltZW50YWxfdHlwZRgHIAEoCzIX", + "LnRlbnNvcmZsb3cuRnVsbFR5cGVEZWYaQgoJQXR0ckVudHJ5EgsKA2tleRgB", + "IAEoCRIkCgV2YWx1ZRgCIAEoCzIVLnRlbnNvcmZsb3cuQXR0clZhbHVlOgI4", + "ARpRChVFeHBlcmltZW50YWxEZWJ1Z0luZm8SGwoTb3JpZ2luYWxfbm9kZV9u", + "YW1lcxgBIAMoCRIbChNvcmlnaW5hbF9mdW5jX25hbWVzGAIgAygJQnsKGG9y", + "Zy50ZW5zb3JmbG93LmZyYW1ld29ya0IJTm9kZVByb3RvUAFaT2dpdGh1Yi5j", + "b20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9m", + "cmFtZXdvcmsvbm9kZV9kZWZfZ29fcHJvdG/4AQFiBnByb3RvMw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, - new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, }, - new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.NodeDef), global::Tensorflow.NodeDef.Parser, new[]{ "Name", "Op", "Input", "Device", "Attr" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, }) + new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.FullTypeReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.NodeDef), global::Tensorflow.NodeDef.Parser, new[]{ "Name", "Op", "Input", "Device", "Attr", "ExperimentalDebugInfo", "ExperimentalType" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.NodeDef.Types.ExperimentalDebugInfo), global::Tensorflow.NodeDef.Types.ExperimentalDebugInfo.Parser, new[]{ "OriginalNodeNames", "OriginalFuncNames" }, null, null, null, null)}) })); } #endregion } #region Messages - public sealed partial class NodeDef : pb::IMessage { + public sealed partial class NodeDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NodeDef()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.NodeDefReflection.Descriptor.MessageTypes[0]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public NodeDef() { OnConstruction(); } @@ -66,16 +81,20 @@ public NodeDef() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public NodeDef(NodeDef other) : this() { name_ = other.name_; op_ = other.op_; input_ = other.input_.Clone(); device_ = other.device_; attr_ = other.attr_.Clone(); + experimentalDebugInfo_ = other.experimentalDebugInfo_ != null ? other.experimentalDebugInfo_.Clone() : null; + experimentalType_ = other.experimentalType_ != null ? other.experimentalType_.Clone() : null; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public NodeDef Clone() { return new NodeDef(this); } @@ -86,9 +105,10 @@ public NodeDef Clone() { /// /// The name given to this operator. Used for naming inputs, /// logging, visualization, etc. Unique within a single GraphDef. - /// Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + /// Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_>./]*". /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Name { get { return name_; } set { @@ -104,6 +124,7 @@ public string Name { /// Op names starting with an underscore are reserved for internal use. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Op { get { return op_; } set { @@ -124,6 +145,7 @@ public string Op { /// "^node". /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField Input { get { return input_; } } @@ -154,6 +176,7 @@ public string Op { /// choose a device automatically. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Device { get { return device_; } set { @@ -164,7 +187,7 @@ public string Device { /// Field number for the "attr" field. public const int AttrFieldNumber = 5; private static readonly pbc::MapField.Codec _map_attr_codec - = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 42); + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 42); private readonly pbc::MapField attr_ = new pbc::MapField(); /// /// Operation-specific graph-construction-time configuration. @@ -181,16 +204,52 @@ public string Device { /// TODO(josh11b): Add some examples here showing best practices. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::MapField Attr { get { return attr_; } } + /// Field number for the "experimental_debug_info" field. + public const int ExperimentalDebugInfoFieldNumber = 6; + private global::Tensorflow.NodeDef.Types.ExperimentalDebugInfo experimentalDebugInfo_; + /// + /// This stores debug information associated with the node. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.NodeDef.Types.ExperimentalDebugInfo ExperimentalDebugInfo { + get { return experimentalDebugInfo_; } + set { + experimentalDebugInfo_ = value; + } + } + + /// Field number for the "experimental_type" field. + public const int ExperimentalTypeFieldNumber = 7; + private global::Tensorflow.FullTypeDef experimentalType_; + /// + /// The complete type of this node. Experimental and subject to change. + /// Currently, the field only contains the return types of the node. That will + /// extend in the future to contain the entire signature of the node, as a + /// function type. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.FullTypeDef ExperimentalType { + get { return experimentalType_; } + set { + experimentalType_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as NodeDef); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(NodeDef other) { if (ReferenceEquals(other, null)) { return false; @@ -203,10 +262,13 @@ public bool Equals(NodeDef other) { if(!input_.Equals(other.input_)) return false; if (Device != other.Device) return false; if (!Attr.Equals(other.Attr)) return false; + if (!object.Equals(ExperimentalDebugInfo, other.ExperimentalDebugInfo)) return false; + if (!object.Equals(ExperimentalType, other.ExperimentalType)) return false; return Equals(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; if (Name.Length != 0) hash ^= Name.GetHashCode(); @@ -214,6 +276,8 @@ public override int GetHashCode() { hash ^= input_.GetHashCode(); if (Device.Length != 0) hash ^= Device.GetHashCode(); hash ^= Attr.GetHashCode(); + if (experimentalDebugInfo_ != null) hash ^= ExperimentalDebugInfo.GetHashCode(); + if (experimentalType_ != null) hash ^= ExperimentalType.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -221,12 +285,17 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else if (Name.Length != 0) { output.WriteRawTag(10); output.WriteString(Name); @@ -241,12 +310,54 @@ public void WriteTo(pb::CodedOutputStream output) { output.WriteString(Device); } attr_.WriteTo(output, _map_attr_codec); + if (experimentalDebugInfo_ != null) { + output.WriteRawTag(50); + output.WriteMessage(ExperimentalDebugInfo); + } + if (experimentalType_ != null) { + output.WriteRawTag(58); + output.WriteMessage(ExperimentalType); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Op.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Op); + } + input_.WriteTo(ref output, _repeated_input_codec); + if (Device.Length != 0) { + output.WriteRawTag(34); + output.WriteString(Device); + } + attr_.WriteTo(ref output, _map_attr_codec); + if (experimentalDebugInfo_ != null) { + output.WriteRawTag(50); + output.WriteMessage(ExperimentalDebugInfo); + } + if (experimentalType_ != null) { + output.WriteRawTag(58); + output.WriteMessage(ExperimentalType); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } } + #endif [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; if (Name.Length != 0) { @@ -260,6 +371,12 @@ public int CalculateSize() { size += 1 + pb::CodedOutputStream.ComputeStringSize(Device); } size += attr_.CalculateSize(_map_attr_codec); + if (experimentalDebugInfo_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ExperimentalDebugInfo); + } + if (experimentalType_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ExperimentalType); + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -267,6 +384,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(NodeDef other) { if (other == null) { return; @@ -282,11 +400,27 @@ public void MergeFrom(NodeDef other) { Device = other.Device; } attr_.Add(other.attr_); + if (other.experimentalDebugInfo_ != null) { + if (experimentalDebugInfo_ == null) { + ExperimentalDebugInfo = new global::Tensorflow.NodeDef.Types.ExperimentalDebugInfo(); + } + ExperimentalDebugInfo.MergeFrom(other.ExperimentalDebugInfo); + } + if (other.experimentalType_ != null) { + if (experimentalType_ == null) { + ExperimentalType = new global::Tensorflow.FullTypeDef(); + } + ExperimentalType.MergeFrom(other.ExperimentalType); + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -313,9 +447,303 @@ public void MergeFrom(pb::CodedInputStream input) { attr_.AddEntriesFrom(input, _map_attr_codec); break; } + case 50: { + if (experimentalDebugInfo_ == null) { + ExperimentalDebugInfo = new global::Tensorflow.NodeDef.Types.ExperimentalDebugInfo(); + } + input.ReadMessage(ExperimentalDebugInfo); + break; + } + case 58: { + if (experimentalType_ == null) { + ExperimentalType = new global::Tensorflow.FullTypeDef(); + } + input.ReadMessage(ExperimentalType); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + Op = input.ReadString(); + break; + } + case 26: { + input_.AddEntriesFrom(ref input, _repeated_input_codec); + break; + } + case 34: { + Device = input.ReadString(); + break; + } + case 42: { + attr_.AddEntriesFrom(ref input, _map_attr_codec); + break; + } + case 50: { + if (experimentalDebugInfo_ == null) { + ExperimentalDebugInfo = new global::Tensorflow.NodeDef.Types.ExperimentalDebugInfo(); + } + input.ReadMessage(ExperimentalDebugInfo); + break; + } + case 58: { + if (experimentalType_ == null) { + ExperimentalType = new global::Tensorflow.FullTypeDef(); + } + input.ReadMessage(ExperimentalType); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the NodeDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public sealed partial class ExperimentalDebugInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ExperimentalDebugInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.NodeDef.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExperimentalDebugInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExperimentalDebugInfo(ExperimentalDebugInfo other) : this() { + originalNodeNames_ = other.originalNodeNames_.Clone(); + originalFuncNames_ = other.originalFuncNames_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExperimentalDebugInfo Clone() { + return new ExperimentalDebugInfo(this); + } + + /// Field number for the "original_node_names" field. + public const int OriginalNodeNamesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_originalNodeNames_codec + = pb::FieldCodec.ForString(10); + private readonly pbc::RepeatedField originalNodeNames_ = new pbc::RepeatedField(); + /// + /// Opaque string inserted into error messages created by the runtime. + /// + /// This is intended to store the list of names of the nodes from the + /// original graph that this node was derived. For example if this node, say + /// C, was result of a fusion of 2 nodes A and B, then 'original_node' would + /// be {A, B}. This information can be used to map errors originating at the + /// current node to some top level source code. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OriginalNodeNames { + get { return originalNodeNames_; } + } + + /// Field number for the "original_func_names" field. + public const int OriginalFuncNamesFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_originalFuncNames_codec + = pb::FieldCodec.ForString(18); + private readonly pbc::RepeatedField originalFuncNames_ = new pbc::RepeatedField(); + /// + /// This is intended to store the list of names of the functions from the + /// original graph that this node was derived. For example if this node, say + /// C, was result of a fusion of node A in function FA and node B in function + /// FB, then `original_funcs` would be {FA, FB}. If the node is in the top + /// level graph, the `original_func` is empty. This information, with the + /// `original_node_names` can be used to map errors originating at the + /// current ndoe to some top level source code. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OriginalFuncNames { + get { return originalFuncNames_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ExperimentalDebugInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ExperimentalDebugInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!originalNodeNames_.Equals(other.originalNodeNames_)) return false; + if(!originalFuncNames_.Equals(other.originalFuncNames_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= originalNodeNames_.GetHashCode(); + hash ^= originalFuncNames_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + originalNodeNames_.WriteTo(output, _repeated_originalNodeNames_codec); + originalFuncNames_.WriteTo(output, _repeated_originalFuncNames_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + originalNodeNames_.WriteTo(ref output, _repeated_originalNodeNames_codec); + originalFuncNames_.WriteTo(ref output, _repeated_originalFuncNames_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += originalNodeNames_.CalculateSize(_repeated_originalNodeNames_codec); + size += originalFuncNames_.CalculateSize(_repeated_originalFuncNames_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ExperimentalDebugInfo other) { + if (other == null) { + return; + } + originalNodeNames_.Add(other.originalNodeNames_); + originalFuncNames_.Add(other.originalFuncNames_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + originalNodeNames_.AddEntriesFrom(input, _repeated_originalNodeNames_codec); + break; + } + case 18: { + originalFuncNames_.AddEntriesFrom(input, _repeated_originalFuncNames_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + originalNodeNames_.AddEntriesFrom(ref input, _repeated_originalNodeNames_codec); + break; + } + case 18: { + originalFuncNames_.AddEntriesFrom(ref input, _repeated_originalFuncNames_codec); + break; + } + } + } } + #endif + } + } + #endregion } diff --git a/src/TensorFlowNET.Core/Protobuf/OpDef.cs b/src/TensorFlowNET.Core/Protobuf/OpDef.cs index 737e97e5f..dd6a26450 100644 --- a/src/TensorFlowNET.Core/Protobuf/OpDef.cs +++ b/src/TensorFlowNET.Core/Protobuf/OpDef.cs @@ -1,8 +1,8 @@ // // Generated by the protocol buffer compiler. DO NOT EDIT! -// source: op_def.proto +// source: tensorflow/core/framework/op_def.proto // -#pragma warning disable 1591, 0612, 3021 +#pragma warning disable 1591, 0612, 3021, 8981 #region Designer generated code using pb = global::Google.Protobuf; @@ -11,11 +11,11 @@ using scg = global::System.Collections.Generic; namespace Tensorflow { - /// Holder for reflection information generated from op_def.proto + /// Holder for reflection information generated from tensorflow/core/framework/op_def.proto public static partial class OpDefReflection { #region Descriptor - /// File descriptor for op_def.proto + /// File descriptor for tensorflow/core/framework/op_def.proto public static pbr::FileDescriptor Descriptor { get { return descriptor; } } @@ -24,35 +24,43 @@ public static partial class OpDefReflection { static OpDefReflection() { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( - "CgxvcF9kZWYucHJvdG8SCnRlbnNvcmZsb3caEGF0dHJfdmFsdWUucHJvdG8a", - "C3R5cGVzLnByb3RvIrgFCgVPcERlZhIMCgRuYW1lGAEgASgJEisKCWlucHV0", - "X2FyZxgCIAMoCzIYLnRlbnNvcmZsb3cuT3BEZWYuQXJnRGVmEiwKCm91dHB1", - "dF9hcmcYAyADKAsyGC50ZW5zb3JmbG93Lk9wRGVmLkFyZ0RlZhInCgRhdHRy", - "GAQgAygLMhkudGVuc29yZmxvdy5PcERlZi5BdHRyRGVmEi4KC2RlcHJlY2F0", - "aW9uGAggASgLMhkudGVuc29yZmxvdy5PcERlcHJlY2F0aW9uEg8KB3N1bW1h", - "cnkYBSABKAkSEwoLZGVzY3JpcHRpb24YBiABKAkSFgoOaXNfY29tbXV0YXRp", - "dmUYEiABKAgSFAoMaXNfYWdncmVnYXRlGBAgASgIEhMKC2lzX3N0YXRlZnVs", - "GBEgASgIEiIKGmFsbG93c191bmluaXRpYWxpemVkX2lucHV0GBMgASgIGp8B", - "CgZBcmdEZWYSDAoEbmFtZRgBIAEoCRITCgtkZXNjcmlwdGlvbhgCIAEoCRIi", - "CgR0eXBlGAMgASgOMhQudGVuc29yZmxvdy5EYXRhVHlwZRIRCgl0eXBlX2F0", - "dHIYBCABKAkSEwoLbnVtYmVyX2F0dHIYBSABKAkSFgoOdHlwZV9saXN0X2F0", - "dHIYBiABKAkSDgoGaXNfcmVmGBAgASgIGr0BCgdBdHRyRGVmEgwKBG5hbWUY", - "ASABKAkSDAoEdHlwZRgCIAEoCRIsCg1kZWZhdWx0X3ZhbHVlGAMgASgLMhUu", - "dGVuc29yZmxvdy5BdHRyVmFsdWUSEwoLZGVzY3JpcHRpb24YBCABKAkSEwoL", - "aGFzX21pbmltdW0YBSABKAgSDwoHbWluaW11bRgGIAEoAxItCg5hbGxvd2Vk", - "X3ZhbHVlcxgHIAEoCzIVLnRlbnNvcmZsb3cuQXR0clZhbHVlIjUKDU9wRGVw", - "cmVjYXRpb24SDwoHdmVyc2lvbhgBIAEoBRITCgtleHBsYW5hdGlvbhgCIAEo", - "CSInCgZPcExpc3QSHQoCb3AYASADKAsyES50ZW5zb3JmbG93Lk9wRGVmQmsK", - "GG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0ILT3BEZWZQcm90b3NQAVo9Z2l0", - "aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9j", - "b3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); + "CiZ0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL29wX2RlZi5wcm90bxIKdGVu", + "c29yZmxvdxoqdGVuc29yZmxvdy9jb3JlL2ZyYW1ld29yay9hdHRyX3ZhbHVl", + "LnByb3RvGil0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Z1bGxfdHlwZS5w", + "cm90bxovdGVuc29yZmxvdy9jb3JlL2ZyYW1ld29yay9yZXNvdXJjZV9oYW5k", + "bGUucHJvdG8aJXRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvdHlwZXMucHJv", + "dG8i8wYKBU9wRGVmEgwKBG5hbWUYASABKAkSKwoJaW5wdXRfYXJnGAIgAygL", + "MhgudGVuc29yZmxvdy5PcERlZi5BcmdEZWYSLAoKb3V0cHV0X2FyZxgDIAMo", + "CzIYLnRlbnNvcmZsb3cuT3BEZWYuQXJnRGVmEhYKDmNvbnRyb2xfb3V0cHV0", + "GBQgAygJEicKBGF0dHIYBCADKAsyGS50ZW5zb3JmbG93Lk9wRGVmLkF0dHJE", + "ZWYSLgoLZGVwcmVjYXRpb24YCCABKAsyGS50ZW5zb3JmbG93Lk9wRGVwcmVj", + "YXRpb24SDwoHc3VtbWFyeRgFIAEoCRITCgtkZXNjcmlwdGlvbhgGIAEoCRIW", + "Cg5pc19jb21tdXRhdGl2ZRgSIAEoCBIUCgxpc19hZ2dyZWdhdGUYECABKAgS", + "EwoLaXNfc3RhdGVmdWwYESABKAgSIgoaYWxsb3dzX3VuaW5pdGlhbGl6ZWRf", + "aW5wdXQYEyABKAgSJAocaXNfZGlzdHJpYnV0ZWRfY29tbXVuaWNhdGlvbhgV", + "IAEoCBqcAgoGQXJnRGVmEgwKBG5hbWUYASABKAkSEwoLZGVzY3JpcHRpb24Y", + "AiABKAkSIgoEdHlwZRgDIAEoDjIULnRlbnNvcmZsb3cuRGF0YVR5cGUSEQoJ", + "dHlwZV9hdHRyGAQgASgJEhMKC251bWJlcl9hdHRyGAUgASgJEhYKDnR5cGVf", + "bGlzdF9hdHRyGAYgASgJEkIKC2hhbmRsZV9kYXRhGAcgAygLMi0udGVuc29y", + "Zmxvdy5SZXNvdXJjZUhhbmRsZVByb3RvLkR0eXBlQW5kU2hhcGUSDgoGaXNf", + "cmVmGBAgASgIEjcKFmV4cGVyaW1lbnRhbF9mdWxsX3R5cGUYESABKAsyFy50", + "ZW5zb3JmbG93LkZ1bGxUeXBlRGVmGr0BCgdBdHRyRGVmEgwKBG5hbWUYASAB", + "KAkSDAoEdHlwZRgCIAEoCRIsCg1kZWZhdWx0X3ZhbHVlGAMgASgLMhUudGVu", + "c29yZmxvdy5BdHRyVmFsdWUSEwoLZGVzY3JpcHRpb24YBCABKAkSEwoLaGFz", + "X21pbmltdW0YBSABKAgSDwoHbWluaW11bRgGIAEoAxItCg5hbGxvd2VkX3Zh", + "bHVlcxgHIAEoCzIVLnRlbnNvcmZsb3cuQXR0clZhbHVlIjUKDU9wRGVwcmVj", + "YXRpb24SDwoHdmVyc2lvbhgBIAEoBRITCgtleHBsYW5hdGlvbhgCIAEoCSIn", + "CgZPcExpc3QSHQoCb3AYASADKAsyES50ZW5zb3JmbG93Lk9wRGVmQnsKGG9y", + "Zy50ZW5zb3JmbG93LmZyYW1ld29ya0ILT3BEZWZQcm90b3NQAVpNZ2l0aHVi", + "LmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3Jl", + "L2ZyYW1ld29yay9vcF9kZWZfZ29fcHJvdG/4AQFiBnByb3RvMw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, - new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, - new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef), global::Tensorflow.OpDef.Parser, new[]{ "Name", "InputArg", "OutputArg", "Attr", "Deprecation", "Summary", "Description", "IsCommutative", "IsAggregate", "IsStateful", "AllowsUninitializedInput" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef.Types.ArgDef), global::Tensorflow.OpDef.Types.ArgDef.Parser, new[]{ "Name", "Description", "Type", "TypeAttr", "NumberAttr", "TypeListAttr", "IsRef" }, null, null, null), - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef.Types.AttrDef), global::Tensorflow.OpDef.Types.AttrDef.Parser, new[]{ "Name", "Type", "DefaultValue", "Description", "HasMinimum", "Minimum", "AllowedValues" }, null, null, null)}), - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDeprecation), global::Tensorflow.OpDeprecation.Parser, new[]{ "Version", "Explanation" }, null, null, null), - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpList), global::Tensorflow.OpList.Parser, new[]{ "Op" }, null, null, null) + new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.FullTypeReflection.Descriptor, global::Tensorflow.ResourceHandleReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef), global::Tensorflow.OpDef.Parser, new[]{ "Name", "InputArg", "OutputArg", "ControlOutput", "Attr", "Deprecation", "Summary", "Description", "IsCommutative", "IsAggregate", "IsStateful", "AllowsUninitializedInput", "IsDistributedCommunication" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef.Types.ArgDef), global::Tensorflow.OpDef.Types.ArgDef.Parser, new[]{ "Name", "Description", "Type", "TypeAttr", "NumberAttr", "TypeListAttr", "HandleData", "IsRef", "ExperimentalFullType" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef.Types.AttrDef), global::Tensorflow.OpDef.Types.AttrDef.Parser, new[]{ "Name", "Type", "DefaultValue", "Description", "HasMinimum", "Minimum", "AllowedValues" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDeprecation), global::Tensorflow.OpDeprecation.Parser, new[]{ "Version", "Explanation" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpList), global::Tensorflow.OpList.Parser, new[]{ "Op" }, null, null, null, null) })); } #endregion @@ -64,23 +72,31 @@ static OpDefReflection() { /// using the "op" field which should match the name of a OpDef. /// LINT.IfChange /// - public sealed partial class OpDef : pb::IMessage { + public sealed partial class OpDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OpDef()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.OpDefReflection.Descriptor.MessageTypes[0]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public OpDef() { OnConstruction(); } @@ -88,10 +104,12 @@ public OpDef() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public OpDef(OpDef other) : this() { name_ = other.name_; inputArg_ = other.inputArg_.Clone(); outputArg_ = other.outputArg_.Clone(); + controlOutput_ = other.controlOutput_.Clone(); attr_ = other.attr_.Clone(); deprecation_ = other.deprecation_ != null ? other.deprecation_.Clone() : null; summary_ = other.summary_; @@ -100,10 +118,12 @@ public OpDef(OpDef other) : this() { isAggregate_ = other.isAggregate_; isStateful_ = other.isStateful_; allowsUninitializedInput_ = other.allowsUninitializedInput_; + isDistributedCommunication_ = other.isDistributedCommunication_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public OpDef Clone() { return new OpDef(this); } @@ -113,9 +133,10 @@ public OpDef Clone() { private string name_ = ""; /// /// Op names starting with an underscore are reserved for internal use. - /// Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + /// Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9>_]*". /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Name { get { return name_; } set { @@ -132,6 +153,7 @@ public string Name { /// Description of the input(s). /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField InputArg { get { return inputArg_; } } @@ -145,16 +167,33 @@ public string Name { /// Description of the output(s). /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField OutputArg { get { return outputArg_; } } + /// Field number for the "control_output" field. + public const int ControlOutputFieldNumber = 20; + private static readonly pb::FieldCodec _repeated_controlOutput_codec + = pb::FieldCodec.ForString(162); + private readonly pbc::RepeatedField controlOutput_ = new pbc::RepeatedField(); + /// + /// Named control outputs for this operation. Useful only for composite + /// operations (i.e. functions) which want to name different control outputs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ControlOutput { + get { return controlOutput_; } + } + /// Field number for the "attr" field. public const int AttrFieldNumber = 4; private static readonly pb::FieldCodec _repeated_attr_codec = pb::FieldCodec.ForMessage(34, global::Tensorflow.OpDef.Types.AttrDef.Parser); private readonly pbc::RepeatedField attr_ = new pbc::RepeatedField(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField Attr { get { return attr_; } } @@ -166,6 +205,7 @@ public string Name { /// Optional deprecation based on GraphDef versions. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public global::Tensorflow.OpDeprecation Deprecation { get { return deprecation_; } set { @@ -180,6 +220,7 @@ public string Name { /// One-line human-readable description of what the Op does. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Summary { get { return summary_; } set { @@ -194,6 +235,7 @@ public string Summary { /// Additional, longer human-readable description of what the Op does. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Description { get { return description_; } set { @@ -208,6 +250,7 @@ public string Description { /// True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool IsCommutative { get { return isCommutative_; } set { @@ -229,6 +272,7 @@ public bool IsCommutative { /// TODO(josh11b): Implement that optimization. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool IsAggregate { get { return isAggregate_; } set { @@ -253,6 +297,7 @@ public bool IsAggregate { /// Subexpression Elimination (CSE). /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool IsStateful { get { return isStateful_; } set { @@ -270,6 +315,7 @@ public bool IsStateful { /// input. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool AllowsUninitializedInput { get { return allowsUninitializedInput_; } set { @@ -277,12 +323,31 @@ public bool AllowsUninitializedInput { } } + /// Field number for the "is_distributed_communication" field. + public const int IsDistributedCommunicationFieldNumber = 21; + private bool isDistributedCommunication_; + /// + /// Indicates whether the op implementation uses distributed communication. + /// If True, the op is allowed to return errors for network disconnection and + /// trigger TF network failure handling logics. + /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsDistributedCommunication { + get { return isDistributedCommunication_; } + set { + isDistributedCommunication_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as OpDef); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(OpDef other) { if (ReferenceEquals(other, null)) { return false; @@ -293,6 +358,7 @@ public bool Equals(OpDef other) { if (Name != other.Name) return false; if(!inputArg_.Equals(other.inputArg_)) return false; if(!outputArg_.Equals(other.outputArg_)) return false; + if(!controlOutput_.Equals(other.controlOutput_)) return false; if(!attr_.Equals(other.attr_)) return false; if (!object.Equals(Deprecation, other.Deprecation)) return false; if (Summary != other.Summary) return false; @@ -301,15 +367,18 @@ public bool Equals(OpDef other) { if (IsAggregate != other.IsAggregate) return false; if (IsStateful != other.IsStateful) return false; if (AllowsUninitializedInput != other.AllowsUninitializedInput) return false; + if (IsDistributedCommunication != other.IsDistributedCommunication) return false; return Equals(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; if (Name.Length != 0) hash ^= Name.GetHashCode(); hash ^= inputArg_.GetHashCode(); hash ^= outputArg_.GetHashCode(); + hash ^= controlOutput_.GetHashCode(); hash ^= attr_.GetHashCode(); if (deprecation_ != null) hash ^= Deprecation.GetHashCode(); if (Summary.Length != 0) hash ^= Summary.GetHashCode(); @@ -318,6 +387,7 @@ public override int GetHashCode() { if (IsAggregate != false) hash ^= IsAggregate.GetHashCode(); if (IsStateful != false) hash ^= IsStateful.GetHashCode(); if (AllowsUninitializedInput != false) hash ^= AllowsUninitializedInput.GetHashCode(); + if (IsDistributedCommunication != false) hash ^= IsDistributedCommunication.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -325,12 +395,17 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else if (Name.Length != 0) { output.WriteRawTag(10); output.WriteString(Name); @@ -366,12 +441,69 @@ public void WriteTo(pb::CodedOutputStream output) { output.WriteRawTag(152, 1); output.WriteBool(AllowsUninitializedInput); } + controlOutput_.WriteTo(output, _repeated_controlOutput_codec); + if (IsDistributedCommunication != false) { + output.WriteRawTag(168, 1); + output.WriteBool(IsDistributedCommunication); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + inputArg_.WriteTo(ref output, _repeated_inputArg_codec); + outputArg_.WriteTo(ref output, _repeated_outputArg_codec); + attr_.WriteTo(ref output, _repeated_attr_codec); + if (Summary.Length != 0) { + output.WriteRawTag(42); + output.WriteString(Summary); + } + if (Description.Length != 0) { + output.WriteRawTag(50); + output.WriteString(Description); + } + if (deprecation_ != null) { + output.WriteRawTag(66); + output.WriteMessage(Deprecation); + } + if (IsAggregate != false) { + output.WriteRawTag(128, 1); + output.WriteBool(IsAggregate); + } + if (IsStateful != false) { + output.WriteRawTag(136, 1); + output.WriteBool(IsStateful); + } + if (IsCommutative != false) { + output.WriteRawTag(144, 1); + output.WriteBool(IsCommutative); + } + if (AllowsUninitializedInput != false) { + output.WriteRawTag(152, 1); + output.WriteBool(AllowsUninitializedInput); + } + controlOutput_.WriteTo(ref output, _repeated_controlOutput_codec); + if (IsDistributedCommunication != false) { + output.WriteRawTag(168, 1); + output.WriteBool(IsDistributedCommunication); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } } + #endif [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; if (Name.Length != 0) { @@ -379,6 +511,7 @@ public int CalculateSize() { } size += inputArg_.CalculateSize(_repeated_inputArg_codec); size += outputArg_.CalculateSize(_repeated_outputArg_codec); + size += controlOutput_.CalculateSize(_repeated_controlOutput_codec); size += attr_.CalculateSize(_repeated_attr_codec); if (deprecation_ != null) { size += 1 + pb::CodedOutputStream.ComputeMessageSize(Deprecation); @@ -401,6 +534,9 @@ public int CalculateSize() { if (AllowsUninitializedInput != false) { size += 2 + 1; } + if (IsDistributedCommunication != false) { + size += 2 + 1; + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -408,6 +544,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(OpDef other) { if (other == null) { return; @@ -417,10 +554,11 @@ public void MergeFrom(OpDef other) { } inputArg_.Add(other.inputArg_); outputArg_.Add(other.outputArg_); + controlOutput_.Add(other.controlOutput_); attr_.Add(other.attr_); if (other.deprecation_ != null) { if (deprecation_ == null) { - deprecation_ = new global::Tensorflow.OpDeprecation(); + Deprecation = new global::Tensorflow.OpDeprecation(); } Deprecation.MergeFrom(other.Deprecation); } @@ -442,11 +580,18 @@ public void MergeFrom(OpDef other) { if (other.AllowsUninitializedInput != false) { AllowsUninitializedInput = other.AllowsUninitializedInput; } + if (other.IsDistributedCommunication != false) { + IsDistributedCommunication = other.IsDistributedCommunication; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -479,9 +624,9 @@ public void MergeFrom(pb::CodedInputStream input) { } case 66: { if (deprecation_ == null) { - deprecation_ = new global::Tensorflow.OpDeprecation(); + Deprecation = new global::Tensorflow.OpDeprecation(); } - input.ReadMessage(deprecation_); + input.ReadMessage(Deprecation); break; } case 128: { @@ -500,34 +645,122 @@ public void MergeFrom(pb::CodedInputStream input) { AllowsUninitializedInput = input.ReadBool(); break; } + case 162: { + controlOutput_.AddEntriesFrom(input, _repeated_controlOutput_codec); + break; + } + case 168: { + IsDistributedCommunication = input.ReadBool(); + break; + } } } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + inputArg_.AddEntriesFrom(ref input, _repeated_inputArg_codec); + break; + } + case 26: { + outputArg_.AddEntriesFrom(ref input, _repeated_outputArg_codec); + break; + } + case 34: { + attr_.AddEntriesFrom(ref input, _repeated_attr_codec); + break; + } + case 42: { + Summary = input.ReadString(); + break; + } + case 50: { + Description = input.ReadString(); + break; + } + case 66: { + if (deprecation_ == null) { + Deprecation = new global::Tensorflow.OpDeprecation(); + } + input.ReadMessage(Deprecation); + break; + } + case 128: { + IsAggregate = input.ReadBool(); + break; + } + case 136: { + IsStateful = input.ReadBool(); + break; + } + case 144: { + IsCommutative = input.ReadBool(); + break; + } + case 152: { + AllowsUninitializedInput = input.ReadBool(); + break; + } + case 162: { + controlOutput_.AddEntriesFrom(ref input, _repeated_controlOutput_codec); + break; + } + case 168: { + IsDistributedCommunication = input.ReadBool(); + break; + } + } + } + } + #endif + #region Nested types /// Container for nested types declared in the OpDef message type. [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static partial class Types { /// /// For describing inputs and outputs. /// - public sealed partial class ArgDef : pb::IMessage { + public sealed partial class ArgDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ArgDef()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.OpDef.Descriptor.NestedTypes[0]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public ArgDef() { OnConstruction(); } @@ -535,6 +768,7 @@ public ArgDef() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public ArgDef(ArgDef other) : this() { name_ = other.name_; description_ = other.description_; @@ -542,11 +776,14 @@ public ArgDef(ArgDef other) : this() { typeAttr_ = other.typeAttr_; numberAttr_ = other.numberAttr_; typeListAttr_ = other.typeListAttr_; + handleData_ = other.handleData_.Clone(); isRef_ = other.isRef_; + experimentalFullType_ = other.experimentalFullType_ != null ? other.experimentalFullType_.Clone() : null; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public ArgDef Clone() { return new ArgDef(this); } @@ -558,6 +795,7 @@ public ArgDef Clone() { /// Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Name { get { return name_; } set { @@ -572,6 +810,7 @@ public string Name { /// Human readable description. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Description { get { return description_; } set { @@ -581,7 +820,7 @@ public string Description { /// Field number for the "type" field. public const int TypeFieldNumber = 3; - private global::Tensorflow.DataType type_ = 0; + private global::Tensorflow.DataType type_ = global::Tensorflow.DataType.DtInvalid; /// /// Describes the type of one or more tensors that are accepted/produced /// by this input/output arg. The only legal combinations are: @@ -595,6 +834,7 @@ public string Description { /// to the name of an attr with type "list(type)". /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public global::Tensorflow.DataType Type { get { return type_; } set { @@ -609,6 +849,7 @@ public string Description { /// if specified, attr must have type "type" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string TypeAttr { get { return typeAttr_; } set { @@ -623,6 +864,7 @@ public string TypeAttr { /// if specified, attr must have type "int" /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string NumberAttr { get { return numberAttr_; } set { @@ -638,6 +880,7 @@ public string NumberAttr { /// type, type_attr, and number_attr may be specified. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string TypeListAttr { get { return typeListAttr_; } set { @@ -645,6 +888,20 @@ public string TypeListAttr { } } + /// Field number for the "handle_data" field. + public const int HandleDataFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_handleData_codec + = pb::FieldCodec.ForMessage(58, global::Tensorflow.ResourceHandleProto.Types.DtypeAndShape.Parser); + private readonly pbc::RepeatedField handleData_ = new pbc::RepeatedField(); + /// + /// The handle data for resource inputs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField HandleData { + get { return handleData_; } + } + /// Field number for the "is_ref" field. public const int IsRefFieldNumber = 16; private bool isRef_; @@ -654,6 +911,7 @@ public string TypeListAttr { /// For outputs: if true, outputs are refs, otherwise they are not. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool IsRef { get { return isRef_; } set { @@ -661,12 +919,37 @@ public bool IsRef { } } + /// Field number for the "experimental_full_type" field. + public const int ExperimentalFullTypeFieldNumber = 17; + private global::Tensorflow.FullTypeDef experimentalFullType_; + /// + /// Experimental. Full type declaration for this argument. + /// The full type specification combines type, type_attr, type_list_attr, + /// etc. into a unified representation. + /// This declaration may contain non-concrete types (for example, + /// Tensor<TypeVar<'T'>> is a valid type declaration. + /// + /// Note: this is a transient field. The long-term aim is to represent the + /// entire OpDef as a single type: a callable. In that context, this field is + /// just the type of a single argument. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.FullTypeDef ExperimentalFullType { + get { return experimentalFullType_; } + set { + experimentalFullType_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as ArgDef); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(ArgDef other) { if (ReferenceEquals(other, null)) { return false; @@ -680,20 +963,25 @@ public bool Equals(ArgDef other) { if (TypeAttr != other.TypeAttr) return false; if (NumberAttr != other.NumberAttr) return false; if (TypeListAttr != other.TypeListAttr) return false; + if(!handleData_.Equals(other.handleData_)) return false; if (IsRef != other.IsRef) return false; + if (!object.Equals(ExperimentalFullType, other.ExperimentalFullType)) return false; return Equals(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; if (Name.Length != 0) hash ^= Name.GetHashCode(); if (Description.Length != 0) hash ^= Description.GetHashCode(); - if (Type != 0) hash ^= Type.GetHashCode(); + if (Type != global::Tensorflow.DataType.DtInvalid) hash ^= Type.GetHashCode(); if (TypeAttr.Length != 0) hash ^= TypeAttr.GetHashCode(); if (NumberAttr.Length != 0) hash ^= NumberAttr.GetHashCode(); if (TypeListAttr.Length != 0) hash ^= TypeListAttr.GetHashCode(); + hash ^= handleData_.GetHashCode(); if (IsRef != false) hash ^= IsRef.GetHashCode(); + if (experimentalFullType_ != null) hash ^= ExperimentalFullType.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -701,12 +989,17 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else if (Name.Length != 0) { output.WriteRawTag(10); output.WriteString(Name); @@ -715,7 +1008,7 @@ public void WriteTo(pb::CodedOutputStream output) { output.WriteRawTag(18); output.WriteString(Description); } - if (Type != 0) { + if (Type != global::Tensorflow.DataType.DtInvalid) { output.WriteRawTag(24); output.WriteEnum((int) Type); } @@ -731,16 +1024,66 @@ public void WriteTo(pb::CodedOutputStream output) { output.WriteRawTag(50); output.WriteString(TypeListAttr); } + handleData_.WriteTo(output, _repeated_handleData_codec); if (IsRef != false) { output.WriteRawTag(128, 1); output.WriteBool(IsRef); } + if (experimentalFullType_ != null) { + output.WriteRawTag(138, 1); + output.WriteMessage(ExperimentalFullType); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Description.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Description); + } + if (Type != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(24); + output.WriteEnum((int) Type); + } + if (TypeAttr.Length != 0) { + output.WriteRawTag(34); + output.WriteString(TypeAttr); + } + if (NumberAttr.Length != 0) { + output.WriteRawTag(42); + output.WriteString(NumberAttr); + } + if (TypeListAttr.Length != 0) { + output.WriteRawTag(50); + output.WriteString(TypeListAttr); + } + handleData_.WriteTo(ref output, _repeated_handleData_codec); + if (IsRef != false) { + output.WriteRawTag(128, 1); + output.WriteBool(IsRef); + } + if (experimentalFullType_ != null) { + output.WriteRawTag(138, 1); + output.WriteMessage(ExperimentalFullType); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; if (Name.Length != 0) { @@ -749,7 +1092,7 @@ public int CalculateSize() { if (Description.Length != 0) { size += 1 + pb::CodedOutputStream.ComputeStringSize(Description); } - if (Type != 0) { + if (Type != global::Tensorflow.DataType.DtInvalid) { size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Type); } if (TypeAttr.Length != 0) { @@ -761,9 +1104,13 @@ public int CalculateSize() { if (TypeListAttr.Length != 0) { size += 1 + pb::CodedOutputStream.ComputeStringSize(TypeListAttr); } + size += handleData_.CalculateSize(_repeated_handleData_codec); if (IsRef != false) { size += 2 + 1; } + if (experimentalFullType_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(ExperimentalFullType); + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -771,6 +1118,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(ArgDef other) { if (other == null) { return; @@ -781,7 +1129,7 @@ public void MergeFrom(ArgDef other) { if (other.Description.Length != 0) { Description = other.Description; } - if (other.Type != 0) { + if (other.Type != global::Tensorflow.DataType.DtInvalid) { Type = other.Type; } if (other.TypeAttr.Length != 0) { @@ -793,14 +1141,25 @@ public void MergeFrom(ArgDef other) { if (other.TypeListAttr.Length != 0) { TypeListAttr = other.TypeListAttr; } + handleData_.Add(other.handleData_); if (other.IsRef != false) { IsRef = other.IsRef; } + if (other.experimentalFullType_ != null) { + if (experimentalFullType_ == null) { + ExperimentalFullType = new global::Tensorflow.FullTypeDef(); + } + ExperimentalFullType.MergeFrom(other.ExperimentalFullType); + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -816,7 +1175,61 @@ public void MergeFrom(pb::CodedInputStream input) { break; } case 24: { - type_ = (global::Tensorflow.DataType) input.ReadEnum(); + Type = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 34: { + TypeAttr = input.ReadString(); + break; + } + case 42: { + NumberAttr = input.ReadString(); + break; + } + case 50: { + TypeListAttr = input.ReadString(); + break; + } + case 58: { + handleData_.AddEntriesFrom(input, _repeated_handleData_codec); + break; + } + case 128: { + IsRef = input.ReadBool(); + break; + } + case 138: { + if (experimentalFullType_ == null) { + ExperimentalFullType = new global::Tensorflow.FullTypeDef(); + } + input.ReadMessage(ExperimentalFullType); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + Description = input.ReadString(); + break; + } + case 24: { + Type = (global::Tensorflow.DataType) input.ReadEnum(); break; } case 34: { @@ -831,13 +1244,25 @@ public void MergeFrom(pb::CodedInputStream input) { TypeListAttr = input.ReadString(); break; } + case 58: { + handleData_.AddEntriesFrom(ref input, _repeated_handleData_codec); + break; + } case 128: { IsRef = input.ReadBool(); break; } + case 138: { + if (experimentalFullType_ == null) { + ExperimentalFullType = new global::Tensorflow.FullTypeDef(); + } + input.ReadMessage(ExperimentalFullType); + break; + } } } } + #endif } @@ -846,23 +1271,31 @@ public void MergeFrom(pb::CodedInputStream input) { /// Op. That is to say, this describes the attr fields that will /// be specified in the NodeDef. /// - public sealed partial class AttrDef : pb::IMessage { + public sealed partial class AttrDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AttrDef()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.OpDef.Descriptor.NestedTypes[1]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public AttrDef() { OnConstruction(); } @@ -870,6 +1303,7 @@ public AttrDef() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public AttrDef(AttrDef other) : this() { name_ = other.name_; type_ = other.type_; @@ -882,6 +1316,7 @@ public AttrDef(AttrDef other) : this() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public AttrDef Clone() { return new AttrDef(this); } @@ -895,6 +1330,7 @@ public AttrDef Clone() { /// the regexp "[a-z][a-z0-9_]+". /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Name { get { return name_; } set { @@ -910,6 +1346,7 @@ public string Name { /// "int", etc.). /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Type { get { return type_; } set { @@ -925,6 +1362,7 @@ public string Type { /// a value. If not specified, the user must supply a value. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public global::Tensorflow.AttrValue DefaultValue { get { return defaultValue_; } set { @@ -939,6 +1377,7 @@ public string Type { /// Human-readable description. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Description { get { return description_; } set { @@ -954,6 +1393,7 @@ public string Description { /// types, this is the minimum length. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool HasMinimum { get { return hasMinimum_; } set { @@ -965,6 +1405,7 @@ public bool HasMinimum { public const int MinimumFieldNumber = 6; private long minimum_; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public long Minimum { get { return minimum_; } set { @@ -984,6 +1425,7 @@ public long Minimum { /// "allowed_values.list" has the set of allowed strings. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public global::Tensorflow.AttrValue AllowedValues { get { return allowedValues_; } set { @@ -992,11 +1434,13 @@ public long Minimum { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as AttrDef); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(AttrDef other) { if (ReferenceEquals(other, null)) { return false; @@ -1015,6 +1459,7 @@ public bool Equals(AttrDef other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; if (Name.Length != 0) hash ^= Name.GetHashCode(); @@ -1031,12 +1476,17 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else if (Name.Length != 0) { output.WriteRawTag(10); output.WriteString(Name); @@ -1068,9 +1518,49 @@ public void WriteTo(pb::CodedOutputStream output) { if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Type.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Type); + } + if (defaultValue_ != null) { + output.WriteRawTag(26); + output.WriteMessage(DefaultValue); + } + if (Description.Length != 0) { + output.WriteRawTag(34); + output.WriteString(Description); + } + if (HasMinimum != false) { + output.WriteRawTag(40); + output.WriteBool(HasMinimum); + } + if (Minimum != 0L) { + output.WriteRawTag(48); + output.WriteInt64(Minimum); + } + if (allowedValues_ != null) { + output.WriteRawTag(58); + output.WriteMessage(AllowedValues); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } } + #endif [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; if (Name.Length != 0) { @@ -1101,6 +1591,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(AttrDef other) { if (other == null) { return; @@ -1113,7 +1604,7 @@ public void MergeFrom(AttrDef other) { } if (other.defaultValue_ != null) { if (defaultValue_ == null) { - defaultValue_ = new global::Tensorflow.AttrValue(); + DefaultValue = new global::Tensorflow.AttrValue(); } DefaultValue.MergeFrom(other.DefaultValue); } @@ -1128,7 +1619,7 @@ public void MergeFrom(AttrDef other) { } if (other.allowedValues_ != null) { if (allowedValues_ == null) { - allowedValues_ = new global::Tensorflow.AttrValue(); + AllowedValues = new global::Tensorflow.AttrValue(); } AllowedValues.MergeFrom(other.AllowedValues); } @@ -1136,7 +1627,11 @@ public void MergeFrom(AttrDef other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -1153,9 +1648,9 @@ public void MergeFrom(pb::CodedInputStream input) { } case 26: { if (defaultValue_ == null) { - defaultValue_ = new global::Tensorflow.AttrValue(); + DefaultValue = new global::Tensorflow.AttrValue(); } - input.ReadMessage(defaultValue_); + input.ReadMessage(DefaultValue); break; } case 34: { @@ -1172,15 +1667,65 @@ public void MergeFrom(pb::CodedInputStream input) { } case 58: { if (allowedValues_ == null) { - allowedValues_ = new global::Tensorflow.AttrValue(); + AllowedValues = new global::Tensorflow.AttrValue(); } - input.ReadMessage(allowedValues_); + input.ReadMessage(AllowedValues); break; } } } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + Type = input.ReadString(); + break; + } + case 26: { + if (defaultValue_ == null) { + DefaultValue = new global::Tensorflow.AttrValue(); + } + input.ReadMessage(DefaultValue); + break; + } + case 34: { + Description = input.ReadString(); + break; + } + case 40: { + HasMinimum = input.ReadBool(); + break; + } + case 48: { + Minimum = input.ReadInt64(); + break; + } + case 58: { + if (allowedValues_ == null) { + AllowedValues = new global::Tensorflow.AttrValue(); + } + input.ReadMessage(AllowedValues); + break; + } + } + } + } + #endif + } } @@ -1191,23 +1736,31 @@ public void MergeFrom(pb::CodedInputStream input) { /// /// Information about version-dependent deprecation of an op /// - public sealed partial class OpDeprecation : pb::IMessage { + public sealed partial class OpDeprecation : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OpDeprecation()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.OpDefReflection.Descriptor.MessageTypes[1]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public OpDeprecation() { OnConstruction(); } @@ -1215,6 +1768,7 @@ public OpDeprecation() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public OpDeprecation(OpDeprecation other) : this() { version_ = other.version_; explanation_ = other.explanation_; @@ -1222,6 +1776,7 @@ public OpDeprecation(OpDeprecation other) : this() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public OpDeprecation Clone() { return new OpDeprecation(this); } @@ -1233,6 +1788,7 @@ public OpDeprecation Clone() { /// First GraphDef version at which the op is disallowed. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int Version { get { return version_; } set { @@ -1247,6 +1803,7 @@ public int Version { /// Explanation of why it was deprecated and what to use instead. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Explanation { get { return explanation_; } set { @@ -1255,11 +1812,13 @@ public string Explanation { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as OpDeprecation); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(OpDeprecation other) { if (ReferenceEquals(other, null)) { return false; @@ -1273,6 +1832,7 @@ public bool Equals(OpDeprecation other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; if (Version != 0) hash ^= Version.GetHashCode(); @@ -1284,12 +1844,17 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else if (Version != 0) { output.WriteRawTag(8); output.WriteInt32(Version); @@ -1301,9 +1866,29 @@ public void WriteTo(pb::CodedOutputStream output) { if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Version != 0) { + output.WriteRawTag(8); + output.WriteInt32(Version); + } + if (Explanation.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Explanation); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; if (Version != 0) { @@ -1319,6 +1904,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(OpDeprecation other) { if (other == null) { return; @@ -1333,7 +1919,11 @@ public void MergeFrom(OpDeprecation other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -1350,30 +1940,62 @@ public void MergeFrom(pb::CodedInputStream input) { } } } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Version = input.ReadInt32(); + break; + } + case 18: { + Explanation = input.ReadString(); + break; + } + } + } + } + #endif + } /// /// A collection of OpDefs /// - public sealed partial class OpList : pb::IMessage { + public sealed partial class OpList : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OpList()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.OpDefReflection.Descriptor.MessageTypes[2]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public OpList() { OnConstruction(); } @@ -1381,12 +2003,14 @@ public OpList() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public OpList(OpList other) : this() { op_ = other.op_.Clone(); _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public OpList Clone() { return new OpList(this); } @@ -1397,16 +2021,19 @@ public OpList Clone() { = pb::FieldCodec.ForMessage(10, global::Tensorflow.OpDef.Parser); private readonly pbc::RepeatedField op_ = new pbc::RepeatedField(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField Op { get { return op_; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as OpList); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(OpList other) { if (ReferenceEquals(other, null)) { return false; @@ -1419,6 +2046,7 @@ public bool Equals(OpList other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; hash ^= op_.GetHashCode(); @@ -1429,19 +2057,37 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else op_.WriteTo(output, _repeated_op_codec); if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + op_.WriteTo(ref output, _repeated_op_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } } + #endif [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; size += op_.CalculateSize(_repeated_op_codec); @@ -1452,6 +2098,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(OpList other) { if (other == null) { return; @@ -1461,7 +2108,11 @@ public void MergeFrom(OpList other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -1474,7 +2125,27 @@ public void MergeFrom(pb::CodedInputStream input) { } } } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + op_.AddEntriesFrom(ref input, _repeated_op_codec); + break; + } + } + } } + #endif } diff --git a/src/TensorFlowNET.Core/Protobuf/Protocol.cs b/src/TensorFlowNET.Core/Protobuf/Protocol.cs new file mode 100644 index 000000000..6463a9b54 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Protocol.cs @@ -0,0 +1,3840 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/compiler/xla/pjrt/distributed/protocol.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Xla { + + /// Holder for reflection information generated from tensorflow/compiler/xla/pjrt/distributed/protocol.proto + public static partial class ProtocolReflection { + + #region Descriptor + /// File descriptor for tensorflow/compiler/xla/pjrt/distributed/protocol.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ProtocolReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cjd0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9wanJ0L2Rpc3RyaWJ1dGVkL3By", + "b3RvY29sLnByb3RvEgN4bGEiYwoLRGV2aWNlUHJvdG8SHAoUbG9jYWxfZGV2", + "aWNlX29yZGluYWwYASABKAUSDAoEbmFtZRgCIAEoCRIOCgZ2ZW5kb3IYAyAB", + "KAkSGAoQZ2xvYmFsX2RldmljZV9pZBgEIAEoBSJIChJMb2NhbFRvcG9sb2d5", + "UHJvdG8SDwoHbm9kZV9pZBgBIAEoBRIhCgdkZXZpY2VzGAIgAygLMhAueGxh", + "LkRldmljZVByb3RvIj0KE0dsb2JhbFRvcG9sb2d5UHJvdG8SJgoFbm9kZXMY", + "ASADKAsyFy54bGEuTG9jYWxUb3BvbG9neVByb3RvImwKDkNvbm5lY3RSZXF1", + "ZXN0EhgKEHByb3RvY29sX3ZlcnNpb24YASABKAUSHAoUdGltZW91dF9taWxs", + "aXNlY29uZHMYAiABKAUSDwoHbm9kZV9pZBgDIAEoBRIRCgljbGllbnRfaWQY", + "BCABKAQiJQoPQ29ubmVjdFJlc3BvbnNlEhIKCnNlc3Npb25faWQYASABKAQi", + "XgoXRW51bWVyYXRlRGV2aWNlc1JlcXVlc3QSEgoKc2Vzc2lvbl9pZBgBIAEo", + "BBIvCg5sb2NhbF90b3BvbG9neRgDIAEoCzIXLnhsYS5Mb2NhbFRvcG9sb2d5", + "UHJvdG8iTQoYRW51bWVyYXRlRGV2aWNlc1Jlc3BvbnNlEjEKD2dsb2JhbF90", + "b3BvbG9neRgBIAEoCzIYLnhsYS5HbG9iYWxUb3BvbG9neVByb3RvIlMKEktl", + "eVZhbHVlR2V0UmVxdWVzdBISCgpzZXNzaW9uX2lkGAEgASgEEgsKA2tleRgC", + "IAEoDBIcChR0aW1lb3V0X21pbGxpc2Vjb25kcxgDIAEoBSIzChNLZXlWYWx1", + "ZUdldFJlc3BvbnNlEg0KBWZvdW5kGAEgASgIEg0KBXZhbHVlGAIgASgMIkQK", + "EktleVZhbHVlU2V0UmVxdWVzdBISCgpzZXNzaW9uX2lkGAEgASgEEgsKA2tl", + "eRgCIAEoDBINCgV2YWx1ZRgDIAEoDCIVChNLZXlWYWx1ZVNldFJlc3BvbnNl", + "Im0KFFdhaXRBdEJhcnJpZXJSZXF1ZXN0EhIKCnNlc3Npb25faWQYASABKAQS", + "EgoKYmFycmllcl9pZBgCIAEoDBIPCgdub2RlX2lkGAMgASgFEhwKFHRpbWVv", + "dXRfbWlsbGlzZWNvbmRzGAQgASgFIhcKFVdhaXRBdEJhcnJpZXJSZXNwb25z", + "ZSI3ChBIZWFydGJlYXRSZXF1ZXN0EhIKCnNlc3Npb25faWQYASABKAQSDwoH", + "bm9kZV9pZBgCIAEoBSITChFIZWFydGJlYXRSZXNwb25zZSI2Cg9TaHV0ZG93", + "blJlcXVlc3QSEgoKc2Vzc2lvbl9pZBgBIAEoBBIPCgdub2RlX2lkGAIgASgF", + "IhIKEFNodXRkb3duUmVzcG9uc2Uy8QMKGURpc3RyaWJ1dGVkUnVudGltZVNl", + "cnZpY2USNgoHQ29ubmVjdBITLnhsYS5Db25uZWN0UmVxdWVzdBoULnhsYS5D", + "b25uZWN0UmVzcG9uc2UiABJRChBFbnVtZXJhdGVEZXZpY2VzEhwueGxhLkVu", + "dW1lcmF0ZURldmljZXNSZXF1ZXN0Gh0ueGxhLkVudW1lcmF0ZURldmljZXNS", + "ZXNwb25zZSIAEjwKCUhlYXJ0YmVhdBIVLnhsYS5IZWFydGJlYXRSZXF1ZXN0", + "GhYueGxhLkhlYXJ0YmVhdFJlc3BvbnNlIgASOQoIU2h1dGRvd24SFC54bGEu", + "U2h1dGRvd25SZXF1ZXN0GhUueGxhLlNodXRkb3duUmVzcG9uc2UiABJCCgtL", + "ZXlWYWx1ZUdldBIXLnhsYS5LZXlWYWx1ZUdldFJlcXVlc3QaGC54bGEuS2V5", + "VmFsdWVHZXRSZXNwb25zZSIAEkIKC0tleVZhbHVlU2V0EhcueGxhLktleVZh", + "bHVlU2V0UmVxdWVzdBoYLnhsYS5LZXlWYWx1ZVNldFJlc3BvbnNlIgASSAoN", + "V2FpdEF0QmFycmllchIZLnhsYS5XYWl0QXRCYXJyaWVyUmVxdWVzdBoaLnhs", + "YS5XYWl0QXRCYXJyaWVyUmVzcG9uc2UiAEJgWl5naXRodWIuY29tL3RlbnNv", + "cmZsb3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dvL2NvbXBpbGVyL3hsYS9w", + "anJ0L2Rpc3RyaWJ1dGVkL3Byb3RvY29sX2dvX3Byb3RvYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.DeviceProto), global::Xla.DeviceProto.Parser, new[]{ "LocalDeviceOrdinal", "Name", "Vendor", "GlobalDeviceId" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.LocalTopologyProto), global::Xla.LocalTopologyProto.Parser, new[]{ "NodeId", "Devices" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.GlobalTopologyProto), global::Xla.GlobalTopologyProto.Parser, new[]{ "Nodes" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ConnectRequest), global::Xla.ConnectRequest.Parser, new[]{ "ProtocolVersion", "TimeoutMilliseconds", "NodeId", "ClientId" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ConnectResponse), global::Xla.ConnectResponse.Parser, new[]{ "SessionId" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.EnumerateDevicesRequest), global::Xla.EnumerateDevicesRequest.Parser, new[]{ "SessionId", "LocalTopology" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.EnumerateDevicesResponse), global::Xla.EnumerateDevicesResponse.Parser, new[]{ "GlobalTopology" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.KeyValueGetRequest), global::Xla.KeyValueGetRequest.Parser, new[]{ "SessionId", "Key", "TimeoutMilliseconds" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.KeyValueGetResponse), global::Xla.KeyValueGetResponse.Parser, new[]{ "Found", "Value" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.KeyValueSetRequest), global::Xla.KeyValueSetRequest.Parser, new[]{ "SessionId", "Key", "Value" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.KeyValueSetResponse), global::Xla.KeyValueSetResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.WaitAtBarrierRequest), global::Xla.WaitAtBarrierRequest.Parser, new[]{ "SessionId", "BarrierId", "NodeId", "TimeoutMilliseconds" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.WaitAtBarrierResponse), global::Xla.WaitAtBarrierResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HeartbeatRequest), global::Xla.HeartbeatRequest.Parser, new[]{ "SessionId", "NodeId" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.HeartbeatResponse), global::Xla.HeartbeatResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ShutdownRequest), global::Xla.ShutdownRequest.Parser, new[]{ "SessionId", "NodeId" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ShutdownResponse), global::Xla.ShutdownResponse.Parser, null, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Describes a device local to a host. + /// + public sealed partial class DeviceProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DeviceProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceProto(DeviceProto other) : this() { + localDeviceOrdinal_ = other.localDeviceOrdinal_; + name_ = other.name_; + vendor_ = other.vendor_; + globalDeviceId_ = other.globalDeviceId_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceProto Clone() { + return new DeviceProto(this); + } + + /// Field number for the "local_device_ordinal" field. + public const int LocalDeviceOrdinalFieldNumber = 1; + private int localDeviceOrdinal_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int LocalDeviceOrdinal { + get { return localDeviceOrdinal_; } + set { + localDeviceOrdinal_ = value; + } + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 2; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "vendor" field. + public const int VendorFieldNumber = 3; + private string vendor_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Vendor { + get { return vendor_; } + set { + vendor_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "global_device_id" field. + public const int GlobalDeviceIdFieldNumber = 4; + private int globalDeviceId_; + /// + /// The following fields are present in the GlobalTopologyProto message + /// returned by EnumerateDevices() but not in the LocalTopologyProto messages + /// passed to EnumerateDevices(). In other words, the coordinator node + /// determines the global device IDs during EnumerateDevices(). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int GlobalDeviceId { + get { return globalDeviceId_; } + set { + globalDeviceId_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DeviceProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DeviceProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (LocalDeviceOrdinal != other.LocalDeviceOrdinal) return false; + if (Name != other.Name) return false; + if (Vendor != other.Vendor) return false; + if (GlobalDeviceId != other.GlobalDeviceId) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (LocalDeviceOrdinal != 0) hash ^= LocalDeviceOrdinal.GetHashCode(); + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Vendor.Length != 0) hash ^= Vendor.GetHashCode(); + if (GlobalDeviceId != 0) hash ^= GlobalDeviceId.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (LocalDeviceOrdinal != 0) { + output.WriteRawTag(8); + output.WriteInt32(LocalDeviceOrdinal); + } + if (Name.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Name); + } + if (Vendor.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Vendor); + } + if (GlobalDeviceId != 0) { + output.WriteRawTag(32); + output.WriteInt32(GlobalDeviceId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (LocalDeviceOrdinal != 0) { + output.WriteRawTag(8); + output.WriteInt32(LocalDeviceOrdinal); + } + if (Name.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Name); + } + if (Vendor.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Vendor); + } + if (GlobalDeviceId != 0) { + output.WriteRawTag(32); + output.WriteInt32(GlobalDeviceId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (LocalDeviceOrdinal != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(LocalDeviceOrdinal); + } + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Vendor.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Vendor); + } + if (GlobalDeviceId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(GlobalDeviceId); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DeviceProto other) { + if (other == null) { + return; + } + if (other.LocalDeviceOrdinal != 0) { + LocalDeviceOrdinal = other.LocalDeviceOrdinal; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Vendor.Length != 0) { + Vendor = other.Vendor; + } + if (other.GlobalDeviceId != 0) { + GlobalDeviceId = other.GlobalDeviceId; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + LocalDeviceOrdinal = input.ReadInt32(); + break; + } + case 18: { + Name = input.ReadString(); + break; + } + case 26: { + Vendor = input.ReadString(); + break; + } + case 32: { + GlobalDeviceId = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + LocalDeviceOrdinal = input.ReadInt32(); + break; + } + case 18: { + Name = input.ReadString(); + break; + } + case 26: { + Vendor = input.ReadString(); + break; + } + case 32: { + GlobalDeviceId = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + public sealed partial class LocalTopologyProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new LocalTopologyProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LocalTopologyProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LocalTopologyProto(LocalTopologyProto other) : this() { + nodeId_ = other.nodeId_; + devices_ = other.devices_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LocalTopologyProto Clone() { + return new LocalTopologyProto(this); + } + + /// Field number for the "node_id" field. + public const int NodeIdFieldNumber = 1; + private int nodeId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NodeId { + get { return nodeId_; } + set { + nodeId_ = value; + } + } + + /// Field number for the "devices" field. + public const int DevicesFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_devices_codec + = pb::FieldCodec.ForMessage(18, global::Xla.DeviceProto.Parser); + private readonly pbc::RepeatedField devices_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Devices { + get { return devices_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as LocalTopologyProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(LocalTopologyProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NodeId != other.NodeId) return false; + if(!devices_.Equals(other.devices_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (NodeId != 0) hash ^= NodeId.GetHashCode(); + hash ^= devices_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (NodeId != 0) { + output.WriteRawTag(8); + output.WriteInt32(NodeId); + } + devices_.WriteTo(output, _repeated_devices_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (NodeId != 0) { + output.WriteRawTag(8); + output.WriteInt32(NodeId); + } + devices_.WriteTo(ref output, _repeated_devices_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (NodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NodeId); + } + size += devices_.CalculateSize(_repeated_devices_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(LocalTopologyProto other) { + if (other == null) { + return; + } + if (other.NodeId != 0) { + NodeId = other.NodeId; + } + devices_.Add(other.devices_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NodeId = input.ReadInt32(); + break; + } + case 18: { + devices_.AddEntriesFrom(input, _repeated_devices_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + NodeId = input.ReadInt32(); + break; + } + case 18: { + devices_.AddEntriesFrom(ref input, _repeated_devices_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class GlobalTopologyProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GlobalTopologyProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GlobalTopologyProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GlobalTopologyProto(GlobalTopologyProto other) : this() { + nodes_ = other.nodes_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GlobalTopologyProto Clone() { + return new GlobalTopologyProto(this); + } + + /// Field number for the "nodes" field. + public const int NodesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_nodes_codec + = pb::FieldCodec.ForMessage(10, global::Xla.LocalTopologyProto.Parser); + private readonly pbc::RepeatedField nodes_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Nodes { + get { return nodes_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GlobalTopologyProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GlobalTopologyProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!nodes_.Equals(other.nodes_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= nodes_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + nodes_.WriteTo(output, _repeated_nodes_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + nodes_.WriteTo(ref output, _repeated_nodes_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += nodes_.CalculateSize(_repeated_nodes_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GlobalTopologyProto other) { + if (other == null) { + return; + } + nodes_.Add(other.nodes_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + nodes_.AddEntriesFrom(input, _repeated_nodes_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + nodes_.AddEntriesFrom(ref input, _repeated_nodes_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class ConnectRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ConnectRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ConnectRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ConnectRequest(ConnectRequest other) : this() { + protocolVersion_ = other.protocolVersion_; + timeoutMilliseconds_ = other.timeoutMilliseconds_; + nodeId_ = other.nodeId_; + clientId_ = other.clientId_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ConnectRequest Clone() { + return new ConnectRequest(this); + } + + /// Field number for the "protocol_version" field. + public const int ProtocolVersionFieldNumber = 1; + private int protocolVersion_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ProtocolVersion { + get { return protocolVersion_; } + set { + protocolVersion_ = value; + } + } + + /// Field number for the "timeout_milliseconds" field. + public const int TimeoutMillisecondsFieldNumber = 2; + private int timeoutMilliseconds_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int TimeoutMilliseconds { + get { return timeoutMilliseconds_; } + set { + timeoutMilliseconds_ = value; + } + } + + /// Field number for the "node_id" field. + public const int NodeIdFieldNumber = 3; + private int nodeId_; + /// + /// We assume that each node knows its globally-unique node ID, provided by + /// whatever mechanism launches the tasks. Node IDs should form a dense range + /// of integers [0, num_nodes). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NodeId { + get { return nodeId_; } + set { + nodeId_ = value; + } + } + + /// Field number for the "client_id" field. + public const int ClientIdFieldNumber = 4; + private ulong clientId_; + /// + /// A unique ID number for the client. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong ClientId { + get { return clientId_; } + set { + clientId_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ConnectRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ConnectRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ProtocolVersion != other.ProtocolVersion) return false; + if (TimeoutMilliseconds != other.TimeoutMilliseconds) return false; + if (NodeId != other.NodeId) return false; + if (ClientId != other.ClientId) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ProtocolVersion != 0) hash ^= ProtocolVersion.GetHashCode(); + if (TimeoutMilliseconds != 0) hash ^= TimeoutMilliseconds.GetHashCode(); + if (NodeId != 0) hash ^= NodeId.GetHashCode(); + if (ClientId != 0UL) hash ^= ClientId.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ProtocolVersion != 0) { + output.WriteRawTag(8); + output.WriteInt32(ProtocolVersion); + } + if (TimeoutMilliseconds != 0) { + output.WriteRawTag(16); + output.WriteInt32(TimeoutMilliseconds); + } + if (NodeId != 0) { + output.WriteRawTag(24); + output.WriteInt32(NodeId); + } + if (ClientId != 0UL) { + output.WriteRawTag(32); + output.WriteUInt64(ClientId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ProtocolVersion != 0) { + output.WriteRawTag(8); + output.WriteInt32(ProtocolVersion); + } + if (TimeoutMilliseconds != 0) { + output.WriteRawTag(16); + output.WriteInt32(TimeoutMilliseconds); + } + if (NodeId != 0) { + output.WriteRawTag(24); + output.WriteInt32(NodeId); + } + if (ClientId != 0UL) { + output.WriteRawTag(32); + output.WriteUInt64(ClientId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ProtocolVersion != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ProtocolVersion); + } + if (TimeoutMilliseconds != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(TimeoutMilliseconds); + } + if (NodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NodeId); + } + if (ClientId != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(ClientId); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ConnectRequest other) { + if (other == null) { + return; + } + if (other.ProtocolVersion != 0) { + ProtocolVersion = other.ProtocolVersion; + } + if (other.TimeoutMilliseconds != 0) { + TimeoutMilliseconds = other.TimeoutMilliseconds; + } + if (other.NodeId != 0) { + NodeId = other.NodeId; + } + if (other.ClientId != 0UL) { + ClientId = other.ClientId; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + ProtocolVersion = input.ReadInt32(); + break; + } + case 16: { + TimeoutMilliseconds = input.ReadInt32(); + break; + } + case 24: { + NodeId = input.ReadInt32(); + break; + } + case 32: { + ClientId = input.ReadUInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + ProtocolVersion = input.ReadInt32(); + break; + } + case 16: { + TimeoutMilliseconds = input.ReadInt32(); + break; + } + case 24: { + NodeId = input.ReadInt32(); + break; + } + case 32: { + ClientId = input.ReadUInt64(); + break; + } + } + } + } + #endif + + } + + public sealed partial class ConnectResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ConnectResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ConnectResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ConnectResponse(ConnectResponse other) : this() { + sessionId_ = other.sessionId_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ConnectResponse Clone() { + return new ConnectResponse(this); + } + + /// Field number for the "session_id" field. + public const int SessionIdFieldNumber = 1; + private ulong sessionId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong SessionId { + get { return sessionId_; } + set { + sessionId_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ConnectResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ConnectResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (SessionId != other.SessionId) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (SessionId != 0UL) hash ^= SessionId.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (SessionId != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(SessionId); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ConnectResponse other) { + if (other == null) { + return; + } + if (other.SessionId != 0UL) { + SessionId = other.SessionId; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + } + } + } + #endif + + } + + public sealed partial class EnumerateDevicesRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new EnumerateDevicesRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public EnumerateDevicesRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public EnumerateDevicesRequest(EnumerateDevicesRequest other) : this() { + sessionId_ = other.sessionId_; + localTopology_ = other.localTopology_ != null ? other.localTopology_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public EnumerateDevicesRequest Clone() { + return new EnumerateDevicesRequest(this); + } + + /// Field number for the "session_id" field. + public const int SessionIdFieldNumber = 1; + private ulong sessionId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong SessionId { + get { return sessionId_; } + set { + sessionId_ = value; + } + } + + /// Field number for the "local_topology" field. + public const int LocalTopologyFieldNumber = 3; + private global::Xla.LocalTopologyProto localTopology_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.LocalTopologyProto LocalTopology { + get { return localTopology_; } + set { + localTopology_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as EnumerateDevicesRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(EnumerateDevicesRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (SessionId != other.SessionId) return false; + if (!object.Equals(LocalTopology, other.LocalTopology)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (SessionId != 0UL) hash ^= SessionId.GetHashCode(); + if (localTopology_ != null) hash ^= LocalTopology.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (localTopology_ != null) { + output.WriteRawTag(26); + output.WriteMessage(LocalTopology); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (localTopology_ != null) { + output.WriteRawTag(26); + output.WriteMessage(LocalTopology); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (SessionId != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(SessionId); + } + if (localTopology_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(LocalTopology); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(EnumerateDevicesRequest other) { + if (other == null) { + return; + } + if (other.SessionId != 0UL) { + SessionId = other.SessionId; + } + if (other.localTopology_ != null) { + if (localTopology_ == null) { + LocalTopology = new global::Xla.LocalTopologyProto(); + } + LocalTopology.MergeFrom(other.LocalTopology); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + case 26: { + if (localTopology_ == null) { + LocalTopology = new global::Xla.LocalTopologyProto(); + } + input.ReadMessage(LocalTopology); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + case 26: { + if (localTopology_ == null) { + LocalTopology = new global::Xla.LocalTopologyProto(); + } + input.ReadMessage(LocalTopology); + break; + } + } + } + } + #endif + + } + + public sealed partial class EnumerateDevicesResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new EnumerateDevicesResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public EnumerateDevicesResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public EnumerateDevicesResponse(EnumerateDevicesResponse other) : this() { + globalTopology_ = other.globalTopology_ != null ? other.globalTopology_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public EnumerateDevicesResponse Clone() { + return new EnumerateDevicesResponse(this); + } + + /// Field number for the "global_topology" field. + public const int GlobalTopologyFieldNumber = 1; + private global::Xla.GlobalTopologyProto globalTopology_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.GlobalTopologyProto GlobalTopology { + get { return globalTopology_; } + set { + globalTopology_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as EnumerateDevicesResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(EnumerateDevicesResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(GlobalTopology, other.GlobalTopology)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (globalTopology_ != null) hash ^= GlobalTopology.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (globalTopology_ != null) { + output.WriteRawTag(10); + output.WriteMessage(GlobalTopology); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (globalTopology_ != null) { + output.WriteRawTag(10); + output.WriteMessage(GlobalTopology); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (globalTopology_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(GlobalTopology); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(EnumerateDevicesResponse other) { + if (other == null) { + return; + } + if (other.globalTopology_ != null) { + if (globalTopology_ == null) { + GlobalTopology = new global::Xla.GlobalTopologyProto(); + } + GlobalTopology.MergeFrom(other.GlobalTopology); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (globalTopology_ == null) { + GlobalTopology = new global::Xla.GlobalTopologyProto(); + } + input.ReadMessage(GlobalTopology); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (globalTopology_ == null) { + GlobalTopology = new global::Xla.GlobalTopologyProto(); + } + input.ReadMessage(GlobalTopology); + break; + } + } + } + } + #endif + + } + + public sealed partial class KeyValueGetRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new KeyValueGetRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueGetRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueGetRequest(KeyValueGetRequest other) : this() { + sessionId_ = other.sessionId_; + key_ = other.key_; + timeoutMilliseconds_ = other.timeoutMilliseconds_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueGetRequest Clone() { + return new KeyValueGetRequest(this); + } + + /// Field number for the "session_id" field. + public const int SessionIdFieldNumber = 1; + private ulong sessionId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong SessionId { + get { return sessionId_; } + set { + sessionId_ = value; + } + } + + /// Field number for the "key" field. + public const int KeyFieldNumber = 2; + private pb::ByteString key_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString Key { + get { return key_; } + set { + key_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "timeout_milliseconds" field. + public const int TimeoutMillisecondsFieldNumber = 3; + private int timeoutMilliseconds_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int TimeoutMilliseconds { + get { return timeoutMilliseconds_; } + set { + timeoutMilliseconds_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as KeyValueGetRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(KeyValueGetRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (SessionId != other.SessionId) return false; + if (Key != other.Key) return false; + if (TimeoutMilliseconds != other.TimeoutMilliseconds) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (SessionId != 0UL) hash ^= SessionId.GetHashCode(); + if (Key.Length != 0) hash ^= Key.GetHashCode(); + if (TimeoutMilliseconds != 0) hash ^= TimeoutMilliseconds.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (Key.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(Key); + } + if (TimeoutMilliseconds != 0) { + output.WriteRawTag(24); + output.WriteInt32(TimeoutMilliseconds); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (Key.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(Key); + } + if (TimeoutMilliseconds != 0) { + output.WriteRawTag(24); + output.WriteInt32(TimeoutMilliseconds); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (SessionId != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(SessionId); + } + if (Key.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(Key); + } + if (TimeoutMilliseconds != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(TimeoutMilliseconds); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(KeyValueGetRequest other) { + if (other == null) { + return; + } + if (other.SessionId != 0UL) { + SessionId = other.SessionId; + } + if (other.Key.Length != 0) { + Key = other.Key; + } + if (other.TimeoutMilliseconds != 0) { + TimeoutMilliseconds = other.TimeoutMilliseconds; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + case 18: { + Key = input.ReadBytes(); + break; + } + case 24: { + TimeoutMilliseconds = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + case 18: { + Key = input.ReadBytes(); + break; + } + case 24: { + TimeoutMilliseconds = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + public sealed partial class KeyValueGetResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new KeyValueGetResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[8]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueGetResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueGetResponse(KeyValueGetResponse other) : this() { + found_ = other.found_; + value_ = other.value_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueGetResponse Clone() { + return new KeyValueGetResponse(this); + } + + /// Field number for the "found" field. + public const int FoundFieldNumber = 1; + private bool found_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Found { + get { return found_; } + set { + found_ = value; + } + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 2; + private pb::ByteString value_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString Value { + get { return value_; } + set { + value_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as KeyValueGetResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(KeyValueGetResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Found != other.Found) return false; + if (Value != other.Value) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Found != false) hash ^= Found.GetHashCode(); + if (Value.Length != 0) hash ^= Value.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Found != false) { + output.WriteRawTag(8); + output.WriteBool(Found); + } + if (Value.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(Value); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Found != false) { + output.WriteRawTag(8); + output.WriteBool(Found); + } + if (Value.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(Value); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Found != false) { + size += 1 + 1; + } + if (Value.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(Value); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(KeyValueGetResponse other) { + if (other == null) { + return; + } + if (other.Found != false) { + Found = other.Found; + } + if (other.Value.Length != 0) { + Value = other.Value; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Found = input.ReadBool(); + break; + } + case 18: { + Value = input.ReadBytes(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Found = input.ReadBool(); + break; + } + case 18: { + Value = input.ReadBytes(); + break; + } + } + } + } + #endif + + } + + public sealed partial class KeyValueSetRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new KeyValueSetRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[9]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueSetRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueSetRequest(KeyValueSetRequest other) : this() { + sessionId_ = other.sessionId_; + key_ = other.key_; + value_ = other.value_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueSetRequest Clone() { + return new KeyValueSetRequest(this); + } + + /// Field number for the "session_id" field. + public const int SessionIdFieldNumber = 1; + private ulong sessionId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong SessionId { + get { return sessionId_; } + set { + sessionId_ = value; + } + } + + /// Field number for the "key" field. + public const int KeyFieldNumber = 2; + private pb::ByteString key_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString Key { + get { return key_; } + set { + key_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 3; + private pb::ByteString value_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString Value { + get { return value_; } + set { + value_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as KeyValueSetRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(KeyValueSetRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (SessionId != other.SessionId) return false; + if (Key != other.Key) return false; + if (Value != other.Value) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (SessionId != 0UL) hash ^= SessionId.GetHashCode(); + if (Key.Length != 0) hash ^= Key.GetHashCode(); + if (Value.Length != 0) hash ^= Value.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (Key.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(Key); + } + if (Value.Length != 0) { + output.WriteRawTag(26); + output.WriteBytes(Value); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (Key.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(Key); + } + if (Value.Length != 0) { + output.WriteRawTag(26); + output.WriteBytes(Value); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (SessionId != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(SessionId); + } + if (Key.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(Key); + } + if (Value.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(Value); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(KeyValueSetRequest other) { + if (other == null) { + return; + } + if (other.SessionId != 0UL) { + SessionId = other.SessionId; + } + if (other.Key.Length != 0) { + Key = other.Key; + } + if (other.Value.Length != 0) { + Value = other.Value; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + case 18: { + Key = input.ReadBytes(); + break; + } + case 26: { + Value = input.ReadBytes(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + case 18: { + Key = input.ReadBytes(); + break; + } + case 26: { + Value = input.ReadBytes(); + break; + } + } + } + } + #endif + + } + + public sealed partial class KeyValueSetResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new KeyValueSetResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[10]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueSetResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueSetResponse(KeyValueSetResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KeyValueSetResponse Clone() { + return new KeyValueSetResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as KeyValueSetResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(KeyValueSetResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(KeyValueSetResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + public sealed partial class WaitAtBarrierRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WaitAtBarrierRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[11]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitAtBarrierRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitAtBarrierRequest(WaitAtBarrierRequest other) : this() { + sessionId_ = other.sessionId_; + barrierId_ = other.barrierId_; + nodeId_ = other.nodeId_; + timeoutMilliseconds_ = other.timeoutMilliseconds_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitAtBarrierRequest Clone() { + return new WaitAtBarrierRequest(this); + } + + /// Field number for the "session_id" field. + public const int SessionIdFieldNumber = 1; + private ulong sessionId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong SessionId { + get { return sessionId_; } + set { + sessionId_ = value; + } + } + + /// Field number for the "barrier_id" field. + public const int BarrierIdFieldNumber = 2; + private pb::ByteString barrierId_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString BarrierId { + get { return barrierId_; } + set { + barrierId_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "node_id" field. + public const int NodeIdFieldNumber = 3; + private int nodeId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NodeId { + get { return nodeId_; } + set { + nodeId_ = value; + } + } + + /// Field number for the "timeout_milliseconds" field. + public const int TimeoutMillisecondsFieldNumber = 4; + private int timeoutMilliseconds_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int TimeoutMilliseconds { + get { return timeoutMilliseconds_; } + set { + timeoutMilliseconds_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WaitAtBarrierRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WaitAtBarrierRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (SessionId != other.SessionId) return false; + if (BarrierId != other.BarrierId) return false; + if (NodeId != other.NodeId) return false; + if (TimeoutMilliseconds != other.TimeoutMilliseconds) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (SessionId != 0UL) hash ^= SessionId.GetHashCode(); + if (BarrierId.Length != 0) hash ^= BarrierId.GetHashCode(); + if (NodeId != 0) hash ^= NodeId.GetHashCode(); + if (TimeoutMilliseconds != 0) hash ^= TimeoutMilliseconds.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (BarrierId.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(BarrierId); + } + if (NodeId != 0) { + output.WriteRawTag(24); + output.WriteInt32(NodeId); + } + if (TimeoutMilliseconds != 0) { + output.WriteRawTag(32); + output.WriteInt32(TimeoutMilliseconds); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (BarrierId.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(BarrierId); + } + if (NodeId != 0) { + output.WriteRawTag(24); + output.WriteInt32(NodeId); + } + if (TimeoutMilliseconds != 0) { + output.WriteRawTag(32); + output.WriteInt32(TimeoutMilliseconds); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (SessionId != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(SessionId); + } + if (BarrierId.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(BarrierId); + } + if (NodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NodeId); + } + if (TimeoutMilliseconds != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(TimeoutMilliseconds); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WaitAtBarrierRequest other) { + if (other == null) { + return; + } + if (other.SessionId != 0UL) { + SessionId = other.SessionId; + } + if (other.BarrierId.Length != 0) { + BarrierId = other.BarrierId; + } + if (other.NodeId != 0) { + NodeId = other.NodeId; + } + if (other.TimeoutMilliseconds != 0) { + TimeoutMilliseconds = other.TimeoutMilliseconds; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + case 18: { + BarrierId = input.ReadBytes(); + break; + } + case 24: { + NodeId = input.ReadInt32(); + break; + } + case 32: { + TimeoutMilliseconds = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + case 18: { + BarrierId = input.ReadBytes(); + break; + } + case 24: { + NodeId = input.ReadInt32(); + break; + } + case 32: { + TimeoutMilliseconds = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + public sealed partial class WaitAtBarrierResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WaitAtBarrierResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[12]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitAtBarrierResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitAtBarrierResponse(WaitAtBarrierResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitAtBarrierResponse Clone() { + return new WaitAtBarrierResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WaitAtBarrierResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WaitAtBarrierResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WaitAtBarrierResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + public sealed partial class HeartbeatRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HeartbeatRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[13]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeartbeatRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeartbeatRequest(HeartbeatRequest other) : this() { + sessionId_ = other.sessionId_; + nodeId_ = other.nodeId_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeartbeatRequest Clone() { + return new HeartbeatRequest(this); + } + + /// Field number for the "session_id" field. + public const int SessionIdFieldNumber = 1; + private ulong sessionId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong SessionId { + get { return sessionId_; } + set { + sessionId_ = value; + } + } + + /// Field number for the "node_id" field. + public const int NodeIdFieldNumber = 2; + private int nodeId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NodeId { + get { return nodeId_; } + set { + nodeId_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HeartbeatRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HeartbeatRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (SessionId != other.SessionId) return false; + if (NodeId != other.NodeId) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (SessionId != 0UL) hash ^= SessionId.GetHashCode(); + if (NodeId != 0) hash ^= NodeId.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (NodeId != 0) { + output.WriteRawTag(16); + output.WriteInt32(NodeId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (NodeId != 0) { + output.WriteRawTag(16); + output.WriteInt32(NodeId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (SessionId != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(SessionId); + } + if (NodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NodeId); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HeartbeatRequest other) { + if (other == null) { + return; + } + if (other.SessionId != 0UL) { + SessionId = other.SessionId; + } + if (other.NodeId != 0) { + NodeId = other.NodeId; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + case 16: { + NodeId = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + case 16: { + NodeId = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + public sealed partial class HeartbeatResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HeartbeatResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[14]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeartbeatResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeartbeatResponse(HeartbeatResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HeartbeatResponse Clone() { + return new HeartbeatResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as HeartbeatResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(HeartbeatResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(HeartbeatResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + public sealed partial class ShutdownRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ShutdownRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[15]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShutdownRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShutdownRequest(ShutdownRequest other) : this() { + sessionId_ = other.sessionId_; + nodeId_ = other.nodeId_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShutdownRequest Clone() { + return new ShutdownRequest(this); + } + + /// Field number for the "session_id" field. + public const int SessionIdFieldNumber = 1; + private ulong sessionId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong SessionId { + get { return sessionId_; } + set { + sessionId_ = value; + } + } + + /// Field number for the "node_id" field. + public const int NodeIdFieldNumber = 2; + private int nodeId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NodeId { + get { return nodeId_; } + set { + nodeId_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ShutdownRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ShutdownRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (SessionId != other.SessionId) return false; + if (NodeId != other.NodeId) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (SessionId != 0UL) hash ^= SessionId.GetHashCode(); + if (NodeId != 0) hash ^= NodeId.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (NodeId != 0) { + output.WriteRawTag(16); + output.WriteInt32(NodeId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (SessionId != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(SessionId); + } + if (NodeId != 0) { + output.WriteRawTag(16); + output.WriteInt32(NodeId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (SessionId != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(SessionId); + } + if (NodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NodeId); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ShutdownRequest other) { + if (other == null) { + return; + } + if (other.SessionId != 0UL) { + SessionId = other.SessionId; + } + if (other.NodeId != 0) { + NodeId = other.NodeId; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + case 16: { + NodeId = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + SessionId = input.ReadUInt64(); + break; + } + case 16: { + NodeId = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + public sealed partial class ShutdownResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ShutdownResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.ProtocolReflection.Descriptor.MessageTypes[16]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShutdownResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShutdownResponse(ShutdownResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShutdownResponse Clone() { + return new ShutdownResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ShutdownResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ShutdownResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ShutdownResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/README.md b/src/TensorFlowNET.Core/Protobuf/README.md index 4b4cc3d33..c6ffa5e38 100644 --- a/src/TensorFlowNET.Core/Protobuf/README.md +++ b/src/TensorFlowNET.Core/Protobuf/README.md @@ -1,12 +1,12 @@ -### Download compiler from https://github.com/protocolbuffers/protobuf/releases -```shell -set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\framework -set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Tensorflow - -.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto -.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto -.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto -.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto -.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto -.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto -``` \ No newline at end of file +```console +PM> Install-Package Google.Protobuf +``` + +#### How to generate `proto` files. + +Download compiler from https://github.com/protocolbuffers/protobuf/releases. + +Download `any.proto` from https://github.com/protocolbuffers/protobuf/tree/master/src/google/protobuf, place it at `google/protobuf/any.proto`. + +Run `Gen.bat` under `src\TensorFlowNET.Core\Protobuf` folder. + diff --git a/src/TensorFlowNET.Core/Protobuf/ReaderBase.cs b/src/TensorFlowNET.Core/Protobuf/ReaderBase.cs new file mode 100644 index 000000000..fc27a3362 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/ReaderBase.cs @@ -0,0 +1,265 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/reader_base.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/reader_base.proto + public static partial class ReaderBaseReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/reader_base.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ReaderBaseReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cit0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3JlYWRlcl9iYXNlLnByb3Rv", + "Egp0ZW5zb3JmbG93InIKD1JlYWRlckJhc2VTdGF0ZRIUCgx3b3JrX3N0YXJ0", + "ZWQYASABKAMSFQoNd29ya19maW5pc2hlZBgCIAEoAxIcChRudW1fcmVjb3Jk", + "c19wcm9kdWNlZBgDIAEoAxIUCgxjdXJyZW50X3dvcmsYBCABKAxCcAoYb3Jn", + "LnRlbnNvcmZsb3cuZnJhbWV3b3JrQhBSZWFkZXJCYXNlUHJvdG9zUAFaPWdp", + "dGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28v", + "Y29yZS9mcmFtZXdvcmv4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ReaderBaseState), global::Tensorflow.ReaderBaseState.Parser, new[]{ "WorkStarted", "WorkFinished", "NumRecordsProduced", "CurrentWork" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// For serializing and restoring the state of ReaderBase, see + /// reader_base.h for details. + /// + public sealed partial class ReaderBaseState : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ReaderBaseState()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ReaderBaseReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ReaderBaseState() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ReaderBaseState(ReaderBaseState other) : this() { + workStarted_ = other.workStarted_; + workFinished_ = other.workFinished_; + numRecordsProduced_ = other.numRecordsProduced_; + currentWork_ = other.currentWork_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ReaderBaseState Clone() { + return new ReaderBaseState(this); + } + + /// Field number for the "work_started" field. + public const int WorkStartedFieldNumber = 1; + private long workStarted_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long WorkStarted { + get { return workStarted_; } + set { + workStarted_ = value; + } + } + + /// Field number for the "work_finished" field. + public const int WorkFinishedFieldNumber = 2; + private long workFinished_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long WorkFinished { + get { return workFinished_; } + set { + workFinished_ = value; + } + } + + /// Field number for the "num_records_produced" field. + public const int NumRecordsProducedFieldNumber = 3; + private long numRecordsProduced_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long NumRecordsProduced { + get { return numRecordsProduced_; } + set { + numRecordsProduced_ = value; + } + } + + /// Field number for the "current_work" field. + public const int CurrentWorkFieldNumber = 4; + private pb::ByteString currentWork_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pb::ByteString CurrentWork { + get { return currentWork_; } + set { + currentWork_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ReaderBaseState); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ReaderBaseState other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (WorkStarted != other.WorkStarted) return false; + if (WorkFinished != other.WorkFinished) return false; + if (NumRecordsProduced != other.NumRecordsProduced) return false; + if (CurrentWork != other.CurrentWork) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (WorkStarted != 0L) hash ^= WorkStarted.GetHashCode(); + if (WorkFinished != 0L) hash ^= WorkFinished.GetHashCode(); + if (NumRecordsProduced != 0L) hash ^= NumRecordsProduced.GetHashCode(); + if (CurrentWork.Length != 0) hash ^= CurrentWork.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (WorkStarted != 0L) { + output.WriteRawTag(8); + output.WriteInt64(WorkStarted); + } + if (WorkFinished != 0L) { + output.WriteRawTag(16); + output.WriteInt64(WorkFinished); + } + if (NumRecordsProduced != 0L) { + output.WriteRawTag(24); + output.WriteInt64(NumRecordsProduced); + } + if (CurrentWork.Length != 0) { + output.WriteRawTag(34); + output.WriteBytes(CurrentWork); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (WorkStarted != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(WorkStarted); + } + if (WorkFinished != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(WorkFinished); + } + if (NumRecordsProduced != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(NumRecordsProduced); + } + if (CurrentWork.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(CurrentWork); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ReaderBaseState other) { + if (other == null) { + return; + } + if (other.WorkStarted != 0L) { + WorkStarted = other.WorkStarted; + } + if (other.WorkFinished != 0L) { + WorkFinished = other.WorkFinished; + } + if (other.NumRecordsProduced != 0L) { + NumRecordsProduced = other.NumRecordsProduced; + } + if (other.CurrentWork.Length != 0) { + CurrentWork = other.CurrentWork; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + WorkStarted = input.ReadInt64(); + break; + } + case 16: { + WorkFinished = input.ReadInt64(); + break; + } + case 24: { + NumRecordsProduced = input.ReadInt64(); + break; + } + case 34: { + CurrentWork = input.ReadBytes(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/ResourceHandle.cs b/src/TensorFlowNET.Core/Protobuf/ResourceHandle.cs index b9c3033cc..77e84cc53 100644 --- a/src/TensorFlowNET.Core/Protobuf/ResourceHandle.cs +++ b/src/TensorFlowNET.Core/Protobuf/ResourceHandle.cs @@ -1,8 +1,8 @@ // // Generated by the protocol buffer compiler. DO NOT EDIT! -// source: resource_handle.proto +// source: tensorflow/core/framework/resource_handle.proto // -#pragma warning disable 1591, 0612, 3021 +#pragma warning disable 1591, 0612, 3021, 8981 #region Designer generated code using pb = global::Google.Protobuf; @@ -11,11 +11,11 @@ using scg = global::System.Collections.Generic; namespace Tensorflow { - /// Holder for reflection information generated from resource_handle.proto + /// Holder for reflection information generated from tensorflow/core/framework/resource_handle.proto public static partial class ResourceHandleReflection { #region Descriptor - /// File descriptor for resource_handle.proto + /// File descriptor for tensorflow/core/framework/resource_handle.proto public static pbr::FileDescriptor Descriptor { get { return descriptor; } } @@ -24,17 +24,24 @@ public static partial class ResourceHandleReflection { static ResourceHandleReflection() { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( - "ChVyZXNvdXJjZV9oYW5kbGUucHJvdG8SCnRlbnNvcmZsb3cicgoTUmVzb3Vy", - "Y2VIYW5kbGVQcm90bxIOCgZkZXZpY2UYASABKAkSEQoJY29udGFpbmVyGAIg", - "ASgJEgwKBG5hbWUYAyABKAkSEQoJaGFzaF9jb2RlGAQgASgEEhcKD21heWJl", - "X3R5cGVfbmFtZRgFIAEoCUJuChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtC", - "DlJlc291cmNlSGFuZGxlUAFaPWdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5z", - "b3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9mcmFtZXdvcmv4AQFiBnByb3Rv", - "Mw==")); + "Ci90ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3Jlc291cmNlX2hhbmRsZS5w", + "cm90bxIKdGVuc29yZmxvdxosdGVuc29yZmxvdy9jb3JlL2ZyYW1ld29yay90", + "ZW5zb3Jfc2hhcGUucHJvdG8aJXRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsv", + "dHlwZXMucHJvdG8ipQIKE1Jlc291cmNlSGFuZGxlUHJvdG8SDgoGZGV2aWNl", + "GAEgASgJEhEKCWNvbnRhaW5lchgCIAEoCRIMCgRuYW1lGAMgASgJEhEKCWhh", + "c2hfY29kZRgEIAEoBBIXCg9tYXliZV90eXBlX25hbWUYBSABKAkSSAoRZHR5", + "cGVzX2FuZF9zaGFwZXMYBiADKAsyLS50ZW5zb3JmbG93LlJlc291cmNlSGFu", + "ZGxlUHJvdG8uRHR5cGVBbmRTaGFwZRphCg1EdHlwZUFuZFNoYXBlEiMKBWR0", + "eXBlGAEgASgOMhQudGVuc29yZmxvdy5EYXRhVHlwZRIrCgVzaGFwZRgCIAEo", + "CzIcLnRlbnNvcmZsb3cuVGVuc29yU2hhcGVQcm90b0oECAcQCEKHAQoYb3Jn", + "LnRlbnNvcmZsb3cuZnJhbWV3b3JrQg5SZXNvdXJjZUhhbmRsZVABWlZnaXRo", + "dWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dvL2Nv", + "cmUvZnJhbWV3b3JrL3Jlc291cmNlX2hhbmRsZV9nb19wcm90b/gBAWIGcHJv", + "dG8z")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, - new pbr::FileDescriptor[] { }, - new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ResourceHandleProto), global::Tensorflow.ResourceHandleProto.Parser, new[]{ "Device", "Container", "Name", "HashCode", "MaybeTypeName" }, null, null, null) + new pbr::FileDescriptor[] { global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ResourceHandleProto), global::Tensorflow.ResourceHandleProto.Parser, new[]{ "Device", "Container", "Name", "HashCode", "MaybeTypeName", "DtypesAndShapes" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ResourceHandleProto.Types.DtypeAndShape), global::Tensorflow.ResourceHandleProto.Types.DtypeAndShape.Parser, new[]{ "Dtype", "Shape" }, null, null, null, null)}) })); } #endregion @@ -46,23 +53,31 @@ static ResourceHandleReflection() { /// not valid across executions, but can be serialized back and forth from within /// a single run. /// - public sealed partial class ResourceHandleProto : pb::IMessage { + public sealed partial class ResourceHandleProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ResourceHandleProto()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.ResourceHandleReflection.Descriptor.MessageTypes[0]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public ResourceHandleProto() { OnConstruction(); } @@ -70,16 +85,19 @@ public ResourceHandleProto() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public ResourceHandleProto(ResourceHandleProto other) : this() { device_ = other.device_; container_ = other.container_; name_ = other.name_; hashCode_ = other.hashCode_; maybeTypeName_ = other.maybeTypeName_; + dtypesAndShapes_ = other.dtypesAndShapes_.Clone(); _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public ResourceHandleProto Clone() { return new ResourceHandleProto(this); } @@ -91,6 +109,7 @@ public ResourceHandleProto Clone() { /// Unique name for the device containing the resource. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Device { get { return device_; } set { @@ -105,6 +124,7 @@ public string Device { /// Container in which this resource is placed. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Container { get { return container_; } set { @@ -119,6 +139,7 @@ public string Container { /// Unique name of this resource. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Name { get { return name_; } set { @@ -134,6 +155,7 @@ public string Name { /// and in the same execution. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public ulong HashCode { get { return hashCode_; } set { @@ -149,6 +171,7 @@ public ulong HashCode { /// available. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string MaybeTypeName { get { return maybeTypeName_; } set { @@ -156,12 +179,28 @@ public string MaybeTypeName { } } + /// Field number for the "dtypes_and_shapes" field. + public const int DtypesAndShapesFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_dtypesAndShapes_codec + = pb::FieldCodec.ForMessage(50, global::Tensorflow.ResourceHandleProto.Types.DtypeAndShape.Parser); + private readonly pbc::RepeatedField dtypesAndShapes_ = new pbc::RepeatedField(); + /// + /// Data types and shapes for the underlying resource. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DtypesAndShapes { + get { return dtypesAndShapes_; } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as ResourceHandleProto); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(ResourceHandleProto other) { if (ReferenceEquals(other, null)) { return false; @@ -174,10 +213,12 @@ public bool Equals(ResourceHandleProto other) { if (Name != other.Name) return false; if (HashCode != other.HashCode) return false; if (MaybeTypeName != other.MaybeTypeName) return false; + if(!dtypesAndShapes_.Equals(other.dtypesAndShapes_)) return false; return Equals(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; if (Device.Length != 0) hash ^= Device.GetHashCode(); @@ -185,6 +226,7 @@ public override int GetHashCode() { if (Name.Length != 0) hash ^= Name.GetHashCode(); if (HashCode != 0UL) hash ^= HashCode.GetHashCode(); if (MaybeTypeName.Length != 0) hash ^= MaybeTypeName.GetHashCode(); + hash ^= dtypesAndShapes_.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -192,12 +234,17 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else if (Device.Length != 0) { output.WriteRawTag(10); output.WriteString(Device); @@ -218,12 +265,46 @@ public void WriteTo(pb::CodedOutputStream output) { output.WriteRawTag(42); output.WriteString(MaybeTypeName); } + dtypesAndShapes_.WriteTo(output, _repeated_dtypesAndShapes_codec); if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Device.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Device); + } + if (Container.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Container); + } + if (Name.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Name); + } + if (HashCode != 0UL) { + output.WriteRawTag(32); + output.WriteUInt64(HashCode); + } + if (MaybeTypeName.Length != 0) { + output.WriteRawTag(42); + output.WriteString(MaybeTypeName); + } + dtypesAndShapes_.WriteTo(ref output, _repeated_dtypesAndShapes_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } } + #endif [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; if (Device.Length != 0) { @@ -241,6 +322,7 @@ public int CalculateSize() { if (MaybeTypeName.Length != 0) { size += 1 + pb::CodedOutputStream.ComputeStringSize(MaybeTypeName); } + size += dtypesAndShapes_.CalculateSize(_repeated_dtypesAndShapes_codec); if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -248,6 +330,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(ResourceHandleProto other) { if (other == null) { return; @@ -267,11 +350,16 @@ public void MergeFrom(ResourceHandleProto other) { if (other.MaybeTypeName.Length != 0) { MaybeTypeName = other.MaybeTypeName; } + dtypesAndShapes_.Add(other.dtypesAndShapes_); _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -298,10 +386,300 @@ public void MergeFrom(pb::CodedInputStream input) { MaybeTypeName = input.ReadString(); break; } + case 50: { + dtypesAndShapes_.AddEntriesFrom(input, _repeated_dtypesAndShapes_codec); + break; + } } } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Device = input.ReadString(); + break; + } + case 18: { + Container = input.ReadString(); + break; + } + case 26: { + Name = input.ReadString(); + break; + } + case 32: { + HashCode = input.ReadUInt64(); + break; + } + case 42: { + MaybeTypeName = input.ReadString(); + break; + } + case 50: { + dtypesAndShapes_.AddEntriesFrom(ref input, _repeated_dtypesAndShapes_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the ResourceHandleProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Protocol buffer representing a pair of (data type, tensor shape). + /// + public sealed partial class DtypeAndShape : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DtypeAndShape()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ResourceHandleProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DtypeAndShape() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DtypeAndShape(DtypeAndShape other) : this() { + dtype_ = other.dtype_; + shape_ = other.shape_ != null ? other.shape_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DtypeAndShape Clone() { + return new DtypeAndShape(this); + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 1; + private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 2; + private global::Tensorflow.TensorShapeProto shape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorShapeProto Shape { + get { return shape_; } + set { + shape_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DtypeAndShape); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DtypeAndShape other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Dtype != other.Dtype) return false; + if (!object.Equals(Shape, other.Shape)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); + if (shape_ != null) hash ^= Shape.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(8); + output.WriteEnum((int) Dtype); + } + if (shape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Shape); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(8); + output.WriteEnum((int) Dtype); + } + if (shape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Shape); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (shape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DtypeAndShape other) { + if (other == null) { + return; + } + if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { + Dtype = other.Dtype; + } + if (other.shape_ != null) { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + Shape.MergeFrom(other.Shape); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 18: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 18: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + } + } + } + #endif + + } + + } + #endregion + } #endregion diff --git a/src/TensorFlowNET.Core/Protobuf/RewriterConfig.cs b/src/TensorFlowNET.Core/Protobuf/RewriterConfig.cs new file mode 100644 index 000000000..eae000206 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/RewriterConfig.cs @@ -0,0 +1,2491 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/rewriter_config.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/rewriter_config.proto + public static partial class RewriterConfigReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/rewriter_config.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static RewriterConfigReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Ci50ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvcmV3cml0ZXJfY29uZmlnLnBy", + "b3RvEgp0ZW5zb3JmbG93Gip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2F0", + "dHJfdmFsdWUucHJvdG8aLnRlbnNvcmZsb3cvY29yZS9wcm90b2J1Zi92ZXJp", + "Zmllcl9jb25maWcucHJvdG8iOwoTQXV0b1BhcmFsbGVsT3B0aW9ucxIOCgZl", + "bmFibGUYASABKAgSFAoMbnVtX3JlcGxpY2FzGAIgASgFIisKFlNjb3BlZEFs", + "bG9jYXRvck9wdGlvbnMSEQoJZW5hYmxlX29wGAEgAygJIvYVCg5SZXdyaXRl", + "ckNvbmZpZxJDChVjcHVfbGF5b3V0X2NvbnZlcnNpb24YMiABKA4yJC50ZW5z", + "b3JmbG93LlJld3JpdGVyQ29uZmlnLkNwdUxheW91dBI7ChBsYXlvdXRfb3B0", + "aW1pemVyGAEgASgOMiEudGVuc29yZmxvdy5SZXdyaXRlckNvbmZpZy5Ub2dn", + "bGUSOwoQY29uc3RhbnRfZm9sZGluZxgDIAEoDjIhLnRlbnNvcmZsb3cuUmV3", + "cml0ZXJDb25maWcuVG9nZ2xlEj0KEnNoYXBlX29wdGltaXphdGlvbhgNIAEo", + "DjIhLnRlbnNvcmZsb3cuUmV3cml0ZXJDb25maWcuVG9nZ2xlEjQKCXJlbWFw", + "cGluZxgOIAEoDjIhLnRlbnNvcmZsb3cuUmV3cml0ZXJDb25maWcuVG9nZ2xl", + "EkYKG2NvbW1vbl9zdWJncmFwaF9lbGltaW5hdGlvbhgYIAEoDjIhLnRlbnNv", + "cmZsb3cuUmV3cml0ZXJDb25maWcuVG9nZ2xlEkIKF2FyaXRobWV0aWNfb3B0", + "aW1pemF0aW9uGAcgASgOMiEudGVuc29yZmxvdy5SZXdyaXRlckNvbmZpZy5U", + "b2dnbGUSQgoXZGVwZW5kZW5jeV9vcHRpbWl6YXRpb24YCCABKA4yIS50ZW5z", + "b3JmbG93LlJld3JpdGVyQ29uZmlnLlRvZ2dsZRI8ChFsb29wX29wdGltaXph", + "dGlvbhgJIAEoDjIhLnRlbnNvcmZsb3cuUmV3cml0ZXJDb25maWcuVG9nZ2xl", + "EkAKFWZ1bmN0aW9uX29wdGltaXphdGlvbhgKIAEoDjIhLnRlbnNvcmZsb3cu", + "UmV3cml0ZXJDb25maWcuVG9nZ2xlEjkKDmRlYnVnX3N0cmlwcGVyGAsgASgO", + "MiEudGVuc29yZmxvdy5SZXdyaXRlckNvbmZpZy5Ub2dnbGUSHQoVZGlzYWJs", + "ZV9tb2RlbF9wcnVuaW5nGAIgASgIEkgKHXNjb3BlZF9hbGxvY2F0b3Jfb3B0", + "aW1pemF0aW9uGA8gASgOMiEudGVuc29yZmxvdy5SZXdyaXRlckNvbmZpZy5U", + "b2dnbGUSQwoYcGluX3RvX2hvc3Rfb3B0aW1pemF0aW9uGBIgASgOMiEudGVu", + "c29yZmxvdy5SZXdyaXRlckNvbmZpZy5Ub2dnbGUSQgoXaW1wbGVtZW50YXRp", + "b25fc2VsZWN0b3IYFiABKA4yIS50ZW5zb3JmbG93LlJld3JpdGVyQ29uZmln", + "LlRvZ2dsZRI/ChRhdXRvX21peGVkX3ByZWNpc2lvbhgXIAEoDjIhLnRlbnNv", + "cmZsb3cuUmV3cml0ZXJDb25maWcuVG9nZ2xlEkMKGGF1dG9fbWl4ZWRfcHJl", + "Y2lzaW9uX21rbBgZIAEoDjIhLnRlbnNvcmZsb3cuUmV3cml0ZXJDb25maWcu", + "VG9nZ2xlEk8KJGF1dG9fbWl4ZWRfcHJlY2lzaW9uX29uZWRubl9iZmxvYXQx", + "NhgfIAEoDjIhLnRlbnNvcmZsb3cuUmV3cml0ZXJDb25maWcuVG9nZ2xlEkMK", + "GGF1dG9fbWl4ZWRfcHJlY2lzaW9uX2NwdRgdIAEoDjIhLnRlbnNvcmZsb3cu", + "UmV3cml0ZXJDb25maWcuVG9nZ2xlEh4KFmRpc2FibGVfbWV0YV9vcHRpbWl6", + "ZXIYEyABKAgSQAoVdXNlX3BsdWdpbl9vcHRpbWl6ZXJzGBwgASgOMiEudGVu", + "c29yZmxvdy5SZXdyaXRlckNvbmZpZy5Ub2dnbGUSTwokZXhwZXJpbWVudGFs", + "X2NvbmRpdGlvbmFsX2NvZGVfbW90aW9uGB4gASgOMiEudGVuc29yZmxvdy5S", + "ZXdyaXRlckNvbmZpZy5Ub2dnbGUSTwoZbWV0YV9vcHRpbWl6ZXJfaXRlcmF0", + "aW9ucxgMIAEoDjIsLnRlbnNvcmZsb3cuUmV3cml0ZXJDb25maWcuTnVtSXRl", + "cmF0aW9uc1R5cGUSFwoPbWluX2dyYXBoX25vZGVzGBEgASgFEjsKM2V4cGVy", + "aW1lbnRhbF9kaXNhYmxlX2NvbXByZXNzZWRfdGVuc29yX29wdGltaXphdGlv", + "bhgaIAEoCBI7CjNleHBlcmltZW50YWxfZGlzYWJsZV9mb2xkaW5nX3F1YW50", + "aXphdGlvbl9lbXVsYXRpb24YGyABKAgSQgoTbWVtb3J5X29wdGltaXphdGlv", + "bhgEIAEoDjIlLnRlbnNvcmZsb3cuUmV3cml0ZXJDb25maWcuTWVtT3B0VHlw", + "ZRIvCidtZW1vcnlfb3B0aW1pemVyX3RhcmdldF9ub2RlX25hbWVfc2NvcGUY", + "BiABKAkSIQoZbWV0YV9vcHRpbWl6ZXJfdGltZW91dF9tcxgUIAEoAxI2Cg1h", + "dXRvX3BhcmFsbGVsGAUgASgLMh8udGVuc29yZmxvdy5BdXRvUGFyYWxsZWxP", + "cHRpb25zEiAKGGZhaWxfb25fb3B0aW1pemVyX2Vycm9ycxgVIAEoCBJBChVz", + "Y29wZWRfYWxsb2NhdG9yX29wdHMYECABKAsyIi50ZW5zb3JmbG93LlNjb3Bl", + "ZEFsbG9jYXRvck9wdGlvbnMSEgoKb3B0aW1pemVycxhkIAMoCRJLChFjdXN0", + "b21fb3B0aW1pemVycxjIASADKAsyLy50ZW5zb3JmbG93LlJld3JpdGVyQ29u", + "ZmlnLkN1c3RvbUdyYXBoT3B0aW1pemVyEkQKH2ludGVyX29wdGltaXplcl92", + "ZXJpZmllcl9jb25maWcYrAIgASgLMhoudGVuc29yZmxvdy5WZXJpZmllckNv", + "bmZpZxJGCiFwb3N0X29wdGltaXphdGlvbl92ZXJpZmllcl9jb25maWcYrQIg", + "ASgLMhoudGVuc29yZmxvdy5WZXJpZmllckNvbmZpZxrKAQoUQ3VzdG9tR3Jh", + "cGhPcHRpbWl6ZXISDAoEbmFtZRgBIAEoCRJYCg1wYXJhbWV0ZXJfbWFwGAIg", + "AygLMkEudGVuc29yZmxvdy5SZXdyaXRlckNvbmZpZy5DdXN0b21HcmFwaE9w", + "dGltaXplci5QYXJhbWV0ZXJNYXBFbnRyeRpKChFQYXJhbWV0ZXJNYXBFbnRy", + "eRILCgNrZXkYASABKAkSJAoFdmFsdWUYAiABKAsyFS50ZW5zb3JmbG93LkF0", + "dHJWYWx1ZToCOAEiZAoGVG9nZ2xlEgsKB0RFRkFVTFQQABIGCgJPThABEgcK", + "A09GRhACEg4KCkFHR1JFU1NJVkUQAxIVChFFWFBFUklNRU5UQUxfTUxJUhAE", + "EhUKEUVYUEVSSU1FTlRBTF9CT1RIEAUiSQoJQ3B1TGF5b3V0EhgKFE5PX0NP", + "TlZFUlNJT05fT05fQ1BVEAASEAoMTkNIV19UT19OSFdDEAESEAoMTkhXQ19U", + "T19OQ0hXEAIiPAoRTnVtSXRlcmF0aW9uc1R5cGUSFQoRREVGQVVMVF9OVU1f", + "SVRFUlMQABIHCgNPTkUQARIHCgNUV08QAiKfAQoKTWVtT3B0VHlwZRITCg9E", + "RUZBVUxUX01FTV9PUFQQABIOCgpOT19NRU1fT1BUEAESCgoGTUFOVUFMEAIS", + "FwoTU1dBUFBJTkdfSEVVUklTVElDUxAEEhwKGFJFQ09NUFVUQVRJT05fSEVV", + "UklTVElDUxAFEhkKFVNDSEVEVUxJTkdfSEVVUklTVElDUxAGEg4KCkhFVVJJ", + "U1RJQ1MQA0KMAQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQhRSZXdyaXRl", + "ckNvbmZpZ1Byb3Rvc1ABWlVnaXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29y", + "Zmxvdy90ZW5zb3JmbG93L2dvL2NvcmUvcHJvdG9idWYvZm9yX2NvcmVfcHJv", + "dG9zX2dvX3Byb3Rv+AEBYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.VerifierConfigReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.AutoParallelOptions), global::Tensorflow.AutoParallelOptions.Parser, new[]{ "Enable", "NumReplicas" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ScopedAllocatorOptions), global::Tensorflow.ScopedAllocatorOptions.Parser, new[]{ "EnableOp" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RewriterConfig), global::Tensorflow.RewriterConfig.Parser, new[]{ "CpuLayoutConversion", "LayoutOptimizer", "ConstantFolding", "ShapeOptimization", "Remapping", "CommonSubgraphElimination", "ArithmeticOptimization", "DependencyOptimization", "LoopOptimization", "FunctionOptimization", "DebugStripper", "DisableModelPruning", "ScopedAllocatorOptimization", "PinToHostOptimization", "ImplementationSelector", "AutoMixedPrecision", "AutoMixedPrecisionMkl", "AutoMixedPrecisionOnednnBfloat16", "AutoMixedPrecisionCpu", "DisableMetaOptimizer", "UsePluginOptimizers", "ExperimentalConditionalCodeMotion", "MetaOptimizerIterations", "MinGraphNodes", "ExperimentalDisableCompressedTensorOptimization", "ExperimentalDisableFoldingQuantizationEmulation", "MemoryOptimization", "MemoryOptimizerTargetNodeNameScope", "MetaOptimizerTimeoutMs", "AutoParallel", "FailOnOptimizerErrors", "ScopedAllocatorOpts", "Optimizers", "CustomOptimizers", "InterOptimizerVerifierConfig", "PostOptimizationVerifierConfig" }, null, new[]{ typeof(global::Tensorflow.RewriterConfig.Types.Toggle), typeof(global::Tensorflow.RewriterConfig.Types.CpuLayout), typeof(global::Tensorflow.RewriterConfig.Types.NumIterationsType), typeof(global::Tensorflow.RewriterConfig.Types.MemOptType) }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RewriterConfig.Types.CustomGraphOptimizer), global::Tensorflow.RewriterConfig.Types.CustomGraphOptimizer.Parser, new[]{ "Name", "ParameterMap" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, })}) + })); + } + #endregion + + } + #region Messages + public sealed partial class AutoParallelOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AutoParallelOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.RewriterConfigReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AutoParallelOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AutoParallelOptions(AutoParallelOptions other) : this() { + enable_ = other.enable_; + numReplicas_ = other.numReplicas_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AutoParallelOptions Clone() { + return new AutoParallelOptions(this); + } + + /// Field number for the "enable" field. + public const int EnableFieldNumber = 1; + private bool enable_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Enable { + get { return enable_; } + set { + enable_ = value; + } + } + + /// Field number for the "num_replicas" field. + public const int NumReplicasFieldNumber = 2; + private int numReplicas_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NumReplicas { + get { return numReplicas_; } + set { + numReplicas_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as AutoParallelOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(AutoParallelOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Enable != other.Enable) return false; + if (NumReplicas != other.NumReplicas) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Enable != false) hash ^= Enable.GetHashCode(); + if (NumReplicas != 0) hash ^= NumReplicas.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Enable != false) { + output.WriteRawTag(8); + output.WriteBool(Enable); + } + if (NumReplicas != 0) { + output.WriteRawTag(16); + output.WriteInt32(NumReplicas); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Enable != false) { + output.WriteRawTag(8); + output.WriteBool(Enable); + } + if (NumReplicas != 0) { + output.WriteRawTag(16); + output.WriteInt32(NumReplicas); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Enable != false) { + size += 1 + 1; + } + if (NumReplicas != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumReplicas); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(AutoParallelOptions other) { + if (other == null) { + return; + } + if (other.Enable != false) { + Enable = other.Enable; + } + if (other.NumReplicas != 0) { + NumReplicas = other.NumReplicas; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Enable = input.ReadBool(); + break; + } + case 16: { + NumReplicas = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Enable = input.ReadBool(); + break; + } + case 16: { + NumReplicas = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + public sealed partial class ScopedAllocatorOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ScopedAllocatorOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.RewriterConfigReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ScopedAllocatorOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ScopedAllocatorOptions(ScopedAllocatorOptions other) : this() { + enableOp_ = other.enableOp_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ScopedAllocatorOptions Clone() { + return new ScopedAllocatorOptions(this); + } + + /// Field number for the "enable_op" field. + public const int EnableOpFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_enableOp_codec + = pb::FieldCodec.ForString(10); + private readonly pbc::RepeatedField enableOp_ = new pbc::RepeatedField(); + /// + /// If present, only perform optimization for these ops. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField EnableOp { + get { return enableOp_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ScopedAllocatorOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ScopedAllocatorOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!enableOp_.Equals(other.enableOp_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= enableOp_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + enableOp_.WriteTo(output, _repeated_enableOp_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + enableOp_.WriteTo(ref output, _repeated_enableOp_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += enableOp_.CalculateSize(_repeated_enableOp_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ScopedAllocatorOptions other) { + if (other == null) { + return; + } + enableOp_.Add(other.enableOp_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + enableOp_.AddEntriesFrom(input, _repeated_enableOp_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + enableOp_.AddEntriesFrom(ref input, _repeated_enableOp_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Graph rewriting is experimental and subject to change, not covered by any + /// API stability guarantees. + /// + public sealed partial class RewriterConfig : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RewriterConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.RewriterConfigReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RewriterConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RewriterConfig(RewriterConfig other) : this() { + cpuLayoutConversion_ = other.cpuLayoutConversion_; + layoutOptimizer_ = other.layoutOptimizer_; + constantFolding_ = other.constantFolding_; + shapeOptimization_ = other.shapeOptimization_; + remapping_ = other.remapping_; + commonSubgraphElimination_ = other.commonSubgraphElimination_; + arithmeticOptimization_ = other.arithmeticOptimization_; + dependencyOptimization_ = other.dependencyOptimization_; + loopOptimization_ = other.loopOptimization_; + functionOptimization_ = other.functionOptimization_; + debugStripper_ = other.debugStripper_; + disableModelPruning_ = other.disableModelPruning_; + scopedAllocatorOptimization_ = other.scopedAllocatorOptimization_; + pinToHostOptimization_ = other.pinToHostOptimization_; + implementationSelector_ = other.implementationSelector_; + autoMixedPrecision_ = other.autoMixedPrecision_; + autoMixedPrecisionMkl_ = other.autoMixedPrecisionMkl_; + autoMixedPrecisionOnednnBfloat16_ = other.autoMixedPrecisionOnednnBfloat16_; + autoMixedPrecisionCpu_ = other.autoMixedPrecisionCpu_; + disableMetaOptimizer_ = other.disableMetaOptimizer_; + usePluginOptimizers_ = other.usePluginOptimizers_; + experimentalConditionalCodeMotion_ = other.experimentalConditionalCodeMotion_; + metaOptimizerIterations_ = other.metaOptimizerIterations_; + minGraphNodes_ = other.minGraphNodes_; + experimentalDisableCompressedTensorOptimization_ = other.experimentalDisableCompressedTensorOptimization_; + experimentalDisableFoldingQuantizationEmulation_ = other.experimentalDisableFoldingQuantizationEmulation_; + memoryOptimization_ = other.memoryOptimization_; + memoryOptimizerTargetNodeNameScope_ = other.memoryOptimizerTargetNodeNameScope_; + metaOptimizerTimeoutMs_ = other.metaOptimizerTimeoutMs_; + autoParallel_ = other.autoParallel_ != null ? other.autoParallel_.Clone() : null; + failOnOptimizerErrors_ = other.failOnOptimizerErrors_; + scopedAllocatorOpts_ = other.scopedAllocatorOpts_ != null ? other.scopedAllocatorOpts_.Clone() : null; + optimizers_ = other.optimizers_.Clone(); + customOptimizers_ = other.customOptimizers_.Clone(); + interOptimizerVerifierConfig_ = other.interOptimizerVerifierConfig_ != null ? other.interOptimizerVerifierConfig_.Clone() : null; + postOptimizationVerifierConfig_ = other.postOptimizationVerifierConfig_ != null ? other.postOptimizationVerifierConfig_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RewriterConfig Clone() { + return new RewriterConfig(this); + } + + /// Field number for the "cpu_layout_conversion" field. + public const int CpuLayoutConversionFieldNumber = 50; + private global::Tensorflow.RewriterConfig.Types.CpuLayout cpuLayoutConversion_ = global::Tensorflow.RewriterConfig.Types.CpuLayout.NoConversionOnCpu; + /// + /// CPU Conversion settings between NHCW and NCHW. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.CpuLayout CpuLayoutConversion { + get { return cpuLayoutConversion_; } + set { + cpuLayoutConversion_ = value; + } + } + + /// Field number for the "layout_optimizer" field. + public const int LayoutOptimizerFieldNumber = 1; + private global::Tensorflow.RewriterConfig.Types.Toggle layoutOptimizer_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Optimize tensor layouts (default is ON) + /// e.g. This will try to use NCHW layout on GPU which is faster. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle LayoutOptimizer { + get { return layoutOptimizer_; } + set { + layoutOptimizer_ = value; + } + } + + /// Field number for the "constant_folding" field. + public const int ConstantFoldingFieldNumber = 3; + private global::Tensorflow.RewriterConfig.Types.Toggle constantFolding_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Fold constants (default is ON) + /// Statically infer the value of tensors when possible, and materialize the + /// result using constants. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle ConstantFolding { + get { return constantFolding_; } + set { + constantFolding_ = value; + } + } + + /// Field number for the "shape_optimization" field. + public const int ShapeOptimizationFieldNumber = 13; + private global::Tensorflow.RewriterConfig.Types.Toggle shapeOptimization_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Shape optimizations (default is ON) + /// Simplify computations made on shapes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle ShapeOptimization { + get { return shapeOptimization_; } + set { + shapeOptimization_ = value; + } + } + + /// Field number for the "remapping" field. + public const int RemappingFieldNumber = 14; + private global::Tensorflow.RewriterConfig.Types.Toggle remapping_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Remapping (default is ON) + /// Remap subgraphs onto more efficient implementations. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle Remapping { + get { return remapping_; } + set { + remapping_ = value; + } + } + + /// Field number for the "common_subgraph_elimination" field. + public const int CommonSubgraphEliminationFieldNumber = 24; + private global::Tensorflow.RewriterConfig.Types.Toggle commonSubgraphElimination_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Common subgraph elimination (default is ON) + /// e.g. Simplify arithmetic ops; merge ops with same value (like constants). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle CommonSubgraphElimination { + get { return commonSubgraphElimination_; } + set { + commonSubgraphElimination_ = value; + } + } + + /// Field number for the "arithmetic_optimization" field. + public const int ArithmeticOptimizationFieldNumber = 7; + private global::Tensorflow.RewriterConfig.Types.Toggle arithmeticOptimization_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Arithmetic optimizations (default is ON) + /// e.g. Simplify arithmetic ops; merge ops with same value (like constants). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle ArithmeticOptimization { + get { return arithmeticOptimization_; } + set { + arithmeticOptimization_ = value; + } + } + + /// Field number for the "dependency_optimization" field. + public const int DependencyOptimizationFieldNumber = 8; + private global::Tensorflow.RewriterConfig.Types.Toggle dependencyOptimization_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Control dependency optimizations (default is ON). + /// Remove redundant control dependencies, which may enable other optimization. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle DependencyOptimization { + get { return dependencyOptimization_; } + set { + dependencyOptimization_ = value; + } + } + + /// Field number for the "loop_optimization" field. + public const int LoopOptimizationFieldNumber = 9; + private global::Tensorflow.RewriterConfig.Types.Toggle loopOptimization_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Loop optimizations (default is ON). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle LoopOptimization { + get { return loopOptimization_; } + set { + loopOptimization_ = value; + } + } + + /// Field number for the "function_optimization" field. + public const int FunctionOptimizationFieldNumber = 10; + private global::Tensorflow.RewriterConfig.Types.Toggle functionOptimization_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Function optimizations (default is ON). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle FunctionOptimization { + get { return functionOptimization_; } + set { + functionOptimization_ = value; + } + } + + /// Field number for the "debug_stripper" field. + public const int DebugStripperFieldNumber = 11; + private global::Tensorflow.RewriterConfig.Types.Toggle debugStripper_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Strips debug-related nodes from the graph (off by default). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle DebugStripper { + get { return debugStripper_; } + set { + debugStripper_ = value; + } + } + + /// Field number for the "disable_model_pruning" field. + public const int DisableModelPruningFieldNumber = 2; + private bool disableModelPruning_; + /// + /// If true, don't remove unnecessary ops from the graph + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool DisableModelPruning { + get { return disableModelPruning_; } + set { + disableModelPruning_ = value; + } + } + + /// Field number for the "scoped_allocator_optimization" field. + public const int ScopedAllocatorOptimizationFieldNumber = 15; + private global::Tensorflow.RewriterConfig.Types.Toggle scopedAllocatorOptimization_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Try to allocate some independent Op outputs contiguously in order to + /// merge or eliminate downstream Ops (off by default). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle ScopedAllocatorOptimization { + get { return scopedAllocatorOptimization_; } + set { + scopedAllocatorOptimization_ = value; + } + } + + /// Field number for the "pin_to_host_optimization" field. + public const int PinToHostOptimizationFieldNumber = 18; + private global::Tensorflow.RewriterConfig.Types.Toggle pinToHostOptimization_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Force small ops onto the CPU (default is OFF). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle PinToHostOptimization { + get { return pinToHostOptimization_; } + set { + pinToHostOptimization_ = value; + } + } + + /// Field number for the "implementation_selector" field. + public const int ImplementationSelectorFieldNumber = 22; + private global::Tensorflow.RewriterConfig.Types.Toggle implementationSelector_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Enable the swap of kernel implementations based on the device placement + /// (default is ON). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle ImplementationSelector { + get { return implementationSelector_; } + set { + implementationSelector_ = value; + } + } + + /// Field number for the "auto_mixed_precision" field. + public const int AutoMixedPrecisionFieldNumber = 23; + private global::Tensorflow.RewriterConfig.Types.Toggle autoMixedPrecision_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Optimize data types for CUDA (default is OFF). + /// This will try to use float16 on GPU which is faster. + /// Note that this can change the numerical stability of the graph and may + /// require the use of loss scaling to maintain model convergence. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle AutoMixedPrecision { + get { return autoMixedPrecision_; } + set { + autoMixedPrecision_ = value; + } + } + + /// Field number for the "auto_mixed_precision_mkl" field. + public const int AutoMixedPrecisionMklFieldNumber = 25; + private global::Tensorflow.RewriterConfig.Types.Toggle autoMixedPrecisionMkl_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Optimize data types for oneDNN (default is OFF). + /// This will try to use bfloat16 on CPUs, which is faster. + /// Note that this can change the numerical stability of the graph. + /// Note: this is deprecated. + /// It is replaced by auto_mixed_precision_onednn_bfloat16 + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle AutoMixedPrecisionMkl { + get { return autoMixedPrecisionMkl_; } + set { + autoMixedPrecisionMkl_ = value; + } + } + + /// Field number for the "auto_mixed_precision_onednn_bfloat16" field. + public const int AutoMixedPrecisionOnednnBfloat16FieldNumber = 31; + private global::Tensorflow.RewriterConfig.Types.Toggle autoMixedPrecisionOnednnBfloat16_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Optimize data types for oneDNN (default is OFF). + /// This will try to use bfloat16 on CPUs, which is faster. + /// Note that this can change the numerical stability of the graph. + /// Note: this is equivalent to the deprecated option auto_mixed_precision_mkl + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle AutoMixedPrecisionOnednnBfloat16 { + get { return autoMixedPrecisionOnednnBfloat16_; } + set { + autoMixedPrecisionOnednnBfloat16_ = value; + } + } + + /// Field number for the "auto_mixed_precision_cpu" field. + public const int AutoMixedPrecisionCpuFieldNumber = 29; + private global::Tensorflow.RewriterConfig.Types.Toggle autoMixedPrecisionCpu_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Emulate a model using data type float16 on CPU (default is OFF). + /// This will try to emulate the float16 inputs and outputs of an operator + /// on CPU to have better correlation with float16 on GPU; however the + /// computation in the operator is based on float32. + /// Note that this can change the numerical stability of the graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle AutoMixedPrecisionCpu { + get { return autoMixedPrecisionCpu_; } + set { + autoMixedPrecisionCpu_ = value; + } + } + + /// Field number for the "disable_meta_optimizer" field. + public const int DisableMetaOptimizerFieldNumber = 19; + private bool disableMetaOptimizer_; + /// + /// Disable the entire meta optimizer (off by default). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool DisableMetaOptimizer { + get { return disableMetaOptimizer_; } + set { + disableMetaOptimizer_ = value; + } + } + + /// Field number for the "use_plugin_optimizers" field. + public const int UsePluginOptimizersFieldNumber = 28; + private global::Tensorflow.RewriterConfig.Types.Toggle usePluginOptimizers_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Optimizers registered by plugin (default is ON) + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle UsePluginOptimizers { + get { return usePluginOptimizers_; } + set { + usePluginOptimizers_ = value; + } + } + + /// Field number for the "experimental_conditional_code_motion" field. + public const int ExperimentalConditionalCodeMotionFieldNumber = 30; + private global::Tensorflow.RewriterConfig.Types.Toggle experimentalConditionalCodeMotion_ = global::Tensorflow.RewriterConfig.Types.Toggle.Default; + /// + /// Conditional code motion (default is ON). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.Toggle ExperimentalConditionalCodeMotion { + get { return experimentalConditionalCodeMotion_; } + set { + experimentalConditionalCodeMotion_ = value; + } + } + + /// Field number for the "meta_optimizer_iterations" field. + public const int MetaOptimizerIterationsFieldNumber = 12; + private global::Tensorflow.RewriterConfig.Types.NumIterationsType metaOptimizerIterations_ = global::Tensorflow.RewriterConfig.Types.NumIterationsType.DefaultNumIters; + /// + /// Controls how many times we run the optimizers in meta optimizer (default + /// is once). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.NumIterationsType MetaOptimizerIterations { + get { return metaOptimizerIterations_; } + set { + metaOptimizerIterations_ = value; + } + } + + /// Field number for the "min_graph_nodes" field. + public const int MinGraphNodesFieldNumber = 17; + private int minGraphNodes_; + /// + /// The minimum number of nodes in a graph to optimizer. For smaller graphs, + /// optimization is skipped. + /// 0 means the system picks an appropriate number. + /// < 0 means do not skip optimization. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int MinGraphNodes { + get { return minGraphNodes_; } + set { + minGraphNodes_ = value; + } + } + + /// Field number for the "experimental_disable_compressed_tensor_optimization" field. + public const int ExperimentalDisableCompressedTensorOptimizationFieldNumber = 26; + private bool experimentalDisableCompressedTensorOptimization_; + /// + /// Disable optimizations that assume compressed tensors. Note that this flag + /// is experimental and may be removed in the future. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool ExperimentalDisableCompressedTensorOptimization { + get { return experimentalDisableCompressedTensorOptimization_; } + set { + experimentalDisableCompressedTensorOptimization_ = value; + } + } + + /// Field number for the "experimental_disable_folding_quantization_emulation" field. + public const int ExperimentalDisableFoldingQuantizationEmulationFieldNumber = 27; + private bool experimentalDisableFoldingQuantizationEmulation_; + /// + /// Disable folding quantization emulation ops such as FakeQuantWithMinMax* and + /// QuantizeAndDequantize*. Some compilers (e.g. the TF-to-tflite converter) + /// have to extract quantization configs (e.g. min/max range, number of bits, + /// and per-channel) from the quantization emulation ops. Note that this flag + /// is experimental and may be removed in the future. See b/174138564 for more + /// details. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool ExperimentalDisableFoldingQuantizationEmulation { + get { return experimentalDisableFoldingQuantizationEmulation_; } + set { + experimentalDisableFoldingQuantizationEmulation_ = value; + } + } + + /// Field number for the "memory_optimization" field. + public const int MemoryOptimizationFieldNumber = 4; + private global::Tensorflow.RewriterConfig.Types.MemOptType memoryOptimization_ = global::Tensorflow.RewriterConfig.Types.MemOptType.DefaultMemOpt; + /// + /// Configures memory optimization passes through the meta-optimizer. Has no + /// effect on manually requested memory optimization passes in the optimizers + /// field. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RewriterConfig.Types.MemOptType MemoryOptimization { + get { return memoryOptimization_; } + set { + memoryOptimization_ = value; + } + } + + /// Field number for the "memory_optimizer_target_node_name_scope" field. + public const int MemoryOptimizerTargetNodeNameScopeFieldNumber = 6; + private string memoryOptimizerTargetNodeNameScope_ = ""; + /// + /// A node name scope for node names which are valid outputs of recomputations. + /// Inputs to nodes that match this scope may be recomputed (subject either to + /// manual annotation of those input nodes or to manual annotation and + /// heuristics depending on memory_optimization), but the nodes themselves will + /// not be recomputed. This matches any sub-scopes as well, meaning the scope + /// can appear not just as a top-level scope. For example, if the value is + /// "gradients/", the default, it will match node name "gradients/foo", + /// "foo/gradients/bar", but not "foo_gradients/" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string MemoryOptimizerTargetNodeNameScope { + get { return memoryOptimizerTargetNodeNameScope_; } + set { + memoryOptimizerTargetNodeNameScope_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "meta_optimizer_timeout_ms" field. + public const int MetaOptimizerTimeoutMsFieldNumber = 20; + private long metaOptimizerTimeoutMs_; + /// + /// Maximum number of milliseconds to spend optimizing a single graph before + /// timing out. If less than or equal to 0 (default value) the optimizer will + /// never time out. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long MetaOptimizerTimeoutMs { + get { return metaOptimizerTimeoutMs_; } + set { + metaOptimizerTimeoutMs_ = value; + } + } + + /// Field number for the "auto_parallel" field. + public const int AutoParallelFieldNumber = 5; + private global::Tensorflow.AutoParallelOptions autoParallel_; + /// + /// Configures AutoParallel optimization passes either through the + /// meta-optimizer or when manually specified through the optimizers field. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.AutoParallelOptions AutoParallel { + get { return autoParallel_; } + set { + autoParallel_ = value; + } + } + + /// Field number for the "fail_on_optimizer_errors" field. + public const int FailOnOptimizerErrorsFieldNumber = 21; + private bool failOnOptimizerErrors_; + /// + /// If true, any optimization pass failing will cause the MetaOptimizer to + /// stop with an error. By default - or when set to false, failing passes are + /// skipped silently. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool FailOnOptimizerErrors { + get { return failOnOptimizerErrors_; } + set { + failOnOptimizerErrors_ = value; + } + } + + /// Field number for the "scoped_allocator_opts" field. + public const int ScopedAllocatorOptsFieldNumber = 16; + private global::Tensorflow.ScopedAllocatorOptions scopedAllocatorOpts_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.ScopedAllocatorOptions ScopedAllocatorOpts { + get { return scopedAllocatorOpts_; } + set { + scopedAllocatorOpts_ = value; + } + } + + /// Field number for the "optimizers" field. + public const int OptimizersFieldNumber = 100; + private static readonly pb::FieldCodec _repeated_optimizers_codec + = pb::FieldCodec.ForString(802); + private readonly pbc::RepeatedField optimizers_ = new pbc::RepeatedField(); + /// + /// If non-empty, will use this as an alternative way to specify a list of + /// optimizations to turn on and the order of the optimizations (replacing the + /// meta-optimizer). + /// + /// Of the RewriterConfig options, only the AutoParallel configuration options + /// (the auto_parallel field) apply to manually requested optimization passes + /// ("autoparallel"). Memory optimization passes ("memory") invoked here are + /// not configurable (in contrast to memory optimization passes through the + /// meta-optimizer) and act only on manual op annotations. + /// + /// Custom optimizers (see custom_optimizers) that are not part of this + /// schedule will be run after - in the order that they were specified. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Optimizers { + get { return optimizers_; } + } + + /// Field number for the "custom_optimizers" field. + public const int CustomOptimizersFieldNumber = 200; + private static readonly pb::FieldCodec _repeated_customOptimizers_codec + = pb::FieldCodec.ForMessage(1602, global::Tensorflow.RewriterConfig.Types.CustomGraphOptimizer.Parser); + private readonly pbc::RepeatedField customOptimizers_ = new pbc::RepeatedField(); + /// + /// list of CustomGraphOptimizers to apply. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField CustomOptimizers { + get { return customOptimizers_; } + } + + /// Field number for the "inter_optimizer_verifier_config" field. + public const int InterOptimizerVerifierConfigFieldNumber = 300; + private global::Tensorflow.VerifierConfig interOptimizerVerifierConfig_; + /// + /// VerifierConfig specifying the verifiers to be run after every optimizer. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.VerifierConfig InterOptimizerVerifierConfig { + get { return interOptimizerVerifierConfig_; } + set { + interOptimizerVerifierConfig_ = value; + } + } + + /// Field number for the "post_optimization_verifier_config" field. + public const int PostOptimizationVerifierConfigFieldNumber = 301; + private global::Tensorflow.VerifierConfig postOptimizationVerifierConfig_; + /// + /// VerifierConfig specifying the verifiers to be run at the end, after all + /// optimizers have run. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.VerifierConfig PostOptimizationVerifierConfig { + get { return postOptimizationVerifierConfig_; } + set { + postOptimizationVerifierConfig_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as RewriterConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(RewriterConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (CpuLayoutConversion != other.CpuLayoutConversion) return false; + if (LayoutOptimizer != other.LayoutOptimizer) return false; + if (ConstantFolding != other.ConstantFolding) return false; + if (ShapeOptimization != other.ShapeOptimization) return false; + if (Remapping != other.Remapping) return false; + if (CommonSubgraphElimination != other.CommonSubgraphElimination) return false; + if (ArithmeticOptimization != other.ArithmeticOptimization) return false; + if (DependencyOptimization != other.DependencyOptimization) return false; + if (LoopOptimization != other.LoopOptimization) return false; + if (FunctionOptimization != other.FunctionOptimization) return false; + if (DebugStripper != other.DebugStripper) return false; + if (DisableModelPruning != other.DisableModelPruning) return false; + if (ScopedAllocatorOptimization != other.ScopedAllocatorOptimization) return false; + if (PinToHostOptimization != other.PinToHostOptimization) return false; + if (ImplementationSelector != other.ImplementationSelector) return false; + if (AutoMixedPrecision != other.AutoMixedPrecision) return false; + if (AutoMixedPrecisionMkl != other.AutoMixedPrecisionMkl) return false; + if (AutoMixedPrecisionOnednnBfloat16 != other.AutoMixedPrecisionOnednnBfloat16) return false; + if (AutoMixedPrecisionCpu != other.AutoMixedPrecisionCpu) return false; + if (DisableMetaOptimizer != other.DisableMetaOptimizer) return false; + if (UsePluginOptimizers != other.UsePluginOptimizers) return false; + if (ExperimentalConditionalCodeMotion != other.ExperimentalConditionalCodeMotion) return false; + if (MetaOptimizerIterations != other.MetaOptimizerIterations) return false; + if (MinGraphNodes != other.MinGraphNodes) return false; + if (ExperimentalDisableCompressedTensorOptimization != other.ExperimentalDisableCompressedTensorOptimization) return false; + if (ExperimentalDisableFoldingQuantizationEmulation != other.ExperimentalDisableFoldingQuantizationEmulation) return false; + if (MemoryOptimization != other.MemoryOptimization) return false; + if (MemoryOptimizerTargetNodeNameScope != other.MemoryOptimizerTargetNodeNameScope) return false; + if (MetaOptimizerTimeoutMs != other.MetaOptimizerTimeoutMs) return false; + if (!object.Equals(AutoParallel, other.AutoParallel)) return false; + if (FailOnOptimizerErrors != other.FailOnOptimizerErrors) return false; + if (!object.Equals(ScopedAllocatorOpts, other.ScopedAllocatorOpts)) return false; + if(!optimizers_.Equals(other.optimizers_)) return false; + if(!customOptimizers_.Equals(other.customOptimizers_)) return false; + if (!object.Equals(InterOptimizerVerifierConfig, other.InterOptimizerVerifierConfig)) return false; + if (!object.Equals(PostOptimizationVerifierConfig, other.PostOptimizationVerifierConfig)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (CpuLayoutConversion != global::Tensorflow.RewriterConfig.Types.CpuLayout.NoConversionOnCpu) hash ^= CpuLayoutConversion.GetHashCode(); + if (LayoutOptimizer != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= LayoutOptimizer.GetHashCode(); + if (ConstantFolding != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= ConstantFolding.GetHashCode(); + if (ShapeOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= ShapeOptimization.GetHashCode(); + if (Remapping != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= Remapping.GetHashCode(); + if (CommonSubgraphElimination != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= CommonSubgraphElimination.GetHashCode(); + if (ArithmeticOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= ArithmeticOptimization.GetHashCode(); + if (DependencyOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= DependencyOptimization.GetHashCode(); + if (LoopOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= LoopOptimization.GetHashCode(); + if (FunctionOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= FunctionOptimization.GetHashCode(); + if (DebugStripper != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= DebugStripper.GetHashCode(); + if (DisableModelPruning != false) hash ^= DisableModelPruning.GetHashCode(); + if (ScopedAllocatorOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= ScopedAllocatorOptimization.GetHashCode(); + if (PinToHostOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= PinToHostOptimization.GetHashCode(); + if (ImplementationSelector != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= ImplementationSelector.GetHashCode(); + if (AutoMixedPrecision != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= AutoMixedPrecision.GetHashCode(); + if (AutoMixedPrecisionMkl != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= AutoMixedPrecisionMkl.GetHashCode(); + if (AutoMixedPrecisionOnednnBfloat16 != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= AutoMixedPrecisionOnednnBfloat16.GetHashCode(); + if (AutoMixedPrecisionCpu != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= AutoMixedPrecisionCpu.GetHashCode(); + if (DisableMetaOptimizer != false) hash ^= DisableMetaOptimizer.GetHashCode(); + if (UsePluginOptimizers != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= UsePluginOptimizers.GetHashCode(); + if (ExperimentalConditionalCodeMotion != global::Tensorflow.RewriterConfig.Types.Toggle.Default) hash ^= ExperimentalConditionalCodeMotion.GetHashCode(); + if (MetaOptimizerIterations != global::Tensorflow.RewriterConfig.Types.NumIterationsType.DefaultNumIters) hash ^= MetaOptimizerIterations.GetHashCode(); + if (MinGraphNodes != 0) hash ^= MinGraphNodes.GetHashCode(); + if (ExperimentalDisableCompressedTensorOptimization != false) hash ^= ExperimentalDisableCompressedTensorOptimization.GetHashCode(); + if (ExperimentalDisableFoldingQuantizationEmulation != false) hash ^= ExperimentalDisableFoldingQuantizationEmulation.GetHashCode(); + if (MemoryOptimization != global::Tensorflow.RewriterConfig.Types.MemOptType.DefaultMemOpt) hash ^= MemoryOptimization.GetHashCode(); + if (MemoryOptimizerTargetNodeNameScope.Length != 0) hash ^= MemoryOptimizerTargetNodeNameScope.GetHashCode(); + if (MetaOptimizerTimeoutMs != 0L) hash ^= MetaOptimizerTimeoutMs.GetHashCode(); + if (autoParallel_ != null) hash ^= AutoParallel.GetHashCode(); + if (FailOnOptimizerErrors != false) hash ^= FailOnOptimizerErrors.GetHashCode(); + if (scopedAllocatorOpts_ != null) hash ^= ScopedAllocatorOpts.GetHashCode(); + hash ^= optimizers_.GetHashCode(); + hash ^= customOptimizers_.GetHashCode(); + if (interOptimizerVerifierConfig_ != null) hash ^= InterOptimizerVerifierConfig.GetHashCode(); + if (postOptimizationVerifierConfig_ != null) hash ^= PostOptimizationVerifierConfig.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (LayoutOptimizer != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(8); + output.WriteEnum((int) LayoutOptimizer); + } + if (DisableModelPruning != false) { + output.WriteRawTag(16); + output.WriteBool(DisableModelPruning); + } + if (ConstantFolding != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(24); + output.WriteEnum((int) ConstantFolding); + } + if (MemoryOptimization != global::Tensorflow.RewriterConfig.Types.MemOptType.DefaultMemOpt) { + output.WriteRawTag(32); + output.WriteEnum((int) MemoryOptimization); + } + if (autoParallel_ != null) { + output.WriteRawTag(42); + output.WriteMessage(AutoParallel); + } + if (MemoryOptimizerTargetNodeNameScope.Length != 0) { + output.WriteRawTag(50); + output.WriteString(MemoryOptimizerTargetNodeNameScope); + } + if (ArithmeticOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(56); + output.WriteEnum((int) ArithmeticOptimization); + } + if (DependencyOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(64); + output.WriteEnum((int) DependencyOptimization); + } + if (LoopOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(72); + output.WriteEnum((int) LoopOptimization); + } + if (FunctionOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(80); + output.WriteEnum((int) FunctionOptimization); + } + if (DebugStripper != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(88); + output.WriteEnum((int) DebugStripper); + } + if (MetaOptimizerIterations != global::Tensorflow.RewriterConfig.Types.NumIterationsType.DefaultNumIters) { + output.WriteRawTag(96); + output.WriteEnum((int) MetaOptimizerIterations); + } + if (ShapeOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(104); + output.WriteEnum((int) ShapeOptimization); + } + if (Remapping != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(112); + output.WriteEnum((int) Remapping); + } + if (ScopedAllocatorOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(120); + output.WriteEnum((int) ScopedAllocatorOptimization); + } + if (scopedAllocatorOpts_ != null) { + output.WriteRawTag(130, 1); + output.WriteMessage(ScopedAllocatorOpts); + } + if (MinGraphNodes != 0) { + output.WriteRawTag(136, 1); + output.WriteInt32(MinGraphNodes); + } + if (PinToHostOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(144, 1); + output.WriteEnum((int) PinToHostOptimization); + } + if (DisableMetaOptimizer != false) { + output.WriteRawTag(152, 1); + output.WriteBool(DisableMetaOptimizer); + } + if (MetaOptimizerTimeoutMs != 0L) { + output.WriteRawTag(160, 1); + output.WriteInt64(MetaOptimizerTimeoutMs); + } + if (FailOnOptimizerErrors != false) { + output.WriteRawTag(168, 1); + output.WriteBool(FailOnOptimizerErrors); + } + if (ImplementationSelector != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(176, 1); + output.WriteEnum((int) ImplementationSelector); + } + if (AutoMixedPrecision != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(184, 1); + output.WriteEnum((int) AutoMixedPrecision); + } + if (CommonSubgraphElimination != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(192, 1); + output.WriteEnum((int) CommonSubgraphElimination); + } + if (AutoMixedPrecisionMkl != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(200, 1); + output.WriteEnum((int) AutoMixedPrecisionMkl); + } + if (ExperimentalDisableCompressedTensorOptimization != false) { + output.WriteRawTag(208, 1); + output.WriteBool(ExperimentalDisableCompressedTensorOptimization); + } + if (ExperimentalDisableFoldingQuantizationEmulation != false) { + output.WriteRawTag(216, 1); + output.WriteBool(ExperimentalDisableFoldingQuantizationEmulation); + } + if (UsePluginOptimizers != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(224, 1); + output.WriteEnum((int) UsePluginOptimizers); + } + if (AutoMixedPrecisionCpu != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(232, 1); + output.WriteEnum((int) AutoMixedPrecisionCpu); + } + if (ExperimentalConditionalCodeMotion != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(240, 1); + output.WriteEnum((int) ExperimentalConditionalCodeMotion); + } + if (AutoMixedPrecisionOnednnBfloat16 != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(248, 1); + output.WriteEnum((int) AutoMixedPrecisionOnednnBfloat16); + } + if (CpuLayoutConversion != global::Tensorflow.RewriterConfig.Types.CpuLayout.NoConversionOnCpu) { + output.WriteRawTag(144, 3); + output.WriteEnum((int) CpuLayoutConversion); + } + optimizers_.WriteTo(output, _repeated_optimizers_codec); + customOptimizers_.WriteTo(output, _repeated_customOptimizers_codec); + if (interOptimizerVerifierConfig_ != null) { + output.WriteRawTag(226, 18); + output.WriteMessage(InterOptimizerVerifierConfig); + } + if (postOptimizationVerifierConfig_ != null) { + output.WriteRawTag(234, 18); + output.WriteMessage(PostOptimizationVerifierConfig); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (LayoutOptimizer != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(8); + output.WriteEnum((int) LayoutOptimizer); + } + if (DisableModelPruning != false) { + output.WriteRawTag(16); + output.WriteBool(DisableModelPruning); + } + if (ConstantFolding != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(24); + output.WriteEnum((int) ConstantFolding); + } + if (MemoryOptimization != global::Tensorflow.RewriterConfig.Types.MemOptType.DefaultMemOpt) { + output.WriteRawTag(32); + output.WriteEnum((int) MemoryOptimization); + } + if (autoParallel_ != null) { + output.WriteRawTag(42); + output.WriteMessage(AutoParallel); + } + if (MemoryOptimizerTargetNodeNameScope.Length != 0) { + output.WriteRawTag(50); + output.WriteString(MemoryOptimizerTargetNodeNameScope); + } + if (ArithmeticOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(56); + output.WriteEnum((int) ArithmeticOptimization); + } + if (DependencyOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(64); + output.WriteEnum((int) DependencyOptimization); + } + if (LoopOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(72); + output.WriteEnum((int) LoopOptimization); + } + if (FunctionOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(80); + output.WriteEnum((int) FunctionOptimization); + } + if (DebugStripper != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(88); + output.WriteEnum((int) DebugStripper); + } + if (MetaOptimizerIterations != global::Tensorflow.RewriterConfig.Types.NumIterationsType.DefaultNumIters) { + output.WriteRawTag(96); + output.WriteEnum((int) MetaOptimizerIterations); + } + if (ShapeOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(104); + output.WriteEnum((int) ShapeOptimization); + } + if (Remapping != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(112); + output.WriteEnum((int) Remapping); + } + if (ScopedAllocatorOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(120); + output.WriteEnum((int) ScopedAllocatorOptimization); + } + if (scopedAllocatorOpts_ != null) { + output.WriteRawTag(130, 1); + output.WriteMessage(ScopedAllocatorOpts); + } + if (MinGraphNodes != 0) { + output.WriteRawTag(136, 1); + output.WriteInt32(MinGraphNodes); + } + if (PinToHostOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(144, 1); + output.WriteEnum((int) PinToHostOptimization); + } + if (DisableMetaOptimizer != false) { + output.WriteRawTag(152, 1); + output.WriteBool(DisableMetaOptimizer); + } + if (MetaOptimizerTimeoutMs != 0L) { + output.WriteRawTag(160, 1); + output.WriteInt64(MetaOptimizerTimeoutMs); + } + if (FailOnOptimizerErrors != false) { + output.WriteRawTag(168, 1); + output.WriteBool(FailOnOptimizerErrors); + } + if (ImplementationSelector != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(176, 1); + output.WriteEnum((int) ImplementationSelector); + } + if (AutoMixedPrecision != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(184, 1); + output.WriteEnum((int) AutoMixedPrecision); + } + if (CommonSubgraphElimination != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(192, 1); + output.WriteEnum((int) CommonSubgraphElimination); + } + if (AutoMixedPrecisionMkl != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(200, 1); + output.WriteEnum((int) AutoMixedPrecisionMkl); + } + if (ExperimentalDisableCompressedTensorOptimization != false) { + output.WriteRawTag(208, 1); + output.WriteBool(ExperimentalDisableCompressedTensorOptimization); + } + if (ExperimentalDisableFoldingQuantizationEmulation != false) { + output.WriteRawTag(216, 1); + output.WriteBool(ExperimentalDisableFoldingQuantizationEmulation); + } + if (UsePluginOptimizers != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(224, 1); + output.WriteEnum((int) UsePluginOptimizers); + } + if (AutoMixedPrecisionCpu != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(232, 1); + output.WriteEnum((int) AutoMixedPrecisionCpu); + } + if (ExperimentalConditionalCodeMotion != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(240, 1); + output.WriteEnum((int) ExperimentalConditionalCodeMotion); + } + if (AutoMixedPrecisionOnednnBfloat16 != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + output.WriteRawTag(248, 1); + output.WriteEnum((int) AutoMixedPrecisionOnednnBfloat16); + } + if (CpuLayoutConversion != global::Tensorflow.RewriterConfig.Types.CpuLayout.NoConversionOnCpu) { + output.WriteRawTag(144, 3); + output.WriteEnum((int) CpuLayoutConversion); + } + optimizers_.WriteTo(ref output, _repeated_optimizers_codec); + customOptimizers_.WriteTo(ref output, _repeated_customOptimizers_codec); + if (interOptimizerVerifierConfig_ != null) { + output.WriteRawTag(226, 18); + output.WriteMessage(InterOptimizerVerifierConfig); + } + if (postOptimizationVerifierConfig_ != null) { + output.WriteRawTag(234, 18); + output.WriteMessage(PostOptimizationVerifierConfig); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (CpuLayoutConversion != global::Tensorflow.RewriterConfig.Types.CpuLayout.NoConversionOnCpu) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) CpuLayoutConversion); + } + if (LayoutOptimizer != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) LayoutOptimizer); + } + if (ConstantFolding != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ConstantFolding); + } + if (ShapeOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ShapeOptimization); + } + if (Remapping != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Remapping); + } + if (CommonSubgraphElimination != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) CommonSubgraphElimination); + } + if (ArithmeticOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ArithmeticOptimization); + } + if (DependencyOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) DependencyOptimization); + } + if (LoopOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) LoopOptimization); + } + if (FunctionOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) FunctionOptimization); + } + if (DebugStripper != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) DebugStripper); + } + if (DisableModelPruning != false) { + size += 1 + 1; + } + if (ScopedAllocatorOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ScopedAllocatorOptimization); + } + if (PinToHostOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) PinToHostOptimization); + } + if (ImplementationSelector != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) ImplementationSelector); + } + if (AutoMixedPrecision != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) AutoMixedPrecision); + } + if (AutoMixedPrecisionMkl != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) AutoMixedPrecisionMkl); + } + if (AutoMixedPrecisionOnednnBfloat16 != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) AutoMixedPrecisionOnednnBfloat16); + } + if (AutoMixedPrecisionCpu != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) AutoMixedPrecisionCpu); + } + if (DisableMetaOptimizer != false) { + size += 2 + 1; + } + if (UsePluginOptimizers != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) UsePluginOptimizers); + } + if (ExperimentalConditionalCodeMotion != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) ExperimentalConditionalCodeMotion); + } + if (MetaOptimizerIterations != global::Tensorflow.RewriterConfig.Types.NumIterationsType.DefaultNumIters) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) MetaOptimizerIterations); + } + if (MinGraphNodes != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(MinGraphNodes); + } + if (ExperimentalDisableCompressedTensorOptimization != false) { + size += 2 + 1; + } + if (ExperimentalDisableFoldingQuantizationEmulation != false) { + size += 2 + 1; + } + if (MemoryOptimization != global::Tensorflow.RewriterConfig.Types.MemOptType.DefaultMemOpt) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) MemoryOptimization); + } + if (MemoryOptimizerTargetNodeNameScope.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MemoryOptimizerTargetNodeNameScope); + } + if (MetaOptimizerTimeoutMs != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(MetaOptimizerTimeoutMs); + } + if (autoParallel_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(AutoParallel); + } + if (FailOnOptimizerErrors != false) { + size += 2 + 1; + } + if (scopedAllocatorOpts_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(ScopedAllocatorOpts); + } + size += optimizers_.CalculateSize(_repeated_optimizers_codec); + size += customOptimizers_.CalculateSize(_repeated_customOptimizers_codec); + if (interOptimizerVerifierConfig_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(InterOptimizerVerifierConfig); + } + if (postOptimizationVerifierConfig_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(PostOptimizationVerifierConfig); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(RewriterConfig other) { + if (other == null) { + return; + } + if (other.CpuLayoutConversion != global::Tensorflow.RewriterConfig.Types.CpuLayout.NoConversionOnCpu) { + CpuLayoutConversion = other.CpuLayoutConversion; + } + if (other.LayoutOptimizer != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + LayoutOptimizer = other.LayoutOptimizer; + } + if (other.ConstantFolding != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + ConstantFolding = other.ConstantFolding; + } + if (other.ShapeOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + ShapeOptimization = other.ShapeOptimization; + } + if (other.Remapping != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + Remapping = other.Remapping; + } + if (other.CommonSubgraphElimination != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + CommonSubgraphElimination = other.CommonSubgraphElimination; + } + if (other.ArithmeticOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + ArithmeticOptimization = other.ArithmeticOptimization; + } + if (other.DependencyOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + DependencyOptimization = other.DependencyOptimization; + } + if (other.LoopOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + LoopOptimization = other.LoopOptimization; + } + if (other.FunctionOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + FunctionOptimization = other.FunctionOptimization; + } + if (other.DebugStripper != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + DebugStripper = other.DebugStripper; + } + if (other.DisableModelPruning != false) { + DisableModelPruning = other.DisableModelPruning; + } + if (other.ScopedAllocatorOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + ScopedAllocatorOptimization = other.ScopedAllocatorOptimization; + } + if (other.PinToHostOptimization != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + PinToHostOptimization = other.PinToHostOptimization; + } + if (other.ImplementationSelector != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + ImplementationSelector = other.ImplementationSelector; + } + if (other.AutoMixedPrecision != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + AutoMixedPrecision = other.AutoMixedPrecision; + } + if (other.AutoMixedPrecisionMkl != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + AutoMixedPrecisionMkl = other.AutoMixedPrecisionMkl; + } + if (other.AutoMixedPrecisionOnednnBfloat16 != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + AutoMixedPrecisionOnednnBfloat16 = other.AutoMixedPrecisionOnednnBfloat16; + } + if (other.AutoMixedPrecisionCpu != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + AutoMixedPrecisionCpu = other.AutoMixedPrecisionCpu; + } + if (other.DisableMetaOptimizer != false) { + DisableMetaOptimizer = other.DisableMetaOptimizer; + } + if (other.UsePluginOptimizers != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + UsePluginOptimizers = other.UsePluginOptimizers; + } + if (other.ExperimentalConditionalCodeMotion != global::Tensorflow.RewriterConfig.Types.Toggle.Default) { + ExperimentalConditionalCodeMotion = other.ExperimentalConditionalCodeMotion; + } + if (other.MetaOptimizerIterations != global::Tensorflow.RewriterConfig.Types.NumIterationsType.DefaultNumIters) { + MetaOptimizerIterations = other.MetaOptimizerIterations; + } + if (other.MinGraphNodes != 0) { + MinGraphNodes = other.MinGraphNodes; + } + if (other.ExperimentalDisableCompressedTensorOptimization != false) { + ExperimentalDisableCompressedTensorOptimization = other.ExperimentalDisableCompressedTensorOptimization; + } + if (other.ExperimentalDisableFoldingQuantizationEmulation != false) { + ExperimentalDisableFoldingQuantizationEmulation = other.ExperimentalDisableFoldingQuantizationEmulation; + } + if (other.MemoryOptimization != global::Tensorflow.RewriterConfig.Types.MemOptType.DefaultMemOpt) { + MemoryOptimization = other.MemoryOptimization; + } + if (other.MemoryOptimizerTargetNodeNameScope.Length != 0) { + MemoryOptimizerTargetNodeNameScope = other.MemoryOptimizerTargetNodeNameScope; + } + if (other.MetaOptimizerTimeoutMs != 0L) { + MetaOptimizerTimeoutMs = other.MetaOptimizerTimeoutMs; + } + if (other.autoParallel_ != null) { + if (autoParallel_ == null) { + AutoParallel = new global::Tensorflow.AutoParallelOptions(); + } + AutoParallel.MergeFrom(other.AutoParallel); + } + if (other.FailOnOptimizerErrors != false) { + FailOnOptimizerErrors = other.FailOnOptimizerErrors; + } + if (other.scopedAllocatorOpts_ != null) { + if (scopedAllocatorOpts_ == null) { + ScopedAllocatorOpts = new global::Tensorflow.ScopedAllocatorOptions(); + } + ScopedAllocatorOpts.MergeFrom(other.ScopedAllocatorOpts); + } + optimizers_.Add(other.optimizers_); + customOptimizers_.Add(other.customOptimizers_); + if (other.interOptimizerVerifierConfig_ != null) { + if (interOptimizerVerifierConfig_ == null) { + InterOptimizerVerifierConfig = new global::Tensorflow.VerifierConfig(); + } + InterOptimizerVerifierConfig.MergeFrom(other.InterOptimizerVerifierConfig); + } + if (other.postOptimizationVerifierConfig_ != null) { + if (postOptimizationVerifierConfig_ == null) { + PostOptimizationVerifierConfig = new global::Tensorflow.VerifierConfig(); + } + PostOptimizationVerifierConfig.MergeFrom(other.PostOptimizationVerifierConfig); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + LayoutOptimizer = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 16: { + DisableModelPruning = input.ReadBool(); + break; + } + case 24: { + ConstantFolding = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 32: { + MemoryOptimization = (global::Tensorflow.RewriterConfig.Types.MemOptType) input.ReadEnum(); + break; + } + case 42: { + if (autoParallel_ == null) { + AutoParallel = new global::Tensorflow.AutoParallelOptions(); + } + input.ReadMessage(AutoParallel); + break; + } + case 50: { + MemoryOptimizerTargetNodeNameScope = input.ReadString(); + break; + } + case 56: { + ArithmeticOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 64: { + DependencyOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 72: { + LoopOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 80: { + FunctionOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 88: { + DebugStripper = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 96: { + MetaOptimizerIterations = (global::Tensorflow.RewriterConfig.Types.NumIterationsType) input.ReadEnum(); + break; + } + case 104: { + ShapeOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 112: { + Remapping = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 120: { + ScopedAllocatorOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 130: { + if (scopedAllocatorOpts_ == null) { + ScopedAllocatorOpts = new global::Tensorflow.ScopedAllocatorOptions(); + } + input.ReadMessage(ScopedAllocatorOpts); + break; + } + case 136: { + MinGraphNodes = input.ReadInt32(); + break; + } + case 144: { + PinToHostOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 152: { + DisableMetaOptimizer = input.ReadBool(); + break; + } + case 160: { + MetaOptimizerTimeoutMs = input.ReadInt64(); + break; + } + case 168: { + FailOnOptimizerErrors = input.ReadBool(); + break; + } + case 176: { + ImplementationSelector = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 184: { + AutoMixedPrecision = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 192: { + CommonSubgraphElimination = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 200: { + AutoMixedPrecisionMkl = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 208: { + ExperimentalDisableCompressedTensorOptimization = input.ReadBool(); + break; + } + case 216: { + ExperimentalDisableFoldingQuantizationEmulation = input.ReadBool(); + break; + } + case 224: { + UsePluginOptimizers = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 232: { + AutoMixedPrecisionCpu = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 240: { + ExperimentalConditionalCodeMotion = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 248: { + AutoMixedPrecisionOnednnBfloat16 = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 400: { + CpuLayoutConversion = (global::Tensorflow.RewriterConfig.Types.CpuLayout) input.ReadEnum(); + break; + } + case 802: { + optimizers_.AddEntriesFrom(input, _repeated_optimizers_codec); + break; + } + case 1602: { + customOptimizers_.AddEntriesFrom(input, _repeated_customOptimizers_codec); + break; + } + case 2402: { + if (interOptimizerVerifierConfig_ == null) { + InterOptimizerVerifierConfig = new global::Tensorflow.VerifierConfig(); + } + input.ReadMessage(InterOptimizerVerifierConfig); + break; + } + case 2410: { + if (postOptimizationVerifierConfig_ == null) { + PostOptimizationVerifierConfig = new global::Tensorflow.VerifierConfig(); + } + input.ReadMessage(PostOptimizationVerifierConfig); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + LayoutOptimizer = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 16: { + DisableModelPruning = input.ReadBool(); + break; + } + case 24: { + ConstantFolding = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 32: { + MemoryOptimization = (global::Tensorflow.RewriterConfig.Types.MemOptType) input.ReadEnum(); + break; + } + case 42: { + if (autoParallel_ == null) { + AutoParallel = new global::Tensorflow.AutoParallelOptions(); + } + input.ReadMessage(AutoParallel); + break; + } + case 50: { + MemoryOptimizerTargetNodeNameScope = input.ReadString(); + break; + } + case 56: { + ArithmeticOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 64: { + DependencyOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 72: { + LoopOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 80: { + FunctionOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 88: { + DebugStripper = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 96: { + MetaOptimizerIterations = (global::Tensorflow.RewriterConfig.Types.NumIterationsType) input.ReadEnum(); + break; + } + case 104: { + ShapeOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 112: { + Remapping = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 120: { + ScopedAllocatorOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 130: { + if (scopedAllocatorOpts_ == null) { + ScopedAllocatorOpts = new global::Tensorflow.ScopedAllocatorOptions(); + } + input.ReadMessage(ScopedAllocatorOpts); + break; + } + case 136: { + MinGraphNodes = input.ReadInt32(); + break; + } + case 144: { + PinToHostOptimization = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 152: { + DisableMetaOptimizer = input.ReadBool(); + break; + } + case 160: { + MetaOptimizerTimeoutMs = input.ReadInt64(); + break; + } + case 168: { + FailOnOptimizerErrors = input.ReadBool(); + break; + } + case 176: { + ImplementationSelector = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 184: { + AutoMixedPrecision = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 192: { + CommonSubgraphElimination = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 200: { + AutoMixedPrecisionMkl = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 208: { + ExperimentalDisableCompressedTensorOptimization = input.ReadBool(); + break; + } + case 216: { + ExperimentalDisableFoldingQuantizationEmulation = input.ReadBool(); + break; + } + case 224: { + UsePluginOptimizers = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 232: { + AutoMixedPrecisionCpu = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 240: { + ExperimentalConditionalCodeMotion = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 248: { + AutoMixedPrecisionOnednnBfloat16 = (global::Tensorflow.RewriterConfig.Types.Toggle) input.ReadEnum(); + break; + } + case 400: { + CpuLayoutConversion = (global::Tensorflow.RewriterConfig.Types.CpuLayout) input.ReadEnum(); + break; + } + case 802: { + optimizers_.AddEntriesFrom(ref input, _repeated_optimizers_codec); + break; + } + case 1602: { + customOptimizers_.AddEntriesFrom(ref input, _repeated_customOptimizers_codec); + break; + } + case 2402: { + if (interOptimizerVerifierConfig_ == null) { + InterOptimizerVerifierConfig = new global::Tensorflow.VerifierConfig(); + } + input.ReadMessage(InterOptimizerVerifierConfig); + break; + } + case 2410: { + if (postOptimizationVerifierConfig_ == null) { + PostOptimizationVerifierConfig = new global::Tensorflow.VerifierConfig(); + } + input.ReadMessage(PostOptimizationVerifierConfig); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the RewriterConfig message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum Toggle { + [pbr::OriginalName("DEFAULT")] Default = 0, + [pbr::OriginalName("ON")] On = 1, + [pbr::OriginalName("OFF")] Off = 2, + /// + /// Enable some aggressive optimizations that use assumptions that TF graphs + /// may break. For example, assume the shape of a placeholder matches its + /// actual feed. + /// + [pbr::OriginalName("AGGRESSIVE")] Aggressive = 3, + /// + /// Run MLIR pass if there's one implemented in TFG, do nothing otherwise. + /// I.e., if there's no corresponding TFG pass, it's an OFF. This is supposed + /// to be mapped with `ON` and there's no `AGGRESSIVE` in MLIR pass now. + /// + [pbr::OriginalName("EXPERIMENTAL_MLIR")] ExperimentalMlir = 4, + /// + /// Run both MLIR and Grappler passes consecutively and MLIR pass will come + /// first. + /// + [pbr::OriginalName("EXPERIMENTAL_BOTH")] ExperimentalBoth = 5, + } + + /// + /// Enum for layout conversion between NCHW and NHWC on CPU. Default is OFF. + /// + public enum CpuLayout { + [pbr::OriginalName("NO_CONVERSION_ON_CPU")] NoConversionOnCpu = 0, + [pbr::OriginalName("NCHW_TO_NHWC")] NchwToNhwc = 1, + [pbr::OriginalName("NHWC_TO_NCHW")] NhwcToNchw = 2, + } + + /// + /// Enum controlling the number of times to run optimizers. The default is to + /// run them twice. + /// + public enum NumIterationsType { + [pbr::OriginalName("DEFAULT_NUM_ITERS")] DefaultNumIters = 0, + [pbr::OriginalName("ONE")] One = 1, + [pbr::OriginalName("TWO")] Two = 2, + } + + public enum MemOptType { + /// + /// The default setting (SCHEDULING and SWAPPING HEURISTICS only) + /// + [pbr::OriginalName("DEFAULT_MEM_OPT")] DefaultMemOpt = 0, + /// + /// Disabled in the meta-optimizer. + /// + [pbr::OriginalName("NO_MEM_OPT")] NoMemOpt = 1, + /// + /// Driven by manual op-level annotations. + /// + [pbr::OriginalName("MANUAL")] Manual = 2, + /// + /// Swapping heuristic will move a tensor from the GPU to the CPU and move + /// it back when needed to reduce peak memory usage. + /// + [pbr::OriginalName("SWAPPING_HEURISTICS")] SwappingHeuristics = 4, + /// + /// Recomputation heuristics will recompute ops (such as Relu activation) + /// during backprop instead of storing them, reducing peak memory usage. + /// + [pbr::OriginalName("RECOMPUTATION_HEURISTICS")] RecomputationHeuristics = 5, + /// + /// Scheduling will split big ops such as AddN and try to enforce a schedule + /// of the new computations that decreases peak memory usage. + /// + [pbr::OriginalName("SCHEDULING_HEURISTICS")] SchedulingHeuristics = 6, + /// + /// Use any combination of swapping and recomputation heuristics. + /// + [pbr::OriginalName("HEURISTICS")] Heuristics = 3, + } + + /// + /// Message to describe custom graph optimizer and its parameters + /// + public sealed partial class CustomGraphOptimizer : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CustomGraphOptimizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.RewriterConfig.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CustomGraphOptimizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CustomGraphOptimizer(CustomGraphOptimizer other) : this() { + name_ = other.name_; + parameterMap_ = other.parameterMap_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CustomGraphOptimizer Clone() { + return new CustomGraphOptimizer(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "parameter_map" field. + public const int ParameterMapFieldNumber = 2; + private static readonly pbc::MapField.Codec _map_parameterMap_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 18); + private readonly pbc::MapField parameterMap_ = new pbc::MapField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField ParameterMap { + get { return parameterMap_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CustomGraphOptimizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CustomGraphOptimizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (!ParameterMap.Equals(other.ParameterMap)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + hash ^= ParameterMap.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + parameterMap_.WriteTo(output, _map_parameterMap_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + parameterMap_.WriteTo(ref output, _map_parameterMap_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + size += parameterMap_.CalculateSize(_map_parameterMap_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CustomGraphOptimizer other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + parameterMap_.Add(other.parameterMap_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + parameterMap_.AddEntriesFrom(input, _map_parameterMap_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + parameterMap_.AddEntriesFrom(ref input, _map_parameterMap_codec); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/SavedModel.cs b/src/TensorFlowNET.Core/Protobuf/SavedModel.cs new file mode 100644 index 000000000..67cea4889 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/SavedModel.cs @@ -0,0 +1,276 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/saved_model.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/saved_model.proto + public static partial class SavedModelReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/saved_model.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static SavedModelReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cip0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvc2F2ZWRfbW9kZWwucHJvdG8S", + "CnRlbnNvcmZsb3caKXRlbnNvcmZsb3cvY29yZS9wcm90b2J1Zi9tZXRhX2dy", + "YXBoLnByb3RvIl8KClNhdmVkTW9kZWwSIgoac2F2ZWRfbW9kZWxfc2NoZW1h", + "X3ZlcnNpb24YASABKAMSLQoLbWV0YV9ncmFwaHMYAiADKAsyGC50ZW5zb3Jm", + "bG93Lk1ldGFHcmFwaERlZkKIAQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3Jr", + "QhBTYXZlZE1vZGVsUHJvdG9zUAFaVWdpdGh1Yi5jb20vdGVuc29yZmxvdy90", + "ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9wcm90b2J1Zi9mb3JfY29y", + "ZV9wcm90b3NfZ29fcHJvdG/4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.MetaGraphReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedModel), global::Tensorflow.SavedModel.Parser, new[]{ "SavedModelSchemaVersion", "MetaGraphs" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// SavedModel is the high level serialization format for TensorFlow Models. + /// See [todo: doc links, similar to session_bundle] for more information. + /// + public sealed partial class SavedModel : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedModel()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedModelReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedModel() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedModel(SavedModel other) : this() { + savedModelSchemaVersion_ = other.savedModelSchemaVersion_; + metaGraphs_ = other.metaGraphs_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedModel Clone() { + return new SavedModel(this); + } + + /// Field number for the "saved_model_schema_version" field. + public const int SavedModelSchemaVersionFieldNumber = 1; + private long savedModelSchemaVersion_; + /// + /// The schema version of the SavedModel instance. Used for versioning when + /// making future changes to the specification/implementation. Initial value + /// at release will be 1. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long SavedModelSchemaVersion { + get { return savedModelSchemaVersion_; } + set { + savedModelSchemaVersion_ = value; + } + } + + /// Field number for the "meta_graphs" field. + public const int MetaGraphsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_metaGraphs_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.MetaGraphDef.Parser); + private readonly pbc::RepeatedField metaGraphs_ = new pbc::RepeatedField(); + /// + /// One or more MetaGraphs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField MetaGraphs { + get { return metaGraphs_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SavedModel); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SavedModel other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (SavedModelSchemaVersion != other.SavedModelSchemaVersion) return false; + if(!metaGraphs_.Equals(other.metaGraphs_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (SavedModelSchemaVersion != 0L) hash ^= SavedModelSchemaVersion.GetHashCode(); + hash ^= metaGraphs_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (SavedModelSchemaVersion != 0L) { + output.WriteRawTag(8); + output.WriteInt64(SavedModelSchemaVersion); + } + metaGraphs_.WriteTo(output, _repeated_metaGraphs_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (SavedModelSchemaVersion != 0L) { + output.WriteRawTag(8); + output.WriteInt64(SavedModelSchemaVersion); + } + metaGraphs_.WriteTo(ref output, _repeated_metaGraphs_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (SavedModelSchemaVersion != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(SavedModelSchemaVersion); + } + size += metaGraphs_.CalculateSize(_repeated_metaGraphs_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SavedModel other) { + if (other == null) { + return; + } + if (other.SavedModelSchemaVersion != 0L) { + SavedModelSchemaVersion = other.SavedModelSchemaVersion; + } + metaGraphs_.Add(other.metaGraphs_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + SavedModelSchemaVersion = input.ReadInt64(); + break; + } + case 18: { + metaGraphs_.AddEntriesFrom(input, _repeated_metaGraphs_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + SavedModelSchemaVersion = input.ReadInt64(); + break; + } + case 18: { + metaGraphs_.AddEntriesFrom(ref input, _repeated_metaGraphs_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs new file mode 100644 index 000000000..df7019ad4 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs @@ -0,0 +1,4189 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/saved_object_graph.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/saved_object_graph.proto + public static partial class SavedObjectGraphReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/saved_object_graph.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static SavedObjectGraphReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjF0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvc2F2ZWRfb2JqZWN0X2dyYXBo", + "LnByb3RvEgp0ZW5zb3JmbG93Ghlnb29nbGUvcHJvdG9idWYvYW55LnByb3Rv", + "Gix0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3RlbnNvcl9zaGFwZS5wcm90", + "bxoldGVuc29yZmxvdy9jb3JlL2ZyYW1ld29yay90eXBlcy5wcm90bxoodGVu", + "c29yZmxvdy9jb3JlL2ZyYW1ld29yay92YXJpYWJsZS5wcm90bxoodGVuc29y", + "Zmxvdy9jb3JlL2ZyYW1ld29yay92ZXJzaW9ucy5wcm90bxoldGVuc29yZmxv", + "dy9jb3JlL3Byb3RvYnVmL3N0cnVjdC5wcm90bxo1dGVuc29yZmxvdy9jb3Jl", + "L3Byb3RvYnVmL3RyYWNrYWJsZV9vYmplY3RfZ3JhcGgucHJvdG8i6AEKEFNh", + "dmVkT2JqZWN0R3JhcGgSJgoFbm9kZXMYASADKAsyFy50ZW5zb3JmbG93LlNh", + "dmVkT2JqZWN0Ek8KEmNvbmNyZXRlX2Z1bmN0aW9ucxgCIAMoCzIzLnRlbnNv", + "cmZsb3cuU2F2ZWRPYmplY3RHcmFwaC5Db25jcmV0ZUZ1bmN0aW9uc0VudHJ5", + "GlsKFkNvbmNyZXRlRnVuY3Rpb25zRW50cnkSCwoDa2V5GAEgASgJEjAKBXZh", + "bHVlGAIgASgLMiEudGVuc29yZmxvdy5TYXZlZENvbmNyZXRlRnVuY3Rpb246", + "AjgBItAHCgtTYXZlZE9iamVjdBJSCghjaGlsZHJlbhgBIAMoCzJALnRlbnNv", + "cmZsb3cuVHJhY2thYmxlT2JqZWN0R3JhcGguVHJhY2thYmxlT2JqZWN0Lk9i", + "amVjdFJlZmVyZW5jZRJWCgxkZXBlbmRlbmNpZXMYDyADKAsyQC50ZW5zb3Jm", + "bG93LlRyYWNrYWJsZU9iamVjdEdyYXBoLlRyYWNrYWJsZU9iamVjdC5PYmpl", + "Y3RSZWZlcmVuY2USXgoOc2xvdF92YXJpYWJsZXMYAyADKAsyRi50ZW5zb3Jm", + "bG93LlRyYWNrYWJsZU9iamVjdEdyYXBoLlRyYWNrYWJsZU9iamVjdC5TbG90", + "VmFyaWFibGVSZWZlcmVuY2USMgoLdXNlcl9vYmplY3QYBCABKAsyGy50ZW5z", + "b3JmbG93LlNhdmVkVXNlck9iamVjdEgAEicKBWFzc2V0GAUgASgLMhYudGVu", + "c29yZmxvdy5TYXZlZEFzc2V0SAASLQoIZnVuY3Rpb24YBiABKAsyGS50ZW5z", + "b3JmbG93LlNhdmVkRnVuY3Rpb25IABItCgh2YXJpYWJsZRgHIAEoCzIZLnRl", + "bnNvcmZsb3cuU2F2ZWRWYXJpYWJsZUgAEkcKFmJhcmVfY29uY3JldGVfZnVu", + "Y3Rpb24YCCABKAsyJS50ZW5zb3JmbG93LlNhdmVkQmFyZUNvbmNyZXRlRnVu", + "Y3Rpb25IABItCghjb25zdGFudBgJIAEoCzIZLnRlbnNvcmZsb3cuU2F2ZWRD", + "b25zdGFudEgAEi0KCHJlc291cmNlGAogASgLMhkudGVuc29yZmxvdy5TYXZl", + "ZFJlc291cmNlSAASNQoPY2FwdHVyZWRfdGVuc29yGAwgASgLMhoudGVuc29y", + "Zmxvdy5DYXB0dXJlZFRlbnNvckgAEkYKEHNhdmVhYmxlX29iamVjdHMYCyAD", + "KAsyLC50ZW5zb3JmbG93LlNhdmVkT2JqZWN0LlNhdmVhYmxlT2JqZWN0c0Vu", + "dHJ5EhcKD3JlZ2lzdGVyZWRfbmFtZRgNIAEoCRIzChVzZXJpYWxpemVkX3Vz", + "ZXJfcHJvdG8YDiABKAsyFC5nb29nbGUucHJvdG9idWYuQW55EhgKEHJlZ2lz", + "dGVyZWRfc2F2ZXIYECABKAkaUgoUU2F2ZWFibGVPYmplY3RzRW50cnkSCwoD", + "a2V5GAEgASgJEikKBXZhbHVlGAIgASgLMhoudGVuc29yZmxvdy5TYXZlYWJs", + "ZU9iamVjdDoCOAFCBgoEa2luZEoECAIQA1IKYXR0cmlidXRlcyJkCg9TYXZl", + "ZFVzZXJPYmplY3QSEgoKaWRlbnRpZmllchgBIAEoCRInCgd2ZXJzaW9uGAIg", + "ASgLMhYudGVuc29yZmxvdy5WZXJzaW9uRGVmEhQKCG1ldGFkYXRhGAMgASgJ", + "QgIYASIqCgpTYXZlZEFzc2V0EhwKFGFzc2V0X2ZpbGVfZGVmX2luZGV4GAEg", + "ASgFIlwKDVNhdmVkRnVuY3Rpb24SGgoSY29uY3JldGVfZnVuY3Rpb25zGAEg", + "AygJEi8KDWZ1bmN0aW9uX3NwZWMYAiABKAsyGC50ZW5zb3JmbG93LkZ1bmN0", + "aW9uU3BlYyI5Cg5DYXB0dXJlZFRlbnNvchIMCgRuYW1lGAEgASgJEhkKEWNv", + "bmNyZXRlX2Z1bmN0aW9uGAIgASgJIqgBChVTYXZlZENvbmNyZXRlRnVuY3Rp", + "b24SFAoMYm91bmRfaW5wdXRzGAIgAygFEkIKHWNhbm9uaWNhbGl6ZWRfaW5w", + "dXRfc2lnbmF0dXJlGAMgASgLMhsudGVuc29yZmxvdy5TdHJ1Y3R1cmVkVmFs", + "dWUSNQoQb3V0cHV0X3NpZ25hdHVyZRgEIAEoCzIbLnRlbnNvcmZsb3cuU3Ry", + "dWN0dXJlZFZhbHVlIq0BChlTYXZlZEJhcmVDb25jcmV0ZUZ1bmN0aW9uEh4K", + "FmNvbmNyZXRlX2Z1bmN0aW9uX25hbWUYASABKAkSGQoRYXJndW1lbnRfa2V5", + "d29yZHMYAiADKAkSJAocYWxsb3dlZF9wb3NpdGlvbmFsX2FyZ3VtZW50cxgD", + "IAEoAxIvCg1mdW5jdGlvbl9zcGVjGAQgASgLMhgudGVuc29yZmxvdy5GdW5j", + "dGlvblNwZWMiIgoNU2F2ZWRDb25zdGFudBIRCglvcGVyYXRpb24YASABKAki", + "1wIKDVNhdmVkVmFyaWFibGUSIwoFZHR5cGUYASABKA4yFC50ZW5zb3JmbG93", + "LkRhdGFUeXBlEisKBXNoYXBlGAIgASgLMhwudGVuc29yZmxvdy5UZW5zb3JT", + "aGFwZVByb3RvEhEKCXRyYWluYWJsZRgDIAEoCBI8Cg9zeW5jaHJvbml6YXRp", + "b24YBCABKA4yIy50ZW5zb3JmbG93LlZhcmlhYmxlU3luY2hyb25pemF0aW9u", + "EjQKC2FnZ3JlZ2F0aW9uGAUgASgOMh8udGVuc29yZmxvdy5WYXJpYWJsZUFn", + "Z3JlZ2F0aW9uEgwKBG5hbWUYBiABKAkSDgoGZGV2aWNlGAcgASgJEk8KLGV4", + "cGVyaW1lbnRhbF9kaXN0cmlidXRlZF92YXJpYWJsZV9jb21wb25lbnRzGAgg", + "AygLMhkudGVuc29yZmxvdy5TYXZlZFZhcmlhYmxlIvsBCgxGdW5jdGlvblNw", + "ZWMSMAoLZnVsbGFyZ3NwZWMYASABKAsyGy50ZW5zb3JmbG93LlN0cnVjdHVy", + "ZWRWYWx1ZRIRCglpc19tZXRob2QYAiABKAgSNAoPaW5wdXRfc2lnbmF0dXJl", + "GAUgASgLMhsudGVuc29yZmxvdy5TdHJ1Y3R1cmVkVmFsdWUSOAoLaml0X2Nv", + "bXBpbGUYBiABKA4yIy50ZW5zb3JmbG93LkZ1bmN0aW9uU3BlYy5KaXRDb21w", + "aWxlIioKCkppdENvbXBpbGUSCwoHREVGQVVMVBAAEgYKAk9OEAESBwoDT0ZG", + "EAJKBAgDEARKBAgEEAUiHwoNU2F2ZWRSZXNvdXJjZRIOCgZkZXZpY2UYASAB", + "KAkiQQoOU2F2ZWFibGVPYmplY3QSFQoNc2F2ZV9mdW5jdGlvbhgCIAEoBRIY", + "ChByZXN0b3JlX2Z1bmN0aW9uGAMgASgFQlpaVWdpdGh1Yi5jb20vdGVuc29y", + "Zmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9wcm90b2J1Zi9m", + "b3JfY29yZV9wcm90b3NfZ29fcHJvdG/4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Google.Protobuf.WellKnownTypes.AnyReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, global::Tensorflow.VariableReflection.Descriptor, global::Tensorflow.VersionsReflection.Descriptor, global::Tensorflow.StructReflection.Descriptor, global::Tensorflow.TrackableObjectGraphReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedObjectGraph), global::Tensorflow.SavedObjectGraph.Parser, new[]{ "Nodes", "ConcreteFunctions" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedObject), global::Tensorflow.SavedObject.Parser, new[]{ "Children", "Dependencies", "SlotVariables", "UserObject", "Asset", "Function", "Variable", "BareConcreteFunction", "Constant", "Resource", "CapturedTensor", "SaveableObjects", "RegisteredName", "SerializedUserProto", "RegisteredSaver" }, new[]{ "Kind" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedUserObject), global::Tensorflow.SavedUserObject.Parser, new[]{ "Identifier", "Version", "Metadata" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedAsset), global::Tensorflow.SavedAsset.Parser, new[]{ "AssetFileDefIndex" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedFunction), global::Tensorflow.SavedFunction.Parser, new[]{ "ConcreteFunctions", "FunctionSpec" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CapturedTensor), global::Tensorflow.CapturedTensor.Parser, new[]{ "Name", "ConcreteFunction" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedConcreteFunction), global::Tensorflow.SavedConcreteFunction.Parser, new[]{ "BoundInputs", "CanonicalizedInputSignature", "OutputSignature" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedBareConcreteFunction), global::Tensorflow.SavedBareConcreteFunction.Parser, new[]{ "ConcreteFunctionName", "ArgumentKeywords", "AllowedPositionalArguments", "FunctionSpec" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedConstant), global::Tensorflow.SavedConstant.Parser, new[]{ "Operation" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedVariable), global::Tensorflow.SavedVariable.Parser, new[]{ "Dtype", "Shape", "Trainable", "Synchronization", "Aggregation", "Name", "Device", "ExperimentalDistributedVariableComponents" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionSpec), global::Tensorflow.FunctionSpec.Parser, new[]{ "Fullargspec", "IsMethod", "InputSignature", "JitCompile" }, null, new[]{ typeof(global::Tensorflow.FunctionSpec.Types.JitCompile) }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedResource), global::Tensorflow.SavedResource.Parser, new[]{ "Device" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaveableObject), global::Tensorflow.SaveableObject.Parser, new[]{ "SaveFunction", "RestoreFunction" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class SavedObjectGraph : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedObjectGraph()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedObjectGraph() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedObjectGraph(SavedObjectGraph other) : this() { + nodes_ = other.nodes_.Clone(); + concreteFunctions_ = other.concreteFunctions_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedObjectGraph Clone() { + return new SavedObjectGraph(this); + } + + /// Field number for the "nodes" field. + public const int NodesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_nodes_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.SavedObject.Parser); + private readonly pbc::RepeatedField nodes_ = new pbc::RepeatedField(); + /// + /// Flattened list of objects in the object graph. + /// + /// The position of the object in this list indicates its id. + /// Nodes[0] is considered the root node. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Nodes { + get { return nodes_; } + } + + /// Field number for the "concrete_functions" field. + public const int ConcreteFunctionsFieldNumber = 2; + private static readonly pbc::MapField.Codec _map_concreteFunctions_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.SavedConcreteFunction.Parser), 18); + private readonly pbc::MapField concreteFunctions_ = new pbc::MapField(); + /// + /// Information about captures and output structures in concrete functions. + /// Referenced from SavedBareConcreteFunction and SavedFunction. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField ConcreteFunctions { + get { return concreteFunctions_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SavedObjectGraph); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SavedObjectGraph other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!nodes_.Equals(other.nodes_)) return false; + if (!ConcreteFunctions.Equals(other.ConcreteFunctions)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= nodes_.GetHashCode(); + hash ^= ConcreteFunctions.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + nodes_.WriteTo(output, _repeated_nodes_codec); + concreteFunctions_.WriteTo(output, _map_concreteFunctions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + nodes_.WriteTo(ref output, _repeated_nodes_codec); + concreteFunctions_.WriteTo(ref output, _map_concreteFunctions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += nodes_.CalculateSize(_repeated_nodes_codec); + size += concreteFunctions_.CalculateSize(_map_concreteFunctions_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SavedObjectGraph other) { + if (other == null) { + return; + } + nodes_.Add(other.nodes_); + concreteFunctions_.Add(other.concreteFunctions_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + nodes_.AddEntriesFrom(input, _repeated_nodes_codec); + break; + } + case 18: { + concreteFunctions_.AddEntriesFrom(input, _map_concreteFunctions_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + nodes_.AddEntriesFrom(ref input, _repeated_nodes_codec); + break; + } + case 18: { + concreteFunctions_.AddEntriesFrom(ref input, _map_concreteFunctions_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class SavedObject : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedObject()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedObject() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedObject(SavedObject other) : this() { + children_ = other.children_.Clone(); + dependencies_ = other.dependencies_.Clone(); + slotVariables_ = other.slotVariables_.Clone(); + saveableObjects_ = other.saveableObjects_.Clone(); + registeredName_ = other.registeredName_; + serializedUserProto_ = other.serializedUserProto_ != null ? other.serializedUserProto_.Clone() : null; + registeredSaver_ = other.registeredSaver_; + switch (other.KindCase) { + case KindOneofCase.UserObject: + UserObject = other.UserObject.Clone(); + break; + case KindOneofCase.Asset: + Asset = other.Asset.Clone(); + break; + case KindOneofCase.Function: + Function = other.Function.Clone(); + break; + case KindOneofCase.Variable: + Variable = other.Variable.Clone(); + break; + case KindOneofCase.BareConcreteFunction: + BareConcreteFunction = other.BareConcreteFunction.Clone(); + break; + case KindOneofCase.Constant: + Constant = other.Constant.Clone(); + break; + case KindOneofCase.Resource: + Resource = other.Resource.Clone(); + break; + case KindOneofCase.CapturedTensor: + CapturedTensor = other.CapturedTensor.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedObject Clone() { + return new SavedObject(this); + } + + /// Field number for the "children" field. + public const int ChildrenFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_children_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); + private readonly pbc::RepeatedField children_ = new pbc::RepeatedField(); + /// + /// Objects which this object depends on: named edges in the dependency + /// graph. + /// + /// Note: All kinds of SavedObject may have children, except + /// "constant" and "captured_tensor". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Children { + get { return children_; } + } + + /// Field number for the "dependencies" field. + public const int DependenciesFieldNumber = 15; + private static readonly pb::FieldCodec _repeated_dependencies_codec + = pb::FieldCodec.ForMessage(122, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); + private readonly pbc::RepeatedField dependencies_ = new pbc::RepeatedField(); + /// + /// Ordered list of dependencies that must be loaded before this object. + /// SavedModel loads with the bottom-up approach, by first creating all objects + /// (in the order defined by the dependencies), then connecting the edges. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Dependencies { + get { return dependencies_; } + } + + /// Field number for the "slot_variables" field. + public const int SlotVariablesFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_slotVariables_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference.Parser); + private readonly pbc::RepeatedField slotVariables_ = new pbc::RepeatedField(); + /// + /// Slot variables owned by this object. This describes the three-way + /// (optimizer, variable, slot variable) relationship; none of the three + /// depend on the others directly. + /// + /// Note: currently only valid if kind == "user_object". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SlotVariables { + get { return slotVariables_; } + } + + /// Field number for the "user_object" field. + public const int UserObjectFieldNumber = 4; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SavedUserObject UserObject { + get { return kindCase_ == KindOneofCase.UserObject ? (global::Tensorflow.SavedUserObject) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.UserObject; + } + } + + /// Field number for the "asset" field. + public const int AssetFieldNumber = 5; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SavedAsset Asset { + get { return kindCase_ == KindOneofCase.Asset ? (global::Tensorflow.SavedAsset) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.Asset; + } + } + + /// Field number for the "function" field. + public const int FunctionFieldNumber = 6; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SavedFunction Function { + get { return kindCase_ == KindOneofCase.Function ? (global::Tensorflow.SavedFunction) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.Function; + } + } + + /// Field number for the "variable" field. + public const int VariableFieldNumber = 7; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SavedVariable Variable { + get { return kindCase_ == KindOneofCase.Variable ? (global::Tensorflow.SavedVariable) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.Variable; + } + } + + /// Field number for the "bare_concrete_function" field. + public const int BareConcreteFunctionFieldNumber = 8; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SavedBareConcreteFunction BareConcreteFunction { + get { return kindCase_ == KindOneofCase.BareConcreteFunction ? (global::Tensorflow.SavedBareConcreteFunction) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.BareConcreteFunction; + } + } + + /// Field number for the "constant" field. + public const int ConstantFieldNumber = 9; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SavedConstant Constant { + get { return kindCase_ == KindOneofCase.Constant ? (global::Tensorflow.SavedConstant) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.Constant; + } + } + + /// Field number for the "resource" field. + public const int ResourceFieldNumber = 10; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SavedResource Resource { + get { return kindCase_ == KindOneofCase.Resource ? (global::Tensorflow.SavedResource) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.Resource; + } + } + + /// Field number for the "captured_tensor" field. + public const int CapturedTensorFieldNumber = 12; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.CapturedTensor CapturedTensor { + get { return kindCase_ == KindOneofCase.CapturedTensor ? (global::Tensorflow.CapturedTensor) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.CapturedTensor; + } + } + + /// Field number for the "saveable_objects" field. + public const int SaveableObjectsFieldNumber = 11; + private static readonly pbc::MapField.Codec _map_saveableObjects_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.SaveableObject.Parser), 90); + private readonly pbc::MapField saveableObjects_ = new pbc::MapField(); + /// + /// Stores the functions used to save and restore this object. At most one of + /// `saveable_objects` or `registered_saver` is defined for each SavedObject. + /// See the comment below for the difference between SaveableObject and + /// registered savers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField SaveableObjects { + get { return saveableObjects_; } + } + + /// Field number for the "registered_name" field. + public const int RegisteredNameFieldNumber = 13; + private string registeredName_ = ""; + /// + /// The name of the registered class of the form "{package}.{class_name}". + /// This field is used to search for the registered class at loading time. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string RegisteredName { + get { return registeredName_; } + set { + registeredName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "serialized_user_proto" field. + public const int SerializedUserProtoFieldNumber = 14; + private global::Google.Protobuf.WellKnownTypes.Any serializedUserProto_; + /// + /// The user-generated proto storing metadata for this object, to be passed to + /// the registered classes's _deserialize_from_proto method when this object is + /// loaded from the SavedModel. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Google.Protobuf.WellKnownTypes.Any SerializedUserProto { + get { return serializedUserProto_; } + set { + serializedUserProto_ = value; + } + } + + /// Field number for the "registered_saver" field. + public const int RegisteredSaverFieldNumber = 16; + private string registeredSaver_ = ""; + /// + /// String name of the registered saver. At most one of `saveable_objects` or + /// `registered_saver` is defined for each SavedObject. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string RegisteredSaver { + get { return registeredSaver_; } + set { + registeredSaver_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + private object kind_; + /// Enum of possible cases for the "kind" oneof. + public enum KindOneofCase { + None = 0, + UserObject = 4, + Asset = 5, + Function = 6, + Variable = 7, + BareConcreteFunction = 8, + Constant = 9, + Resource = 10, + CapturedTensor = 12, + } + private KindOneofCase kindCase_ = KindOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KindOneofCase KindCase { + get { return kindCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void ClearKind() { + kindCase_ = KindOneofCase.None; + kind_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SavedObject); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SavedObject other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!children_.Equals(other.children_)) return false; + if(!dependencies_.Equals(other.dependencies_)) return false; + if(!slotVariables_.Equals(other.slotVariables_)) return false; + if (!object.Equals(UserObject, other.UserObject)) return false; + if (!object.Equals(Asset, other.Asset)) return false; + if (!object.Equals(Function, other.Function)) return false; + if (!object.Equals(Variable, other.Variable)) return false; + if (!object.Equals(BareConcreteFunction, other.BareConcreteFunction)) return false; + if (!object.Equals(Constant, other.Constant)) return false; + if (!object.Equals(Resource, other.Resource)) return false; + if (!object.Equals(CapturedTensor, other.CapturedTensor)) return false; + if (!SaveableObjects.Equals(other.SaveableObjects)) return false; + if (RegisteredName != other.RegisteredName) return false; + if (!object.Equals(SerializedUserProto, other.SerializedUserProto)) return false; + if (RegisteredSaver != other.RegisteredSaver) return false; + if (KindCase != other.KindCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= children_.GetHashCode(); + hash ^= dependencies_.GetHashCode(); + hash ^= slotVariables_.GetHashCode(); + if (kindCase_ == KindOneofCase.UserObject) hash ^= UserObject.GetHashCode(); + if (kindCase_ == KindOneofCase.Asset) hash ^= Asset.GetHashCode(); + if (kindCase_ == KindOneofCase.Function) hash ^= Function.GetHashCode(); + if (kindCase_ == KindOneofCase.Variable) hash ^= Variable.GetHashCode(); + if (kindCase_ == KindOneofCase.BareConcreteFunction) hash ^= BareConcreteFunction.GetHashCode(); + if (kindCase_ == KindOneofCase.Constant) hash ^= Constant.GetHashCode(); + if (kindCase_ == KindOneofCase.Resource) hash ^= Resource.GetHashCode(); + if (kindCase_ == KindOneofCase.CapturedTensor) hash ^= CapturedTensor.GetHashCode(); + hash ^= SaveableObjects.GetHashCode(); + if (RegisteredName.Length != 0) hash ^= RegisteredName.GetHashCode(); + if (serializedUserProto_ != null) hash ^= SerializedUserProto.GetHashCode(); + if (RegisteredSaver.Length != 0) hash ^= RegisteredSaver.GetHashCode(); + hash ^= (int) kindCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + children_.WriteTo(output, _repeated_children_codec); + slotVariables_.WriteTo(output, _repeated_slotVariables_codec); + if (kindCase_ == KindOneofCase.UserObject) { + output.WriteRawTag(34); + output.WriteMessage(UserObject); + } + if (kindCase_ == KindOneofCase.Asset) { + output.WriteRawTag(42); + output.WriteMessage(Asset); + } + if (kindCase_ == KindOneofCase.Function) { + output.WriteRawTag(50); + output.WriteMessage(Function); + } + if (kindCase_ == KindOneofCase.Variable) { + output.WriteRawTag(58); + output.WriteMessage(Variable); + } + if (kindCase_ == KindOneofCase.BareConcreteFunction) { + output.WriteRawTag(66); + output.WriteMessage(BareConcreteFunction); + } + if (kindCase_ == KindOneofCase.Constant) { + output.WriteRawTag(74); + output.WriteMessage(Constant); + } + if (kindCase_ == KindOneofCase.Resource) { + output.WriteRawTag(82); + output.WriteMessage(Resource); + } + saveableObjects_.WriteTo(output, _map_saveableObjects_codec); + if (kindCase_ == KindOneofCase.CapturedTensor) { + output.WriteRawTag(98); + output.WriteMessage(CapturedTensor); + } + if (RegisteredName.Length != 0) { + output.WriteRawTag(106); + output.WriteString(RegisteredName); + } + if (serializedUserProto_ != null) { + output.WriteRawTag(114); + output.WriteMessage(SerializedUserProto); + } + dependencies_.WriteTo(output, _repeated_dependencies_codec); + if (RegisteredSaver.Length != 0) { + output.WriteRawTag(130, 1); + output.WriteString(RegisteredSaver); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + children_.WriteTo(ref output, _repeated_children_codec); + slotVariables_.WriteTo(ref output, _repeated_slotVariables_codec); + if (kindCase_ == KindOneofCase.UserObject) { + output.WriteRawTag(34); + output.WriteMessage(UserObject); + } + if (kindCase_ == KindOneofCase.Asset) { + output.WriteRawTag(42); + output.WriteMessage(Asset); + } + if (kindCase_ == KindOneofCase.Function) { + output.WriteRawTag(50); + output.WriteMessage(Function); + } + if (kindCase_ == KindOneofCase.Variable) { + output.WriteRawTag(58); + output.WriteMessage(Variable); + } + if (kindCase_ == KindOneofCase.BareConcreteFunction) { + output.WriteRawTag(66); + output.WriteMessage(BareConcreteFunction); + } + if (kindCase_ == KindOneofCase.Constant) { + output.WriteRawTag(74); + output.WriteMessage(Constant); + } + if (kindCase_ == KindOneofCase.Resource) { + output.WriteRawTag(82); + output.WriteMessage(Resource); + } + saveableObjects_.WriteTo(ref output, _map_saveableObjects_codec); + if (kindCase_ == KindOneofCase.CapturedTensor) { + output.WriteRawTag(98); + output.WriteMessage(CapturedTensor); + } + if (RegisteredName.Length != 0) { + output.WriteRawTag(106); + output.WriteString(RegisteredName); + } + if (serializedUserProto_ != null) { + output.WriteRawTag(114); + output.WriteMessage(SerializedUserProto); + } + dependencies_.WriteTo(ref output, _repeated_dependencies_codec); + if (RegisteredSaver.Length != 0) { + output.WriteRawTag(130, 1); + output.WriteString(RegisteredSaver); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += children_.CalculateSize(_repeated_children_codec); + size += dependencies_.CalculateSize(_repeated_dependencies_codec); + size += slotVariables_.CalculateSize(_repeated_slotVariables_codec); + if (kindCase_ == KindOneofCase.UserObject) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(UserObject); + } + if (kindCase_ == KindOneofCase.Asset) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Asset); + } + if (kindCase_ == KindOneofCase.Function) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Function); + } + if (kindCase_ == KindOneofCase.Variable) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Variable); + } + if (kindCase_ == KindOneofCase.BareConcreteFunction) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(BareConcreteFunction); + } + if (kindCase_ == KindOneofCase.Constant) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Constant); + } + if (kindCase_ == KindOneofCase.Resource) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Resource); + } + if (kindCase_ == KindOneofCase.CapturedTensor) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(CapturedTensor); + } + size += saveableObjects_.CalculateSize(_map_saveableObjects_codec); + if (RegisteredName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(RegisteredName); + } + if (serializedUserProto_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SerializedUserProto); + } + if (RegisteredSaver.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(RegisteredSaver); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SavedObject other) { + if (other == null) { + return; + } + children_.Add(other.children_); + dependencies_.Add(other.dependencies_); + slotVariables_.Add(other.slotVariables_); + saveableObjects_.Add(other.saveableObjects_); + if (other.RegisteredName.Length != 0) { + RegisteredName = other.RegisteredName; + } + if (other.serializedUserProto_ != null) { + if (serializedUserProto_ == null) { + SerializedUserProto = new global::Google.Protobuf.WellKnownTypes.Any(); + } + SerializedUserProto.MergeFrom(other.SerializedUserProto); + } + if (other.RegisteredSaver.Length != 0) { + RegisteredSaver = other.RegisteredSaver; + } + switch (other.KindCase) { + case KindOneofCase.UserObject: + if (UserObject == null) { + UserObject = new global::Tensorflow.SavedUserObject(); + } + UserObject.MergeFrom(other.UserObject); + break; + case KindOneofCase.Asset: + if (Asset == null) { + Asset = new global::Tensorflow.SavedAsset(); + } + Asset.MergeFrom(other.Asset); + break; + case KindOneofCase.Function: + if (Function == null) { + Function = new global::Tensorflow.SavedFunction(); + } + Function.MergeFrom(other.Function); + break; + case KindOneofCase.Variable: + if (Variable == null) { + Variable = new global::Tensorflow.SavedVariable(); + } + Variable.MergeFrom(other.Variable); + break; + case KindOneofCase.BareConcreteFunction: + if (BareConcreteFunction == null) { + BareConcreteFunction = new global::Tensorflow.SavedBareConcreteFunction(); + } + BareConcreteFunction.MergeFrom(other.BareConcreteFunction); + break; + case KindOneofCase.Constant: + if (Constant == null) { + Constant = new global::Tensorflow.SavedConstant(); + } + Constant.MergeFrom(other.Constant); + break; + case KindOneofCase.Resource: + if (Resource == null) { + Resource = new global::Tensorflow.SavedResource(); + } + Resource.MergeFrom(other.Resource); + break; + case KindOneofCase.CapturedTensor: + if (CapturedTensor == null) { + CapturedTensor = new global::Tensorflow.CapturedTensor(); + } + CapturedTensor.MergeFrom(other.CapturedTensor); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + children_.AddEntriesFrom(input, _repeated_children_codec); + break; + } + case 26: { + slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec); + break; + } + case 34: { + global::Tensorflow.SavedUserObject subBuilder = new global::Tensorflow.SavedUserObject(); + if (kindCase_ == KindOneofCase.UserObject) { + subBuilder.MergeFrom(UserObject); + } + input.ReadMessage(subBuilder); + UserObject = subBuilder; + break; + } + case 42: { + global::Tensorflow.SavedAsset subBuilder = new global::Tensorflow.SavedAsset(); + if (kindCase_ == KindOneofCase.Asset) { + subBuilder.MergeFrom(Asset); + } + input.ReadMessage(subBuilder); + Asset = subBuilder; + break; + } + case 50: { + global::Tensorflow.SavedFunction subBuilder = new global::Tensorflow.SavedFunction(); + if (kindCase_ == KindOneofCase.Function) { + subBuilder.MergeFrom(Function); + } + input.ReadMessage(subBuilder); + Function = subBuilder; + break; + } + case 58: { + global::Tensorflow.SavedVariable subBuilder = new global::Tensorflow.SavedVariable(); + if (kindCase_ == KindOneofCase.Variable) { + subBuilder.MergeFrom(Variable); + } + input.ReadMessage(subBuilder); + Variable = subBuilder; + break; + } + case 66: { + global::Tensorflow.SavedBareConcreteFunction subBuilder = new global::Tensorflow.SavedBareConcreteFunction(); + if (kindCase_ == KindOneofCase.BareConcreteFunction) { + subBuilder.MergeFrom(BareConcreteFunction); + } + input.ReadMessage(subBuilder); + BareConcreteFunction = subBuilder; + break; + } + case 74: { + global::Tensorflow.SavedConstant subBuilder = new global::Tensorflow.SavedConstant(); + if (kindCase_ == KindOneofCase.Constant) { + subBuilder.MergeFrom(Constant); + } + input.ReadMessage(subBuilder); + Constant = subBuilder; + break; + } + case 82: { + global::Tensorflow.SavedResource subBuilder = new global::Tensorflow.SavedResource(); + if (kindCase_ == KindOneofCase.Resource) { + subBuilder.MergeFrom(Resource); + } + input.ReadMessage(subBuilder); + Resource = subBuilder; + break; + } + case 90: { + saveableObjects_.AddEntriesFrom(input, _map_saveableObjects_codec); + break; + } + case 98: { + global::Tensorflow.CapturedTensor subBuilder = new global::Tensorflow.CapturedTensor(); + if (kindCase_ == KindOneofCase.CapturedTensor) { + subBuilder.MergeFrom(CapturedTensor); + } + input.ReadMessage(subBuilder); + CapturedTensor = subBuilder; + break; + } + case 106: { + RegisteredName = input.ReadString(); + break; + } + case 114: { + if (serializedUserProto_ == null) { + SerializedUserProto = new global::Google.Protobuf.WellKnownTypes.Any(); + } + input.ReadMessage(SerializedUserProto); + break; + } + case 122: { + dependencies_.AddEntriesFrom(input, _repeated_dependencies_codec); + break; + } + case 130: { + RegisteredSaver = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + children_.AddEntriesFrom(ref input, _repeated_children_codec); + break; + } + case 26: { + slotVariables_.AddEntriesFrom(ref input, _repeated_slotVariables_codec); + break; + } + case 34: { + global::Tensorflow.SavedUserObject subBuilder = new global::Tensorflow.SavedUserObject(); + if (kindCase_ == KindOneofCase.UserObject) { + subBuilder.MergeFrom(UserObject); + } + input.ReadMessage(subBuilder); + UserObject = subBuilder; + break; + } + case 42: { + global::Tensorflow.SavedAsset subBuilder = new global::Tensorflow.SavedAsset(); + if (kindCase_ == KindOneofCase.Asset) { + subBuilder.MergeFrom(Asset); + } + input.ReadMessage(subBuilder); + Asset = subBuilder; + break; + } + case 50: { + global::Tensorflow.SavedFunction subBuilder = new global::Tensorflow.SavedFunction(); + if (kindCase_ == KindOneofCase.Function) { + subBuilder.MergeFrom(Function); + } + input.ReadMessage(subBuilder); + Function = subBuilder; + break; + } + case 58: { + global::Tensorflow.SavedVariable subBuilder = new global::Tensorflow.SavedVariable(); + if (kindCase_ == KindOneofCase.Variable) { + subBuilder.MergeFrom(Variable); + } + input.ReadMessage(subBuilder); + Variable = subBuilder; + break; + } + case 66: { + global::Tensorflow.SavedBareConcreteFunction subBuilder = new global::Tensorflow.SavedBareConcreteFunction(); + if (kindCase_ == KindOneofCase.BareConcreteFunction) { + subBuilder.MergeFrom(BareConcreteFunction); + } + input.ReadMessage(subBuilder); + BareConcreteFunction = subBuilder; + break; + } + case 74: { + global::Tensorflow.SavedConstant subBuilder = new global::Tensorflow.SavedConstant(); + if (kindCase_ == KindOneofCase.Constant) { + subBuilder.MergeFrom(Constant); + } + input.ReadMessage(subBuilder); + Constant = subBuilder; + break; + } + case 82: { + global::Tensorflow.SavedResource subBuilder = new global::Tensorflow.SavedResource(); + if (kindCase_ == KindOneofCase.Resource) { + subBuilder.MergeFrom(Resource); + } + input.ReadMessage(subBuilder); + Resource = subBuilder; + break; + } + case 90: { + saveableObjects_.AddEntriesFrom(ref input, _map_saveableObjects_codec); + break; + } + case 98: { + global::Tensorflow.CapturedTensor subBuilder = new global::Tensorflow.CapturedTensor(); + if (kindCase_ == KindOneofCase.CapturedTensor) { + subBuilder.MergeFrom(CapturedTensor); + } + input.ReadMessage(subBuilder); + CapturedTensor = subBuilder; + break; + } + case 106: { + RegisteredName = input.ReadString(); + break; + } + case 114: { + if (serializedUserProto_ == null) { + SerializedUserProto = new global::Google.Protobuf.WellKnownTypes.Any(); + } + input.ReadMessage(SerializedUserProto); + break; + } + case 122: { + dependencies_.AddEntriesFrom(ref input, _repeated_dependencies_codec); + break; + } + case 130: { + RegisteredSaver = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// A SavedUserObject is an object (in the object-oriented language of the + /// TensorFlow program) of some user- or framework-defined class other than + /// those handled specifically by the other kinds of SavedObjects. + /// + /// This object cannot be evaluated as a tensor, and therefore cannot be bound + /// to an input of a function. + /// + public sealed partial class SavedUserObject : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedUserObject()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedUserObject() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedUserObject(SavedUserObject other) : this() { + identifier_ = other.identifier_; + version_ = other.version_ != null ? other.version_.Clone() : null; + metadata_ = other.metadata_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedUserObject Clone() { + return new SavedUserObject(this); + } + + /// Field number for the "identifier" field. + public const int IdentifierFieldNumber = 1; + private string identifier_ = ""; + /// + /// Corresponds to a registration of the type to use in the loading program. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Identifier { + get { return identifier_; } + set { + identifier_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "version" field. + public const int VersionFieldNumber = 2; + private global::Tensorflow.VersionDef version_; + /// + /// Version information from the producer of this SavedUserObject. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.VersionDef Version { + get { return version_; } + set { + version_ = value; + } + } + + /// Field number for the "metadata" field. + public const int MetadataFieldNumber = 3; + private string metadata_ = ""; + /// + /// Metadata for deserializing this object. + /// + /// Deprecated! At the time of deprecation, Keras was the only user of this + /// field, and its saving and loading code will be updated shortly. + /// Please save your application-specific metadata to a separate file. + /// + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Metadata { + get { return metadata_; } + set { + metadata_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SavedUserObject); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SavedUserObject other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Identifier != other.Identifier) return false; + if (!object.Equals(Version, other.Version)) return false; + if (Metadata != other.Metadata) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Identifier.Length != 0) hash ^= Identifier.GetHashCode(); + if (version_ != null) hash ^= Version.GetHashCode(); + if (Metadata.Length != 0) hash ^= Metadata.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Identifier.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Identifier); + } + if (version_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Version); + } + if (Metadata.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Metadata); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Identifier.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Identifier); + } + if (version_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Version); + } + if (Metadata.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Metadata); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Identifier.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Identifier); + } + if (version_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Version); + } + if (Metadata.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Metadata); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SavedUserObject other) { + if (other == null) { + return; + } + if (other.Identifier.Length != 0) { + Identifier = other.Identifier; + } + if (other.version_ != null) { + if (version_ == null) { + Version = new global::Tensorflow.VersionDef(); + } + Version.MergeFrom(other.Version); + } + if (other.Metadata.Length != 0) { + Metadata = other.Metadata; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Identifier = input.ReadString(); + break; + } + case 18: { + if (version_ == null) { + Version = new global::Tensorflow.VersionDef(); + } + input.ReadMessage(Version); + break; + } + case 26: { + Metadata = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Identifier = input.ReadString(); + break; + } + case 18: { + if (version_ == null) { + Version = new global::Tensorflow.VersionDef(); + } + input.ReadMessage(Version); + break; + } + case 26: { + Metadata = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// A SavedAsset points to an asset in the MetaGraph. + /// + /// When bound to a function this object evaluates to a tensor with the absolute + /// filename. Users should not depend on a particular part of the filename to + /// remain stable (e.g. basename could be changed). + /// + public sealed partial class SavedAsset : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedAsset()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedAsset() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedAsset(SavedAsset other) : this() { + assetFileDefIndex_ = other.assetFileDefIndex_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedAsset Clone() { + return new SavedAsset(this); + } + + /// Field number for the "asset_file_def_index" field. + public const int AssetFileDefIndexFieldNumber = 1; + private int assetFileDefIndex_; + /// + /// Index into `MetaGraphDef.asset_file_def[]` that describes the Asset. + /// + /// Only the field `AssetFileDef.filename` is used. Other fields, such as + /// `AssetFileDef.tensor_info`, MUST be ignored. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int AssetFileDefIndex { + get { return assetFileDefIndex_; } + set { + assetFileDefIndex_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SavedAsset); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SavedAsset other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (AssetFileDefIndex != other.AssetFileDefIndex) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (AssetFileDefIndex != 0) hash ^= AssetFileDefIndex.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (AssetFileDefIndex != 0) { + output.WriteRawTag(8); + output.WriteInt32(AssetFileDefIndex); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (AssetFileDefIndex != 0) { + output.WriteRawTag(8); + output.WriteInt32(AssetFileDefIndex); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (AssetFileDefIndex != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(AssetFileDefIndex); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SavedAsset other) { + if (other == null) { + return; + } + if (other.AssetFileDefIndex != 0) { + AssetFileDefIndex = other.AssetFileDefIndex; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + AssetFileDefIndex = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + AssetFileDefIndex = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + /// + /// A function with multiple signatures, possibly with non-Tensor arguments. + /// + public sealed partial class SavedFunction : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedFunction()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedFunction() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedFunction(SavedFunction other) : this() { + concreteFunctions_ = other.concreteFunctions_.Clone(); + functionSpec_ = other.functionSpec_ != null ? other.functionSpec_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedFunction Clone() { + return new SavedFunction(this); + } + + /// Field number for the "concrete_functions" field. + public const int ConcreteFunctionsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_concreteFunctions_codec + = pb::FieldCodec.ForString(10); + private readonly pbc::RepeatedField concreteFunctions_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ConcreteFunctions { + get { return concreteFunctions_; } + } + + /// Field number for the "function_spec" field. + public const int FunctionSpecFieldNumber = 2; + private global::Tensorflow.FunctionSpec functionSpec_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.FunctionSpec FunctionSpec { + get { return functionSpec_; } + set { + functionSpec_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SavedFunction); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SavedFunction other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!concreteFunctions_.Equals(other.concreteFunctions_)) return false; + if (!object.Equals(FunctionSpec, other.FunctionSpec)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= concreteFunctions_.GetHashCode(); + if (functionSpec_ != null) hash ^= FunctionSpec.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + concreteFunctions_.WriteTo(output, _repeated_concreteFunctions_codec); + if (functionSpec_ != null) { + output.WriteRawTag(18); + output.WriteMessage(FunctionSpec); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + concreteFunctions_.WriteTo(ref output, _repeated_concreteFunctions_codec); + if (functionSpec_ != null) { + output.WriteRawTag(18); + output.WriteMessage(FunctionSpec); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += concreteFunctions_.CalculateSize(_repeated_concreteFunctions_codec); + if (functionSpec_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FunctionSpec); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SavedFunction other) { + if (other == null) { + return; + } + concreteFunctions_.Add(other.concreteFunctions_); + if (other.functionSpec_ != null) { + if (functionSpec_ == null) { + FunctionSpec = new global::Tensorflow.FunctionSpec(); + } + FunctionSpec.MergeFrom(other.FunctionSpec); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + concreteFunctions_.AddEntriesFrom(input, _repeated_concreteFunctions_codec); + break; + } + case 18: { + if (functionSpec_ == null) { + FunctionSpec = new global::Tensorflow.FunctionSpec(); + } + input.ReadMessage(FunctionSpec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + concreteFunctions_.AddEntriesFrom(ref input, _repeated_concreteFunctions_codec); + break; + } + case 18: { + if (functionSpec_ == null) { + FunctionSpec = new global::Tensorflow.FunctionSpec(); + } + input.ReadMessage(FunctionSpec); + break; + } + } + } + } + #endif + + } + + public sealed partial class CapturedTensor : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CapturedTensor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CapturedTensor() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CapturedTensor(CapturedTensor other) : this() { + name_ = other.name_; + concreteFunction_ = other.concreteFunction_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CapturedTensor Clone() { + return new CapturedTensor(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// Name of captured tensor + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "concrete_function" field. + public const int ConcreteFunctionFieldNumber = 2; + private string concreteFunction_ = ""; + /// + /// Name of concrete function which contains the computed graph tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ConcreteFunction { + get { return concreteFunction_; } + set { + concreteFunction_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CapturedTensor); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CapturedTensor other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (ConcreteFunction != other.ConcreteFunction) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (ConcreteFunction.Length != 0) hash ^= ConcreteFunction.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ConcreteFunction.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ConcreteFunction); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ConcreteFunction.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ConcreteFunction); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (ConcreteFunction.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ConcreteFunction); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CapturedTensor other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.ConcreteFunction.Length != 0) { + ConcreteFunction = other.ConcreteFunction; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + ConcreteFunction = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + ConcreteFunction = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// Stores low-level information about a concrete function. Referenced in either + /// a SavedFunction or a SavedBareConcreteFunction. + /// + public sealed partial class SavedConcreteFunction : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedConcreteFunction()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedConcreteFunction() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedConcreteFunction(SavedConcreteFunction other) : this() { + boundInputs_ = other.boundInputs_.Clone(); + canonicalizedInputSignature_ = other.canonicalizedInputSignature_ != null ? other.canonicalizedInputSignature_.Clone() : null; + outputSignature_ = other.outputSignature_ != null ? other.outputSignature_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedConcreteFunction Clone() { + return new SavedConcreteFunction(this); + } + + /// Field number for the "bound_inputs" field. + public const int BoundInputsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_boundInputs_codec + = pb::FieldCodec.ForInt32(18); + private readonly pbc::RepeatedField boundInputs_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField BoundInputs { + get { return boundInputs_; } + } + + /// Field number for the "canonicalized_input_signature" field. + public const int CanonicalizedInputSignatureFieldNumber = 3; + private global::Tensorflow.StructuredValue canonicalizedInputSignature_; + /// + /// Input in canonicalized form that was received to create this concrete + /// function. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.StructuredValue CanonicalizedInputSignature { + get { return canonicalizedInputSignature_; } + set { + canonicalizedInputSignature_ = value; + } + } + + /// Field number for the "output_signature" field. + public const int OutputSignatureFieldNumber = 4; + private global::Tensorflow.StructuredValue outputSignature_; + /// + /// Output that was the return value of this function after replacing all + /// Tensors with TensorSpecs. This can be an arbitrary nested function and will + /// be used to reconstruct the full structure from pure tensors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.StructuredValue OutputSignature { + get { return outputSignature_; } + set { + outputSignature_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SavedConcreteFunction); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SavedConcreteFunction other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!boundInputs_.Equals(other.boundInputs_)) return false; + if (!object.Equals(CanonicalizedInputSignature, other.CanonicalizedInputSignature)) return false; + if (!object.Equals(OutputSignature, other.OutputSignature)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= boundInputs_.GetHashCode(); + if (canonicalizedInputSignature_ != null) hash ^= CanonicalizedInputSignature.GetHashCode(); + if (outputSignature_ != null) hash ^= OutputSignature.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + boundInputs_.WriteTo(output, _repeated_boundInputs_codec); + if (canonicalizedInputSignature_ != null) { + output.WriteRawTag(26); + output.WriteMessage(CanonicalizedInputSignature); + } + if (outputSignature_ != null) { + output.WriteRawTag(34); + output.WriteMessage(OutputSignature); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + boundInputs_.WriteTo(ref output, _repeated_boundInputs_codec); + if (canonicalizedInputSignature_ != null) { + output.WriteRawTag(26); + output.WriteMessage(CanonicalizedInputSignature); + } + if (outputSignature_ != null) { + output.WriteRawTag(34); + output.WriteMessage(OutputSignature); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += boundInputs_.CalculateSize(_repeated_boundInputs_codec); + if (canonicalizedInputSignature_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(CanonicalizedInputSignature); + } + if (outputSignature_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(OutputSignature); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SavedConcreteFunction other) { + if (other == null) { + return; + } + boundInputs_.Add(other.boundInputs_); + if (other.canonicalizedInputSignature_ != null) { + if (canonicalizedInputSignature_ == null) { + CanonicalizedInputSignature = new global::Tensorflow.StructuredValue(); + } + CanonicalizedInputSignature.MergeFrom(other.CanonicalizedInputSignature); + } + if (other.outputSignature_ != null) { + if (outputSignature_ == null) { + OutputSignature = new global::Tensorflow.StructuredValue(); + } + OutputSignature.MergeFrom(other.OutputSignature); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 18: + case 16: { + boundInputs_.AddEntriesFrom(input, _repeated_boundInputs_codec); + break; + } + case 26: { + if (canonicalizedInputSignature_ == null) { + CanonicalizedInputSignature = new global::Tensorflow.StructuredValue(); + } + input.ReadMessage(CanonicalizedInputSignature); + break; + } + case 34: { + if (outputSignature_ == null) { + OutputSignature = new global::Tensorflow.StructuredValue(); + } + input.ReadMessage(OutputSignature); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 18: + case 16: { + boundInputs_.AddEntriesFrom(ref input, _repeated_boundInputs_codec); + break; + } + case 26: { + if (canonicalizedInputSignature_ == null) { + CanonicalizedInputSignature = new global::Tensorflow.StructuredValue(); + } + input.ReadMessage(CanonicalizedInputSignature); + break; + } + case 34: { + if (outputSignature_ == null) { + OutputSignature = new global::Tensorflow.StructuredValue(); + } + input.ReadMessage(OutputSignature); + break; + } + } + } + } + #endif + + } + + public sealed partial class SavedBareConcreteFunction : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedBareConcreteFunction()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedBareConcreteFunction() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedBareConcreteFunction(SavedBareConcreteFunction other) : this() { + concreteFunctionName_ = other.concreteFunctionName_; + argumentKeywords_ = other.argumentKeywords_.Clone(); + allowedPositionalArguments_ = other.allowedPositionalArguments_; + functionSpec_ = other.functionSpec_ != null ? other.functionSpec_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedBareConcreteFunction Clone() { + return new SavedBareConcreteFunction(this); + } + + /// Field number for the "concrete_function_name" field. + public const int ConcreteFunctionNameFieldNumber = 1; + private string concreteFunctionName_ = ""; + /// + /// Identifies a SavedConcreteFunction. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ConcreteFunctionName { + get { return concreteFunctionName_; } + set { + concreteFunctionName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "argument_keywords" field. + public const int ArgumentKeywordsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_argumentKeywords_codec + = pb::FieldCodec.ForString(18); + private readonly pbc::RepeatedField argumentKeywords_ = new pbc::RepeatedField(); + /// + /// A sequence of unique strings, one per Tensor argument. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ArgumentKeywords { + get { return argumentKeywords_; } + } + + /// Field number for the "allowed_positional_arguments" field. + public const int AllowedPositionalArgumentsFieldNumber = 3; + private long allowedPositionalArguments_; + /// + /// The prefix of `argument_keywords` which may be identified by position. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllowedPositionalArguments { + get { return allowedPositionalArguments_; } + set { + allowedPositionalArguments_ = value; + } + } + + /// Field number for the "function_spec" field. + public const int FunctionSpecFieldNumber = 4; + private global::Tensorflow.FunctionSpec functionSpec_; + /// + /// The spec of the function that this ConcreteFunction is traced from. This + /// allows the ConcreteFunction to be called with nest structure inputs. This + /// field may not be populated. If this field is absent, the concrete function + /// can only be called with flat inputs. + /// TODO(b/169361281): support calling saved ConcreteFunction with structured + /// inputs in C++ SavedModel API. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.FunctionSpec FunctionSpec { + get { return functionSpec_; } + set { + functionSpec_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SavedBareConcreteFunction); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SavedBareConcreteFunction other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ConcreteFunctionName != other.ConcreteFunctionName) return false; + if(!argumentKeywords_.Equals(other.argumentKeywords_)) return false; + if (AllowedPositionalArguments != other.AllowedPositionalArguments) return false; + if (!object.Equals(FunctionSpec, other.FunctionSpec)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ConcreteFunctionName.Length != 0) hash ^= ConcreteFunctionName.GetHashCode(); + hash ^= argumentKeywords_.GetHashCode(); + if (AllowedPositionalArguments != 0L) hash ^= AllowedPositionalArguments.GetHashCode(); + if (functionSpec_ != null) hash ^= FunctionSpec.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ConcreteFunctionName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ConcreteFunctionName); + } + argumentKeywords_.WriteTo(output, _repeated_argumentKeywords_codec); + if (AllowedPositionalArguments != 0L) { + output.WriteRawTag(24); + output.WriteInt64(AllowedPositionalArguments); + } + if (functionSpec_ != null) { + output.WriteRawTag(34); + output.WriteMessage(FunctionSpec); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ConcreteFunctionName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ConcreteFunctionName); + } + argumentKeywords_.WriteTo(ref output, _repeated_argumentKeywords_codec); + if (AllowedPositionalArguments != 0L) { + output.WriteRawTag(24); + output.WriteInt64(AllowedPositionalArguments); + } + if (functionSpec_ != null) { + output.WriteRawTag(34); + output.WriteMessage(FunctionSpec); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ConcreteFunctionName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ConcreteFunctionName); + } + size += argumentKeywords_.CalculateSize(_repeated_argumentKeywords_codec); + if (AllowedPositionalArguments != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AllowedPositionalArguments); + } + if (functionSpec_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FunctionSpec); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SavedBareConcreteFunction other) { + if (other == null) { + return; + } + if (other.ConcreteFunctionName.Length != 0) { + ConcreteFunctionName = other.ConcreteFunctionName; + } + argumentKeywords_.Add(other.argumentKeywords_); + if (other.AllowedPositionalArguments != 0L) { + AllowedPositionalArguments = other.AllowedPositionalArguments; + } + if (other.functionSpec_ != null) { + if (functionSpec_ == null) { + FunctionSpec = new global::Tensorflow.FunctionSpec(); + } + FunctionSpec.MergeFrom(other.FunctionSpec); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ConcreteFunctionName = input.ReadString(); + break; + } + case 18: { + argumentKeywords_.AddEntriesFrom(input, _repeated_argumentKeywords_codec); + break; + } + case 24: { + AllowedPositionalArguments = input.ReadInt64(); + break; + } + case 34: { + if (functionSpec_ == null) { + FunctionSpec = new global::Tensorflow.FunctionSpec(); + } + input.ReadMessage(FunctionSpec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + ConcreteFunctionName = input.ReadString(); + break; + } + case 18: { + argumentKeywords_.AddEntriesFrom(ref input, _repeated_argumentKeywords_codec); + break; + } + case 24: { + AllowedPositionalArguments = input.ReadInt64(); + break; + } + case 34: { + if (functionSpec_ == null) { + FunctionSpec = new global::Tensorflow.FunctionSpec(); + } + input.ReadMessage(FunctionSpec); + break; + } + } + } + } + #endif + + } + + public sealed partial class SavedConstant : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedConstant()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[8]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedConstant() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedConstant(SavedConstant other) : this() { + operation_ = other.operation_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedConstant Clone() { + return new SavedConstant(this); + } + + /// Field number for the "operation" field. + public const int OperationFieldNumber = 1; + private string operation_ = ""; + /// + /// An Operation name for a ConstantOp in this SavedObjectGraph's MetaGraph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Operation { + get { return operation_; } + set { + operation_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SavedConstant); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SavedConstant other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Operation != other.Operation) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Operation.Length != 0) hash ^= Operation.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Operation.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Operation); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Operation.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Operation); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Operation.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Operation); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SavedConstant other) { + if (other == null) { + return; + } + if (other.Operation.Length != 0) { + Operation = other.Operation; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Operation = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Operation = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// Represents a Variable that is initialized by loading the contents from the + /// checkpoint. + /// + public sealed partial class SavedVariable : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedVariable()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[9]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedVariable() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedVariable(SavedVariable other) : this() { + dtype_ = other.dtype_; + shape_ = other.shape_ != null ? other.shape_.Clone() : null; + trainable_ = other.trainable_; + synchronization_ = other.synchronization_; + aggregation_ = other.aggregation_; + name_ = other.name_; + device_ = other.device_; + experimentalDistributedVariableComponents_ = other.experimentalDistributedVariableComponents_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedVariable Clone() { + return new SavedVariable(this); + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 1; + private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 2; + private global::Tensorflow.TensorShapeProto shape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorShapeProto Shape { + get { return shape_; } + set { + shape_ = value; + } + } + + /// Field number for the "trainable" field. + public const int TrainableFieldNumber = 3; + private bool trainable_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Trainable { + get { return trainable_; } + set { + trainable_ = value; + } + } + + /// Field number for the "synchronization" field. + public const int SynchronizationFieldNumber = 4; + private global::Tensorflow.VariableSynchronization synchronization_ = global::Tensorflow.VariableSynchronization.Auto; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.VariableSynchronization Synchronization { + get { return synchronization_; } + set { + synchronization_ = value; + } + } + + /// Field number for the "aggregation" field. + public const int AggregationFieldNumber = 5; + private global::Tensorflow.VariableAggregation aggregation_ = global::Tensorflow.VariableAggregation.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.VariableAggregation Aggregation { + get { return aggregation_; } + set { + aggregation_ = value; + } + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 6; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "device" field. + public const int DeviceFieldNumber = 7; + private string device_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Device { + get { return device_; } + set { + device_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "experimental_distributed_variable_components" field. + public const int ExperimentalDistributedVariableComponentsFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_experimentalDistributedVariableComponents_codec + = pb::FieldCodec.ForMessage(66, global::Tensorflow.SavedVariable.Parser); + private readonly pbc::RepeatedField experimentalDistributedVariableComponents_ = new pbc::RepeatedField(); + /// + /// List of component variables for a distributed variable. + /// + /// When this field is non-empty, the SavedVariable will be assumed + /// to be a distributed variable defined by the components listed here. + /// + /// This is only supported by experimental loaders at the moment. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ExperimentalDistributedVariableComponents { + get { return experimentalDistributedVariableComponents_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SavedVariable); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SavedVariable other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Dtype != other.Dtype) return false; + if (!object.Equals(Shape, other.Shape)) return false; + if (Trainable != other.Trainable) return false; + if (Synchronization != other.Synchronization) return false; + if (Aggregation != other.Aggregation) return false; + if (Name != other.Name) return false; + if (Device != other.Device) return false; + if(!experimentalDistributedVariableComponents_.Equals(other.experimentalDistributedVariableComponents_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); + if (shape_ != null) hash ^= Shape.GetHashCode(); + if (Trainable != false) hash ^= Trainable.GetHashCode(); + if (Synchronization != global::Tensorflow.VariableSynchronization.Auto) hash ^= Synchronization.GetHashCode(); + if (Aggregation != global::Tensorflow.VariableAggregation.None) hash ^= Aggregation.GetHashCode(); + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Device.Length != 0) hash ^= Device.GetHashCode(); + hash ^= experimentalDistributedVariableComponents_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(8); + output.WriteEnum((int) Dtype); + } + if (shape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Shape); + } + if (Trainable != false) { + output.WriteRawTag(24); + output.WriteBool(Trainable); + } + if (Synchronization != global::Tensorflow.VariableSynchronization.Auto) { + output.WriteRawTag(32); + output.WriteEnum((int) Synchronization); + } + if (Aggregation != global::Tensorflow.VariableAggregation.None) { + output.WriteRawTag(40); + output.WriteEnum((int) Aggregation); + } + if (Name.Length != 0) { + output.WriteRawTag(50); + output.WriteString(Name); + } + if (Device.Length != 0) { + output.WriteRawTag(58); + output.WriteString(Device); + } + experimentalDistributedVariableComponents_.WriteTo(output, _repeated_experimentalDistributedVariableComponents_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(8); + output.WriteEnum((int) Dtype); + } + if (shape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Shape); + } + if (Trainable != false) { + output.WriteRawTag(24); + output.WriteBool(Trainable); + } + if (Synchronization != global::Tensorflow.VariableSynchronization.Auto) { + output.WriteRawTag(32); + output.WriteEnum((int) Synchronization); + } + if (Aggregation != global::Tensorflow.VariableAggregation.None) { + output.WriteRawTag(40); + output.WriteEnum((int) Aggregation); + } + if (Name.Length != 0) { + output.WriteRawTag(50); + output.WriteString(Name); + } + if (Device.Length != 0) { + output.WriteRawTag(58); + output.WriteString(Device); + } + experimentalDistributedVariableComponents_.WriteTo(ref output, _repeated_experimentalDistributedVariableComponents_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (shape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + if (Trainable != false) { + size += 1 + 1; + } + if (Synchronization != global::Tensorflow.VariableSynchronization.Auto) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Synchronization); + } + if (Aggregation != global::Tensorflow.VariableAggregation.None) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Aggregation); + } + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Device.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Device); + } + size += experimentalDistributedVariableComponents_.CalculateSize(_repeated_experimentalDistributedVariableComponents_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SavedVariable other) { + if (other == null) { + return; + } + if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { + Dtype = other.Dtype; + } + if (other.shape_ != null) { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + Shape.MergeFrom(other.Shape); + } + if (other.Trainable != false) { + Trainable = other.Trainable; + } + if (other.Synchronization != global::Tensorflow.VariableSynchronization.Auto) { + Synchronization = other.Synchronization; + } + if (other.Aggregation != global::Tensorflow.VariableAggregation.None) { + Aggregation = other.Aggregation; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Device.Length != 0) { + Device = other.Device; + } + experimentalDistributedVariableComponents_.Add(other.experimentalDistributedVariableComponents_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 18: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 24: { + Trainable = input.ReadBool(); + break; + } + case 32: { + Synchronization = (global::Tensorflow.VariableSynchronization) input.ReadEnum(); + break; + } + case 40: { + Aggregation = (global::Tensorflow.VariableAggregation) input.ReadEnum(); + break; + } + case 50: { + Name = input.ReadString(); + break; + } + case 58: { + Device = input.ReadString(); + break; + } + case 66: { + experimentalDistributedVariableComponents_.AddEntriesFrom(input, _repeated_experimentalDistributedVariableComponents_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 18: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 24: { + Trainable = input.ReadBool(); + break; + } + case 32: { + Synchronization = (global::Tensorflow.VariableSynchronization) input.ReadEnum(); + break; + } + case 40: { + Aggregation = (global::Tensorflow.VariableAggregation) input.ReadEnum(); + break; + } + case 50: { + Name = input.ReadString(); + break; + } + case 58: { + Device = input.ReadString(); + break; + } + case 66: { + experimentalDistributedVariableComponents_.AddEntriesFrom(ref input, _repeated_experimentalDistributedVariableComponents_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Represents `FunctionSpec` used in `Function`. This represents a + /// function that has been wrapped as a TensorFlow `Function`. + /// + public sealed partial class FunctionSpec : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FunctionSpec()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[10]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FunctionSpec() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FunctionSpec(FunctionSpec other) : this() { + fullargspec_ = other.fullargspec_ != null ? other.fullargspec_.Clone() : null; + isMethod_ = other.isMethod_; + inputSignature_ = other.inputSignature_ != null ? other.inputSignature_.Clone() : null; + jitCompile_ = other.jitCompile_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FunctionSpec Clone() { + return new FunctionSpec(this); + } + + /// Field number for the "fullargspec" field. + public const int FullargspecFieldNumber = 1; + private global::Tensorflow.StructuredValue fullargspec_; + /// + /// Full arg spec from inspect.getfullargspec(). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.StructuredValue Fullargspec { + get { return fullargspec_; } + set { + fullargspec_ = value; + } + } + + /// Field number for the "is_method" field. + public const int IsMethodFieldNumber = 2; + private bool isMethod_; + /// + /// Whether this represents a class method. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsMethod { + get { return isMethod_; } + set { + isMethod_ = value; + } + } + + /// Field number for the "input_signature" field. + public const int InputSignatureFieldNumber = 5; + private global::Tensorflow.StructuredValue inputSignature_; + /// + /// The input signature, if specified. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.StructuredValue InputSignature { + get { return inputSignature_; } + set { + inputSignature_ = value; + } + } + + /// Field number for the "jit_compile" field. + public const int JitCompileFieldNumber = 6; + private global::Tensorflow.FunctionSpec.Types.JitCompile jitCompile_ = global::Tensorflow.FunctionSpec.Types.JitCompile.Default; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.FunctionSpec.Types.JitCompile JitCompile { + get { return jitCompile_; } + set { + jitCompile_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as FunctionSpec); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(FunctionSpec other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Fullargspec, other.Fullargspec)) return false; + if (IsMethod != other.IsMethod) return false; + if (!object.Equals(InputSignature, other.InputSignature)) return false; + if (JitCompile != other.JitCompile) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (fullargspec_ != null) hash ^= Fullargspec.GetHashCode(); + if (IsMethod != false) hash ^= IsMethod.GetHashCode(); + if (inputSignature_ != null) hash ^= InputSignature.GetHashCode(); + if (JitCompile != global::Tensorflow.FunctionSpec.Types.JitCompile.Default) hash ^= JitCompile.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (fullargspec_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Fullargspec); + } + if (IsMethod != false) { + output.WriteRawTag(16); + output.WriteBool(IsMethod); + } + if (inputSignature_ != null) { + output.WriteRawTag(42); + output.WriteMessage(InputSignature); + } + if (JitCompile != global::Tensorflow.FunctionSpec.Types.JitCompile.Default) { + output.WriteRawTag(48); + output.WriteEnum((int) JitCompile); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (fullargspec_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Fullargspec); + } + if (IsMethod != false) { + output.WriteRawTag(16); + output.WriteBool(IsMethod); + } + if (inputSignature_ != null) { + output.WriteRawTag(42); + output.WriteMessage(InputSignature); + } + if (JitCompile != global::Tensorflow.FunctionSpec.Types.JitCompile.Default) { + output.WriteRawTag(48); + output.WriteEnum((int) JitCompile); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (fullargspec_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Fullargspec); + } + if (IsMethod != false) { + size += 1 + 1; + } + if (inputSignature_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(InputSignature); + } + if (JitCompile != global::Tensorflow.FunctionSpec.Types.JitCompile.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) JitCompile); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(FunctionSpec other) { + if (other == null) { + return; + } + if (other.fullargspec_ != null) { + if (fullargspec_ == null) { + Fullargspec = new global::Tensorflow.StructuredValue(); + } + Fullargspec.MergeFrom(other.Fullargspec); + } + if (other.IsMethod != false) { + IsMethod = other.IsMethod; + } + if (other.inputSignature_ != null) { + if (inputSignature_ == null) { + InputSignature = new global::Tensorflow.StructuredValue(); + } + InputSignature.MergeFrom(other.InputSignature); + } + if (other.JitCompile != global::Tensorflow.FunctionSpec.Types.JitCompile.Default) { + JitCompile = other.JitCompile; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (fullargspec_ == null) { + Fullargspec = new global::Tensorflow.StructuredValue(); + } + input.ReadMessage(Fullargspec); + break; + } + case 16: { + IsMethod = input.ReadBool(); + break; + } + case 42: { + if (inputSignature_ == null) { + InputSignature = new global::Tensorflow.StructuredValue(); + } + input.ReadMessage(InputSignature); + break; + } + case 48: { + JitCompile = (global::Tensorflow.FunctionSpec.Types.JitCompile) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (fullargspec_ == null) { + Fullargspec = new global::Tensorflow.StructuredValue(); + } + input.ReadMessage(Fullargspec); + break; + } + case 16: { + IsMethod = input.ReadBool(); + break; + } + case 42: { + if (inputSignature_ == null) { + InputSignature = new global::Tensorflow.StructuredValue(); + } + input.ReadMessage(InputSignature); + break; + } + case 48: { + JitCompile = (global::Tensorflow.FunctionSpec.Types.JitCompile) input.ReadEnum(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the FunctionSpec message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Whether the function should be compiled by XLA. + /// + /// The public interface to `tf.function` uses an optional boolean to + /// represent three distinct states for this field. Unfortunately, proto3 + /// removes the ability to explicitly check for the presence or absence of a + /// field, so we instead map to an enum. + /// + /// See `tf.function` for details. + /// + public enum JitCompile { + [pbr::OriginalName("DEFAULT")] Default = 0, + [pbr::OriginalName("ON")] On = 1, + [pbr::OriginalName("OFF")] Off = 2, + } + + } + #endregion + + } + + /// + /// A SavedResource represents a TF object that holds state during its lifetime. + /// An object of this type can have a reference to a: + /// create_resource() and an initialize() function. + /// + public sealed partial class SavedResource : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedResource()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[11]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedResource() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedResource(SavedResource other) : this() { + device_ = other.device_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SavedResource Clone() { + return new SavedResource(this); + } + + /// Field number for the "device" field. + public const int DeviceFieldNumber = 1; + private string device_ = ""; + /// + /// A device specification indicating a required placement for the resource + /// creation function, e.g. "CPU". An empty string allows the user to select a + /// device. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Device { + get { return device_; } + set { + device_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SavedResource); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SavedResource other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Device != other.Device) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Device.Length != 0) hash ^= Device.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Device.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Device); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Device.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Device); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Device.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Device); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SavedResource other) { + if (other == null) { + return; + } + if (other.Device.Length != 0) { + Device = other.Device; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Device = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Device = input.ReadString(); + break; + } + } + } + } + #endif + + } + + public sealed partial class SaveableObject : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SaveableObject()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[12]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SaveableObject() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SaveableObject(SaveableObject other) : this() { + saveFunction_ = other.saveFunction_; + restoreFunction_ = other.restoreFunction_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SaveableObject Clone() { + return new SaveableObject(this); + } + + /// Field number for the "save_function" field. + public const int SaveFunctionFieldNumber = 2; + private int saveFunction_; + /// + /// Node ids of concrete functions for saving and loading from a checkpoint. + /// These functions save and restore directly from tensors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int SaveFunction { + get { return saveFunction_; } + set { + saveFunction_ = value; + } + } + + /// Field number for the "restore_function" field. + public const int RestoreFunctionFieldNumber = 3; + private int restoreFunction_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int RestoreFunction { + get { return restoreFunction_; } + set { + restoreFunction_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SaveableObject); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SaveableObject other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (SaveFunction != other.SaveFunction) return false; + if (RestoreFunction != other.RestoreFunction) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (SaveFunction != 0) hash ^= SaveFunction.GetHashCode(); + if (RestoreFunction != 0) hash ^= RestoreFunction.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (SaveFunction != 0) { + output.WriteRawTag(16); + output.WriteInt32(SaveFunction); + } + if (RestoreFunction != 0) { + output.WriteRawTag(24); + output.WriteInt32(RestoreFunction); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (SaveFunction != 0) { + output.WriteRawTag(16); + output.WriteInt32(SaveFunction); + } + if (RestoreFunction != 0) { + output.WriteRawTag(24); + output.WriteInt32(RestoreFunction); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (SaveFunction != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(SaveFunction); + } + if (RestoreFunction != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(RestoreFunction); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SaveableObject other) { + if (other == null) { + return; + } + if (other.SaveFunction != 0) { + SaveFunction = other.SaveFunction; + } + if (other.RestoreFunction != 0) { + RestoreFunction = other.RestoreFunction; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 16: { + SaveFunction = input.ReadInt32(); + break; + } + case 24: { + RestoreFunction = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 16: { + SaveFunction = input.ReadInt32(); + break; + } + case 24: { + RestoreFunction = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Saver.cs b/src/TensorFlowNET.Core/Protobuf/Saver.cs new file mode 100644 index 000000000..fac25e329 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Saver.cs @@ -0,0 +1,517 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/saver.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/saver.proto + public static partial class SaverReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/saver.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static SaverReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiR0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvc2F2ZXIucHJvdG8SCnRlbnNv", + "cmZsb3cingIKCFNhdmVyRGVmEhwKFGZpbGVuYW1lX3RlbnNvcl9uYW1lGAEg", + "ASgJEhgKEHNhdmVfdGVuc29yX25hbWUYAiABKAkSFwoPcmVzdG9yZV9vcF9u", + "YW1lGAMgASgJEhMKC21heF90b19rZWVwGAQgASgFEg8KB3NoYXJkZWQYBSAB", + "KAgSJQoda2VlcF9jaGVja3BvaW50X2V2ZXJ5X25faG91cnMYBiABKAISPQoH", + "dmVyc2lvbhgHIAEoDjIsLnRlbnNvcmZsb3cuU2F2ZXJEZWYuQ2hlY2twb2lu", + "dEZvcm1hdFZlcnNpb24iNQoXQ2hlY2twb2ludEZvcm1hdFZlcnNpb24SCgoG", + "TEVHQUNZEAASBgoCVjEQARIGCgJWMhACQn4KE29yZy50ZW5zb3JmbG93LnV0", + "aWxCC1NhdmVyUHJvdG9zUAFaVWdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5z", + "b3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9wcm90b2J1Zi9mb3JfY29yZV9w", + "cm90b3NfZ29fcHJvdG/4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaverDef), global::Tensorflow.SaverDef.Parser, new[]{ "FilenameTensorName", "SaveTensorName", "RestoreOpName", "MaxToKeep", "Sharded", "KeepCheckpointEveryNHours", "Version" }, null, new[]{ typeof(global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) }, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer representing the configuration of a Saver. + /// + public sealed partial class SaverDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SaverDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SaverReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SaverDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SaverDef(SaverDef other) : this() { + filenameTensorName_ = other.filenameTensorName_; + saveTensorName_ = other.saveTensorName_; + restoreOpName_ = other.restoreOpName_; + maxToKeep_ = other.maxToKeep_; + sharded_ = other.sharded_; + keepCheckpointEveryNHours_ = other.keepCheckpointEveryNHours_; + version_ = other.version_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SaverDef Clone() { + return new SaverDef(this); + } + + /// Field number for the "filename_tensor_name" field. + public const int FilenameTensorNameFieldNumber = 1; + private string filenameTensorName_ = ""; + /// + /// The name of the tensor in which to specify the filename when saving or + /// restoring a model checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string FilenameTensorName { + get { return filenameTensorName_; } + set { + filenameTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "save_tensor_name" field. + public const int SaveTensorNameFieldNumber = 2; + private string saveTensorName_ = ""; + /// + /// The operation to run when saving a model checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string SaveTensorName { + get { return saveTensorName_; } + set { + saveTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "restore_op_name" field. + public const int RestoreOpNameFieldNumber = 3; + private string restoreOpName_ = ""; + /// + /// The operation to run when restoring a model checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string RestoreOpName { + get { return restoreOpName_; } + set { + restoreOpName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "max_to_keep" field. + public const int MaxToKeepFieldNumber = 4; + private int maxToKeep_; + /// + /// Maximum number of checkpoints to keep. If 0, no checkpoints are deleted. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int MaxToKeep { + get { return maxToKeep_; } + set { + maxToKeep_ = value; + } + } + + /// Field number for the "sharded" field. + public const int ShardedFieldNumber = 5; + private bool sharded_; + /// + /// Shard the save files, one per device that has Variable nodes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Sharded { + get { return sharded_; } + set { + sharded_ = value; + } + } + + /// Field number for the "keep_checkpoint_every_n_hours" field. + public const int KeepCheckpointEveryNHoursFieldNumber = 6; + private float keepCheckpointEveryNHours_; + /// + /// How often to keep an additional checkpoint. If not specified, only the last + /// "max_to_keep" checkpoints are kept; if specified, in addition to keeping + /// the last "max_to_keep" checkpoints, an additional checkpoint will be kept + /// for every n hours of training. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public float KeepCheckpointEveryNHours { + get { return keepCheckpointEveryNHours_; } + set { + keepCheckpointEveryNHours_ = value; + } + } + + /// Field number for the "version" field. + public const int VersionFieldNumber = 7; + private global::Tensorflow.SaverDef.Types.CheckpointFormatVersion version_ = global::Tensorflow.SaverDef.Types.CheckpointFormatVersion.Legacy; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SaverDef.Types.CheckpointFormatVersion Version { + get { return version_; } + set { + version_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SaverDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SaverDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (FilenameTensorName != other.FilenameTensorName) return false; + if (SaveTensorName != other.SaveTensorName) return false; + if (RestoreOpName != other.RestoreOpName) return false; + if (MaxToKeep != other.MaxToKeep) return false; + if (Sharded != other.Sharded) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(KeepCheckpointEveryNHours, other.KeepCheckpointEveryNHours)) return false; + if (Version != other.Version) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (FilenameTensorName.Length != 0) hash ^= FilenameTensorName.GetHashCode(); + if (SaveTensorName.Length != 0) hash ^= SaveTensorName.GetHashCode(); + if (RestoreOpName.Length != 0) hash ^= RestoreOpName.GetHashCode(); + if (MaxToKeep != 0) hash ^= MaxToKeep.GetHashCode(); + if (Sharded != false) hash ^= Sharded.GetHashCode(); + if (KeepCheckpointEveryNHours != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(KeepCheckpointEveryNHours); + if (Version != global::Tensorflow.SaverDef.Types.CheckpointFormatVersion.Legacy) hash ^= Version.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (FilenameTensorName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(FilenameTensorName); + } + if (SaveTensorName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(SaveTensorName); + } + if (RestoreOpName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(RestoreOpName); + } + if (MaxToKeep != 0) { + output.WriteRawTag(32); + output.WriteInt32(MaxToKeep); + } + if (Sharded != false) { + output.WriteRawTag(40); + output.WriteBool(Sharded); + } + if (KeepCheckpointEveryNHours != 0F) { + output.WriteRawTag(53); + output.WriteFloat(KeepCheckpointEveryNHours); + } + if (Version != global::Tensorflow.SaverDef.Types.CheckpointFormatVersion.Legacy) { + output.WriteRawTag(56); + output.WriteEnum((int) Version); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (FilenameTensorName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(FilenameTensorName); + } + if (SaveTensorName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(SaveTensorName); + } + if (RestoreOpName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(RestoreOpName); + } + if (MaxToKeep != 0) { + output.WriteRawTag(32); + output.WriteInt32(MaxToKeep); + } + if (Sharded != false) { + output.WriteRawTag(40); + output.WriteBool(Sharded); + } + if (KeepCheckpointEveryNHours != 0F) { + output.WriteRawTag(53); + output.WriteFloat(KeepCheckpointEveryNHours); + } + if (Version != global::Tensorflow.SaverDef.Types.CheckpointFormatVersion.Legacy) { + output.WriteRawTag(56); + output.WriteEnum((int) Version); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (FilenameTensorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(FilenameTensorName); + } + if (SaveTensorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(SaveTensorName); + } + if (RestoreOpName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(RestoreOpName); + } + if (MaxToKeep != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxToKeep); + } + if (Sharded != false) { + size += 1 + 1; + } + if (KeepCheckpointEveryNHours != 0F) { + size += 1 + 4; + } + if (Version != global::Tensorflow.SaverDef.Types.CheckpointFormatVersion.Legacy) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Version); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SaverDef other) { + if (other == null) { + return; + } + if (other.FilenameTensorName.Length != 0) { + FilenameTensorName = other.FilenameTensorName; + } + if (other.SaveTensorName.Length != 0) { + SaveTensorName = other.SaveTensorName; + } + if (other.RestoreOpName.Length != 0) { + RestoreOpName = other.RestoreOpName; + } + if (other.MaxToKeep != 0) { + MaxToKeep = other.MaxToKeep; + } + if (other.Sharded != false) { + Sharded = other.Sharded; + } + if (other.KeepCheckpointEveryNHours != 0F) { + KeepCheckpointEveryNHours = other.KeepCheckpointEveryNHours; + } + if (other.Version != global::Tensorflow.SaverDef.Types.CheckpointFormatVersion.Legacy) { + Version = other.Version; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + FilenameTensorName = input.ReadString(); + break; + } + case 18: { + SaveTensorName = input.ReadString(); + break; + } + case 26: { + RestoreOpName = input.ReadString(); + break; + } + case 32: { + MaxToKeep = input.ReadInt32(); + break; + } + case 40: { + Sharded = input.ReadBool(); + break; + } + case 53: { + KeepCheckpointEveryNHours = input.ReadFloat(); + break; + } + case 56: { + Version = (global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + FilenameTensorName = input.ReadString(); + break; + } + case 18: { + SaveTensorName = input.ReadString(); + break; + } + case 26: { + RestoreOpName = input.ReadString(); + break; + } + case 32: { + MaxToKeep = input.ReadInt32(); + break; + } + case 40: { + Sharded = input.ReadBool(); + break; + } + case 53: { + KeepCheckpointEveryNHours = input.ReadFloat(); + break; + } + case 56: { + Version = (global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) input.ReadEnum(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the SaverDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// A version number that identifies a different on-disk checkpoint format. + /// Usually, each subclass of BaseSaverBuilder works with a particular + /// version/format. However, it is possible that the same builder may be + /// upgraded to support a newer checkpoint format in the future. + /// + public enum CheckpointFormatVersion { + /// + /// Internal legacy format. + /// + [pbr::OriginalName("LEGACY")] Legacy = 0, + /// + /// Deprecated format: tf.Saver() which works with tensorflow::table::Table. + /// + [pbr::OriginalName("V1")] V1 = 1, + /// + /// Current format: more efficient. + /// + [pbr::OriginalName("V2")] V2 = 2, + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/ServiceConfig.cs b/src/TensorFlowNET.Core/Protobuf/ServiceConfig.cs new file mode 100644 index 000000000..2197b4bac --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/ServiceConfig.cs @@ -0,0 +1,1179 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/service_config.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Data.Experimental { + + /// Holder for reflection information generated from tensorflow/core/protobuf/service_config.proto + public static partial class ServiceConfigReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/service_config.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ServiceConfigReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Ci10ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvc2VydmljZV9jb25maWcucHJv", + "dG8SHHRlbnNvcmZsb3cuZGF0YS5leHBlcmltZW50YWwaK3RlbnNvcmZsb3cv", + "Y29yZS9wcm90b2J1Zi9kYXRhX3NlcnZpY2UucHJvdG8ijQIKEERpc3BhdGNo", + "ZXJDb25maWcSDAoEcG9ydBgBIAEoAxIQCghwcm90b2NvbBgCIAEoCRIQCgh3", + "b3JrX2RpchgDIAEoCRIbChNmYXVsdF90b2xlcmFudF9tb2RlGAQgASgIEhgK", + "EHdvcmtlcl9hZGRyZXNzZXMYByADKAkSOAoPZGVwbG95bWVudF9tb2RlGAkg", + "ASgOMh8udGVuc29yZmxvdy5kYXRhLkRlcGxveW1lbnRNb2RlEiAKGGpvYl9n", + "Y19jaGVja19pbnRlcnZhbF9tcxgFIAEoAxIZChFqb2JfZ2NfdGltZW91dF9t", + "cxgGIAEoAxIZChFjbGllbnRfdGltZW91dF9tcxgIIAEoAyK+AgoMV29ya2Vy", + "Q29uZmlnEgwKBHBvcnQYASABKAMSEAoIcHJvdG9jb2wYAiABKAkSGgoSZGlz", + "cGF0Y2hlcl9hZGRyZXNzGAMgASgJEhYKDndvcmtlcl9hZGRyZXNzGAQgASgJ", + "EhMKC3dvcmtlcl90YWdzGAogAygJEh0KFWhlYXJ0YmVhdF9pbnRlcnZhbF9t", + "cxgFIAEoAxIdChVkaXNwYXRjaGVyX3RpbWVvdXRfbXMYBiABKAMSHgoWZGF0", + "YV90cmFuc2Zlcl9wcm90b2NvbBgHIAEoCRIdChVkYXRhX3RyYW5zZmVyX2Fk", + "ZHJlc3MYCCABKAkSJgoeY3Jvc3NfdHJhaW5lcl9jYWNoZV9zaXplX2J5dGVz", + "GAsgASgDEiAKGHNodXRkb3duX3F1aWV0X3BlcmlvZF9tcxgJIAEoA0JXWlVn", + "aXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dv", + "L2NvcmUvcHJvdG9idWYvZm9yX2NvcmVfcHJvdG9zX2dvX3Byb3RvYgZwcm90", + "bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.Data.DataServiceReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Data.Experimental.DispatcherConfig), global::Tensorflow.Data.Experimental.DispatcherConfig.Parser, new[]{ "Port", "Protocol", "WorkDir", "FaultTolerantMode", "WorkerAddresses", "DeploymentMode", "JobGcCheckIntervalMs", "JobGcTimeoutMs", "ClientTimeoutMs" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Data.Experimental.WorkerConfig), global::Tensorflow.Data.Experimental.WorkerConfig.Parser, new[]{ "Port", "Protocol", "DispatcherAddress", "WorkerAddress", "WorkerTags", "HeartbeatIntervalMs", "DispatcherTimeoutMs", "DataTransferProtocol", "DataTransferAddress", "CrossTrainerCacheSizeBytes", "ShutdownQuietPeriodMs" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration for a tf.data service DispatchServer. + /// Next id: 10 + /// + public sealed partial class DispatcherConfig : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DispatcherConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Data.Experimental.ServiceConfigReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DispatcherConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DispatcherConfig(DispatcherConfig other) : this() { + port_ = other.port_; + protocol_ = other.protocol_; + workDir_ = other.workDir_; + faultTolerantMode_ = other.faultTolerantMode_; + workerAddresses_ = other.workerAddresses_.Clone(); + deploymentMode_ = other.deploymentMode_; + jobGcCheckIntervalMs_ = other.jobGcCheckIntervalMs_; + jobGcTimeoutMs_ = other.jobGcTimeoutMs_; + clientTimeoutMs_ = other.clientTimeoutMs_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DispatcherConfig Clone() { + return new DispatcherConfig(this); + } + + /// Field number for the "port" field. + public const int PortFieldNumber = 1; + private long port_; + /// + /// The port for the dispatcher to bind to. A value of 0 indicates that the + /// dispatcher may bind to any available port. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Port { + get { return port_; } + set { + port_ = value; + } + } + + /// Field number for the "protocol" field. + public const int ProtocolFieldNumber = 2; + private string protocol_ = ""; + /// + /// The protocol for the dispatcher to use when connecting to workers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Protocol { + get { return protocol_; } + set { + protocol_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "work_dir" field. + public const int WorkDirFieldNumber = 3; + private string workDir_ = ""; + /// + /// A work directory to use for storing dispatcher state, and for recovering + /// during restarts. The empty string indicates not to use any work directory. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string WorkDir { + get { return workDir_; } + set { + workDir_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "fault_tolerant_mode" field. + public const int FaultTolerantModeFieldNumber = 4; + private bool faultTolerantMode_; + /// + /// Whether to run in fault tolerant mode, where dispatcher state is saved + /// across restarts. Requires that `work_dir` is nonempty. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool FaultTolerantMode { + get { return faultTolerantMode_; } + set { + faultTolerantMode_ = value; + } + } + + /// Field number for the "worker_addresses" field. + public const int WorkerAddressesFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_workerAddresses_codec + = pb::FieldCodec.ForString(58); + private readonly pbc::RepeatedField workerAddresses_ = new pbc::RepeatedField(); + /// + /// (Optional.) If the job uses auto-sharding, it needs to specify a fixed list + /// of worker addresses that will register with the dispatcher. The worker + /// addresses should be in the format "host" or "host:port", where "port" is an + /// integer, named port, or %port% to match any port. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField WorkerAddresses { + get { return workerAddresses_; } + } + + /// Field number for the "deployment_mode" field. + public const int DeploymentModeFieldNumber = 9; + private global::Tensorflow.Data.DeploymentMode deploymentMode_ = global::Tensorflow.Data.DeploymentMode.Unspecified; + /// + /// (Optional.) tf.data service deployment mode. Supported values are "REMOTE", + /// "COLOCATED", and "HYBRID". If unspecified, it is assumed to be "REMOTE". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.Data.DeploymentMode DeploymentMode { + get { return deploymentMode_; } + set { + deploymentMode_ = value; + } + } + + /// Field number for the "job_gc_check_interval_ms" field. + public const int JobGcCheckIntervalMsFieldNumber = 5; + private long jobGcCheckIntervalMs_; + /// + /// How often the dispatcher should scan through to delete old and unused + /// jobs. A value of 0 indicates that the decision should be left up to the + /// runtime. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long JobGcCheckIntervalMs { + get { return jobGcCheckIntervalMs_; } + set { + jobGcCheckIntervalMs_ = value; + } + } + + /// Field number for the "job_gc_timeout_ms" field. + public const int JobGcTimeoutMsFieldNumber = 6; + private long jobGcTimeoutMs_; + /// + /// How long a job needs to be unused before it becomes a candidate for garbage + /// collection. A value of -1 indicates that jobs should never be garbage + /// collected. A value of 0 indicates that the decision should be left up to + /// the runtime. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long JobGcTimeoutMs { + get { return jobGcTimeoutMs_; } + set { + jobGcTimeoutMs_ = value; + } + } + + /// Field number for the "client_timeout_ms" field. + public const int ClientTimeoutMsFieldNumber = 8; + private long clientTimeoutMs_; + /// + /// How long to wait before garbage-collecting a client that hasn't + /// heartbeated to the dispatcher. A value of 0 indicates that the timeout + /// should be left to the runtime. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ClientTimeoutMs { + get { return clientTimeoutMs_; } + set { + clientTimeoutMs_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DispatcherConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DispatcherConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Port != other.Port) return false; + if (Protocol != other.Protocol) return false; + if (WorkDir != other.WorkDir) return false; + if (FaultTolerantMode != other.FaultTolerantMode) return false; + if(!workerAddresses_.Equals(other.workerAddresses_)) return false; + if (DeploymentMode != other.DeploymentMode) return false; + if (JobGcCheckIntervalMs != other.JobGcCheckIntervalMs) return false; + if (JobGcTimeoutMs != other.JobGcTimeoutMs) return false; + if (ClientTimeoutMs != other.ClientTimeoutMs) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Port != 0L) hash ^= Port.GetHashCode(); + if (Protocol.Length != 0) hash ^= Protocol.GetHashCode(); + if (WorkDir.Length != 0) hash ^= WorkDir.GetHashCode(); + if (FaultTolerantMode != false) hash ^= FaultTolerantMode.GetHashCode(); + hash ^= workerAddresses_.GetHashCode(); + if (DeploymentMode != global::Tensorflow.Data.DeploymentMode.Unspecified) hash ^= DeploymentMode.GetHashCode(); + if (JobGcCheckIntervalMs != 0L) hash ^= JobGcCheckIntervalMs.GetHashCode(); + if (JobGcTimeoutMs != 0L) hash ^= JobGcTimeoutMs.GetHashCode(); + if (ClientTimeoutMs != 0L) hash ^= ClientTimeoutMs.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Port != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Port); + } + if (Protocol.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Protocol); + } + if (WorkDir.Length != 0) { + output.WriteRawTag(26); + output.WriteString(WorkDir); + } + if (FaultTolerantMode != false) { + output.WriteRawTag(32); + output.WriteBool(FaultTolerantMode); + } + if (JobGcCheckIntervalMs != 0L) { + output.WriteRawTag(40); + output.WriteInt64(JobGcCheckIntervalMs); + } + if (JobGcTimeoutMs != 0L) { + output.WriteRawTag(48); + output.WriteInt64(JobGcTimeoutMs); + } + workerAddresses_.WriteTo(output, _repeated_workerAddresses_codec); + if (ClientTimeoutMs != 0L) { + output.WriteRawTag(64); + output.WriteInt64(ClientTimeoutMs); + } + if (DeploymentMode != global::Tensorflow.Data.DeploymentMode.Unspecified) { + output.WriteRawTag(72); + output.WriteEnum((int) DeploymentMode); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Port != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Port); + } + if (Protocol.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Protocol); + } + if (WorkDir.Length != 0) { + output.WriteRawTag(26); + output.WriteString(WorkDir); + } + if (FaultTolerantMode != false) { + output.WriteRawTag(32); + output.WriteBool(FaultTolerantMode); + } + if (JobGcCheckIntervalMs != 0L) { + output.WriteRawTag(40); + output.WriteInt64(JobGcCheckIntervalMs); + } + if (JobGcTimeoutMs != 0L) { + output.WriteRawTag(48); + output.WriteInt64(JobGcTimeoutMs); + } + workerAddresses_.WriteTo(ref output, _repeated_workerAddresses_codec); + if (ClientTimeoutMs != 0L) { + output.WriteRawTag(64); + output.WriteInt64(ClientTimeoutMs); + } + if (DeploymentMode != global::Tensorflow.Data.DeploymentMode.Unspecified) { + output.WriteRawTag(72); + output.WriteEnum((int) DeploymentMode); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Port != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Port); + } + if (Protocol.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Protocol); + } + if (WorkDir.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(WorkDir); + } + if (FaultTolerantMode != false) { + size += 1 + 1; + } + size += workerAddresses_.CalculateSize(_repeated_workerAddresses_codec); + if (DeploymentMode != global::Tensorflow.Data.DeploymentMode.Unspecified) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) DeploymentMode); + } + if (JobGcCheckIntervalMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(JobGcCheckIntervalMs); + } + if (JobGcTimeoutMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(JobGcTimeoutMs); + } + if (ClientTimeoutMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ClientTimeoutMs); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DispatcherConfig other) { + if (other == null) { + return; + } + if (other.Port != 0L) { + Port = other.Port; + } + if (other.Protocol.Length != 0) { + Protocol = other.Protocol; + } + if (other.WorkDir.Length != 0) { + WorkDir = other.WorkDir; + } + if (other.FaultTolerantMode != false) { + FaultTolerantMode = other.FaultTolerantMode; + } + workerAddresses_.Add(other.workerAddresses_); + if (other.DeploymentMode != global::Tensorflow.Data.DeploymentMode.Unspecified) { + DeploymentMode = other.DeploymentMode; + } + if (other.JobGcCheckIntervalMs != 0L) { + JobGcCheckIntervalMs = other.JobGcCheckIntervalMs; + } + if (other.JobGcTimeoutMs != 0L) { + JobGcTimeoutMs = other.JobGcTimeoutMs; + } + if (other.ClientTimeoutMs != 0L) { + ClientTimeoutMs = other.ClientTimeoutMs; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Port = input.ReadInt64(); + break; + } + case 18: { + Protocol = input.ReadString(); + break; + } + case 26: { + WorkDir = input.ReadString(); + break; + } + case 32: { + FaultTolerantMode = input.ReadBool(); + break; + } + case 40: { + JobGcCheckIntervalMs = input.ReadInt64(); + break; + } + case 48: { + JobGcTimeoutMs = input.ReadInt64(); + break; + } + case 58: { + workerAddresses_.AddEntriesFrom(input, _repeated_workerAddresses_codec); + break; + } + case 64: { + ClientTimeoutMs = input.ReadInt64(); + break; + } + case 72: { + DeploymentMode = (global::Tensorflow.Data.DeploymentMode) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Port = input.ReadInt64(); + break; + } + case 18: { + Protocol = input.ReadString(); + break; + } + case 26: { + WorkDir = input.ReadString(); + break; + } + case 32: { + FaultTolerantMode = input.ReadBool(); + break; + } + case 40: { + JobGcCheckIntervalMs = input.ReadInt64(); + break; + } + case 48: { + JobGcTimeoutMs = input.ReadInt64(); + break; + } + case 58: { + workerAddresses_.AddEntriesFrom(ref input, _repeated_workerAddresses_codec); + break; + } + case 64: { + ClientTimeoutMs = input.ReadInt64(); + break; + } + case 72: { + DeploymentMode = (global::Tensorflow.Data.DeploymentMode) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + /// + /// Configuration for a tf.data service WorkerServer. + /// Next id: 12 + /// + public sealed partial class WorkerConfig : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WorkerConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Data.Experimental.ServiceConfigReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WorkerConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WorkerConfig(WorkerConfig other) : this() { + port_ = other.port_; + protocol_ = other.protocol_; + dispatcherAddress_ = other.dispatcherAddress_; + workerAddress_ = other.workerAddress_; + workerTags_ = other.workerTags_.Clone(); + heartbeatIntervalMs_ = other.heartbeatIntervalMs_; + dispatcherTimeoutMs_ = other.dispatcherTimeoutMs_; + dataTransferProtocol_ = other.dataTransferProtocol_; + dataTransferAddress_ = other.dataTransferAddress_; + crossTrainerCacheSizeBytes_ = other.crossTrainerCacheSizeBytes_; + shutdownQuietPeriodMs_ = other.shutdownQuietPeriodMs_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WorkerConfig Clone() { + return new WorkerConfig(this); + } + + /// Field number for the "port" field. + public const int PortFieldNumber = 1; + private long port_; + /// + /// The port for the worker to bind to. A value of 0 indicates that the + /// worker may bind to any available port. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Port { + get { return port_; } + set { + port_ = value; + } + } + + /// Field number for the "protocol" field. + public const int ProtocolFieldNumber = 2; + private string protocol_ = ""; + /// + /// The protocol for the worker to use when connecting to the dispatcher. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Protocol { + get { return protocol_; } + set { + protocol_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "dispatcher_address" field. + public const int DispatcherAddressFieldNumber = 3; + private string dispatcherAddress_ = ""; + /// + /// The address of the dispatcher to register with. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string DispatcherAddress { + get { return dispatcherAddress_; } + set { + dispatcherAddress_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "worker_address" field. + public const int WorkerAddressFieldNumber = 4; + private string workerAddress_ = ""; + /// + /// The address of the worker server. The substring "%port%", if specified, + /// will be replaced with the worker's bound port. This is useful when the port + /// is set to `0`. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string WorkerAddress { + get { return workerAddress_; } + set { + workerAddress_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "worker_tags" field. + public const int WorkerTagsFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_workerTags_codec + = pb::FieldCodec.ForString(82); + private readonly pbc::RepeatedField workerTags_ = new pbc::RepeatedField(); + /// + /// Tags attached to the worker. This allows reading from selected workers. + /// For example, by applying a "COLOCATED" tag, tf.data service is able to read + /// from the local tf.data worker if one exists, then from off-TF-host workers, + /// to avoid cross-TF-host reads. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField WorkerTags { + get { return workerTags_; } + } + + /// Field number for the "heartbeat_interval_ms" field. + public const int HeartbeatIntervalMsFieldNumber = 5; + private long heartbeatIntervalMs_; + /// + /// How often the worker should heartbeat to the master. A value of 0 indicates + /// that the decision should be left up to the runtime. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long HeartbeatIntervalMs { + get { return heartbeatIntervalMs_; } + set { + heartbeatIntervalMs_ = value; + } + } + + /// Field number for the "dispatcher_timeout_ms" field. + public const int DispatcherTimeoutMsFieldNumber = 6; + private long dispatcherTimeoutMs_; + /// + /// How long to retry requests to the dispatcher before giving up and reporting + /// an error. A value of 0 indicates that the decision should be left up to the + /// runtime. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long DispatcherTimeoutMs { + get { return dispatcherTimeoutMs_; } + set { + dispatcherTimeoutMs_ = value; + } + } + + /// Field number for the "data_transfer_protocol" field. + public const int DataTransferProtocolFieldNumber = 7; + private string dataTransferProtocol_ = ""; + /// + /// The protocol for the worker to use when transferring data to clients. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string DataTransferProtocol { + get { return dataTransferProtocol_; } + set { + dataTransferProtocol_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "data_transfer_address" field. + public const int DataTransferAddressFieldNumber = 8; + private string dataTransferAddress_ = ""; + /// + /// The data transfer address of the worker server. The substring "%port%", if + /// specified, will be replaced with the worker's bound port. This is useful + /// when the port is set to `0`. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string DataTransferAddress { + get { return dataTransferAddress_; } + set { + dataTransferAddress_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "cross_trainer_cache_size_bytes" field. + public const int CrossTrainerCacheSizeBytesFieldNumber = 11; + private long crossTrainerCacheSizeBytes_; + /// + /// Maximum size of the cross-trainer cache in bytes. If enabled, make sure + /// your training job provides sufficient memory resources. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long CrossTrainerCacheSizeBytes { + get { return crossTrainerCacheSizeBytes_; } + set { + crossTrainerCacheSizeBytes_ = value; + } + } + + /// Field number for the "shutdown_quiet_period_ms" field. + public const int ShutdownQuietPeriodMsFieldNumber = 9; + private long shutdownQuietPeriodMs_; + /// + /// When shutting down a worker, how long to wait for the gRPC server to + /// process the final requests. This is used to achieve clean shutdown in unit + /// tests. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ShutdownQuietPeriodMs { + get { return shutdownQuietPeriodMs_; } + set { + shutdownQuietPeriodMs_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WorkerConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WorkerConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Port != other.Port) return false; + if (Protocol != other.Protocol) return false; + if (DispatcherAddress != other.DispatcherAddress) return false; + if (WorkerAddress != other.WorkerAddress) return false; + if(!workerTags_.Equals(other.workerTags_)) return false; + if (HeartbeatIntervalMs != other.HeartbeatIntervalMs) return false; + if (DispatcherTimeoutMs != other.DispatcherTimeoutMs) return false; + if (DataTransferProtocol != other.DataTransferProtocol) return false; + if (DataTransferAddress != other.DataTransferAddress) return false; + if (CrossTrainerCacheSizeBytes != other.CrossTrainerCacheSizeBytes) return false; + if (ShutdownQuietPeriodMs != other.ShutdownQuietPeriodMs) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Port != 0L) hash ^= Port.GetHashCode(); + if (Protocol.Length != 0) hash ^= Protocol.GetHashCode(); + if (DispatcherAddress.Length != 0) hash ^= DispatcherAddress.GetHashCode(); + if (WorkerAddress.Length != 0) hash ^= WorkerAddress.GetHashCode(); + hash ^= workerTags_.GetHashCode(); + if (HeartbeatIntervalMs != 0L) hash ^= HeartbeatIntervalMs.GetHashCode(); + if (DispatcherTimeoutMs != 0L) hash ^= DispatcherTimeoutMs.GetHashCode(); + if (DataTransferProtocol.Length != 0) hash ^= DataTransferProtocol.GetHashCode(); + if (DataTransferAddress.Length != 0) hash ^= DataTransferAddress.GetHashCode(); + if (CrossTrainerCacheSizeBytes != 0L) hash ^= CrossTrainerCacheSizeBytes.GetHashCode(); + if (ShutdownQuietPeriodMs != 0L) hash ^= ShutdownQuietPeriodMs.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Port != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Port); + } + if (Protocol.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Protocol); + } + if (DispatcherAddress.Length != 0) { + output.WriteRawTag(26); + output.WriteString(DispatcherAddress); + } + if (WorkerAddress.Length != 0) { + output.WriteRawTag(34); + output.WriteString(WorkerAddress); + } + if (HeartbeatIntervalMs != 0L) { + output.WriteRawTag(40); + output.WriteInt64(HeartbeatIntervalMs); + } + if (DispatcherTimeoutMs != 0L) { + output.WriteRawTag(48); + output.WriteInt64(DispatcherTimeoutMs); + } + if (DataTransferProtocol.Length != 0) { + output.WriteRawTag(58); + output.WriteString(DataTransferProtocol); + } + if (DataTransferAddress.Length != 0) { + output.WriteRawTag(66); + output.WriteString(DataTransferAddress); + } + if (ShutdownQuietPeriodMs != 0L) { + output.WriteRawTag(72); + output.WriteInt64(ShutdownQuietPeriodMs); + } + workerTags_.WriteTo(output, _repeated_workerTags_codec); + if (CrossTrainerCacheSizeBytes != 0L) { + output.WriteRawTag(88); + output.WriteInt64(CrossTrainerCacheSizeBytes); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Port != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Port); + } + if (Protocol.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Protocol); + } + if (DispatcherAddress.Length != 0) { + output.WriteRawTag(26); + output.WriteString(DispatcherAddress); + } + if (WorkerAddress.Length != 0) { + output.WriteRawTag(34); + output.WriteString(WorkerAddress); + } + if (HeartbeatIntervalMs != 0L) { + output.WriteRawTag(40); + output.WriteInt64(HeartbeatIntervalMs); + } + if (DispatcherTimeoutMs != 0L) { + output.WriteRawTag(48); + output.WriteInt64(DispatcherTimeoutMs); + } + if (DataTransferProtocol.Length != 0) { + output.WriteRawTag(58); + output.WriteString(DataTransferProtocol); + } + if (DataTransferAddress.Length != 0) { + output.WriteRawTag(66); + output.WriteString(DataTransferAddress); + } + if (ShutdownQuietPeriodMs != 0L) { + output.WriteRawTag(72); + output.WriteInt64(ShutdownQuietPeriodMs); + } + workerTags_.WriteTo(ref output, _repeated_workerTags_codec); + if (CrossTrainerCacheSizeBytes != 0L) { + output.WriteRawTag(88); + output.WriteInt64(CrossTrainerCacheSizeBytes); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Port != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Port); + } + if (Protocol.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Protocol); + } + if (DispatcherAddress.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DispatcherAddress); + } + if (WorkerAddress.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(WorkerAddress); + } + size += workerTags_.CalculateSize(_repeated_workerTags_codec); + if (HeartbeatIntervalMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(HeartbeatIntervalMs); + } + if (DispatcherTimeoutMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(DispatcherTimeoutMs); + } + if (DataTransferProtocol.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DataTransferProtocol); + } + if (DataTransferAddress.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DataTransferAddress); + } + if (CrossTrainerCacheSizeBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(CrossTrainerCacheSizeBytes); + } + if (ShutdownQuietPeriodMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ShutdownQuietPeriodMs); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WorkerConfig other) { + if (other == null) { + return; + } + if (other.Port != 0L) { + Port = other.Port; + } + if (other.Protocol.Length != 0) { + Protocol = other.Protocol; + } + if (other.DispatcherAddress.Length != 0) { + DispatcherAddress = other.DispatcherAddress; + } + if (other.WorkerAddress.Length != 0) { + WorkerAddress = other.WorkerAddress; + } + workerTags_.Add(other.workerTags_); + if (other.HeartbeatIntervalMs != 0L) { + HeartbeatIntervalMs = other.HeartbeatIntervalMs; + } + if (other.DispatcherTimeoutMs != 0L) { + DispatcherTimeoutMs = other.DispatcherTimeoutMs; + } + if (other.DataTransferProtocol.Length != 0) { + DataTransferProtocol = other.DataTransferProtocol; + } + if (other.DataTransferAddress.Length != 0) { + DataTransferAddress = other.DataTransferAddress; + } + if (other.CrossTrainerCacheSizeBytes != 0L) { + CrossTrainerCacheSizeBytes = other.CrossTrainerCacheSizeBytes; + } + if (other.ShutdownQuietPeriodMs != 0L) { + ShutdownQuietPeriodMs = other.ShutdownQuietPeriodMs; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Port = input.ReadInt64(); + break; + } + case 18: { + Protocol = input.ReadString(); + break; + } + case 26: { + DispatcherAddress = input.ReadString(); + break; + } + case 34: { + WorkerAddress = input.ReadString(); + break; + } + case 40: { + HeartbeatIntervalMs = input.ReadInt64(); + break; + } + case 48: { + DispatcherTimeoutMs = input.ReadInt64(); + break; + } + case 58: { + DataTransferProtocol = input.ReadString(); + break; + } + case 66: { + DataTransferAddress = input.ReadString(); + break; + } + case 72: { + ShutdownQuietPeriodMs = input.ReadInt64(); + break; + } + case 82: { + workerTags_.AddEntriesFrom(input, _repeated_workerTags_codec); + break; + } + case 88: { + CrossTrainerCacheSizeBytes = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Port = input.ReadInt64(); + break; + } + case 18: { + Protocol = input.ReadString(); + break; + } + case 26: { + DispatcherAddress = input.ReadString(); + break; + } + case 34: { + WorkerAddress = input.ReadString(); + break; + } + case 40: { + HeartbeatIntervalMs = input.ReadInt64(); + break; + } + case 48: { + DispatcherTimeoutMs = input.ReadInt64(); + break; + } + case 58: { + DataTransferProtocol = input.ReadString(); + break; + } + case 66: { + DataTransferAddress = input.ReadString(); + break; + } + case 72: { + ShutdownQuietPeriodMs = input.ReadInt64(); + break; + } + case 82: { + workerTags_.AddEntriesFrom(ref input, _repeated_workerTags_codec); + break; + } + case 88: { + CrossTrainerCacheSizeBytes = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/StepStats.cs b/src/TensorFlowNET.Core/Protobuf/StepStats.cs new file mode 100644 index 000000000..48ecd0d50 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/StepStats.cs @@ -0,0 +1,2484 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/step_stats.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/step_stats.proto + public static partial class StepStatsReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/step_stats.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static StepStatsReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3N0ZXBfc3RhdHMucHJvdG8S", + "CnRlbnNvcmZsb3caNnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvYWxsb2Nh", + "dGlvbl9kZXNjcmlwdGlvbi5wcm90bxoydGVuc29yZmxvdy9jb3JlL2ZyYW1l", + "d29yay90ZW5zb3JfZGVzY3JpcHRpb24ucHJvdG8iPQoQQWxsb2NhdGlvblJl", + "Y29yZBIUCgxhbGxvY19taWNyb3MYASABKAMSEwoLYWxsb2NfYnl0ZXMYAiAB", + "KAMixAEKE0FsbG9jYXRvck1lbW9yeVVzZWQSFgoOYWxsb2NhdG9yX25hbWUY", + "ASABKAkSEwoLdG90YWxfYnl0ZXMYAiABKAMSEgoKcGVha19ieXRlcxgDIAEo", + "AxISCgpsaXZlX2J5dGVzGAQgASgDEjgKEmFsbG9jYXRpb25fcmVjb3JkcxgG", + "IAMoCzIcLnRlbnNvcmZsb3cuQWxsb2NhdGlvblJlY29yZBIeChZhbGxvY2F0", + "b3JfYnl0ZXNfaW5fdXNlGAUgASgDIlUKCk5vZGVPdXRwdXQSDAoEc2xvdBgB", + "IAEoBRI5ChJ0ZW5zb3JfZGVzY3JpcHRpb24YAyABKAsyHS50ZW5zb3JmbG93", + "LlRlbnNvckRlc2NyaXB0aW9uIuwBCgtNZW1vcnlTdGF0cxIYChB0ZW1wX21l", + "bW9yeV9zaXplGAEgASgDEh4KFnBlcnNpc3RlbnRfbWVtb3J5X3NpemUYAyAB", + "KAMSIwobcGVyc2lzdGVudF90ZW5zb3JfYWxsb2NfaWRzGAUgAygDEiMKF2Rl", + "dmljZV90ZW1wX21lbW9yeV9zaXplGAIgASgDQgIYARIpCh1kZXZpY2VfcGVy", + "c2lzdGVudF9tZW1vcnlfc2l6ZRgEIAEoA0ICGAESLgoiZGV2aWNlX3BlcnNp", + "c3RlbnRfdGVuc29yX2FsbG9jX2lkcxgGIAMoA0ICGAEingQKDU5vZGVFeGVj", + "U3RhdHMSEQoJbm9kZV9uYW1lGAEgASgJEhgKEGFsbF9zdGFydF9taWNyb3MY", + "AiABKAMSGwoTb3Bfc3RhcnRfcmVsX21pY3JvcxgDIAEoAxIZChFvcF9lbmRf", + "cmVsX21pY3JvcxgEIAEoAxIaChJhbGxfZW5kX3JlbF9taWNyb3MYBSABKAMS", + "LwoGbWVtb3J5GAYgAygLMh8udGVuc29yZmxvdy5BbGxvY2F0b3JNZW1vcnlV", + "c2VkEiYKBm91dHB1dBgHIAMoCzIWLnRlbnNvcmZsb3cuTm9kZU91dHB1dBIW", + "Cg50aW1lbGluZV9sYWJlbBgIIAEoCRIYChBzY2hlZHVsZWRfbWljcm9zGAkg", + "ASgDEhEKCXRocmVhZF9pZBgKIAEoDRI8ChFyZWZlcmVuY2VkX3RlbnNvchgL", + "IAMoCzIhLnRlbnNvcmZsb3cuQWxsb2NhdGlvbkRlc2NyaXB0aW9uEi0KDG1l", + "bW9yeV9zdGF0cxgMIAEoCzIXLnRlbnNvcmZsb3cuTWVtb3J5U3RhdHMSFwoP", + "YWxsX3N0YXJ0X25hbm9zGA0gASgDEhoKEm9wX3N0YXJ0X3JlbF9uYW5vcxgO", + "IAEoAxIYChBvcF9lbmRfcmVsX25hbm9zGA8gASgDEhkKEWFsbF9lbmRfcmVs", + "X25hbm9zGBAgASgDEhcKD3NjaGVkdWxlZF9uYW5vcxgRIAEoAyLIAQoPRGV2", + "aWNlU3RlcFN0YXRzEg4KBmRldmljZRgBIAEoCRItCgpub2RlX3N0YXRzGAIg", + "AygLMhkudGVuc29yZmxvdy5Ob2RlRXhlY1N0YXRzEkIKDHRocmVhZF9uYW1l", + "cxgDIAMoCzIsLnRlbnNvcmZsb3cuRGV2aWNlU3RlcFN0YXRzLlRocmVhZE5h", + "bWVzRW50cnkaMgoQVGhyZWFkTmFtZXNFbnRyeRILCgNrZXkYASABKA0SDQoF", + "dmFsdWUYAiABKAk6AjgBIjsKCVN0ZXBTdGF0cxIuCglkZXZfc3RhdHMYASAD", + "KAsyGy50ZW5zb3JmbG93LkRldmljZVN0ZXBTdGF0c0KDAQoYb3JnLnRlbnNv", + "cmZsb3cuZnJhbWV3b3JrQg9TdGVwU3RhdHNQcm90b3NQAVpRZ2l0aHViLmNv", + "bS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2Zy", + "YW1ld29yay9zdGVwX3N0YXRzX2dvX3Byb3Rv+AEBYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.AllocationDescriptionReflection.Descriptor, global::Tensorflow.TensorDescriptionReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.AllocationRecord), global::Tensorflow.AllocationRecord.Parser, new[]{ "AllocMicros", "AllocBytes" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.AllocatorMemoryUsed), global::Tensorflow.AllocatorMemoryUsed.Parser, new[]{ "AllocatorName", "TotalBytes", "PeakBytes", "LiveBytes", "AllocationRecords", "AllocatorBytesInUse" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.NodeOutput), global::Tensorflow.NodeOutput.Parser, new[]{ "Slot", "TensorDescription" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MemoryStats), global::Tensorflow.MemoryStats.Parser, new[]{ "TempMemorySize", "PersistentMemorySize", "PersistentTensorAllocIds", "DeviceTempMemorySize", "DevicePersistentMemorySize", "DevicePersistentTensorAllocIds" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.NodeExecStats), global::Tensorflow.NodeExecStats.Parser, new[]{ "NodeName", "AllStartMicros", "OpStartRelMicros", "OpEndRelMicros", "AllEndRelMicros", "Memory", "Output", "TimelineLabel", "ScheduledMicros", "ThreadId", "ReferencedTensor", "MemoryStats", "AllStartNanos", "OpStartRelNanos", "OpEndRelNanos", "AllEndRelNanos", "ScheduledNanos" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceStepStats), global::Tensorflow.DeviceStepStats.Parser, new[]{ "Device", "NodeStats", "ThreadNames" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.StepStats), global::Tensorflow.StepStats.Parser, new[]{ "DevStats" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// An allocation/de-allocation operation performed by the allocator. + /// + public sealed partial class AllocationRecord : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AllocationRecord()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StepStatsReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AllocationRecord() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AllocationRecord(AllocationRecord other) : this() { + allocMicros_ = other.allocMicros_; + allocBytes_ = other.allocBytes_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AllocationRecord Clone() { + return new AllocationRecord(this); + } + + /// Field number for the "alloc_micros" field. + public const int AllocMicrosFieldNumber = 1; + private long allocMicros_; + /// + /// The timestamp of the operation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllocMicros { + get { return allocMicros_; } + set { + allocMicros_ = value; + } + } + + /// Field number for the "alloc_bytes" field. + public const int AllocBytesFieldNumber = 2; + private long allocBytes_; + /// + /// Number of bytes allocated, or de-allocated if negative. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllocBytes { + get { return allocBytes_; } + set { + allocBytes_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as AllocationRecord); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(AllocationRecord other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (AllocMicros != other.AllocMicros) return false; + if (AllocBytes != other.AllocBytes) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (AllocMicros != 0L) hash ^= AllocMicros.GetHashCode(); + if (AllocBytes != 0L) hash ^= AllocBytes.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (AllocMicros != 0L) { + output.WriteRawTag(8); + output.WriteInt64(AllocMicros); + } + if (AllocBytes != 0L) { + output.WriteRawTag(16); + output.WriteInt64(AllocBytes); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (AllocMicros != 0L) { + output.WriteRawTag(8); + output.WriteInt64(AllocMicros); + } + if (AllocBytes != 0L) { + output.WriteRawTag(16); + output.WriteInt64(AllocBytes); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (AllocMicros != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AllocMicros); + } + if (AllocBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AllocBytes); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(AllocationRecord other) { + if (other == null) { + return; + } + if (other.AllocMicros != 0L) { + AllocMicros = other.AllocMicros; + } + if (other.AllocBytes != 0L) { + AllocBytes = other.AllocBytes; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + AllocMicros = input.ReadInt64(); + break; + } + case 16: { + AllocBytes = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + AllocMicros = input.ReadInt64(); + break; + } + case 16: { + AllocBytes = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + public sealed partial class AllocatorMemoryUsed : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AllocatorMemoryUsed()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StepStatsReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AllocatorMemoryUsed() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AllocatorMemoryUsed(AllocatorMemoryUsed other) : this() { + allocatorName_ = other.allocatorName_; + totalBytes_ = other.totalBytes_; + peakBytes_ = other.peakBytes_; + liveBytes_ = other.liveBytes_; + allocationRecords_ = other.allocationRecords_.Clone(); + allocatorBytesInUse_ = other.allocatorBytesInUse_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public AllocatorMemoryUsed Clone() { + return new AllocatorMemoryUsed(this); + } + + /// Field number for the "allocator_name" field. + public const int AllocatorNameFieldNumber = 1; + private string allocatorName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string AllocatorName { + get { return allocatorName_; } + set { + allocatorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "total_bytes" field. + public const int TotalBytesFieldNumber = 2; + private long totalBytes_; + /// + /// These are per-node allocator memory stats. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long TotalBytes { + get { return totalBytes_; } + set { + totalBytes_ = value; + } + } + + /// Field number for the "peak_bytes" field. + public const int PeakBytesFieldNumber = 3; + private long peakBytes_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long PeakBytes { + get { return peakBytes_; } + set { + peakBytes_ = value; + } + } + + /// Field number for the "live_bytes" field. + public const int LiveBytesFieldNumber = 4; + private long liveBytes_; + /// + /// The bytes that are not deallocated. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long LiveBytes { + get { return liveBytes_; } + set { + liveBytes_ = value; + } + } + + /// Field number for the "allocation_records" field. + public const int AllocationRecordsFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_allocationRecords_codec + = pb::FieldCodec.ForMessage(50, global::Tensorflow.AllocationRecord.Parser); + private readonly pbc::RepeatedField allocationRecords_ = new pbc::RepeatedField(); + /// + /// The allocation and deallocation timeline. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField AllocationRecords { + get { return allocationRecords_; } + } + + /// Field number for the "allocator_bytes_in_use" field. + public const int AllocatorBytesInUseFieldNumber = 5; + private long allocatorBytesInUse_; + /// + /// These are snapshots of the overall allocator memory stats. + /// The number of live bytes currently allocated by the allocator. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllocatorBytesInUse { + get { return allocatorBytesInUse_; } + set { + allocatorBytesInUse_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as AllocatorMemoryUsed); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(AllocatorMemoryUsed other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (AllocatorName != other.AllocatorName) return false; + if (TotalBytes != other.TotalBytes) return false; + if (PeakBytes != other.PeakBytes) return false; + if (LiveBytes != other.LiveBytes) return false; + if(!allocationRecords_.Equals(other.allocationRecords_)) return false; + if (AllocatorBytesInUse != other.AllocatorBytesInUse) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (AllocatorName.Length != 0) hash ^= AllocatorName.GetHashCode(); + if (TotalBytes != 0L) hash ^= TotalBytes.GetHashCode(); + if (PeakBytes != 0L) hash ^= PeakBytes.GetHashCode(); + if (LiveBytes != 0L) hash ^= LiveBytes.GetHashCode(); + hash ^= allocationRecords_.GetHashCode(); + if (AllocatorBytesInUse != 0L) hash ^= AllocatorBytesInUse.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (AllocatorName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(AllocatorName); + } + if (TotalBytes != 0L) { + output.WriteRawTag(16); + output.WriteInt64(TotalBytes); + } + if (PeakBytes != 0L) { + output.WriteRawTag(24); + output.WriteInt64(PeakBytes); + } + if (LiveBytes != 0L) { + output.WriteRawTag(32); + output.WriteInt64(LiveBytes); + } + if (AllocatorBytesInUse != 0L) { + output.WriteRawTag(40); + output.WriteInt64(AllocatorBytesInUse); + } + allocationRecords_.WriteTo(output, _repeated_allocationRecords_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (AllocatorName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(AllocatorName); + } + if (TotalBytes != 0L) { + output.WriteRawTag(16); + output.WriteInt64(TotalBytes); + } + if (PeakBytes != 0L) { + output.WriteRawTag(24); + output.WriteInt64(PeakBytes); + } + if (LiveBytes != 0L) { + output.WriteRawTag(32); + output.WriteInt64(LiveBytes); + } + if (AllocatorBytesInUse != 0L) { + output.WriteRawTag(40); + output.WriteInt64(AllocatorBytesInUse); + } + allocationRecords_.WriteTo(ref output, _repeated_allocationRecords_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (AllocatorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(AllocatorName); + } + if (TotalBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(TotalBytes); + } + if (PeakBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(PeakBytes); + } + if (LiveBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(LiveBytes); + } + size += allocationRecords_.CalculateSize(_repeated_allocationRecords_codec); + if (AllocatorBytesInUse != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AllocatorBytesInUse); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(AllocatorMemoryUsed other) { + if (other == null) { + return; + } + if (other.AllocatorName.Length != 0) { + AllocatorName = other.AllocatorName; + } + if (other.TotalBytes != 0L) { + TotalBytes = other.TotalBytes; + } + if (other.PeakBytes != 0L) { + PeakBytes = other.PeakBytes; + } + if (other.LiveBytes != 0L) { + LiveBytes = other.LiveBytes; + } + allocationRecords_.Add(other.allocationRecords_); + if (other.AllocatorBytesInUse != 0L) { + AllocatorBytesInUse = other.AllocatorBytesInUse; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + AllocatorName = input.ReadString(); + break; + } + case 16: { + TotalBytes = input.ReadInt64(); + break; + } + case 24: { + PeakBytes = input.ReadInt64(); + break; + } + case 32: { + LiveBytes = input.ReadInt64(); + break; + } + case 40: { + AllocatorBytesInUse = input.ReadInt64(); + break; + } + case 50: { + allocationRecords_.AddEntriesFrom(input, _repeated_allocationRecords_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + AllocatorName = input.ReadString(); + break; + } + case 16: { + TotalBytes = input.ReadInt64(); + break; + } + case 24: { + PeakBytes = input.ReadInt64(); + break; + } + case 32: { + LiveBytes = input.ReadInt64(); + break; + } + case 40: { + AllocatorBytesInUse = input.ReadInt64(); + break; + } + case 50: { + allocationRecords_.AddEntriesFrom(ref input, _repeated_allocationRecords_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Output sizes recorded for a single execution of a graph node. + /// + public sealed partial class NodeOutput : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NodeOutput()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StepStatsReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NodeOutput() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NodeOutput(NodeOutput other) : this() { + slot_ = other.slot_; + tensorDescription_ = other.tensorDescription_ != null ? other.tensorDescription_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NodeOutput Clone() { + return new NodeOutput(this); + } + + /// Field number for the "slot" field. + public const int SlotFieldNumber = 1; + private int slot_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int Slot { + get { return slot_; } + set { + slot_ = value; + } + } + + /// Field number for the "tensor_description" field. + public const int TensorDescriptionFieldNumber = 3; + private global::Tensorflow.TensorDescription tensorDescription_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorDescription TensorDescription { + get { return tensorDescription_; } + set { + tensorDescription_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as NodeOutput); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(NodeOutput other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Slot != other.Slot) return false; + if (!object.Equals(TensorDescription, other.TensorDescription)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Slot != 0) hash ^= Slot.GetHashCode(); + if (tensorDescription_ != null) hash ^= TensorDescription.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Slot != 0) { + output.WriteRawTag(8); + output.WriteInt32(Slot); + } + if (tensorDescription_ != null) { + output.WriteRawTag(26); + output.WriteMessage(TensorDescription); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Slot != 0) { + output.WriteRawTag(8); + output.WriteInt32(Slot); + } + if (tensorDescription_ != null) { + output.WriteRawTag(26); + output.WriteMessage(TensorDescription); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Slot != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Slot); + } + if (tensorDescription_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TensorDescription); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(NodeOutput other) { + if (other == null) { + return; + } + if (other.Slot != 0) { + Slot = other.Slot; + } + if (other.tensorDescription_ != null) { + if (tensorDescription_ == null) { + TensorDescription = new global::Tensorflow.TensorDescription(); + } + TensorDescription.MergeFrom(other.TensorDescription); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Slot = input.ReadInt32(); + break; + } + case 26: { + if (tensorDescription_ == null) { + TensorDescription = new global::Tensorflow.TensorDescription(); + } + input.ReadMessage(TensorDescription); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Slot = input.ReadInt32(); + break; + } + case 26: { + if (tensorDescription_ == null) { + TensorDescription = new global::Tensorflow.TensorDescription(); + } + input.ReadMessage(TensorDescription); + break; + } + } + } + } + #endif + + } + + /// + /// For memory tracking. + /// + public sealed partial class MemoryStats : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MemoryStats()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StepStatsReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryStats() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryStats(MemoryStats other) : this() { + tempMemorySize_ = other.tempMemorySize_; + persistentMemorySize_ = other.persistentMemorySize_; + persistentTensorAllocIds_ = other.persistentTensorAllocIds_.Clone(); + deviceTempMemorySize_ = other.deviceTempMemorySize_; + devicePersistentMemorySize_ = other.devicePersistentMemorySize_; + devicePersistentTensorAllocIds_ = other.devicePersistentTensorAllocIds_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MemoryStats Clone() { + return new MemoryStats(this); + } + + /// Field number for the "temp_memory_size" field. + public const int TempMemorySizeFieldNumber = 1; + private long tempMemorySize_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long TempMemorySize { + get { return tempMemorySize_; } + set { + tempMemorySize_ = value; + } + } + + /// Field number for the "persistent_memory_size" field. + public const int PersistentMemorySizeFieldNumber = 3; + private long persistentMemorySize_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long PersistentMemorySize { + get { return persistentMemorySize_; } + set { + persistentMemorySize_ = value; + } + } + + /// Field number for the "persistent_tensor_alloc_ids" field. + public const int PersistentTensorAllocIdsFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_persistentTensorAllocIds_codec + = pb::FieldCodec.ForInt64(42); + private readonly pbc::RepeatedField persistentTensorAllocIds_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField PersistentTensorAllocIds { + get { return persistentTensorAllocIds_; } + } + + /// Field number for the "device_temp_memory_size" field. + public const int DeviceTempMemorySizeFieldNumber = 2; + private long deviceTempMemorySize_; + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long DeviceTempMemorySize { + get { return deviceTempMemorySize_; } + set { + deviceTempMemorySize_ = value; + } + } + + /// Field number for the "device_persistent_memory_size" field. + public const int DevicePersistentMemorySizeFieldNumber = 4; + private long devicePersistentMemorySize_; + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long DevicePersistentMemorySize { + get { return devicePersistentMemorySize_; } + set { + devicePersistentMemorySize_ = value; + } + } + + /// Field number for the "device_persistent_tensor_alloc_ids" field. + public const int DevicePersistentTensorAllocIdsFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_devicePersistentTensorAllocIds_codec + = pb::FieldCodec.ForInt64(50); + private readonly pbc::RepeatedField devicePersistentTensorAllocIds_ = new pbc::RepeatedField(); + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DevicePersistentTensorAllocIds { + get { return devicePersistentTensorAllocIds_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MemoryStats); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MemoryStats other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (TempMemorySize != other.TempMemorySize) return false; + if (PersistentMemorySize != other.PersistentMemorySize) return false; + if(!persistentTensorAllocIds_.Equals(other.persistentTensorAllocIds_)) return false; + if (DeviceTempMemorySize != other.DeviceTempMemorySize) return false; + if (DevicePersistentMemorySize != other.DevicePersistentMemorySize) return false; + if(!devicePersistentTensorAllocIds_.Equals(other.devicePersistentTensorAllocIds_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (TempMemorySize != 0L) hash ^= TempMemorySize.GetHashCode(); + if (PersistentMemorySize != 0L) hash ^= PersistentMemorySize.GetHashCode(); + hash ^= persistentTensorAllocIds_.GetHashCode(); + if (DeviceTempMemorySize != 0L) hash ^= DeviceTempMemorySize.GetHashCode(); + if (DevicePersistentMemorySize != 0L) hash ^= DevicePersistentMemorySize.GetHashCode(); + hash ^= devicePersistentTensorAllocIds_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (TempMemorySize != 0L) { + output.WriteRawTag(8); + output.WriteInt64(TempMemorySize); + } + if (DeviceTempMemorySize != 0L) { + output.WriteRawTag(16); + output.WriteInt64(DeviceTempMemorySize); + } + if (PersistentMemorySize != 0L) { + output.WriteRawTag(24); + output.WriteInt64(PersistentMemorySize); + } + if (DevicePersistentMemorySize != 0L) { + output.WriteRawTag(32); + output.WriteInt64(DevicePersistentMemorySize); + } + persistentTensorAllocIds_.WriteTo(output, _repeated_persistentTensorAllocIds_codec); + devicePersistentTensorAllocIds_.WriteTo(output, _repeated_devicePersistentTensorAllocIds_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (TempMemorySize != 0L) { + output.WriteRawTag(8); + output.WriteInt64(TempMemorySize); + } + if (DeviceTempMemorySize != 0L) { + output.WriteRawTag(16); + output.WriteInt64(DeviceTempMemorySize); + } + if (PersistentMemorySize != 0L) { + output.WriteRawTag(24); + output.WriteInt64(PersistentMemorySize); + } + if (DevicePersistentMemorySize != 0L) { + output.WriteRawTag(32); + output.WriteInt64(DevicePersistentMemorySize); + } + persistentTensorAllocIds_.WriteTo(ref output, _repeated_persistentTensorAllocIds_codec); + devicePersistentTensorAllocIds_.WriteTo(ref output, _repeated_devicePersistentTensorAllocIds_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (TempMemorySize != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(TempMemorySize); + } + if (PersistentMemorySize != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(PersistentMemorySize); + } + size += persistentTensorAllocIds_.CalculateSize(_repeated_persistentTensorAllocIds_codec); + if (DeviceTempMemorySize != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(DeviceTempMemorySize); + } + if (DevicePersistentMemorySize != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(DevicePersistentMemorySize); + } + size += devicePersistentTensorAllocIds_.CalculateSize(_repeated_devicePersistentTensorAllocIds_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MemoryStats other) { + if (other == null) { + return; + } + if (other.TempMemorySize != 0L) { + TempMemorySize = other.TempMemorySize; + } + if (other.PersistentMemorySize != 0L) { + PersistentMemorySize = other.PersistentMemorySize; + } + persistentTensorAllocIds_.Add(other.persistentTensorAllocIds_); + if (other.DeviceTempMemorySize != 0L) { + DeviceTempMemorySize = other.DeviceTempMemorySize; + } + if (other.DevicePersistentMemorySize != 0L) { + DevicePersistentMemorySize = other.DevicePersistentMemorySize; + } + devicePersistentTensorAllocIds_.Add(other.devicePersistentTensorAllocIds_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + TempMemorySize = input.ReadInt64(); + break; + } + case 16: { + DeviceTempMemorySize = input.ReadInt64(); + break; + } + case 24: { + PersistentMemorySize = input.ReadInt64(); + break; + } + case 32: { + DevicePersistentMemorySize = input.ReadInt64(); + break; + } + case 42: + case 40: { + persistentTensorAllocIds_.AddEntriesFrom(input, _repeated_persistentTensorAllocIds_codec); + break; + } + case 50: + case 48: { + devicePersistentTensorAllocIds_.AddEntriesFrom(input, _repeated_devicePersistentTensorAllocIds_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + TempMemorySize = input.ReadInt64(); + break; + } + case 16: { + DeviceTempMemorySize = input.ReadInt64(); + break; + } + case 24: { + PersistentMemorySize = input.ReadInt64(); + break; + } + case 32: { + DevicePersistentMemorySize = input.ReadInt64(); + break; + } + case 42: + case 40: { + persistentTensorAllocIds_.AddEntriesFrom(ref input, _repeated_persistentTensorAllocIds_codec); + break; + } + case 50: + case 48: { + devicePersistentTensorAllocIds_.AddEntriesFrom(ref input, _repeated_devicePersistentTensorAllocIds_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Time/size stats recorded for a single execution of a graph node. + /// + public sealed partial class NodeExecStats : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NodeExecStats()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StepStatsReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NodeExecStats() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NodeExecStats(NodeExecStats other) : this() { + nodeName_ = other.nodeName_; + allStartMicros_ = other.allStartMicros_; + opStartRelMicros_ = other.opStartRelMicros_; + opEndRelMicros_ = other.opEndRelMicros_; + allEndRelMicros_ = other.allEndRelMicros_; + memory_ = other.memory_.Clone(); + output_ = other.output_.Clone(); + timelineLabel_ = other.timelineLabel_; + scheduledMicros_ = other.scheduledMicros_; + threadId_ = other.threadId_; + referencedTensor_ = other.referencedTensor_.Clone(); + memoryStats_ = other.memoryStats_ != null ? other.memoryStats_.Clone() : null; + allStartNanos_ = other.allStartNanos_; + opStartRelNanos_ = other.opStartRelNanos_; + opEndRelNanos_ = other.opEndRelNanos_; + allEndRelNanos_ = other.allEndRelNanos_; + scheduledNanos_ = other.scheduledNanos_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NodeExecStats Clone() { + return new NodeExecStats(this); + } + + /// Field number for the "node_name" field. + public const int NodeNameFieldNumber = 1; + private string nodeName_ = ""; + /// + /// TODO(tucker): Use some more compact form of node identity than + /// the full string name. Either all processes should agree on a + /// global id (cost_id?) for each node, or we should use a hash of + /// the name. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string NodeName { + get { return nodeName_; } + set { + nodeName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "all_start_micros" field. + public const int AllStartMicrosFieldNumber = 2; + private long allStartMicros_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllStartMicros { + get { return allStartMicros_; } + set { + allStartMicros_ = value; + } + } + + /// Field number for the "op_start_rel_micros" field. + public const int OpStartRelMicrosFieldNumber = 3; + private long opStartRelMicros_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long OpStartRelMicros { + get { return opStartRelMicros_; } + set { + opStartRelMicros_ = value; + } + } + + /// Field number for the "op_end_rel_micros" field. + public const int OpEndRelMicrosFieldNumber = 4; + private long opEndRelMicros_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long OpEndRelMicros { + get { return opEndRelMicros_; } + set { + opEndRelMicros_ = value; + } + } + + /// Field number for the "all_end_rel_micros" field. + public const int AllEndRelMicrosFieldNumber = 5; + private long allEndRelMicros_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllEndRelMicros { + get { return allEndRelMicros_; } + set { + allEndRelMicros_ = value; + } + } + + /// Field number for the "memory" field. + public const int MemoryFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_memory_codec + = pb::FieldCodec.ForMessage(50, global::Tensorflow.AllocatorMemoryUsed.Parser); + private readonly pbc::RepeatedField memory_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Memory { + get { return memory_; } + } + + /// Field number for the "output" field. + public const int OutputFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_output_codec + = pb::FieldCodec.ForMessage(58, global::Tensorflow.NodeOutput.Parser); + private readonly pbc::RepeatedField output_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Output { + get { return output_; } + } + + /// Field number for the "timeline_label" field. + public const int TimelineLabelFieldNumber = 8; + private string timelineLabel_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string TimelineLabel { + get { return timelineLabel_; } + set { + timelineLabel_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "scheduled_micros" field. + public const int ScheduledMicrosFieldNumber = 9; + private long scheduledMicros_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ScheduledMicros { + get { return scheduledMicros_; } + set { + scheduledMicros_ = value; + } + } + + /// Field number for the "thread_id" field. + public const int ThreadIdFieldNumber = 10; + private uint threadId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public uint ThreadId { + get { return threadId_; } + set { + threadId_ = value; + } + } + + /// Field number for the "referenced_tensor" field. + public const int ReferencedTensorFieldNumber = 11; + private static readonly pb::FieldCodec _repeated_referencedTensor_codec + = pb::FieldCodec.ForMessage(90, global::Tensorflow.AllocationDescription.Parser); + private readonly pbc::RepeatedField referencedTensor_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ReferencedTensor { + get { return referencedTensor_; } + } + + /// Field number for the "memory_stats" field. + public const int MemoryStatsFieldNumber = 12; + private global::Tensorflow.MemoryStats memoryStats_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.MemoryStats MemoryStats { + get { return memoryStats_; } + set { + memoryStats_ = value; + } + } + + /// Field number for the "all_start_nanos" field. + public const int AllStartNanosFieldNumber = 13; + private long allStartNanos_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllStartNanos { + get { return allStartNanos_; } + set { + allStartNanos_ = value; + } + } + + /// Field number for the "op_start_rel_nanos" field. + public const int OpStartRelNanosFieldNumber = 14; + private long opStartRelNanos_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long OpStartRelNanos { + get { return opStartRelNanos_; } + set { + opStartRelNanos_ = value; + } + } + + /// Field number for the "op_end_rel_nanos" field. + public const int OpEndRelNanosFieldNumber = 15; + private long opEndRelNanos_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long OpEndRelNanos { + get { return opEndRelNanos_; } + set { + opEndRelNanos_ = value; + } + } + + /// Field number for the "all_end_rel_nanos" field. + public const int AllEndRelNanosFieldNumber = 16; + private long allEndRelNanos_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AllEndRelNanos { + get { return allEndRelNanos_; } + set { + allEndRelNanos_ = value; + } + } + + /// Field number for the "scheduled_nanos" field. + public const int ScheduledNanosFieldNumber = 17; + private long scheduledNanos_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ScheduledNanos { + get { return scheduledNanos_; } + set { + scheduledNanos_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as NodeExecStats); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(NodeExecStats other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NodeName != other.NodeName) return false; + if (AllStartMicros != other.AllStartMicros) return false; + if (OpStartRelMicros != other.OpStartRelMicros) return false; + if (OpEndRelMicros != other.OpEndRelMicros) return false; + if (AllEndRelMicros != other.AllEndRelMicros) return false; + if(!memory_.Equals(other.memory_)) return false; + if(!output_.Equals(other.output_)) return false; + if (TimelineLabel != other.TimelineLabel) return false; + if (ScheduledMicros != other.ScheduledMicros) return false; + if (ThreadId != other.ThreadId) return false; + if(!referencedTensor_.Equals(other.referencedTensor_)) return false; + if (!object.Equals(MemoryStats, other.MemoryStats)) return false; + if (AllStartNanos != other.AllStartNanos) return false; + if (OpStartRelNanos != other.OpStartRelNanos) return false; + if (OpEndRelNanos != other.OpEndRelNanos) return false; + if (AllEndRelNanos != other.AllEndRelNanos) return false; + if (ScheduledNanos != other.ScheduledNanos) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (NodeName.Length != 0) hash ^= NodeName.GetHashCode(); + if (AllStartMicros != 0L) hash ^= AllStartMicros.GetHashCode(); + if (OpStartRelMicros != 0L) hash ^= OpStartRelMicros.GetHashCode(); + if (OpEndRelMicros != 0L) hash ^= OpEndRelMicros.GetHashCode(); + if (AllEndRelMicros != 0L) hash ^= AllEndRelMicros.GetHashCode(); + hash ^= memory_.GetHashCode(); + hash ^= output_.GetHashCode(); + if (TimelineLabel.Length != 0) hash ^= TimelineLabel.GetHashCode(); + if (ScheduledMicros != 0L) hash ^= ScheduledMicros.GetHashCode(); + if (ThreadId != 0) hash ^= ThreadId.GetHashCode(); + hash ^= referencedTensor_.GetHashCode(); + if (memoryStats_ != null) hash ^= MemoryStats.GetHashCode(); + if (AllStartNanos != 0L) hash ^= AllStartNanos.GetHashCode(); + if (OpStartRelNanos != 0L) hash ^= OpStartRelNanos.GetHashCode(); + if (OpEndRelNanos != 0L) hash ^= OpEndRelNanos.GetHashCode(); + if (AllEndRelNanos != 0L) hash ^= AllEndRelNanos.GetHashCode(); + if (ScheduledNanos != 0L) hash ^= ScheduledNanos.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (NodeName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(NodeName); + } + if (AllStartMicros != 0L) { + output.WriteRawTag(16); + output.WriteInt64(AllStartMicros); + } + if (OpStartRelMicros != 0L) { + output.WriteRawTag(24); + output.WriteInt64(OpStartRelMicros); + } + if (OpEndRelMicros != 0L) { + output.WriteRawTag(32); + output.WriteInt64(OpEndRelMicros); + } + if (AllEndRelMicros != 0L) { + output.WriteRawTag(40); + output.WriteInt64(AllEndRelMicros); + } + memory_.WriteTo(output, _repeated_memory_codec); + output_.WriteTo(output, _repeated_output_codec); + if (TimelineLabel.Length != 0) { + output.WriteRawTag(66); + output.WriteString(TimelineLabel); + } + if (ScheduledMicros != 0L) { + output.WriteRawTag(72); + output.WriteInt64(ScheduledMicros); + } + if (ThreadId != 0) { + output.WriteRawTag(80); + output.WriteUInt32(ThreadId); + } + referencedTensor_.WriteTo(output, _repeated_referencedTensor_codec); + if (memoryStats_ != null) { + output.WriteRawTag(98); + output.WriteMessage(MemoryStats); + } + if (AllStartNanos != 0L) { + output.WriteRawTag(104); + output.WriteInt64(AllStartNanos); + } + if (OpStartRelNanos != 0L) { + output.WriteRawTag(112); + output.WriteInt64(OpStartRelNanos); + } + if (OpEndRelNanos != 0L) { + output.WriteRawTag(120); + output.WriteInt64(OpEndRelNanos); + } + if (AllEndRelNanos != 0L) { + output.WriteRawTag(128, 1); + output.WriteInt64(AllEndRelNanos); + } + if (ScheduledNanos != 0L) { + output.WriteRawTag(136, 1); + output.WriteInt64(ScheduledNanos); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (NodeName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(NodeName); + } + if (AllStartMicros != 0L) { + output.WriteRawTag(16); + output.WriteInt64(AllStartMicros); + } + if (OpStartRelMicros != 0L) { + output.WriteRawTag(24); + output.WriteInt64(OpStartRelMicros); + } + if (OpEndRelMicros != 0L) { + output.WriteRawTag(32); + output.WriteInt64(OpEndRelMicros); + } + if (AllEndRelMicros != 0L) { + output.WriteRawTag(40); + output.WriteInt64(AllEndRelMicros); + } + memory_.WriteTo(ref output, _repeated_memory_codec); + output_.WriteTo(ref output, _repeated_output_codec); + if (TimelineLabel.Length != 0) { + output.WriteRawTag(66); + output.WriteString(TimelineLabel); + } + if (ScheduledMicros != 0L) { + output.WriteRawTag(72); + output.WriteInt64(ScheduledMicros); + } + if (ThreadId != 0) { + output.WriteRawTag(80); + output.WriteUInt32(ThreadId); + } + referencedTensor_.WriteTo(ref output, _repeated_referencedTensor_codec); + if (memoryStats_ != null) { + output.WriteRawTag(98); + output.WriteMessage(MemoryStats); + } + if (AllStartNanos != 0L) { + output.WriteRawTag(104); + output.WriteInt64(AllStartNanos); + } + if (OpStartRelNanos != 0L) { + output.WriteRawTag(112); + output.WriteInt64(OpStartRelNanos); + } + if (OpEndRelNanos != 0L) { + output.WriteRawTag(120); + output.WriteInt64(OpEndRelNanos); + } + if (AllEndRelNanos != 0L) { + output.WriteRawTag(128, 1); + output.WriteInt64(AllEndRelNanos); + } + if (ScheduledNanos != 0L) { + output.WriteRawTag(136, 1); + output.WriteInt64(ScheduledNanos); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (NodeName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(NodeName); + } + if (AllStartMicros != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AllStartMicros); + } + if (OpStartRelMicros != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(OpStartRelMicros); + } + if (OpEndRelMicros != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(OpEndRelMicros); + } + if (AllEndRelMicros != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AllEndRelMicros); + } + size += memory_.CalculateSize(_repeated_memory_codec); + size += output_.CalculateSize(_repeated_output_codec); + if (TimelineLabel.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TimelineLabel); + } + if (ScheduledMicros != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ScheduledMicros); + } + if (ThreadId != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(ThreadId); + } + size += referencedTensor_.CalculateSize(_repeated_referencedTensor_codec); + if (memoryStats_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(MemoryStats); + } + if (AllStartNanos != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AllStartNanos); + } + if (OpStartRelNanos != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(OpStartRelNanos); + } + if (OpEndRelNanos != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(OpEndRelNanos); + } + if (AllEndRelNanos != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(AllEndRelNanos); + } + if (ScheduledNanos != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(ScheduledNanos); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(NodeExecStats other) { + if (other == null) { + return; + } + if (other.NodeName.Length != 0) { + NodeName = other.NodeName; + } + if (other.AllStartMicros != 0L) { + AllStartMicros = other.AllStartMicros; + } + if (other.OpStartRelMicros != 0L) { + OpStartRelMicros = other.OpStartRelMicros; + } + if (other.OpEndRelMicros != 0L) { + OpEndRelMicros = other.OpEndRelMicros; + } + if (other.AllEndRelMicros != 0L) { + AllEndRelMicros = other.AllEndRelMicros; + } + memory_.Add(other.memory_); + output_.Add(other.output_); + if (other.TimelineLabel.Length != 0) { + TimelineLabel = other.TimelineLabel; + } + if (other.ScheduledMicros != 0L) { + ScheduledMicros = other.ScheduledMicros; + } + if (other.ThreadId != 0) { + ThreadId = other.ThreadId; + } + referencedTensor_.Add(other.referencedTensor_); + if (other.memoryStats_ != null) { + if (memoryStats_ == null) { + MemoryStats = new global::Tensorflow.MemoryStats(); + } + MemoryStats.MergeFrom(other.MemoryStats); + } + if (other.AllStartNanos != 0L) { + AllStartNanos = other.AllStartNanos; + } + if (other.OpStartRelNanos != 0L) { + OpStartRelNanos = other.OpStartRelNanos; + } + if (other.OpEndRelNanos != 0L) { + OpEndRelNanos = other.OpEndRelNanos; + } + if (other.AllEndRelNanos != 0L) { + AllEndRelNanos = other.AllEndRelNanos; + } + if (other.ScheduledNanos != 0L) { + ScheduledNanos = other.ScheduledNanos; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + NodeName = input.ReadString(); + break; + } + case 16: { + AllStartMicros = input.ReadInt64(); + break; + } + case 24: { + OpStartRelMicros = input.ReadInt64(); + break; + } + case 32: { + OpEndRelMicros = input.ReadInt64(); + break; + } + case 40: { + AllEndRelMicros = input.ReadInt64(); + break; + } + case 50: { + memory_.AddEntriesFrom(input, _repeated_memory_codec); + break; + } + case 58: { + output_.AddEntriesFrom(input, _repeated_output_codec); + break; + } + case 66: { + TimelineLabel = input.ReadString(); + break; + } + case 72: { + ScheduledMicros = input.ReadInt64(); + break; + } + case 80: { + ThreadId = input.ReadUInt32(); + break; + } + case 90: { + referencedTensor_.AddEntriesFrom(input, _repeated_referencedTensor_codec); + break; + } + case 98: { + if (memoryStats_ == null) { + MemoryStats = new global::Tensorflow.MemoryStats(); + } + input.ReadMessage(MemoryStats); + break; + } + case 104: { + AllStartNanos = input.ReadInt64(); + break; + } + case 112: { + OpStartRelNanos = input.ReadInt64(); + break; + } + case 120: { + OpEndRelNanos = input.ReadInt64(); + break; + } + case 128: { + AllEndRelNanos = input.ReadInt64(); + break; + } + case 136: { + ScheduledNanos = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + NodeName = input.ReadString(); + break; + } + case 16: { + AllStartMicros = input.ReadInt64(); + break; + } + case 24: { + OpStartRelMicros = input.ReadInt64(); + break; + } + case 32: { + OpEndRelMicros = input.ReadInt64(); + break; + } + case 40: { + AllEndRelMicros = input.ReadInt64(); + break; + } + case 50: { + memory_.AddEntriesFrom(ref input, _repeated_memory_codec); + break; + } + case 58: { + output_.AddEntriesFrom(ref input, _repeated_output_codec); + break; + } + case 66: { + TimelineLabel = input.ReadString(); + break; + } + case 72: { + ScheduledMicros = input.ReadInt64(); + break; + } + case 80: { + ThreadId = input.ReadUInt32(); + break; + } + case 90: { + referencedTensor_.AddEntriesFrom(ref input, _repeated_referencedTensor_codec); + break; + } + case 98: { + if (memoryStats_ == null) { + MemoryStats = new global::Tensorflow.MemoryStats(); + } + input.ReadMessage(MemoryStats); + break; + } + case 104: { + AllStartNanos = input.ReadInt64(); + break; + } + case 112: { + OpStartRelNanos = input.ReadInt64(); + break; + } + case 120: { + OpEndRelNanos = input.ReadInt64(); + break; + } + case 128: { + AllEndRelNanos = input.ReadInt64(); + break; + } + case 136: { + ScheduledNanos = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + public sealed partial class DeviceStepStats : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DeviceStepStats()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StepStatsReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceStepStats() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceStepStats(DeviceStepStats other) : this() { + device_ = other.device_; + nodeStats_ = other.nodeStats_.Clone(); + threadNames_ = other.threadNames_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceStepStats Clone() { + return new DeviceStepStats(this); + } + + /// Field number for the "device" field. + public const int DeviceFieldNumber = 1; + private string device_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Device { + get { return device_; } + set { + device_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "node_stats" field. + public const int NodeStatsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_nodeStats_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.NodeExecStats.Parser); + private readonly pbc::RepeatedField nodeStats_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField NodeStats { + get { return nodeStats_; } + } + + /// Field number for the "thread_names" field. + public const int ThreadNamesFieldNumber = 3; + private static readonly pbc::MapField.Codec _map_threadNames_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForUInt32(8, 0), pb::FieldCodec.ForString(18, ""), 26); + private readonly pbc::MapField threadNames_ = new pbc::MapField(); + /// + /// Its key is thread id. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField ThreadNames { + get { return threadNames_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DeviceStepStats); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DeviceStepStats other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Device != other.Device) return false; + if(!nodeStats_.Equals(other.nodeStats_)) return false; + if (!ThreadNames.Equals(other.ThreadNames)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Device.Length != 0) hash ^= Device.GetHashCode(); + hash ^= nodeStats_.GetHashCode(); + hash ^= ThreadNames.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Device.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Device); + } + nodeStats_.WriteTo(output, _repeated_nodeStats_codec); + threadNames_.WriteTo(output, _map_threadNames_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Device.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Device); + } + nodeStats_.WriteTo(ref output, _repeated_nodeStats_codec); + threadNames_.WriteTo(ref output, _map_threadNames_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Device.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Device); + } + size += nodeStats_.CalculateSize(_repeated_nodeStats_codec); + size += threadNames_.CalculateSize(_map_threadNames_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DeviceStepStats other) { + if (other == null) { + return; + } + if (other.Device.Length != 0) { + Device = other.Device; + } + nodeStats_.Add(other.nodeStats_); + threadNames_.Add(other.threadNames_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Device = input.ReadString(); + break; + } + case 18: { + nodeStats_.AddEntriesFrom(input, _repeated_nodeStats_codec); + break; + } + case 26: { + threadNames_.AddEntriesFrom(input, _map_threadNames_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Device = input.ReadString(); + break; + } + case 18: { + nodeStats_.AddEntriesFrom(ref input, _repeated_nodeStats_codec); + break; + } + case 26: { + threadNames_.AddEntriesFrom(ref input, _map_threadNames_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class StepStats : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new StepStats()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StepStatsReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public StepStats() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public StepStats(StepStats other) : this() { + devStats_ = other.devStats_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public StepStats Clone() { + return new StepStats(this); + } + + /// Field number for the "dev_stats" field. + public const int DevStatsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_devStats_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.DeviceStepStats.Parser); + private readonly pbc::RepeatedField devStats_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DevStats { + get { return devStats_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as StepStats); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(StepStats other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!devStats_.Equals(other.devStats_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= devStats_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + devStats_.WriteTo(output, _repeated_devStats_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + devStats_.WriteTo(ref output, _repeated_devStats_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += devStats_.CalculateSize(_repeated_devStats_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(StepStats other) { + if (other == null) { + return; + } + devStats_.Add(other.devStats_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + devStats_.AddEntriesFrom(input, _repeated_devStats_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + devStats_.AddEntriesFrom(ref input, _repeated_devStats_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Struct.cs b/src/TensorFlowNET.Core/Protobuf/Struct.cs new file mode 100644 index 000000000..6a2e39f37 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Struct.cs @@ -0,0 +1,3216 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/struct.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/struct.proto + public static partial class StructReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/struct.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static StructReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiV0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvc3RydWN0LnByb3RvEgp0ZW5z", + "b3JmbG93GiZ0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3RlbnNvci5wcm90", + "bxosdGVuc29yZmxvdy9jb3JlL2ZyYW1ld29yay90ZW5zb3Jfc2hhcGUucHJv", + "dG8aJXRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvdHlwZXMucHJvdG8ikAUK", + "D1N0cnVjdHVyZWRWYWx1ZRIrCgpub25lX3ZhbHVlGAEgASgLMhUudGVuc29y", + "Zmxvdy5Ob25lVmFsdWVIABIXCg1mbG9hdDY0X3ZhbHVlGAsgASgBSAASFQoL", + "aW50NjRfdmFsdWUYDCABKBJIABIWCgxzdHJpbmdfdmFsdWUYDSABKAlIABIU", + "Cgpib29sX3ZhbHVlGA4gASgISAASOgoSdGVuc29yX3NoYXBlX3ZhbHVlGB8g", + "ASgLMhwudGVuc29yZmxvdy5UZW5zb3JTaGFwZVByb3RvSAASMgoSdGVuc29y", + "X2R0eXBlX3ZhbHVlGCAgASgOMhQudGVuc29yZmxvdy5EYXRhVHlwZUgAEjgK", + "EXRlbnNvcl9zcGVjX3ZhbHVlGCEgASgLMhsudGVuc29yZmxvdy5UZW5zb3JT", + "cGVjUHJvdG9IABI0Cg90eXBlX3NwZWNfdmFsdWUYIiABKAsyGS50ZW5zb3Jm", + "bG93LlR5cGVTcGVjUHJvdG9IABJHChlib3VuZGVkX3RlbnNvcl9zcGVjX3Zh", + "bHVlGCMgASgLMiIudGVuc29yZmxvdy5Cb3VuZGVkVGVuc29yU3BlY1Byb3Rv", + "SAASKwoKbGlzdF92YWx1ZRgzIAEoCzIVLnRlbnNvcmZsb3cuTGlzdFZhbHVl", + "SAASLQoLdHVwbGVfdmFsdWUYNCABKAsyFi50ZW5zb3JmbG93LlR1cGxlVmFs", + "dWVIABIrCgpkaWN0X3ZhbHVlGDUgASgLMhUudGVuc29yZmxvdy5EaWN0VmFs", + "dWVIABI4ChFuYW1lZF90dXBsZV92YWx1ZRg2IAEoCzIbLnRlbnNvcmZsb3cu", + "TmFtZWRUdXBsZVZhbHVlSABCBgoEa2luZCILCglOb25lVmFsdWUiOAoJTGlz", + "dFZhbHVlEisKBnZhbHVlcxgBIAMoCzIbLnRlbnNvcmZsb3cuU3RydWN0dXJl", + "ZFZhbHVlIjkKClR1cGxlVmFsdWUSKwoGdmFsdWVzGAEgAygLMhsudGVuc29y", + "Zmxvdy5TdHJ1Y3R1cmVkVmFsdWUiigEKCURpY3RWYWx1ZRIxCgZmaWVsZHMY", + "ASADKAsyIS50ZW5zb3JmbG93LkRpY3RWYWx1ZS5GaWVsZHNFbnRyeRpKCgtG", + "aWVsZHNFbnRyeRILCgNrZXkYASABKAkSKgoFdmFsdWUYAiABKAsyGy50ZW5z", + "b3JmbG93LlN0cnVjdHVyZWRWYWx1ZToCOAEiRAoJUGFpclZhbHVlEgsKA2tl", + "eRgBIAEoCRIqCgV2YWx1ZRgCIAEoCzIbLnRlbnNvcmZsb3cuU3RydWN0dXJl", + "ZFZhbHVlIkYKD05hbWVkVHVwbGVWYWx1ZRIMCgRuYW1lGAEgASgJEiUKBnZh", + "bHVlcxgCIAMoCzIVLnRlbnNvcmZsb3cuUGFpclZhbHVlInEKD1RlbnNvclNw", + "ZWNQcm90bxIMCgRuYW1lGAEgASgJEisKBXNoYXBlGAIgASgLMhwudGVuc29y", + "Zmxvdy5UZW5zb3JTaGFwZVByb3RvEiMKBWR0eXBlGAMgASgOMhQudGVuc29y", + "Zmxvdy5EYXRhVHlwZSLMAQoWQm91bmRlZFRlbnNvclNwZWNQcm90bxIMCgRu", + "YW1lGAEgASgJEisKBXNoYXBlGAIgASgLMhwudGVuc29yZmxvdy5UZW5zb3JT", + "aGFwZVByb3RvEiMKBWR0eXBlGAMgASgOMhQudGVuc29yZmxvdy5EYXRhVHlw", + "ZRIoCgdtaW5pbXVtGAQgASgLMhcudGVuc29yZmxvdy5UZW5zb3JQcm90bxIo", + "CgdtYXhpbXVtGAUgASgLMhcudGVuc29yZmxvdy5UZW5zb3JQcm90byL4AwoN", + "VHlwZVNwZWNQcm90bxJACg90eXBlX3NwZWNfY2xhc3MYASABKA4yJy50ZW5z", + "b3JmbG93LlR5cGVTcGVjUHJvdG8uVHlwZVNwZWNDbGFzcxIvCgp0eXBlX3N0", + "YXRlGAIgASgLMhsudGVuc29yZmxvdy5TdHJ1Y3R1cmVkVmFsdWUSHAoUdHlw", + "ZV9zcGVjX2NsYXNzX25hbWUYAyABKAkSGwoTbnVtX2ZsYXRfY29tcG9uZW50", + "cxgEIAEoBSK4AgoNVHlwZVNwZWNDbGFzcxILCgdVTktOT1dOEAASFgoSU1BB", + "UlNFX1RFTlNPUl9TUEVDEAESFwoTSU5ERVhFRF9TTElDRVNfU1BFQxACEhYK", + "ElJBR0dFRF9URU5TT1JfU1BFQxADEhUKEVRFTlNPUl9BUlJBWV9TUEVDEAQS", + "FQoRREFUQV9EQVRBU0VUX1NQRUMQBRIWChJEQVRBX0lURVJBVE9SX1NQRUMQ", + "BhIRCg1PUFRJT05BTF9TUEVDEAcSFAoQUEVSX1JFUExJQ0FfU1BFQxAIEhEK", + "DVZBUklBQkxFX1NQRUMQCRIWChJST1dfUEFSVElUSU9OX1NQRUMQChIYChRS", + "RUdJU1RFUkVEX1RZUEVfU1BFQxAMEhcKE0VYVEVOU0lPTl9UWVBFX1NQRUMQ", + "DSIECAsQC0JXWlVnaXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90", + "ZW5zb3JmbG93L2dvL2NvcmUvcHJvdG9idWYvZm9yX2NvcmVfcHJvdG9zX2dv", + "X3Byb3RvYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.TensorReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.StructuredValue), global::Tensorflow.StructuredValue.Parser, new[]{ "NoneValue", "Float64Value", "Int64Value", "StringValue", "BoolValue", "TensorShapeValue", "TensorDtypeValue", "TensorSpecValue", "TypeSpecValue", "BoundedTensorSpecValue", "ListValue", "TupleValue", "DictValue", "NamedTupleValue" }, new[]{ "Kind" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.NoneValue), global::Tensorflow.NoneValue.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ListValue), global::Tensorflow.ListValue.Parser, new[]{ "Values" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TupleValue), global::Tensorflow.TupleValue.Parser, new[]{ "Values" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DictValue), global::Tensorflow.DictValue.Parser, new[]{ "Fields" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.PairValue), global::Tensorflow.PairValue.Parser, new[]{ "Key", "Value" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.NamedTupleValue), global::Tensorflow.NamedTupleValue.Parser, new[]{ "Name", "Values" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorSpecProto), global::Tensorflow.TensorSpecProto.Parser, new[]{ "Name", "Shape", "Dtype" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.BoundedTensorSpecProto), global::Tensorflow.BoundedTensorSpecProto.Parser, new[]{ "Name", "Shape", "Dtype", "Minimum", "Maximum" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TypeSpecProto), global::Tensorflow.TypeSpecProto.Parser, new[]{ "TypeSpecClass", "TypeState", "TypeSpecClassName", "NumFlatComponents" }, null, new[]{ typeof(global::Tensorflow.TypeSpecProto.Types.TypeSpecClass) }, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// `StructuredValue` represents a dynamically typed value representing various + /// data structures that are inspired by Python data structures typically used in + /// TensorFlow functions as inputs and outputs. + /// + /// For example when saving a Layer there may be a `training` argument. If the + /// user passes a boolean True/False, that switches between two concrete + /// TensorFlow functions. In order to switch between them in the same way after + /// loading the SavedModel, we need to represent "True" and "False". + /// + /// A more advanced example might be a function which takes a list of + /// dictionaries mapping from strings to Tensors. In order to map from + /// user-specified arguments `[{"a": tf.constant(1.)}, {"q": tf.constant(3.)}]` + /// after load to the right saved TensorFlow function, we need to represent the + /// nested structure and the strings, recording that we have a trace for anything + /// matching `[{"a": tf.TensorSpec(None, tf.float32)}, {"q": tf.TensorSpec([], + /// tf.float64)}]` as an example. + /// + /// Likewise functions may return nested structures of Tensors, for example + /// returning a dictionary mapping from strings to Tensors. In order for the + /// loaded function to return the same structure we need to serialize it. + /// + /// This is an ergonomic aid for working with loaded SavedModels, not a promise + /// to serialize all possible function signatures. For example we do not expect + /// to pickle generic Python objects, and ideally we'd stay language-agnostic. + /// + public sealed partial class StructuredValue : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new StructuredValue()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StructReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public StructuredValue() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public StructuredValue(StructuredValue other) : this() { + switch (other.KindCase) { + case KindOneofCase.NoneValue: + NoneValue = other.NoneValue.Clone(); + break; + case KindOneofCase.Float64Value: + Float64Value = other.Float64Value; + break; + case KindOneofCase.Int64Value: + Int64Value = other.Int64Value; + break; + case KindOneofCase.StringValue: + StringValue = other.StringValue; + break; + case KindOneofCase.BoolValue: + BoolValue = other.BoolValue; + break; + case KindOneofCase.TensorShapeValue: + TensorShapeValue = other.TensorShapeValue.Clone(); + break; + case KindOneofCase.TensorDtypeValue: + TensorDtypeValue = other.TensorDtypeValue; + break; + case KindOneofCase.TensorSpecValue: + TensorSpecValue = other.TensorSpecValue.Clone(); + break; + case KindOneofCase.TypeSpecValue: + TypeSpecValue = other.TypeSpecValue.Clone(); + break; + case KindOneofCase.BoundedTensorSpecValue: + BoundedTensorSpecValue = other.BoundedTensorSpecValue.Clone(); + break; + case KindOneofCase.ListValue: + ListValue = other.ListValue.Clone(); + break; + case KindOneofCase.TupleValue: + TupleValue = other.TupleValue.Clone(); + break; + case KindOneofCase.DictValue: + DictValue = other.DictValue.Clone(); + break; + case KindOneofCase.NamedTupleValue: + NamedTupleValue = other.NamedTupleValue.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public StructuredValue Clone() { + return new StructuredValue(this); + } + + /// Field number for the "none_value" field. + public const int NoneValueFieldNumber = 1; + /// + /// Represents None. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.NoneValue NoneValue { + get { return kindCase_ == KindOneofCase.NoneValue ? (global::Tensorflow.NoneValue) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.NoneValue; + } + } + + /// Field number for the "float64_value" field. + public const int Float64ValueFieldNumber = 11; + /// + /// Represents a double-precision floating-point value (a Python `float`). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double Float64Value { + get { return kindCase_ == KindOneofCase.Float64Value ? (double) kind_ : 0D; } + set { + kind_ = value; + kindCase_ = KindOneofCase.Float64Value; + } + } + + /// Field number for the "int64_value" field. + public const int Int64ValueFieldNumber = 12; + /// + /// Represents a signed integer value, limited to 64 bits. + /// Larger values from Python's arbitrary-precision integers are unsupported. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Int64Value { + get { return kindCase_ == KindOneofCase.Int64Value ? (long) kind_ : 0L; } + set { + kind_ = value; + kindCase_ = KindOneofCase.Int64Value; + } + } + + /// Field number for the "string_value" field. + public const int StringValueFieldNumber = 13; + /// + /// Represents a string of Unicode characters stored in a Python `str`. + /// In Python 3, this is exactly what type `str` is. + /// In Python 2, this is the UTF-8 encoding of the characters. + /// For strings with ASCII characters only (as often used in TensorFlow code) + /// there is effectively no difference between the language versions. + /// The obsolescent `unicode` type of Python 2 is not supported here. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string StringValue { + get { return kindCase_ == KindOneofCase.StringValue ? (string) kind_ : ""; } + set { + kind_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + kindCase_ = KindOneofCase.StringValue; + } + } + + /// Field number for the "bool_value" field. + public const int BoolValueFieldNumber = 14; + /// + /// Represents a boolean value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool BoolValue { + get { return kindCase_ == KindOneofCase.BoolValue ? (bool) kind_ : false; } + set { + kind_ = value; + kindCase_ = KindOneofCase.BoolValue; + } + } + + /// Field number for the "tensor_shape_value" field. + public const int TensorShapeValueFieldNumber = 31; + /// + /// Represents a TensorShape. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorShapeProto TensorShapeValue { + get { return kindCase_ == KindOneofCase.TensorShapeValue ? (global::Tensorflow.TensorShapeProto) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.TensorShapeValue; + } + } + + /// Field number for the "tensor_dtype_value" field. + public const int TensorDtypeValueFieldNumber = 32; + /// + /// Represents an enum value for dtype. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType TensorDtypeValue { + get { return kindCase_ == KindOneofCase.TensorDtypeValue ? (global::Tensorflow.DataType) kind_ : global::Tensorflow.DataType.DtInvalid; } + set { + kind_ = value; + kindCase_ = KindOneofCase.TensorDtypeValue; + } + } + + /// Field number for the "tensor_spec_value" field. + public const int TensorSpecValueFieldNumber = 33; + /// + /// Represents a value for tf.TensorSpec. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorSpecProto TensorSpecValue { + get { return kindCase_ == KindOneofCase.TensorSpecValue ? (global::Tensorflow.TensorSpecProto) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.TensorSpecValue; + } + } + + /// Field number for the "type_spec_value" field. + public const int TypeSpecValueFieldNumber = 34; + /// + /// Represents a value for tf.TypeSpec. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TypeSpecProto TypeSpecValue { + get { return kindCase_ == KindOneofCase.TypeSpecValue ? (global::Tensorflow.TypeSpecProto) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.TypeSpecValue; + } + } + + /// Field number for the "bounded_tensor_spec_value" field. + public const int BoundedTensorSpecValueFieldNumber = 35; + /// + /// Represents a value for tf.BoundedTensorSpec. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.BoundedTensorSpecProto BoundedTensorSpecValue { + get { return kindCase_ == KindOneofCase.BoundedTensorSpecValue ? (global::Tensorflow.BoundedTensorSpecProto) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.BoundedTensorSpecValue; + } + } + + /// Field number for the "list_value" field. + public const int ListValueFieldNumber = 51; + /// + /// Represents a list of `Value`. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.ListValue ListValue { + get { return kindCase_ == KindOneofCase.ListValue ? (global::Tensorflow.ListValue) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.ListValue; + } + } + + /// Field number for the "tuple_value" field. + public const int TupleValueFieldNumber = 52; + /// + /// Represents a tuple of `Value`. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TupleValue TupleValue { + get { return kindCase_ == KindOneofCase.TupleValue ? (global::Tensorflow.TupleValue) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.TupleValue; + } + } + + /// Field number for the "dict_value" field. + public const int DictValueFieldNumber = 53; + /// + /// Represents a dict `Value`. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DictValue DictValue { + get { return kindCase_ == KindOneofCase.DictValue ? (global::Tensorflow.DictValue) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.DictValue; + } + } + + /// Field number for the "named_tuple_value" field. + public const int NamedTupleValueFieldNumber = 54; + /// + /// Represents Python's namedtuple. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.NamedTupleValue NamedTupleValue { + get { return kindCase_ == KindOneofCase.NamedTupleValue ? (global::Tensorflow.NamedTupleValue) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.NamedTupleValue; + } + } + + private object kind_; + /// Enum of possible cases for the "kind" oneof. + public enum KindOneofCase { + None = 0, + NoneValue = 1, + Float64Value = 11, + Int64Value = 12, + StringValue = 13, + BoolValue = 14, + TensorShapeValue = 31, + TensorDtypeValue = 32, + TensorSpecValue = 33, + TypeSpecValue = 34, + BoundedTensorSpecValue = 35, + ListValue = 51, + TupleValue = 52, + DictValue = 53, + NamedTupleValue = 54, + } + private KindOneofCase kindCase_ = KindOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KindOneofCase KindCase { + get { return kindCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void ClearKind() { + kindCase_ = KindOneofCase.None; + kind_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as StructuredValue); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(StructuredValue other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(NoneValue, other.NoneValue)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(Float64Value, other.Float64Value)) return false; + if (Int64Value != other.Int64Value) return false; + if (StringValue != other.StringValue) return false; + if (BoolValue != other.BoolValue) return false; + if (!object.Equals(TensorShapeValue, other.TensorShapeValue)) return false; + if (TensorDtypeValue != other.TensorDtypeValue) return false; + if (!object.Equals(TensorSpecValue, other.TensorSpecValue)) return false; + if (!object.Equals(TypeSpecValue, other.TypeSpecValue)) return false; + if (!object.Equals(BoundedTensorSpecValue, other.BoundedTensorSpecValue)) return false; + if (!object.Equals(ListValue, other.ListValue)) return false; + if (!object.Equals(TupleValue, other.TupleValue)) return false; + if (!object.Equals(DictValue, other.DictValue)) return false; + if (!object.Equals(NamedTupleValue, other.NamedTupleValue)) return false; + if (KindCase != other.KindCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (kindCase_ == KindOneofCase.NoneValue) hash ^= NoneValue.GetHashCode(); + if (kindCase_ == KindOneofCase.Float64Value) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(Float64Value); + if (kindCase_ == KindOneofCase.Int64Value) hash ^= Int64Value.GetHashCode(); + if (kindCase_ == KindOneofCase.StringValue) hash ^= StringValue.GetHashCode(); + if (kindCase_ == KindOneofCase.BoolValue) hash ^= BoolValue.GetHashCode(); + if (kindCase_ == KindOneofCase.TensorShapeValue) hash ^= TensorShapeValue.GetHashCode(); + if (kindCase_ == KindOneofCase.TensorDtypeValue) hash ^= TensorDtypeValue.GetHashCode(); + if (kindCase_ == KindOneofCase.TensorSpecValue) hash ^= TensorSpecValue.GetHashCode(); + if (kindCase_ == KindOneofCase.TypeSpecValue) hash ^= TypeSpecValue.GetHashCode(); + if (kindCase_ == KindOneofCase.BoundedTensorSpecValue) hash ^= BoundedTensorSpecValue.GetHashCode(); + if (kindCase_ == KindOneofCase.ListValue) hash ^= ListValue.GetHashCode(); + if (kindCase_ == KindOneofCase.TupleValue) hash ^= TupleValue.GetHashCode(); + if (kindCase_ == KindOneofCase.DictValue) hash ^= DictValue.GetHashCode(); + if (kindCase_ == KindOneofCase.NamedTupleValue) hash ^= NamedTupleValue.GetHashCode(); + hash ^= (int) kindCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (kindCase_ == KindOneofCase.NoneValue) { + output.WriteRawTag(10); + output.WriteMessage(NoneValue); + } + if (kindCase_ == KindOneofCase.Float64Value) { + output.WriteRawTag(89); + output.WriteDouble(Float64Value); + } + if (kindCase_ == KindOneofCase.Int64Value) { + output.WriteRawTag(96); + output.WriteSInt64(Int64Value); + } + if (kindCase_ == KindOneofCase.StringValue) { + output.WriteRawTag(106); + output.WriteString(StringValue); + } + if (kindCase_ == KindOneofCase.BoolValue) { + output.WriteRawTag(112); + output.WriteBool(BoolValue); + } + if (kindCase_ == KindOneofCase.TensorShapeValue) { + output.WriteRawTag(250, 1); + output.WriteMessage(TensorShapeValue); + } + if (kindCase_ == KindOneofCase.TensorDtypeValue) { + output.WriteRawTag(128, 2); + output.WriteEnum((int) TensorDtypeValue); + } + if (kindCase_ == KindOneofCase.TensorSpecValue) { + output.WriteRawTag(138, 2); + output.WriteMessage(TensorSpecValue); + } + if (kindCase_ == KindOneofCase.TypeSpecValue) { + output.WriteRawTag(146, 2); + output.WriteMessage(TypeSpecValue); + } + if (kindCase_ == KindOneofCase.BoundedTensorSpecValue) { + output.WriteRawTag(154, 2); + output.WriteMessage(BoundedTensorSpecValue); + } + if (kindCase_ == KindOneofCase.ListValue) { + output.WriteRawTag(154, 3); + output.WriteMessage(ListValue); + } + if (kindCase_ == KindOneofCase.TupleValue) { + output.WriteRawTag(162, 3); + output.WriteMessage(TupleValue); + } + if (kindCase_ == KindOneofCase.DictValue) { + output.WriteRawTag(170, 3); + output.WriteMessage(DictValue); + } + if (kindCase_ == KindOneofCase.NamedTupleValue) { + output.WriteRawTag(178, 3); + output.WriteMessage(NamedTupleValue); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (kindCase_ == KindOneofCase.NoneValue) { + output.WriteRawTag(10); + output.WriteMessage(NoneValue); + } + if (kindCase_ == KindOneofCase.Float64Value) { + output.WriteRawTag(89); + output.WriteDouble(Float64Value); + } + if (kindCase_ == KindOneofCase.Int64Value) { + output.WriteRawTag(96); + output.WriteSInt64(Int64Value); + } + if (kindCase_ == KindOneofCase.StringValue) { + output.WriteRawTag(106); + output.WriteString(StringValue); + } + if (kindCase_ == KindOneofCase.BoolValue) { + output.WriteRawTag(112); + output.WriteBool(BoolValue); + } + if (kindCase_ == KindOneofCase.TensorShapeValue) { + output.WriteRawTag(250, 1); + output.WriteMessage(TensorShapeValue); + } + if (kindCase_ == KindOneofCase.TensorDtypeValue) { + output.WriteRawTag(128, 2); + output.WriteEnum((int) TensorDtypeValue); + } + if (kindCase_ == KindOneofCase.TensorSpecValue) { + output.WriteRawTag(138, 2); + output.WriteMessage(TensorSpecValue); + } + if (kindCase_ == KindOneofCase.TypeSpecValue) { + output.WriteRawTag(146, 2); + output.WriteMessage(TypeSpecValue); + } + if (kindCase_ == KindOneofCase.BoundedTensorSpecValue) { + output.WriteRawTag(154, 2); + output.WriteMessage(BoundedTensorSpecValue); + } + if (kindCase_ == KindOneofCase.ListValue) { + output.WriteRawTag(154, 3); + output.WriteMessage(ListValue); + } + if (kindCase_ == KindOneofCase.TupleValue) { + output.WriteRawTag(162, 3); + output.WriteMessage(TupleValue); + } + if (kindCase_ == KindOneofCase.DictValue) { + output.WriteRawTag(170, 3); + output.WriteMessage(DictValue); + } + if (kindCase_ == KindOneofCase.NamedTupleValue) { + output.WriteRawTag(178, 3); + output.WriteMessage(NamedTupleValue); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (kindCase_ == KindOneofCase.NoneValue) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(NoneValue); + } + if (kindCase_ == KindOneofCase.Float64Value) { + size += 1 + 8; + } + if (kindCase_ == KindOneofCase.Int64Value) { + size += 1 + pb::CodedOutputStream.ComputeSInt64Size(Int64Value); + } + if (kindCase_ == KindOneofCase.StringValue) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(StringValue); + } + if (kindCase_ == KindOneofCase.BoolValue) { + size += 1 + 1; + } + if (kindCase_ == KindOneofCase.TensorShapeValue) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(TensorShapeValue); + } + if (kindCase_ == KindOneofCase.TensorDtypeValue) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) TensorDtypeValue); + } + if (kindCase_ == KindOneofCase.TensorSpecValue) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(TensorSpecValue); + } + if (kindCase_ == KindOneofCase.TypeSpecValue) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(TypeSpecValue); + } + if (kindCase_ == KindOneofCase.BoundedTensorSpecValue) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(BoundedTensorSpecValue); + } + if (kindCase_ == KindOneofCase.ListValue) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(ListValue); + } + if (kindCase_ == KindOneofCase.TupleValue) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(TupleValue); + } + if (kindCase_ == KindOneofCase.DictValue) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(DictValue); + } + if (kindCase_ == KindOneofCase.NamedTupleValue) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(NamedTupleValue); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(StructuredValue other) { + if (other == null) { + return; + } + switch (other.KindCase) { + case KindOneofCase.NoneValue: + if (NoneValue == null) { + NoneValue = new global::Tensorflow.NoneValue(); + } + NoneValue.MergeFrom(other.NoneValue); + break; + case KindOneofCase.Float64Value: + Float64Value = other.Float64Value; + break; + case KindOneofCase.Int64Value: + Int64Value = other.Int64Value; + break; + case KindOneofCase.StringValue: + StringValue = other.StringValue; + break; + case KindOneofCase.BoolValue: + BoolValue = other.BoolValue; + break; + case KindOneofCase.TensorShapeValue: + if (TensorShapeValue == null) { + TensorShapeValue = new global::Tensorflow.TensorShapeProto(); + } + TensorShapeValue.MergeFrom(other.TensorShapeValue); + break; + case KindOneofCase.TensorDtypeValue: + TensorDtypeValue = other.TensorDtypeValue; + break; + case KindOneofCase.TensorSpecValue: + if (TensorSpecValue == null) { + TensorSpecValue = new global::Tensorflow.TensorSpecProto(); + } + TensorSpecValue.MergeFrom(other.TensorSpecValue); + break; + case KindOneofCase.TypeSpecValue: + if (TypeSpecValue == null) { + TypeSpecValue = new global::Tensorflow.TypeSpecProto(); + } + TypeSpecValue.MergeFrom(other.TypeSpecValue); + break; + case KindOneofCase.BoundedTensorSpecValue: + if (BoundedTensorSpecValue == null) { + BoundedTensorSpecValue = new global::Tensorflow.BoundedTensorSpecProto(); + } + BoundedTensorSpecValue.MergeFrom(other.BoundedTensorSpecValue); + break; + case KindOneofCase.ListValue: + if (ListValue == null) { + ListValue = new global::Tensorflow.ListValue(); + } + ListValue.MergeFrom(other.ListValue); + break; + case KindOneofCase.TupleValue: + if (TupleValue == null) { + TupleValue = new global::Tensorflow.TupleValue(); + } + TupleValue.MergeFrom(other.TupleValue); + break; + case KindOneofCase.DictValue: + if (DictValue == null) { + DictValue = new global::Tensorflow.DictValue(); + } + DictValue.MergeFrom(other.DictValue); + break; + case KindOneofCase.NamedTupleValue: + if (NamedTupleValue == null) { + NamedTupleValue = new global::Tensorflow.NamedTupleValue(); + } + NamedTupleValue.MergeFrom(other.NamedTupleValue); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.NoneValue subBuilder = new global::Tensorflow.NoneValue(); + if (kindCase_ == KindOneofCase.NoneValue) { + subBuilder.MergeFrom(NoneValue); + } + input.ReadMessage(subBuilder); + NoneValue = subBuilder; + break; + } + case 89: { + Float64Value = input.ReadDouble(); + break; + } + case 96: { + Int64Value = input.ReadSInt64(); + break; + } + case 106: { + StringValue = input.ReadString(); + break; + } + case 112: { + BoolValue = input.ReadBool(); + break; + } + case 250: { + global::Tensorflow.TensorShapeProto subBuilder = new global::Tensorflow.TensorShapeProto(); + if (kindCase_ == KindOneofCase.TensorShapeValue) { + subBuilder.MergeFrom(TensorShapeValue); + } + input.ReadMessage(subBuilder); + TensorShapeValue = subBuilder; + break; + } + case 256: { + kind_ = input.ReadEnum(); + kindCase_ = KindOneofCase.TensorDtypeValue; + break; + } + case 266: { + global::Tensorflow.TensorSpecProto subBuilder = new global::Tensorflow.TensorSpecProto(); + if (kindCase_ == KindOneofCase.TensorSpecValue) { + subBuilder.MergeFrom(TensorSpecValue); + } + input.ReadMessage(subBuilder); + TensorSpecValue = subBuilder; + break; + } + case 274: { + global::Tensorflow.TypeSpecProto subBuilder = new global::Tensorflow.TypeSpecProto(); + if (kindCase_ == KindOneofCase.TypeSpecValue) { + subBuilder.MergeFrom(TypeSpecValue); + } + input.ReadMessage(subBuilder); + TypeSpecValue = subBuilder; + break; + } + case 282: { + global::Tensorflow.BoundedTensorSpecProto subBuilder = new global::Tensorflow.BoundedTensorSpecProto(); + if (kindCase_ == KindOneofCase.BoundedTensorSpecValue) { + subBuilder.MergeFrom(BoundedTensorSpecValue); + } + input.ReadMessage(subBuilder); + BoundedTensorSpecValue = subBuilder; + break; + } + case 410: { + global::Tensorflow.ListValue subBuilder = new global::Tensorflow.ListValue(); + if (kindCase_ == KindOneofCase.ListValue) { + subBuilder.MergeFrom(ListValue); + } + input.ReadMessage(subBuilder); + ListValue = subBuilder; + break; + } + case 418: { + global::Tensorflow.TupleValue subBuilder = new global::Tensorflow.TupleValue(); + if (kindCase_ == KindOneofCase.TupleValue) { + subBuilder.MergeFrom(TupleValue); + } + input.ReadMessage(subBuilder); + TupleValue = subBuilder; + break; + } + case 426: { + global::Tensorflow.DictValue subBuilder = new global::Tensorflow.DictValue(); + if (kindCase_ == KindOneofCase.DictValue) { + subBuilder.MergeFrom(DictValue); + } + input.ReadMessage(subBuilder); + DictValue = subBuilder; + break; + } + case 434: { + global::Tensorflow.NamedTupleValue subBuilder = new global::Tensorflow.NamedTupleValue(); + if (kindCase_ == KindOneofCase.NamedTupleValue) { + subBuilder.MergeFrom(NamedTupleValue); + } + input.ReadMessage(subBuilder); + NamedTupleValue = subBuilder; + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + global::Tensorflow.NoneValue subBuilder = new global::Tensorflow.NoneValue(); + if (kindCase_ == KindOneofCase.NoneValue) { + subBuilder.MergeFrom(NoneValue); + } + input.ReadMessage(subBuilder); + NoneValue = subBuilder; + break; + } + case 89: { + Float64Value = input.ReadDouble(); + break; + } + case 96: { + Int64Value = input.ReadSInt64(); + break; + } + case 106: { + StringValue = input.ReadString(); + break; + } + case 112: { + BoolValue = input.ReadBool(); + break; + } + case 250: { + global::Tensorflow.TensorShapeProto subBuilder = new global::Tensorflow.TensorShapeProto(); + if (kindCase_ == KindOneofCase.TensorShapeValue) { + subBuilder.MergeFrom(TensorShapeValue); + } + input.ReadMessage(subBuilder); + TensorShapeValue = subBuilder; + break; + } + case 256: { + kind_ = input.ReadEnum(); + kindCase_ = KindOneofCase.TensorDtypeValue; + break; + } + case 266: { + global::Tensorflow.TensorSpecProto subBuilder = new global::Tensorflow.TensorSpecProto(); + if (kindCase_ == KindOneofCase.TensorSpecValue) { + subBuilder.MergeFrom(TensorSpecValue); + } + input.ReadMessage(subBuilder); + TensorSpecValue = subBuilder; + break; + } + case 274: { + global::Tensorflow.TypeSpecProto subBuilder = new global::Tensorflow.TypeSpecProto(); + if (kindCase_ == KindOneofCase.TypeSpecValue) { + subBuilder.MergeFrom(TypeSpecValue); + } + input.ReadMessage(subBuilder); + TypeSpecValue = subBuilder; + break; + } + case 282: { + global::Tensorflow.BoundedTensorSpecProto subBuilder = new global::Tensorflow.BoundedTensorSpecProto(); + if (kindCase_ == KindOneofCase.BoundedTensorSpecValue) { + subBuilder.MergeFrom(BoundedTensorSpecValue); + } + input.ReadMessage(subBuilder); + BoundedTensorSpecValue = subBuilder; + break; + } + case 410: { + global::Tensorflow.ListValue subBuilder = new global::Tensorflow.ListValue(); + if (kindCase_ == KindOneofCase.ListValue) { + subBuilder.MergeFrom(ListValue); + } + input.ReadMessage(subBuilder); + ListValue = subBuilder; + break; + } + case 418: { + global::Tensorflow.TupleValue subBuilder = new global::Tensorflow.TupleValue(); + if (kindCase_ == KindOneofCase.TupleValue) { + subBuilder.MergeFrom(TupleValue); + } + input.ReadMessage(subBuilder); + TupleValue = subBuilder; + break; + } + case 426: { + global::Tensorflow.DictValue subBuilder = new global::Tensorflow.DictValue(); + if (kindCase_ == KindOneofCase.DictValue) { + subBuilder.MergeFrom(DictValue); + } + input.ReadMessage(subBuilder); + DictValue = subBuilder; + break; + } + case 434: { + global::Tensorflow.NamedTupleValue subBuilder = new global::Tensorflow.NamedTupleValue(); + if (kindCase_ == KindOneofCase.NamedTupleValue) { + subBuilder.MergeFrom(NamedTupleValue); + } + input.ReadMessage(subBuilder); + NamedTupleValue = subBuilder; + break; + } + } + } + } + #endif + + } + + /// + /// Represents None. + /// + public sealed partial class NoneValue : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NoneValue()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StructReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NoneValue() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NoneValue(NoneValue other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NoneValue Clone() { + return new NoneValue(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as NoneValue); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(NoneValue other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(NoneValue other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + /// + /// Represents a Python list. + /// + public sealed partial class ListValue : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ListValue()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StructReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ListValue() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ListValue(ListValue other) : this() { + values_ = other.values_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ListValue Clone() { + return new ListValue(this); + } + + /// Field number for the "values" field. + public const int ValuesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_values_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.StructuredValue.Parser); + private readonly pbc::RepeatedField values_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Values { + get { return values_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ListValue); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ListValue other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!values_.Equals(other.values_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= values_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + values_.WriteTo(output, _repeated_values_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + values_.WriteTo(ref output, _repeated_values_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += values_.CalculateSize(_repeated_values_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ListValue other) { + if (other == null) { + return; + } + values_.Add(other.values_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + values_.AddEntriesFrom(input, _repeated_values_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + values_.AddEntriesFrom(ref input, _repeated_values_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Represents a Python tuple. + /// + public sealed partial class TupleValue : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TupleValue()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StructReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TupleValue() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TupleValue(TupleValue other) : this() { + values_ = other.values_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TupleValue Clone() { + return new TupleValue(this); + } + + /// Field number for the "values" field. + public const int ValuesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_values_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.StructuredValue.Parser); + private readonly pbc::RepeatedField values_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Values { + get { return values_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TupleValue); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TupleValue other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!values_.Equals(other.values_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= values_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + values_.WriteTo(output, _repeated_values_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + values_.WriteTo(ref output, _repeated_values_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += values_.CalculateSize(_repeated_values_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TupleValue other) { + if (other == null) { + return; + } + values_.Add(other.values_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + values_.AddEntriesFrom(input, _repeated_values_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + values_.AddEntriesFrom(ref input, _repeated_values_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Represents a Python dict keyed by `str`. + /// The comment on Unicode from Value.string_value applies analogously. + /// + public sealed partial class DictValue : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DictValue()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StructReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DictValue() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DictValue(DictValue other) : this() { + fields_ = other.fields_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DictValue Clone() { + return new DictValue(this); + } + + /// Field number for the "fields" field. + public const int FieldsFieldNumber = 1; + private static readonly pbc::MapField.Codec _map_fields_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.StructuredValue.Parser), 10); + private readonly pbc::MapField fields_ = new pbc::MapField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField Fields { + get { return fields_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DictValue); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DictValue other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!Fields.Equals(other.Fields)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= Fields.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + fields_.WriteTo(output, _map_fields_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + fields_.WriteTo(ref output, _map_fields_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += fields_.CalculateSize(_map_fields_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DictValue other) { + if (other == null) { + return; + } + fields_.Add(other.fields_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + fields_.AddEntriesFrom(input, _map_fields_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + fields_.AddEntriesFrom(ref input, _map_fields_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Represents a (key, value) pair. + /// + public sealed partial class PairValue : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new PairValue()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StructReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PairValue() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PairValue(PairValue other) : this() { + key_ = other.key_; + value_ = other.value_ != null ? other.value_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PairValue Clone() { + return new PairValue(this); + } + + /// Field number for the "key" field. + public const int KeyFieldNumber = 1; + private string key_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Key { + get { return key_; } + set { + key_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 2; + private global::Tensorflow.StructuredValue value_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.StructuredValue Value { + get { return value_; } + set { + value_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as PairValue); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(PairValue other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Key != other.Key) return false; + if (!object.Equals(Value, other.Value)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Key.Length != 0) hash ^= Key.GetHashCode(); + if (value_ != null) hash ^= Value.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Key.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Key); + } + if (value_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Value); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Key.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Key); + } + if (value_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Value); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Key.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Key); + } + if (value_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Value); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(PairValue other) { + if (other == null) { + return; + } + if (other.Key.Length != 0) { + Key = other.Key; + } + if (other.value_ != null) { + if (value_ == null) { + Value = new global::Tensorflow.StructuredValue(); + } + Value.MergeFrom(other.Value); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Key = input.ReadString(); + break; + } + case 18: { + if (value_ == null) { + Value = new global::Tensorflow.StructuredValue(); + } + input.ReadMessage(Value); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Key = input.ReadString(); + break; + } + case 18: { + if (value_ == null) { + Value = new global::Tensorflow.StructuredValue(); + } + input.ReadMessage(Value); + break; + } + } + } + } + #endif + + } + + /// + /// Represents Python's namedtuple. + /// + public sealed partial class NamedTupleValue : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NamedTupleValue()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StructReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NamedTupleValue() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NamedTupleValue(NamedTupleValue other) : this() { + name_ = other.name_; + values_ = other.values_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public NamedTupleValue Clone() { + return new NamedTupleValue(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "values" field. + public const int ValuesFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_values_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.PairValue.Parser); + private readonly pbc::RepeatedField values_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Values { + get { return values_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as NamedTupleValue); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(NamedTupleValue other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if(!values_.Equals(other.values_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + hash ^= values_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + values_.WriteTo(output, _repeated_values_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + values_.WriteTo(ref output, _repeated_values_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + size += values_.CalculateSize(_repeated_values_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(NamedTupleValue other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + values_.Add(other.values_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + values_.AddEntriesFrom(input, _repeated_values_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + values_.AddEntriesFrom(ref input, _repeated_values_codec); + break; + } + } + } + } + #endif + + } + + /// + /// A protobuf to represent tf.TensorSpec. + /// + public sealed partial class TensorSpecProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorSpecProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StructReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorSpecProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorSpecProto(TensorSpecProto other) : this() { + name_ = other.name_; + shape_ = other.shape_ != null ? other.shape_.Clone() : null; + dtype_ = other.dtype_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorSpecProto Clone() { + return new TensorSpecProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 2; + private global::Tensorflow.TensorShapeProto shape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorShapeProto Shape { + get { return shape_; } + set { + shape_ = value; + } + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 3; + private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TensorSpecProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TensorSpecProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (!object.Equals(Shape, other.Shape)) return false; + if (Dtype != other.Dtype) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (shape_ != null) hash ^= Shape.GetHashCode(); + if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (shape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Shape); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(24); + output.WriteEnum((int) Dtype); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (shape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Shape); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(24); + output.WriteEnum((int) Dtype); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (shape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TensorSpecProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.shape_ != null) { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + Shape.MergeFrom(other.Shape); + } + if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { + Dtype = other.Dtype; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 24: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 24: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + /// + /// A protobuf to represent tf.BoundedTensorSpec. + /// + public sealed partial class BoundedTensorSpecProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BoundedTensorSpecProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StructReflection.Descriptor.MessageTypes[8]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BoundedTensorSpecProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BoundedTensorSpecProto(BoundedTensorSpecProto other) : this() { + name_ = other.name_; + shape_ = other.shape_ != null ? other.shape_.Clone() : null; + dtype_ = other.dtype_; + minimum_ = other.minimum_ != null ? other.minimum_.Clone() : null; + maximum_ = other.maximum_ != null ? other.maximum_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public BoundedTensorSpecProto Clone() { + return new BoundedTensorSpecProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 2; + private global::Tensorflow.TensorShapeProto shape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorShapeProto Shape { + get { return shape_; } + set { + shape_ = value; + } + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 3; + private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + /// Field number for the "minimum" field. + public const int MinimumFieldNumber = 4; + private global::Tensorflow.TensorProto minimum_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorProto Minimum { + get { return minimum_; } + set { + minimum_ = value; + } + } + + /// Field number for the "maximum" field. + public const int MaximumFieldNumber = 5; + private global::Tensorflow.TensorProto maximum_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorProto Maximum { + get { return maximum_; } + set { + maximum_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as BoundedTensorSpecProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(BoundedTensorSpecProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (!object.Equals(Shape, other.Shape)) return false; + if (Dtype != other.Dtype) return false; + if (!object.Equals(Minimum, other.Minimum)) return false; + if (!object.Equals(Maximum, other.Maximum)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (shape_ != null) hash ^= Shape.GetHashCode(); + if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); + if (minimum_ != null) hash ^= Minimum.GetHashCode(); + if (maximum_ != null) hash ^= Maximum.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (shape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Shape); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(24); + output.WriteEnum((int) Dtype); + } + if (minimum_ != null) { + output.WriteRawTag(34); + output.WriteMessage(Minimum); + } + if (maximum_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Maximum); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (shape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Shape); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(24); + output.WriteEnum((int) Dtype); + } + if (minimum_ != null) { + output.WriteRawTag(34); + output.WriteMessage(Minimum); + } + if (maximum_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Maximum); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (shape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (minimum_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Minimum); + } + if (maximum_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Maximum); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(BoundedTensorSpecProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.shape_ != null) { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + Shape.MergeFrom(other.Shape); + } + if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { + Dtype = other.Dtype; + } + if (other.minimum_ != null) { + if (minimum_ == null) { + Minimum = new global::Tensorflow.TensorProto(); + } + Minimum.MergeFrom(other.Minimum); + } + if (other.maximum_ != null) { + if (maximum_ == null) { + Maximum = new global::Tensorflow.TensorProto(); + } + Maximum.MergeFrom(other.Maximum); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 24: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 34: { + if (minimum_ == null) { + Minimum = new global::Tensorflow.TensorProto(); + } + input.ReadMessage(Minimum); + break; + } + case 42: { + if (maximum_ == null) { + Maximum = new global::Tensorflow.TensorProto(); + } + input.ReadMessage(Maximum); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 24: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 34: { + if (minimum_ == null) { + Minimum = new global::Tensorflow.TensorProto(); + } + input.ReadMessage(Minimum); + break; + } + case 42: { + if (maximum_ == null) { + Maximum = new global::Tensorflow.TensorProto(); + } + input.ReadMessage(Maximum); + break; + } + } + } + } + #endif + + } + + /// + /// Represents a tf.TypeSpec + /// + public sealed partial class TypeSpecProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TypeSpecProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.StructReflection.Descriptor.MessageTypes[9]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TypeSpecProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TypeSpecProto(TypeSpecProto other) : this() { + typeSpecClass_ = other.typeSpecClass_; + typeState_ = other.typeState_ != null ? other.typeState_.Clone() : null; + typeSpecClassName_ = other.typeSpecClassName_; + numFlatComponents_ = other.numFlatComponents_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TypeSpecProto Clone() { + return new TypeSpecProto(this); + } + + /// Field number for the "type_spec_class" field. + public const int TypeSpecClassFieldNumber = 1; + private global::Tensorflow.TypeSpecProto.Types.TypeSpecClass typeSpecClass_ = global::Tensorflow.TypeSpecProto.Types.TypeSpecClass.Unknown; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TypeSpecProto.Types.TypeSpecClass TypeSpecClass { + get { return typeSpecClass_; } + set { + typeSpecClass_ = value; + } + } + + /// Field number for the "type_state" field. + public const int TypeStateFieldNumber = 2; + private global::Tensorflow.StructuredValue typeState_; + /// + /// The value returned by TypeSpec._serialize(). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.StructuredValue TypeState { + get { return typeState_; } + set { + typeState_ = value; + } + } + + /// Field number for the "type_spec_class_name" field. + public const int TypeSpecClassNameFieldNumber = 3; + private string typeSpecClassName_ = ""; + /// + /// The name of the TypeSpec class. + /// * If type_spec_class == REGISTERED_TYPE_SPEC, the TypeSpec class is + /// the one registered under this name. For types registered outside + /// core TensorFlow by an add-on library, that library must be loaded + /// before this value can be deserialized by nested_structure_coder. + /// * If type_spec_class specifies a particular TypeSpec class, this field is + /// redundant with the type_spec_class enum, and is only used for error + /// reporting in older binaries that do not know the tupe_spec_class enum. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string TypeSpecClassName { + get { return typeSpecClassName_; } + set { + typeSpecClassName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "num_flat_components" field. + public const int NumFlatComponentsFieldNumber = 4; + private int numFlatComponents_; + /// + /// The number of flat tensor components required by this TypeSpec. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NumFlatComponents { + get { return numFlatComponents_; } + set { + numFlatComponents_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TypeSpecProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TypeSpecProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (TypeSpecClass != other.TypeSpecClass) return false; + if (!object.Equals(TypeState, other.TypeState)) return false; + if (TypeSpecClassName != other.TypeSpecClassName) return false; + if (NumFlatComponents != other.NumFlatComponents) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (TypeSpecClass != global::Tensorflow.TypeSpecProto.Types.TypeSpecClass.Unknown) hash ^= TypeSpecClass.GetHashCode(); + if (typeState_ != null) hash ^= TypeState.GetHashCode(); + if (TypeSpecClassName.Length != 0) hash ^= TypeSpecClassName.GetHashCode(); + if (NumFlatComponents != 0) hash ^= NumFlatComponents.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (TypeSpecClass != global::Tensorflow.TypeSpecProto.Types.TypeSpecClass.Unknown) { + output.WriteRawTag(8); + output.WriteEnum((int) TypeSpecClass); + } + if (typeState_ != null) { + output.WriteRawTag(18); + output.WriteMessage(TypeState); + } + if (TypeSpecClassName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(TypeSpecClassName); + } + if (NumFlatComponents != 0) { + output.WriteRawTag(32); + output.WriteInt32(NumFlatComponents); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (TypeSpecClass != global::Tensorflow.TypeSpecProto.Types.TypeSpecClass.Unknown) { + output.WriteRawTag(8); + output.WriteEnum((int) TypeSpecClass); + } + if (typeState_ != null) { + output.WriteRawTag(18); + output.WriteMessage(TypeState); + } + if (TypeSpecClassName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(TypeSpecClassName); + } + if (NumFlatComponents != 0) { + output.WriteRawTag(32); + output.WriteInt32(NumFlatComponents); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (TypeSpecClass != global::Tensorflow.TypeSpecProto.Types.TypeSpecClass.Unknown) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) TypeSpecClass); + } + if (typeState_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TypeState); + } + if (TypeSpecClassName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TypeSpecClassName); + } + if (NumFlatComponents != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumFlatComponents); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TypeSpecProto other) { + if (other == null) { + return; + } + if (other.TypeSpecClass != global::Tensorflow.TypeSpecProto.Types.TypeSpecClass.Unknown) { + TypeSpecClass = other.TypeSpecClass; + } + if (other.typeState_ != null) { + if (typeState_ == null) { + TypeState = new global::Tensorflow.StructuredValue(); + } + TypeState.MergeFrom(other.TypeState); + } + if (other.TypeSpecClassName.Length != 0) { + TypeSpecClassName = other.TypeSpecClassName; + } + if (other.NumFlatComponents != 0) { + NumFlatComponents = other.NumFlatComponents; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + TypeSpecClass = (global::Tensorflow.TypeSpecProto.Types.TypeSpecClass) input.ReadEnum(); + break; + } + case 18: { + if (typeState_ == null) { + TypeState = new global::Tensorflow.StructuredValue(); + } + input.ReadMessage(TypeState); + break; + } + case 26: { + TypeSpecClassName = input.ReadString(); + break; + } + case 32: { + NumFlatComponents = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + TypeSpecClass = (global::Tensorflow.TypeSpecProto.Types.TypeSpecClass) input.ReadEnum(); + break; + } + case 18: { + if (typeState_ == null) { + TypeState = new global::Tensorflow.StructuredValue(); + } + input.ReadMessage(TypeState); + break; + } + case 26: { + TypeSpecClassName = input.ReadString(); + break; + } + case 32: { + NumFlatComponents = input.ReadInt32(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the TypeSpecProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum TypeSpecClass { + [pbr::OriginalName("UNKNOWN")] Unknown = 0, + /// + /// tf.SparseTensorSpec + /// + [pbr::OriginalName("SPARSE_TENSOR_SPEC")] SparseTensorSpec = 1, + /// + /// tf.IndexedSlicesSpec + /// + [pbr::OriginalName("INDEXED_SLICES_SPEC")] IndexedSlicesSpec = 2, + /// + /// tf.RaggedTensorSpec + /// + [pbr::OriginalName("RAGGED_TENSOR_SPEC")] RaggedTensorSpec = 3, + /// + /// tf.TensorArraySpec + /// + [pbr::OriginalName("TENSOR_ARRAY_SPEC")] TensorArraySpec = 4, + /// + /// tf.data.DatasetSpec + /// + [pbr::OriginalName("DATA_DATASET_SPEC")] DataDatasetSpec = 5, + /// + /// IteratorSpec from data/ops/iterator_ops.py + /// + [pbr::OriginalName("DATA_ITERATOR_SPEC")] DataIteratorSpec = 6, + /// + /// tf.OptionalSpec + /// + [pbr::OriginalName("OPTIONAL_SPEC")] OptionalSpec = 7, + /// + /// PerReplicaSpec from distribute/values.py + /// + [pbr::OriginalName("PER_REPLICA_SPEC")] PerReplicaSpec = 8, + /// + /// tf.VariableSpec + /// + [pbr::OriginalName("VARIABLE_SPEC")] VariableSpec = 9, + /// + /// RowPartitionSpec from ragged/row_partition.py + /// + [pbr::OriginalName("ROW_PARTITION_SPEC")] RowPartitionSpec = 10, + /// + /// The type registered as type_spec_class_name. + /// + [pbr::OriginalName("REGISTERED_TYPE_SPEC")] RegisteredTypeSpec = 12, + /// + /// Subclasses of tf.ExtensionType + /// + [pbr::OriginalName("EXTENSION_TYPE_SPEC")] ExtensionTypeSpec = 13, + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Summary.cs b/src/TensorFlowNET.Core/Protobuf/Summary.cs new file mode 100644 index 000000000..8f17e8dff --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Summary.cs @@ -0,0 +1,2336 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/summary.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/summary.proto + public static partial class SummaryReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/summary.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static SummaryReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cid0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3N1bW1hcnkucHJvdG8SCnRl", + "bnNvcmZsb3caJ3RlbnNvcmZsb3cvdHNsL3Byb3RvYnVmL2hpc3RvZ3JhbS5w", + "cm90bxomdGVuc29yZmxvdy9jb3JlL2ZyYW1ld29yay90ZW5zb3IucHJvdG8i", + "JwoSU3VtbWFyeURlc2NyaXB0aW9uEhEKCXR5cGVfaGludBgBIAEoCSLgAQoP", + "U3VtbWFyeU1ldGFkYXRhEjsKC3BsdWdpbl9kYXRhGAEgASgLMiYudGVuc29y", + "Zmxvdy5TdW1tYXJ5TWV0YWRhdGEuUGx1Z2luRGF0YRIUCgxkaXNwbGF5X25h", + "bWUYAiABKAkSGwoTc3VtbWFyeV9kZXNjcmlwdGlvbhgDIAEoCRIpCgpkYXRh", + "X2NsYXNzGAQgASgOMhUudGVuc29yZmxvdy5EYXRhQ2xhc3MaMgoKUGx1Z2lu", + "RGF0YRITCgtwbHVnaW5fbmFtZRgBIAEoCRIPCgdjb250ZW50GAIgASgMIt4E", + "CgdTdW1tYXJ5EigKBXZhbHVlGAEgAygLMhkudGVuc29yZmxvdy5TdW1tYXJ5", + "LlZhbHVlGlgKBUltYWdlEg4KBmhlaWdodBgBIAEoBRINCgV3aWR0aBgCIAEo", + "BRISCgpjb2xvcnNwYWNlGAMgASgFEhwKFGVuY29kZWRfaW1hZ2Vfc3RyaW5n", + "GAQgASgMGn0KBUF1ZGlvEhMKC3NhbXBsZV9yYXRlGAEgASgCEhQKDG51bV9j", + "aGFubmVscxgCIAEoAxIVCg1sZW5ndGhfZnJhbWVzGAMgASgDEhwKFGVuY29k", + "ZWRfYXVkaW9fc3RyaW5nGAQgASgMEhQKDGNvbnRlbnRfdHlwZRgFIAEoCRrP", + "AgoFVmFsdWUSEQoJbm9kZV9uYW1lGAcgASgJEgsKA3RhZxgBIAEoCRItCght", + "ZXRhZGF0YRgJIAEoCzIbLnRlbnNvcmZsb3cuU3VtbWFyeU1ldGFkYXRhEhYK", + "DHNpbXBsZV92YWx1ZRgCIAEoAkgAEiYKHG9ic29sZXRlX29sZF9zdHlsZV9o", + "aXN0b2dyYW0YAyABKAxIABIqCgVpbWFnZRgEIAEoCzIZLnRlbnNvcmZsb3cu", + "U3VtbWFyeS5JbWFnZUgAEisKBWhpc3RvGAUgASgLMhoudGVuc29yZmxvdy5I", + "aXN0b2dyYW1Qcm90b0gAEioKBWF1ZGlvGAYgASgLMhkudGVuc29yZmxvdy5T", + "dW1tYXJ5LkF1ZGlvSAASKQoGdGVuc29yGAggASgLMhcudGVuc29yZmxvdy5U", + "ZW5zb3JQcm90b0gAQgcKBXZhbHVlKm8KCURhdGFDbGFzcxIWChJEQVRBX0NM", + "QVNTX1VOS05PV04QABIVChFEQVRBX0NMQVNTX1NDQUxBUhABEhUKEURBVEFf", + "Q0xBU1NfVEVOU09SEAISHAoYREFUQV9DTEFTU19CTE9CX1NFUVVFTkNFEANC", + "fgoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQg1TdW1tYXJ5UHJvdG9zUAFa", + "TmdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cv", + "Z28vY29yZS9mcmFtZXdvcmsvc3VtbWFyeV9nb19wcm90b/gBAVAAYgZwcm90", + "bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.HistogramReflection.Descriptor, global::Tensorflow.TensorReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.DataClass), }, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SummaryDescription), global::Tensorflow.SummaryDescription.Parser, new[]{ "TypeHint" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SummaryMetadata), global::Tensorflow.SummaryMetadata.Parser, new[]{ "PluginData", "DisplayName", "SummaryDescription", "DataClass" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SummaryMetadata.Types.PluginData), global::Tensorflow.SummaryMetadata.Types.PluginData.Parser, new[]{ "PluginName", "Content" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Summary), global::Tensorflow.Summary.Parser, new[]{ "Value" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Summary.Types.Image), global::Tensorflow.Summary.Types.Image.Parser, new[]{ "Height", "Width", "Colorspace", "EncodedImageString" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Summary.Types.Audio), global::Tensorflow.Summary.Types.Audio.Parser, new[]{ "SampleRate", "NumChannels", "LengthFrames", "EncodedAudioString", "ContentType" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Summary.Types.Value), global::Tensorflow.Summary.Types.Value.Parser, new[]{ "NodeName", "Tag", "Metadata", "SimpleValue", "ObsoleteOldStyleHistogram", "Image", "Histo", "Audio", "Tensor" }, new[]{ "Value" }, null, null, null)}) + })); + } + #endregion + + } + #region Enums + public enum DataClass { + /// + /// Unknown data class, used (implicitly) for legacy data. Will not be + /// processed by data ingestion pipelines. + /// + [pbr::OriginalName("DATA_CLASS_UNKNOWN")] Unknown = 0, + /// + /// Scalar time series. Each `Value` for the corresponding tag must have + /// `tensor` set to a rank-0 tensor of type `DT_FLOAT` (float32). + /// + [pbr::OriginalName("DATA_CLASS_SCALAR")] Scalar = 1, + /// + /// Tensor time series. Each `Value` for the corresponding tag must have + /// `tensor` set. The tensor value is arbitrary, but should be small to + /// accommodate direct storage in database backends: an upper bound of a few + /// kilobytes is a reasonable rule of thumb. + /// + [pbr::OriginalName("DATA_CLASS_TENSOR")] Tensor = 2, + /// + /// Blob sequence time series. Each `Value` for the corresponding tag must + /// have `tensor` set to a rank-1 tensor of bytestring dtype. + /// + [pbr::OriginalName("DATA_CLASS_BLOB_SEQUENCE")] BlobSequence = 3, + } + + #endregion + + #region Messages + /// + /// Metadata associated with a series of Summary data + /// + public sealed partial class SummaryDescription : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SummaryDescription()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SummaryReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SummaryDescription() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SummaryDescription(SummaryDescription other) : this() { + typeHint_ = other.typeHint_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SummaryDescription Clone() { + return new SummaryDescription(this); + } + + /// Field number for the "type_hint" field. + public const int TypeHintFieldNumber = 1; + private string typeHint_ = ""; + /// + /// Hint on how plugins should process the data in this series. + /// Supported values include "scalar", "histogram", "image", "audio" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string TypeHint { + get { return typeHint_; } + set { + typeHint_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SummaryDescription); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SummaryDescription other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (TypeHint != other.TypeHint) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (TypeHint.Length != 0) hash ^= TypeHint.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (TypeHint.Length != 0) { + output.WriteRawTag(10); + output.WriteString(TypeHint); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (TypeHint.Length != 0) { + output.WriteRawTag(10); + output.WriteString(TypeHint); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (TypeHint.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TypeHint); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SummaryDescription other) { + if (other == null) { + return; + } + if (other.TypeHint.Length != 0) { + TypeHint = other.TypeHint; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + TypeHint = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + TypeHint = input.ReadString(); + break; + } + } + } + } + #endif + + } + + /// + /// A SummaryMetadata encapsulates information on which plugins are able to make + /// use of a certain summary value. + /// + public sealed partial class SummaryMetadata : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SummaryMetadata()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SummaryReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SummaryMetadata() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SummaryMetadata(SummaryMetadata other) : this() { + pluginData_ = other.pluginData_ != null ? other.pluginData_.Clone() : null; + displayName_ = other.displayName_; + summaryDescription_ = other.summaryDescription_; + dataClass_ = other.dataClass_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SummaryMetadata Clone() { + return new SummaryMetadata(this); + } + + /// Field number for the "plugin_data" field. + public const int PluginDataFieldNumber = 1; + private global::Tensorflow.SummaryMetadata.Types.PluginData pluginData_; + /// + /// Data that associates a summary with a certain plugin. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SummaryMetadata.Types.PluginData PluginData { + get { return pluginData_; } + set { + pluginData_ = value; + } + } + + /// Field number for the "display_name" field. + public const int DisplayNameFieldNumber = 2; + private string displayName_ = ""; + /// + /// Display name for viewing in TensorBoard. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string DisplayName { + get { return displayName_; } + set { + displayName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "summary_description" field. + public const int SummaryDescriptionFieldNumber = 3; + private string summaryDescription_ = ""; + /// + /// Longform readable description of the summary sequence. Markdown supported. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string SummaryDescription { + get { return summaryDescription_; } + set { + summaryDescription_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "data_class" field. + public const int DataClassFieldNumber = 4; + private global::Tensorflow.DataClass dataClass_ = global::Tensorflow.DataClass.Unknown; + /// + /// Class of data stored in this time series. Required for compatibility with + /// TensorBoard's generic data facilities (`DataProvider`, et al.). This value + /// imposes constraints on the dtype and shape of the corresponding tensor + /// values. See `DataClass` docs for details. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataClass DataClass { + get { return dataClass_; } + set { + dataClass_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SummaryMetadata); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SummaryMetadata other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(PluginData, other.PluginData)) return false; + if (DisplayName != other.DisplayName) return false; + if (SummaryDescription != other.SummaryDescription) return false; + if (DataClass != other.DataClass) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (pluginData_ != null) hash ^= PluginData.GetHashCode(); + if (DisplayName.Length != 0) hash ^= DisplayName.GetHashCode(); + if (SummaryDescription.Length != 0) hash ^= SummaryDescription.GetHashCode(); + if (DataClass != global::Tensorflow.DataClass.Unknown) hash ^= DataClass.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (pluginData_ != null) { + output.WriteRawTag(10); + output.WriteMessage(PluginData); + } + if (DisplayName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(DisplayName); + } + if (SummaryDescription.Length != 0) { + output.WriteRawTag(26); + output.WriteString(SummaryDescription); + } + if (DataClass != global::Tensorflow.DataClass.Unknown) { + output.WriteRawTag(32); + output.WriteEnum((int) DataClass); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (pluginData_ != null) { + output.WriteRawTag(10); + output.WriteMessage(PluginData); + } + if (DisplayName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(DisplayName); + } + if (SummaryDescription.Length != 0) { + output.WriteRawTag(26); + output.WriteString(SummaryDescription); + } + if (DataClass != global::Tensorflow.DataClass.Unknown) { + output.WriteRawTag(32); + output.WriteEnum((int) DataClass); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (pluginData_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(PluginData); + } + if (DisplayName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DisplayName); + } + if (SummaryDescription.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(SummaryDescription); + } + if (DataClass != global::Tensorflow.DataClass.Unknown) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) DataClass); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SummaryMetadata other) { + if (other == null) { + return; + } + if (other.pluginData_ != null) { + if (pluginData_ == null) { + PluginData = new global::Tensorflow.SummaryMetadata.Types.PluginData(); + } + PluginData.MergeFrom(other.PluginData); + } + if (other.DisplayName.Length != 0) { + DisplayName = other.DisplayName; + } + if (other.SummaryDescription.Length != 0) { + SummaryDescription = other.SummaryDescription; + } + if (other.DataClass != global::Tensorflow.DataClass.Unknown) { + DataClass = other.DataClass; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (pluginData_ == null) { + PluginData = new global::Tensorflow.SummaryMetadata.Types.PluginData(); + } + input.ReadMessage(PluginData); + break; + } + case 18: { + DisplayName = input.ReadString(); + break; + } + case 26: { + SummaryDescription = input.ReadString(); + break; + } + case 32: { + DataClass = (global::Tensorflow.DataClass) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (pluginData_ == null) { + PluginData = new global::Tensorflow.SummaryMetadata.Types.PluginData(); + } + input.ReadMessage(PluginData); + break; + } + case 18: { + DisplayName = input.ReadString(); + break; + } + case 26: { + SummaryDescription = input.ReadString(); + break; + } + case 32: { + DataClass = (global::Tensorflow.DataClass) input.ReadEnum(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the SummaryMetadata message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public sealed partial class PluginData : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new PluginData()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SummaryMetadata.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PluginData() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PluginData(PluginData other) : this() { + pluginName_ = other.pluginName_; + content_ = other.content_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PluginData Clone() { + return new PluginData(this); + } + + /// Field number for the "plugin_name" field. + public const int PluginNameFieldNumber = 1; + private string pluginName_ = ""; + /// + /// The name of the plugin this data pertains to. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string PluginName { + get { return pluginName_; } + set { + pluginName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "content" field. + public const int ContentFieldNumber = 2; + private pb::ByteString content_ = pb::ByteString.Empty; + /// + /// The content to store for the plugin. The best practice is for this to be + /// a binary serialized protocol buffer. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString Content { + get { return content_; } + set { + content_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as PluginData); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(PluginData other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (PluginName != other.PluginName) return false; + if (Content != other.Content) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (PluginName.Length != 0) hash ^= PluginName.GetHashCode(); + if (Content.Length != 0) hash ^= Content.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (PluginName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(PluginName); + } + if (Content.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(Content); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (PluginName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(PluginName); + } + if (Content.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(Content); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (PluginName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PluginName); + } + if (Content.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(Content); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(PluginData other) { + if (other == null) { + return; + } + if (other.PluginName.Length != 0) { + PluginName = other.PluginName; + } + if (other.Content.Length != 0) { + Content = other.Content; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + PluginName = input.ReadString(); + break; + } + case 18: { + Content = input.ReadBytes(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + PluginName = input.ReadString(); + break; + } + case 18: { + Content = input.ReadBytes(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// A Summary is a set of named values to be displayed by the + /// visualizer. + /// + /// Summaries are produced regularly during training, as controlled by + /// the "summary_interval_secs" attribute of the training operation. + /// Summaries are also produced at the end of an evaluation. + /// + public sealed partial class Summary : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Summary()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SummaryReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Summary() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Summary(Summary other) : this() { + value_ = other.value_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Summary Clone() { + return new Summary(this); + } + + /// Field number for the "value" field. + public const int ValueFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_value_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.Summary.Types.Value.Parser); + private readonly pbc::RepeatedField value_ = new pbc::RepeatedField(); + /// + /// Set of values for the summary. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Value { + get { return value_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Summary); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Summary other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!value_.Equals(other.value_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= value_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + value_.WriteTo(output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + value_.WriteTo(ref output, _repeated_value_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += value_.CalculateSize(_repeated_value_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Summary other) { + if (other == null) { + return; + } + value_.Add(other.value_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + value_.AddEntriesFrom(input, _repeated_value_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + value_.AddEntriesFrom(ref input, _repeated_value_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the Summary message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public sealed partial class Image : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Image()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Summary.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Image() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Image(Image other) : this() { + height_ = other.height_; + width_ = other.width_; + colorspace_ = other.colorspace_; + encodedImageString_ = other.encodedImageString_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Image Clone() { + return new Image(this); + } + + /// Field number for the "height" field. + public const int HeightFieldNumber = 1; + private int height_; + /// + /// Dimensions of the image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int Height { + get { return height_; } + set { + height_ = value; + } + } + + /// Field number for the "width" field. + public const int WidthFieldNumber = 2; + private int width_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int Width { + get { return width_; } + set { + width_ = value; + } + } + + /// Field number for the "colorspace" field. + public const int ColorspaceFieldNumber = 3; + private int colorspace_; + /// + /// Valid colorspace values are + /// 1 - grayscale + /// 2 - grayscale + alpha + /// 3 - RGB + /// 4 - RGBA + /// 5 - DIGITAL_YUV + /// 6 - BGRA + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int Colorspace { + get { return colorspace_; } + set { + colorspace_ = value; + } + } + + /// Field number for the "encoded_image_string" field. + public const int EncodedImageStringFieldNumber = 4; + private pb::ByteString encodedImageString_ = pb::ByteString.Empty; + /// + /// Image data in encoded format. All image formats supported by + /// image_codec::CoderUtil can be stored here. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString EncodedImageString { + get { return encodedImageString_; } + set { + encodedImageString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Image); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Image other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Height != other.Height) return false; + if (Width != other.Width) return false; + if (Colorspace != other.Colorspace) return false; + if (EncodedImageString != other.EncodedImageString) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Height != 0) hash ^= Height.GetHashCode(); + if (Width != 0) hash ^= Width.GetHashCode(); + if (Colorspace != 0) hash ^= Colorspace.GetHashCode(); + if (EncodedImageString.Length != 0) hash ^= EncodedImageString.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Height != 0) { + output.WriteRawTag(8); + output.WriteInt32(Height); + } + if (Width != 0) { + output.WriteRawTag(16); + output.WriteInt32(Width); + } + if (Colorspace != 0) { + output.WriteRawTag(24); + output.WriteInt32(Colorspace); + } + if (EncodedImageString.Length != 0) { + output.WriteRawTag(34); + output.WriteBytes(EncodedImageString); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Height != 0) { + output.WriteRawTag(8); + output.WriteInt32(Height); + } + if (Width != 0) { + output.WriteRawTag(16); + output.WriteInt32(Width); + } + if (Colorspace != 0) { + output.WriteRawTag(24); + output.WriteInt32(Colorspace); + } + if (EncodedImageString.Length != 0) { + output.WriteRawTag(34); + output.WriteBytes(EncodedImageString); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Height != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Height); + } + if (Width != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Width); + } + if (Colorspace != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Colorspace); + } + if (EncodedImageString.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(EncodedImageString); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Image other) { + if (other == null) { + return; + } + if (other.Height != 0) { + Height = other.Height; + } + if (other.Width != 0) { + Width = other.Width; + } + if (other.Colorspace != 0) { + Colorspace = other.Colorspace; + } + if (other.EncodedImageString.Length != 0) { + EncodedImageString = other.EncodedImageString; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Height = input.ReadInt32(); + break; + } + case 16: { + Width = input.ReadInt32(); + break; + } + case 24: { + Colorspace = input.ReadInt32(); + break; + } + case 34: { + EncodedImageString = input.ReadBytes(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Height = input.ReadInt32(); + break; + } + case 16: { + Width = input.ReadInt32(); + break; + } + case 24: { + Colorspace = input.ReadInt32(); + break; + } + case 34: { + EncodedImageString = input.ReadBytes(); + break; + } + } + } + } + #endif + + } + + public sealed partial class Audio : pb::IMessage [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public global::Tensorflow.TensorShapeProto TensorShape { get { return tensorShape_; } set { @@ -143,6 +158,7 @@ public TensorProto Clone() { /// to represent a constant Tensor with a single value. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int VersionNumber { get { return versionNumber_; } set { @@ -161,6 +177,7 @@ public int VersionNumber { /// many repeated small items. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pb::ByteString TensorContent { get { return tensorContent_; } set { @@ -178,6 +195,7 @@ public int VersionNumber { /// have some pointless zero padding for each value here. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField HalfVal { get { return halfVal_; } } @@ -191,6 +209,7 @@ public int VersionNumber { /// DT_FLOAT. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField FloatVal { get { return floatVal_; } } @@ -204,6 +223,7 @@ public int VersionNumber { /// DT_DOUBLE. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField DoubleVal { get { return doubleVal_; } } @@ -214,9 +234,10 @@ public int VersionNumber { = pb::FieldCodec.ForInt32(58); private readonly pbc::RepeatedField intVal_ = new pbc::RepeatedField(); /// - /// DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + /// DT_INT32, DT_INT16, DT_UINT16, DT_INT8, DT_UINT8. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField IntVal { get { return intVal_; } } @@ -230,6 +251,7 @@ public int VersionNumber { /// DT_STRING /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField StringVal { get { return stringVal_; } } @@ -244,6 +266,7 @@ public int VersionNumber { /// and imaginary parts of i-th single precision complex. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField ScomplexVal { get { return scomplexVal_; } } @@ -257,6 +280,7 @@ public int VersionNumber { /// DT_INT64 /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField Int64Val { get { return int64Val_; } } @@ -270,6 +294,7 @@ public int VersionNumber { /// DT_BOOL /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField BoolVal { get { return boolVal_; } } @@ -284,6 +309,7 @@ public int VersionNumber { /// and imaginary parts of i-th double precision complex. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField DcomplexVal { get { return dcomplexVal_; } } @@ -297,6 +323,7 @@ public int VersionNumber { /// DT_RESOURCE /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField ResourceHandleVal { get { return resourceHandleVal_; } } @@ -310,6 +337,7 @@ public int VersionNumber { /// DT_VARIANT /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField VariantVal { get { return variantVal_; } } @@ -323,6 +351,7 @@ public int VersionNumber { /// DT_UINT32 /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField Uint32Val { get { return uint32Val_; } } @@ -336,16 +365,19 @@ public int VersionNumber { /// DT_UINT64 /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField Uint64Val { get { return uint64Val_; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as TensorProto); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(TensorProto other) { if (ReferenceEquals(other, null)) { return false; @@ -374,9 +406,10 @@ public bool Equals(TensorProto other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; - if (Dtype != 0) hash ^= Dtype.GetHashCode(); + if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); if (tensorShape_ != null) hash ^= TensorShape.GetHashCode(); if (VersionNumber != 0) hash ^= VersionNumber.GetHashCode(); if (TensorContent.Length != 0) hash ^= TensorContent.GetHashCode(); @@ -400,13 +433,18 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { - if (Dtype != 0) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Dtype != global::Tensorflow.DataType.DtInvalid) { output.WriteRawTag(8); output.WriteEnum((int) Dtype); } @@ -438,12 +476,53 @@ public void WriteTo(pb::CodedOutputStream output) { if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(8); + output.WriteEnum((int) Dtype); + } + if (tensorShape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(TensorShape); + } + if (VersionNumber != 0) { + output.WriteRawTag(24); + output.WriteInt32(VersionNumber); + } + if (TensorContent.Length != 0) { + output.WriteRawTag(34); + output.WriteBytes(TensorContent); + } + floatVal_.WriteTo(ref output, _repeated_floatVal_codec); + doubleVal_.WriteTo(ref output, _repeated_doubleVal_codec); + intVal_.WriteTo(ref output, _repeated_intVal_codec); + stringVal_.WriteTo(ref output, _repeated_stringVal_codec); + scomplexVal_.WriteTo(ref output, _repeated_scomplexVal_codec); + int64Val_.WriteTo(ref output, _repeated_int64Val_codec); + boolVal_.WriteTo(ref output, _repeated_boolVal_codec); + dcomplexVal_.WriteTo(ref output, _repeated_dcomplexVal_codec); + halfVal_.WriteTo(ref output, _repeated_halfVal_codec); + resourceHandleVal_.WriteTo(ref output, _repeated_resourceHandleVal_codec); + variantVal_.WriteTo(ref output, _repeated_variantVal_codec); + uint32Val_.WriteTo(ref output, _repeated_uint32Val_codec); + uint64Val_.WriteTo(ref output, _repeated_uint64Val_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; - if (Dtype != 0) { + if (Dtype != global::Tensorflow.DataType.DtInvalid) { size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); } if (tensorShape_ != null) { @@ -475,16 +554,17 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(TensorProto other) { if (other == null) { return; } - if (other.Dtype != 0) { + if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { Dtype = other.Dtype; } if (other.tensorShape_ != null) { if (tensorShape_ == null) { - tensorShape_ = new global::Tensorflow.TensorShapeProto(); + TensorShape = new global::Tensorflow.TensorShapeProto(); } TensorShape.MergeFrom(other.TensorShape); } @@ -511,7 +591,11 @@ public void MergeFrom(TensorProto other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -519,14 +603,14 @@ public void MergeFrom(pb::CodedInputStream input) { _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); break; case 8: { - dtype_ = (global::Tensorflow.DataType) input.ReadEnum(); + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); break; } case 18: { if (tensorShape_ == null) { - tensorShape_ = new global::Tensorflow.TensorShapeProto(); + TensorShape = new global::Tensorflow.TensorShapeProto(); } - input.ReadMessage(tensorShape_); + input.ReadMessage(TensorShape); break; } case 24: { @@ -601,30 +685,135 @@ public void MergeFrom(pb::CodedInputStream input) { } } } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 18: { + if (tensorShape_ == null) { + TensorShape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(TensorShape); + break; + } + case 24: { + VersionNumber = input.ReadInt32(); + break; + } + case 34: { + TensorContent = input.ReadBytes(); + break; + } + case 42: + case 45: { + floatVal_.AddEntriesFrom(ref input, _repeated_floatVal_codec); + break; + } + case 50: + case 49: { + doubleVal_.AddEntriesFrom(ref input, _repeated_doubleVal_codec); + break; + } + case 58: + case 56: { + intVal_.AddEntriesFrom(ref input, _repeated_intVal_codec); + break; + } + case 66: { + stringVal_.AddEntriesFrom(ref input, _repeated_stringVal_codec); + break; + } + case 74: + case 77: { + scomplexVal_.AddEntriesFrom(ref input, _repeated_scomplexVal_codec); + break; + } + case 82: + case 80: { + int64Val_.AddEntriesFrom(ref input, _repeated_int64Val_codec); + break; + } + case 90: + case 88: { + boolVal_.AddEntriesFrom(ref input, _repeated_boolVal_codec); + break; + } + case 98: + case 97: { + dcomplexVal_.AddEntriesFrom(ref input, _repeated_dcomplexVal_codec); + break; + } + case 106: + case 104: { + halfVal_.AddEntriesFrom(ref input, _repeated_halfVal_codec); + break; + } + case 114: { + resourceHandleVal_.AddEntriesFrom(ref input, _repeated_resourceHandleVal_codec); + break; + } + case 122: { + variantVal_.AddEntriesFrom(ref input, _repeated_variantVal_codec); + break; + } + case 130: + case 128: { + uint32Val_.AddEntriesFrom(ref input, _repeated_uint32Val_codec); + break; + } + case 138: + case 136: { + uint64Val_.AddEntriesFrom(ref input, _repeated_uint64Val_codec); + break; + } + } + } + } + #endif + } /// /// Protocol buffer representing the serialization format of DT_VARIANT tensors. /// - public sealed partial class VariantTensorDataProto : pb::IMessage { + public sealed partial class VariantTensorDataProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new VariantTensorDataProto()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.TensorReflection.Descriptor.MessageTypes[1]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public VariantTensorDataProto() { OnConstruction(); } @@ -632,6 +821,7 @@ public VariantTensorDataProto() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public VariantTensorDataProto(VariantTensorDataProto other) : this() { typeName_ = other.typeName_; metadata_ = other.metadata_; @@ -640,6 +830,7 @@ public VariantTensorDataProto(VariantTensorDataProto other) : this() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public VariantTensorDataProto Clone() { return new VariantTensorDataProto(this); } @@ -651,6 +842,7 @@ public VariantTensorDataProto Clone() { /// Name of the type of objects being serialized. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string TypeName { get { return typeName_; } set { @@ -665,6 +857,7 @@ public string TypeName { /// Portions of the object that are not Tensors. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pb::ByteString Metadata { get { return metadata_; } set { @@ -681,16 +874,19 @@ public string TypeName { /// Tensors contained within objects being serialized. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField Tensors { get { return tensors_; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as VariantTensorDataProto); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(VariantTensorDataProto other) { if (ReferenceEquals(other, null)) { return false; @@ -705,6 +901,7 @@ public bool Equals(VariantTensorDataProto other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; if (TypeName.Length != 0) hash ^= TypeName.GetHashCode(); @@ -717,12 +914,17 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else if (TypeName.Length != 0) { output.WriteRawTag(10); output.WriteString(TypeName); @@ -735,9 +937,30 @@ public void WriteTo(pb::CodedOutputStream output) { if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (TypeName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(TypeName); + } + if (Metadata.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(Metadata); + } + tensors_.WriteTo(ref output, _repeated_tensors_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } } + #endif [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; if (TypeName.Length != 0) { @@ -754,6 +977,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(VariantTensorDataProto other) { if (other == null) { return; @@ -769,7 +993,11 @@ public void MergeFrom(VariantTensorDataProto other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -790,7 +1018,35 @@ public void MergeFrom(pb::CodedInputStream input) { } } } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + TypeName = input.ReadString(); + break; + } + case 18: { + Metadata = input.ReadBytes(); + break; + } + case 26: { + tensors_.AddEntriesFrom(ref input, _repeated_tensors_codec); + break; + } + } + } } + #endif } diff --git a/src/TensorFlowNET.Core/Protobuf/TensorDescription.cs b/src/TensorFlowNET.Core/Protobuf/TensorDescription.cs new file mode 100644 index 000000000..81b170abe --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/TensorDescription.cs @@ -0,0 +1,343 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/tensor_description.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/tensor_description.proto + public static partial class TensorDescriptionReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/tensor_description.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static TensorDescriptionReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjJ0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3RlbnNvcl9kZXNjcmlwdGlv", + "bi5wcm90bxIKdGVuc29yZmxvdxo2dGVuc29yZmxvdy9jb3JlL2ZyYW1ld29y", + "ay9hbGxvY2F0aW9uX2Rlc2NyaXB0aW9uLnByb3RvGix0ZW5zb3JmbG93L2Nv", + "cmUvZnJhbWV3b3JrL3RlbnNvcl9zaGFwZS5wcm90bxoldGVuc29yZmxvdy9j", + "b3JlL2ZyYW1ld29yay90eXBlcy5wcm90byKoAQoRVGVuc29yRGVzY3JpcHRp", + "b24SIwoFZHR5cGUYASABKA4yFC50ZW5zb3JmbG93LkRhdGFUeXBlEisKBXNo", + "YXBlGAIgASgLMhwudGVuc29yZmxvdy5UZW5zb3JTaGFwZVByb3RvEkEKFmFs", + "bG9jYXRpb25fZGVzY3JpcHRpb24YBCABKAsyIS50ZW5zb3JmbG93LkFsbG9j", + "YXRpb25EZXNjcmlwdGlvbkKTAQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3Jr", + "QhdUZW5zb3JEZXNjcmlwdGlvblByb3Rvc1ABWllnaXRodWIuY29tL3RlbnNv", + "cmZsb3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dvL2NvcmUvZnJhbWV3b3Jr", + "L3RlbnNvcl9kZXNjcmlwdGlvbl9nb19wcm90b/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.AllocationDescriptionReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorDescription), global::Tensorflow.TensorDescription.Parser, new[]{ "Dtype", "Shape", "AllocationDescription" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class TensorDescription : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorDescription()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TensorDescriptionReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorDescription() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorDescription(TensorDescription other) : this() { + dtype_ = other.dtype_; + shape_ = other.shape_ != null ? other.shape_.Clone() : null; + allocationDescription_ = other.allocationDescription_ != null ? other.allocationDescription_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorDescription Clone() { + return new TensorDescription(this); + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 1; + private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; + /// + /// Data type of tensor elements + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 2; + private global::Tensorflow.TensorShapeProto shape_; + /// + /// Shape of the tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.TensorShapeProto Shape { + get { return shape_; } + set { + shape_ = value; + } + } + + /// Field number for the "allocation_description" field. + public const int AllocationDescriptionFieldNumber = 4; + private global::Tensorflow.AllocationDescription allocationDescription_; + /// + /// Information about the size and allocator used for the data + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.AllocationDescription AllocationDescription { + get { return allocationDescription_; } + set { + allocationDescription_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TensorDescription); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TensorDescription other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Dtype != other.Dtype) return false; + if (!object.Equals(Shape, other.Shape)) return false; + if (!object.Equals(AllocationDescription, other.AllocationDescription)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); + if (shape_ != null) hash ^= Shape.GetHashCode(); + if (allocationDescription_ != null) hash ^= AllocationDescription.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(8); + output.WriteEnum((int) Dtype); + } + if (shape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Shape); + } + if (allocationDescription_ != null) { + output.WriteRawTag(34); + output.WriteMessage(AllocationDescription); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(8); + output.WriteEnum((int) Dtype); + } + if (shape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Shape); + } + if (allocationDescription_ != null) { + output.WriteRawTag(34); + output.WriteMessage(AllocationDescription); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Dtype != global::Tensorflow.DataType.DtInvalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (shape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + if (allocationDescription_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(AllocationDescription); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TensorDescription other) { + if (other == null) { + return; + } + if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { + Dtype = other.Dtype; + } + if (other.shape_ != null) { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + Shape.MergeFrom(other.Shape); + } + if (other.allocationDescription_ != null) { + if (allocationDescription_ == null) { + AllocationDescription = new global::Tensorflow.AllocationDescription(); + } + AllocationDescription.MergeFrom(other.AllocationDescription); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 18: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 34: { + if (allocationDescription_ == null) { + AllocationDescription = new global::Tensorflow.AllocationDescription(); + } + input.ReadMessage(AllocationDescription); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Dtype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 18: { + if (shape_ == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 34: { + if (allocationDescription_ == null) { + AllocationDescription = new global::Tensorflow.AllocationDescription(); + } + input.ReadMessage(AllocationDescription); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/TensorShape.cs b/src/TensorFlowNET.Core/Protobuf/TensorShape.cs index 0a891dd0e..e22ed820b 100644 --- a/src/TensorFlowNET.Core/Protobuf/TensorShape.cs +++ b/src/TensorFlowNET.Core/Protobuf/TensorShape.cs @@ -1,8 +1,8 @@ // // Generated by the protocol buffer compiler. DO NOT EDIT! -// source: tensor_shape.proto +// source: tensorflow/core/framework/tensor_shape.proto // -#pragma warning disable 1591, 0612, 3021 +#pragma warning disable 1591, 0612, 3021, 8981 #region Designer generated code using pb = global::Google.Protobuf; @@ -11,11 +11,11 @@ using scg = global::System.Collections.Generic; namespace Tensorflow { - /// Holder for reflection information generated from tensor_shape.proto + /// Holder for reflection information generated from tensorflow/core/framework/tensor_shape.proto public static partial class TensorShapeReflection { #region Descriptor - /// File descriptor for tensor_shape.proto + /// File descriptor for tensorflow/core/framework/tensor_shape.proto public static pbr::FileDescriptor Descriptor { get { return descriptor; } } @@ -24,17 +24,18 @@ public static partial class TensorShapeReflection { static TensorShapeReflection() { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( - "ChJ0ZW5zb3Jfc2hhcGUucHJvdG8SCnRlbnNvcmZsb3ciegoQVGVuc29yU2hh", - "cGVQcm90bxItCgNkaW0YAiADKAsyIC50ZW5zb3JmbG93LlRlbnNvclNoYXBl", - "UHJvdG8uRGltEhQKDHVua25vd25fcmFuaxgDIAEoCBohCgNEaW0SDAoEc2l6", - "ZRgBIAEoAxIMCgRuYW1lGAIgASgJQnEKGG9yZy50ZW5zb3JmbG93LmZyYW1l", - "d29ya0IRVGVuc29yU2hhcGVQcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3Jm", - "bG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gB", - "AWIGcHJvdG8z")); + "Cix0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3RlbnNvcl9zaGFwZS5wcm90", + "bxIKdGVuc29yZmxvdyJ6ChBUZW5zb3JTaGFwZVByb3RvEi0KA2RpbRgCIAMo", + "CzIgLnRlbnNvcmZsb3cuVGVuc29yU2hhcGVQcm90by5EaW0SFAoMdW5rbm93", + "bl9yYW5rGAMgASgIGiEKA0RpbRIMCgRzaXplGAEgASgDEgwKBG5hbWUYAiAB", + "KAlChwEKGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IRVGVuc29yU2hhcGVQ", + "cm90b3NQAVpTZ2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVu", + "c29yZmxvdy9nby9jb3JlL2ZyYW1ld29yay90ZW5zb3Jfc2hhcGVfZ29fcHJv", + "dG/4AQFiBnByb3RvMw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { }, - new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorShapeProto), global::Tensorflow.TensorShapeProto.Parser, new[]{ "Dim", "UnknownRank" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorShapeProto.Types.Dim), global::Tensorflow.TensorShapeProto.Types.Dim.Parser, new[]{ "Size", "Name" }, null, null, null)}) + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorShapeProto), global::Tensorflow.TensorShapeProto.Parser, new[]{ "Dim", "UnknownRank" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorShapeProto.Types.Dim), global::Tensorflow.TensorShapeProto.Types.Dim.Parser, new[]{ "Size", "Name" }, null, null, null, null)}) })); } #endregion @@ -44,23 +45,31 @@ static TensorShapeReflection() { /// /// Dimensions of a tensor. /// - public sealed partial class TensorShapeProto : pb::IMessage { + public sealed partial class TensorShapeProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorShapeProto()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.TensorShapeReflection.Descriptor.MessageTypes[0]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public TensorShapeProto() { OnConstruction(); } @@ -68,6 +77,7 @@ public TensorShapeProto() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public TensorShapeProto(TensorShapeProto other) : this() { dim_ = other.dim_.Clone(); unknownRank_ = other.unknownRank_; @@ -75,6 +85,7 @@ public TensorShapeProto(TensorShapeProto other) : this() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public TensorShapeProto Clone() { return new TensorShapeProto(this); } @@ -100,6 +111,7 @@ public TensorShapeProto Clone() { /// If "dim.size()" > 0, "unknown_rank" must be false. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public pbc::RepeatedField Dim { get { return dim_; } } @@ -113,6 +125,7 @@ public TensorShapeProto Clone() { /// If true, "dim.size()" must be 0. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool UnknownRank { get { return unknownRank_; } set { @@ -121,11 +134,13 @@ public bool UnknownRank { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as TensorShapeProto); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(TensorShapeProto other) { if (ReferenceEquals(other, null)) { return false; @@ -139,6 +154,7 @@ public bool Equals(TensorShapeProto other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; hash ^= dim_.GetHashCode(); @@ -150,12 +166,17 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else dim_.WriteTo(output, _repeated_dim_codec); if (UnknownRank != false) { output.WriteRawTag(24); @@ -164,9 +185,26 @@ public void WriteTo(pb::CodedOutputStream output) { if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + dim_.WriteTo(ref output, _repeated_dim_codec); + if (UnknownRank != false) { + output.WriteRawTag(24); + output.WriteBool(UnknownRank); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; size += dim_.CalculateSize(_repeated_dim_codec); @@ -180,6 +218,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(TensorShapeProto other) { if (other == null) { return; @@ -192,7 +231,11 @@ public void MergeFrom(TensorShapeProto other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -209,32 +252,65 @@ public void MergeFrom(pb::CodedInputStream input) { } } } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 18: { + dim_.AddEntriesFrom(ref input, _repeated_dim_codec); + break; + } + case 24: { + UnknownRank = input.ReadBool(); + break; + } + } + } + } + #endif + #region Nested types /// Container for nested types declared in the TensorShapeProto message type. [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static partial class Types { /// /// One dimension of the tensor. /// - public sealed partial class Dim : pb::IMessage { + public sealed partial class Dim : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Dim()); private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public static pbr::MessageDescriptor Descriptor { get { return global::Tensorflow.TensorShapeProto.Descriptor.NestedTypes[0]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] pbr::MessageDescriptor pb::IMessage.Descriptor { get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public Dim() { OnConstruction(); } @@ -242,6 +318,7 @@ public Dim() { partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public Dim(Dim other) : this() { size_ = other.size_; name_ = other.name_; @@ -249,6 +326,7 @@ public Dim(Dim other) : this() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public Dim Clone() { return new Dim(this); } @@ -264,6 +342,7 @@ public Dim Clone() { /// a TensorShapeProto containing a dim value of -1. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public long Size { get { return size_; } set { @@ -278,6 +357,7 @@ public long Size { /// Optional name of the tensor dimension. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public string Name { get { return name_; } set { @@ -286,11 +366,13 @@ public string Name { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override bool Equals(object other) { return Equals(other as Dim); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public bool Equals(Dim other) { if (ReferenceEquals(other, null)) { return false; @@ -304,6 +386,7 @@ public bool Equals(Dim other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override int GetHashCode() { int hash = 1; if (Size != 0L) hash ^= Size.GetHashCode(); @@ -315,12 +398,17 @@ public override int GetHashCode() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public override string ToString() { return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else if (Size != 0L) { output.WriteRawTag(8); output.WriteInt64(Size); @@ -332,9 +420,29 @@ public void WriteTo(pb::CodedOutputStream output) { if (_unknownFields != null) { _unknownFields.WriteTo(output); } + #endif } + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Size != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Size); + } + if (Name.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Name); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public int CalculateSize() { int size = 0; if (Size != 0L) { @@ -350,6 +458,7 @@ public int CalculateSize() { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(Dim other) { if (other == null) { return; @@ -364,7 +473,11 @@ public void MergeFrom(Dim other) { } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else uint tag; while ((tag = input.ReadTag()) != 0) { switch(tag) { @@ -381,7 +494,31 @@ public void MergeFrom(pb::CodedInputStream input) { } } } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Size = input.ReadInt64(); + break; + } + case 18: { + Name = input.ReadString(); + break; + } + } + } } + #endif } diff --git a/src/TensorFlowNET.Core/Protobuf/TensorSlice.cs b/src/TensorFlowNET.Core/Protobuf/TensorSlice.cs new file mode 100644 index 000000000..cf1c44d35 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/TensorSlice.cs @@ -0,0 +1,507 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/tensor_slice.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/tensor_slice.proto + public static partial class TensorSliceReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/tensor_slice.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static TensorSliceReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cix0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3RlbnNvcl9zbGljZS5wcm90", + "bxIKdGVuc29yZmxvdyKAAQoQVGVuc29yU2xpY2VQcm90bxIzCgZleHRlbnQY", + "ASADKAsyIy50ZW5zb3JmbG93LlRlbnNvclNsaWNlUHJvdG8uRXh0ZW50GjcK", + "BkV4dGVudBINCgVzdGFydBgBIAEoAxIQCgZsZW5ndGgYAiABKANIAEIMCgpo", + "YXNfbGVuZ3RoQocBChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtCEVRlbnNv", + "clNsaWNlUHJvdG9zUAFaU2dpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3Jm", + "bG93L3RlbnNvcmZsb3cvZ28vY29yZS9mcmFtZXdvcmsvdGVuc29yX3NsaWNl", + "X2dvX3Byb3Rv+AEBYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorSliceProto), global::Tensorflow.TensorSliceProto.Parser, new[]{ "Extent" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorSliceProto.Types.Extent), global::Tensorflow.TensorSliceProto.Types.Extent.Parser, new[]{ "Start", "Length" }, new[]{ "HasLength" }, null, null, null)}) + })); + } + #endregion + + } + #region Messages + /// + /// Can only be interpreted if you know the corresponding TensorShape. + /// + public sealed partial class TensorSliceProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorSliceProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TensorSliceReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorSliceProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorSliceProto(TensorSliceProto other) : this() { + extent_ = other.extent_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TensorSliceProto Clone() { + return new TensorSliceProto(this); + } + + /// Field number for the "extent" field. + public const int ExtentFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_extent_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.TensorSliceProto.Types.Extent.Parser); + private readonly pbc::RepeatedField extent_ = new pbc::RepeatedField(); + /// + /// Extent of the slice in all tensor dimensions. + /// + /// Must have one entry for each of the dimension of the tensor that this + /// slice belongs to. The order of sizes is the same as the order of + /// dimensions in the TensorShape. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Extent { + get { return extent_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TensorSliceProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TensorSliceProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!extent_.Equals(other.extent_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= extent_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + extent_.WriteTo(output, _repeated_extent_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + extent_.WriteTo(ref output, _repeated_extent_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += extent_.CalculateSize(_repeated_extent_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TensorSliceProto other) { + if (other == null) { + return; + } + extent_.Add(other.extent_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + extent_.AddEntriesFrom(input, _repeated_extent_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + extent_.AddEntriesFrom(ref input, _repeated_extent_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the TensorSliceProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Extent of the slice in one dimension. + /// + public sealed partial class Extent : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Extent()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TensorSliceProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Extent() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Extent(Extent other) : this() { + start_ = other.start_; + switch (other.HasLengthCase) { + case HasLengthOneofCase.Length: + Length = other.Length; + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Extent Clone() { + return new Extent(this); + } + + /// Field number for the "start" field. + public const int StartFieldNumber = 1; + private long start_; + /// + /// Start index of the slice, starting at 0. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Start { + get { return start_; } + set { + start_ = value; + } + } + + /// Field number for the "length" field. + public const int LengthFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Length { + get { return hasLengthCase_ == HasLengthOneofCase.Length ? (long) hasLength_ : 0L; } + set { + hasLength_ = value; + hasLengthCase_ = HasLengthOneofCase.Length; + } + } + + private object hasLength_; + /// Enum of possible cases for the "has_length" oneof. + public enum HasLengthOneofCase { + None = 0, + Length = 2, + } + private HasLengthOneofCase hasLengthCase_ = HasLengthOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public HasLengthOneofCase HasLengthCase { + get { return hasLengthCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void ClearHasLength() { + hasLengthCase_ = HasLengthOneofCase.None; + hasLength_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Extent); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Extent other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Start != other.Start) return false; + if (Length != other.Length) return false; + if (HasLengthCase != other.HasLengthCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Start != 0L) hash ^= Start.GetHashCode(); + if (hasLengthCase_ == HasLengthOneofCase.Length) hash ^= Length.GetHashCode(); + hash ^= (int) hasLengthCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Start != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Start); + } + if (hasLengthCase_ == HasLengthOneofCase.Length) { + output.WriteRawTag(16); + output.WriteInt64(Length); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Start != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Start); + } + if (hasLengthCase_ == HasLengthOneofCase.Length) { + output.WriteRawTag(16); + output.WriteInt64(Length); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Start != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Start); + } + if (hasLengthCase_ == HasLengthOneofCase.Length) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Length); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Extent other) { + if (other == null) { + return; + } + if (other.Start != 0L) { + Start = other.Start; + } + switch (other.HasLengthCase) { + case HasLengthOneofCase.Length: + Length = other.Length; + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Start = input.ReadInt64(); + break; + } + case 16: { + Length = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Start = input.ReadInt64(); + break; + } + case 16: { + Length = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs new file mode 100644 index 000000000..89bc07521 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs @@ -0,0 +1,1617 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/trackable_object_graph.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/trackable_object_graph.proto + public static partial class TrackableObjectGraphReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/trackable_object_graph.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static TrackableObjectGraphReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjV0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvdHJhY2thYmxlX29iamVjdF9n", + "cmFwaC5wcm90bxIKdGVuc29yZmxvdxoeZ29vZ2xlL3Byb3RvYnVmL3dyYXBw", + "ZXJzLnByb3RvIvMFChRUcmFja2FibGVPYmplY3RHcmFwaBI/CgVub2RlcxgB", + "IAMoCzIwLnRlbnNvcmZsb3cuVHJhY2thYmxlT2JqZWN0R3JhcGguVHJhY2th", + "YmxlT2JqZWN0GpkFCg9UcmFja2FibGVPYmplY3QSUgoIY2hpbGRyZW4YASAD", + "KAsyQC50ZW5zb3JmbG93LlRyYWNrYWJsZU9iamVjdEdyYXBoLlRyYWNrYWJs", + "ZU9iamVjdC5PYmplY3RSZWZlcmVuY2USVQoKYXR0cmlidXRlcxgCIAMoCzJB", + "LnRlbnNvcmZsb3cuVHJhY2thYmxlT2JqZWN0R3JhcGguVHJhY2thYmxlT2Jq", + "ZWN0LlNlcmlhbGl6ZWRUZW5zb3ISXgoOc2xvdF92YXJpYWJsZXMYAyADKAsy", + "Ri50ZW5zb3JmbG93LlRyYWNrYWJsZU9iamVjdEdyYXBoLlRyYWNrYWJsZU9i", + "amVjdC5TbG90VmFyaWFibGVSZWZlcmVuY2USNQoQcmVnaXN0ZXJlZF9zYXZl", + "chgEIAEoCzIbLnRlbnNvcmZsb3cuUmVnaXN0ZXJlZFNhdmVyEjkKFWhhc19j", + "aGVja3BvaW50X3ZhbHVlcxgFIAEoCzIaLmdvb2dsZS5wcm90b2J1Zi5Cb29s", + "VmFsdWUaNgoPT2JqZWN0UmVmZXJlbmNlEg8KB25vZGVfaWQYASABKAUSEgoK", + "bG9jYWxfbmFtZRgCIAEoCRpjChBTZXJpYWxpemVkVGVuc29yEgwKBG5hbWUY", + "ASABKAkSEQoJZnVsbF9uYW1lGAIgASgJEhYKDmNoZWNrcG9pbnRfa2V5GAMg", + "ASgJSgQIBBAFUhBvcHRpb25hbF9yZXN0b3JlGmwKFVNsb3RWYXJpYWJsZVJl", + "ZmVyZW5jZRIhChlvcmlnaW5hbF92YXJpYWJsZV9ub2RlX2lkGAEgASgFEhEK", + "CXNsb3RfbmFtZRgCIAEoCRIdChVzbG90X3ZhcmlhYmxlX25vZGVfaWQYAyAB", + "KAUiNAoPUmVnaXN0ZXJlZFNhdmVyEgwKBG5hbWUYASABKAkSEwoLb2JqZWN0", + "X25hbWUYAiABKAlCWlpVZ2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZs", + "b3cvdGVuc29yZmxvdy9nby9jb3JlL3Byb3RvYnVmL2Zvcl9jb3JlX3Byb3Rv", + "c19nb19wcm90b/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Google.Protobuf.WellKnownTypes.WrappersReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TrackableObjectGraph), global::Tensorflow.TrackableObjectGraph.Parser, new[]{ "Nodes" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TrackableObjectGraph.Types.TrackableObject), global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Parser, new[]{ "Children", "Attributes", "SlotVariables", "RegisteredSaver", "HasCheckpointValues" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference), global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser, new[]{ "NodeId", "LocalName" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor), global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor.Parser, new[]{ "Name", "FullName", "CheckpointKey" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference), global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference.Parser, new[]{ "OriginalVariableNodeId", "SlotName", "SlotVariableNodeId" }, null, null, null, null)})}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RegisteredSaver), global::Tensorflow.RegisteredSaver.Parser, new[]{ "Name", "ObjectName" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class TrackableObjectGraph : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TrackableObjectGraph()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TrackableObjectGraphReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TrackableObjectGraph() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TrackableObjectGraph(TrackableObjectGraph other) : this() { + nodes_ = other.nodes_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TrackableObjectGraph Clone() { + return new TrackableObjectGraph(this); + } + + /// Field number for the "nodes" field. + public const int NodesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_nodes_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Parser); + private readonly pbc::RepeatedField nodes_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Nodes { + get { return nodes_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TrackableObjectGraph); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TrackableObjectGraph other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!nodes_.Equals(other.nodes_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= nodes_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + nodes_.WriteTo(output, _repeated_nodes_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + nodes_.WriteTo(ref output, _repeated_nodes_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += nodes_.CalculateSize(_repeated_nodes_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TrackableObjectGraph other) { + if (other == null) { + return; + } + nodes_.Add(other.nodes_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + nodes_.AddEntriesFrom(input, _repeated_nodes_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + nodes_.AddEntriesFrom(ref input, _repeated_nodes_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the TrackableObjectGraph message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public sealed partial class TrackableObject : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TrackableObject()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TrackableObjectGraph.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TrackableObject() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TrackableObject(TrackableObject other) : this() { + children_ = other.children_.Clone(); + attributes_ = other.attributes_.Clone(); + slotVariables_ = other.slotVariables_.Clone(); + registeredSaver_ = other.registeredSaver_ != null ? other.registeredSaver_.Clone() : null; + HasCheckpointValues = other.HasCheckpointValues; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TrackableObject Clone() { + return new TrackableObject(this); + } + + /// Field number for the "children" field. + public const int ChildrenFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_children_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); + private readonly pbc::RepeatedField children_ = new pbc::RepeatedField(); + /// + /// Objects which this object depends on. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Children { + get { return children_; } + } + + /// Field number for the "attributes" field. + public const int AttributesFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_attributes_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor.Parser); + private readonly pbc::RepeatedField attributes_ = new pbc::RepeatedField(); + /// + /// Serialized data specific to this object. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Attributes { + get { return attributes_; } + } + + /// Field number for the "slot_variables" field. + public const int SlotVariablesFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_slotVariables_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference.Parser); + private readonly pbc::RepeatedField slotVariables_ = new pbc::RepeatedField(); + /// + /// Slot variables owned by this object. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SlotVariables { + get { return slotVariables_; } + } + + /// Field number for the "registered_saver" field. + public const int RegisteredSaverFieldNumber = 4; + private global::Tensorflow.RegisteredSaver registeredSaver_; + /// + /// The registered saver used to save this object. If this saver is not + /// present when loading the checkpoint, then loading will fail. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.RegisteredSaver RegisteredSaver { + get { return registeredSaver_; } + set { + registeredSaver_ = value; + } + } + + /// Field number for the "has_checkpoint_values" field. + public const int HasCheckpointValuesFieldNumber = 5; + private static readonly pb::FieldCodec _single_hasCheckpointValues_codec = pb::FieldCodec.ForStructWrapper(42); + private bool? hasCheckpointValues_; + /// + /// Whether this object has checkpoint values or descendants with checkpoint + /// values. This is computed at save time to avoid traversing the entire + /// object graph proto when restoring (which also has to traverse the live + /// object graph). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool? HasCheckpointValues { + get { return hasCheckpointValues_; } + set { + hasCheckpointValues_ = value; + } + } + + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TrackableObject); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TrackableObject other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!children_.Equals(other.children_)) return false; + if(!attributes_.Equals(other.attributes_)) return false; + if(!slotVariables_.Equals(other.slotVariables_)) return false; + if (!object.Equals(RegisteredSaver, other.RegisteredSaver)) return false; + if (HasCheckpointValues != other.HasCheckpointValues) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= children_.GetHashCode(); + hash ^= attributes_.GetHashCode(); + hash ^= slotVariables_.GetHashCode(); + if (registeredSaver_ != null) hash ^= RegisteredSaver.GetHashCode(); + if (hasCheckpointValues_ != null) hash ^= HasCheckpointValues.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + children_.WriteTo(output, _repeated_children_codec); + attributes_.WriteTo(output, _repeated_attributes_codec); + slotVariables_.WriteTo(output, _repeated_slotVariables_codec); + if (registeredSaver_ != null) { + output.WriteRawTag(34); + output.WriteMessage(RegisteredSaver); + } + if (hasCheckpointValues_ != null) { + _single_hasCheckpointValues_codec.WriteTagAndValue(output, HasCheckpointValues); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + children_.WriteTo(ref output, _repeated_children_codec); + attributes_.WriteTo(ref output, _repeated_attributes_codec); + slotVariables_.WriteTo(ref output, _repeated_slotVariables_codec); + if (registeredSaver_ != null) { + output.WriteRawTag(34); + output.WriteMessage(RegisteredSaver); + } + if (hasCheckpointValues_ != null) { + _single_hasCheckpointValues_codec.WriteTagAndValue(ref output, HasCheckpointValues); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += children_.CalculateSize(_repeated_children_codec); + size += attributes_.CalculateSize(_repeated_attributes_codec); + size += slotVariables_.CalculateSize(_repeated_slotVariables_codec); + if (registeredSaver_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RegisteredSaver); + } + if (hasCheckpointValues_ != null) { + size += _single_hasCheckpointValues_codec.CalculateSizeWithTag(HasCheckpointValues); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TrackableObject other) { + if (other == null) { + return; + } + children_.Add(other.children_); + attributes_.Add(other.attributes_); + slotVariables_.Add(other.slotVariables_); + if (other.registeredSaver_ != null) { + if (registeredSaver_ == null) { + RegisteredSaver = new global::Tensorflow.RegisteredSaver(); + } + RegisteredSaver.MergeFrom(other.RegisteredSaver); + } + if (other.hasCheckpointValues_ != null) { + if (hasCheckpointValues_ == null || other.HasCheckpointValues != false) { + HasCheckpointValues = other.HasCheckpointValues; + } + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + children_.AddEntriesFrom(input, _repeated_children_codec); + break; + } + case 18: { + attributes_.AddEntriesFrom(input, _repeated_attributes_codec); + break; + } + case 26: { + slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec); + break; + } + case 34: { + if (registeredSaver_ == null) { + RegisteredSaver = new global::Tensorflow.RegisteredSaver(); + } + input.ReadMessage(RegisteredSaver); + break; + } + case 42: { + bool? value = _single_hasCheckpointValues_codec.Read(input); + if (hasCheckpointValues_ == null || value != false) { + HasCheckpointValues = value; + } + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + children_.AddEntriesFrom(ref input, _repeated_children_codec); + break; + } + case 18: { + attributes_.AddEntriesFrom(ref input, _repeated_attributes_codec); + break; + } + case 26: { + slotVariables_.AddEntriesFrom(ref input, _repeated_slotVariables_codec); + break; + } + case 34: { + if (registeredSaver_ == null) { + RegisteredSaver = new global::Tensorflow.RegisteredSaver(); + } + input.ReadMessage(RegisteredSaver); + break; + } + case 42: { + bool? value = _single_hasCheckpointValues_codec.Read(ref input); + if (hasCheckpointValues_ == null || value != false) { + HasCheckpointValues = value; + } + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the TrackableObject message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public sealed partial class ObjectReference : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ObjectReference()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ObjectReference() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ObjectReference(ObjectReference other) : this() { + nodeId_ = other.nodeId_; + localName_ = other.localName_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ObjectReference Clone() { + return new ObjectReference(this); + } + + /// Field number for the "node_id" field. + public const int NodeIdFieldNumber = 1; + private int nodeId_; + /// + /// An index into `TrackableObjectGraph.nodes`, indicating the object + /// being referenced. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NodeId { + get { return nodeId_; } + set { + nodeId_ = value; + } + } + + /// Field number for the "local_name" field. + public const int LocalNameFieldNumber = 2; + private string localName_ = ""; + /// + /// A user-provided name for the edge. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string LocalName { + get { return localName_; } + set { + localName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ObjectReference); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ObjectReference other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NodeId != other.NodeId) return false; + if (LocalName != other.LocalName) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (NodeId != 0) hash ^= NodeId.GetHashCode(); + if (LocalName.Length != 0) hash ^= LocalName.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (NodeId != 0) { + output.WriteRawTag(8); + output.WriteInt32(NodeId); + } + if (LocalName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(LocalName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (NodeId != 0) { + output.WriteRawTag(8); + output.WriteInt32(NodeId); + } + if (LocalName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(LocalName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (NodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NodeId); + } + if (LocalName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(LocalName); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ObjectReference other) { + if (other == null) { + return; + } + if (other.NodeId != 0) { + NodeId = other.NodeId; + } + if (other.LocalName.Length != 0) { + LocalName = other.LocalName; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NodeId = input.ReadInt32(); + break; + } + case 18: { + LocalName = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + NodeId = input.ReadInt32(); + break; + } + case 18: { + LocalName = input.ReadString(); + break; + } + } + } + } + #endif + + } + + public sealed partial class SerializedTensor : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SerializedTensor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SerializedTensor() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SerializedTensor(SerializedTensor other) : this() { + name_ = other.name_; + fullName_ = other.fullName_; + checkpointKey_ = other.checkpointKey_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SerializedTensor Clone() { + return new SerializedTensor(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// A name for the Tensor. Simple variables have only one + /// `SerializedTensor` named "VARIABLE_VALUE" by convention. This value may + /// be restored on object creation as an optimization. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "full_name" field. + public const int FullNameFieldNumber = 2; + private string fullName_ = ""; + /// + /// The full name of the variable/tensor, if applicable. Used to allow + /// name-based loading of checkpoints which were saved using an + /// object-based API. Should match the checkpoint key which would have been + /// assigned by tf.train.Saver. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string FullName { + get { return fullName_; } + set { + fullName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "checkpoint_key" field. + public const int CheckpointKeyFieldNumber = 3; + private string checkpointKey_ = ""; + /// + /// The generated name of the Tensor in the checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string CheckpointKey { + get { return checkpointKey_; } + set { + checkpointKey_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SerializedTensor); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SerializedTensor other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (FullName != other.FullName) return false; + if (CheckpointKey != other.CheckpointKey) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (FullName.Length != 0) hash ^= FullName.GetHashCode(); + if (CheckpointKey.Length != 0) hash ^= CheckpointKey.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (FullName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(FullName); + } + if (CheckpointKey.Length != 0) { + output.WriteRawTag(26); + output.WriteString(CheckpointKey); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (FullName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(FullName); + } + if (CheckpointKey.Length != 0) { + output.WriteRawTag(26); + output.WriteString(CheckpointKey); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (FullName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(FullName); + } + if (CheckpointKey.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(CheckpointKey); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SerializedTensor other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.FullName.Length != 0) { + FullName = other.FullName; + } + if (other.CheckpointKey.Length != 0) { + CheckpointKey = other.CheckpointKey; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + FullName = input.ReadString(); + break; + } + case 26: { + CheckpointKey = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + FullName = input.ReadString(); + break; + } + case 26: { + CheckpointKey = input.ReadString(); + break; + } + } + } + } + #endif + + } + + public sealed partial class SlotVariableReference : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SlotVariableReference()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Descriptor.NestedTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SlotVariableReference() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SlotVariableReference(SlotVariableReference other) : this() { + originalVariableNodeId_ = other.originalVariableNodeId_; + slotName_ = other.slotName_; + slotVariableNodeId_ = other.slotVariableNodeId_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SlotVariableReference Clone() { + return new SlotVariableReference(this); + } + + /// Field number for the "original_variable_node_id" field. + public const int OriginalVariableNodeIdFieldNumber = 1; + private int originalVariableNodeId_; + /// + /// An index into `TrackableObjectGraph.nodes`, indicating the + /// variable object this slot was created for. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int OriginalVariableNodeId { + get { return originalVariableNodeId_; } + set { + originalVariableNodeId_ = value; + } + } + + /// Field number for the "slot_name" field. + public const int SlotNameFieldNumber = 2; + private string slotName_ = ""; + /// + /// The name of the slot (e.g. "m"/"v"). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string SlotName { + get { return slotName_; } + set { + slotName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "slot_variable_node_id" field. + public const int SlotVariableNodeIdFieldNumber = 3; + private int slotVariableNodeId_; + /// + /// An index into `TrackableObjectGraph.nodes`, indicating the + /// `Object` with the value of the slot variable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int SlotVariableNodeId { + get { return slotVariableNodeId_; } + set { + slotVariableNodeId_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SlotVariableReference); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SlotVariableReference other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (OriginalVariableNodeId != other.OriginalVariableNodeId) return false; + if (SlotName != other.SlotName) return false; + if (SlotVariableNodeId != other.SlotVariableNodeId) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (OriginalVariableNodeId != 0) hash ^= OriginalVariableNodeId.GetHashCode(); + if (SlotName.Length != 0) hash ^= SlotName.GetHashCode(); + if (SlotVariableNodeId != 0) hash ^= SlotVariableNodeId.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (OriginalVariableNodeId != 0) { + output.WriteRawTag(8); + output.WriteInt32(OriginalVariableNodeId); + } + if (SlotName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(SlotName); + } + if (SlotVariableNodeId != 0) { + output.WriteRawTag(24); + output.WriteInt32(SlotVariableNodeId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (OriginalVariableNodeId != 0) { + output.WriteRawTag(8); + output.WriteInt32(OriginalVariableNodeId); + } + if (SlotName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(SlotName); + } + if (SlotVariableNodeId != 0) { + output.WriteRawTag(24); + output.WriteInt32(SlotVariableNodeId); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (OriginalVariableNodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(OriginalVariableNodeId); + } + if (SlotName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(SlotName); + } + if (SlotVariableNodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(SlotVariableNodeId); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SlotVariableReference other) { + if (other == null) { + return; + } + if (other.OriginalVariableNodeId != 0) { + OriginalVariableNodeId = other.OriginalVariableNodeId; + } + if (other.SlotName.Length != 0) { + SlotName = other.SlotName; + } + if (other.SlotVariableNodeId != 0) { + SlotVariableNodeId = other.SlotVariableNodeId; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + OriginalVariableNodeId = input.ReadInt32(); + break; + } + case 18: { + SlotName = input.ReadString(); + break; + } + case 24: { + SlotVariableNodeId = input.ReadInt32(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + OriginalVariableNodeId = input.ReadInt32(); + break; + } + case 18: { + SlotName = input.ReadString(); + break; + } + case 24: { + SlotVariableNodeId = input.ReadInt32(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + } + #endregion + + } + + public sealed partial class RegisteredSaver : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RegisteredSaver()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TrackableObjectGraphReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RegisteredSaver() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RegisteredSaver(RegisteredSaver other) : this() { + name_ = other.name_; + objectName_ = other.objectName_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public RegisteredSaver Clone() { + return new RegisteredSaver(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// The name of the registered saver/restore function. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "object_name" field. + public const int ObjectNameFieldNumber = 2; + private string objectName_ = ""; + /// + /// Unique auto-generated name of the object. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ObjectName { + get { return objectName_; } + set { + objectName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as RegisteredSaver); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(RegisteredSaver other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (ObjectName != other.ObjectName) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (ObjectName.Length != 0) hash ^= ObjectName.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ObjectName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ObjectName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ObjectName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ObjectName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (ObjectName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ObjectName); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(RegisteredSaver other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.ObjectName.Length != 0) { + ObjectName = other.ObjectName; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + ObjectName = input.ReadString(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + ObjectName = input.ReadString(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Types.cs b/src/TensorFlowNET.Core/Protobuf/Types.cs index 887ff3223..a2d3bac5d 100644 --- a/src/TensorFlowNET.Core/Protobuf/Types.cs +++ b/src/TensorFlowNET.Core/Protobuf/Types.cs @@ -1,8 +1,8 @@ // // Generated by the protocol buffer compiler. DO NOT EDIT! -// source: types.proto +// source: tensorflow/core/framework/types.proto // -#pragma warning disable 1591, 0612, 3021 +#pragma warning disable 1591, 0612, 3021, 8981 #region Designer generated code using pb = global::Google.Protobuf; @@ -11,11 +11,11 @@ using scg = global::System.Collections.Generic; namespace Tensorflow { - /// Holder for reflection information generated from types.proto + /// Holder for reflection information generated from tensorflow/core/framework/types.proto public static partial class TypesReflection { #region Descriptor - /// File descriptor for types.proto + /// File descriptor for tensorflow/core/framework/types.proto public static pbr::FileDescriptor Descriptor { get { return descriptor; } } @@ -24,37 +24,42 @@ public static partial class TypesReflection { static TypesReflection() { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( - "Cgt0eXBlcy5wcm90bxIKdGVuc29yZmxvdyqqBgoIRGF0YVR5cGUSDgoKRFRf", - "SU5WQUxJRBAAEgwKCERUX0ZMT0FUEAESDQoJRFRfRE9VQkxFEAISDAoIRFRf", - "SU5UMzIQAxIMCghEVF9VSU5UOBAEEgwKCERUX0lOVDE2EAUSCwoHRFRfSU5U", - "OBAGEg0KCURUX1NUUklORxAHEhAKDERUX0NPTVBMRVg2NBAIEgwKCERUX0lO", - "VDY0EAkSCwoHRFRfQk9PTBAKEgwKCERUX1FJTlQ4EAsSDQoJRFRfUVVJTlQ4", - "EAwSDQoJRFRfUUlOVDMyEA0SDwoLRFRfQkZMT0FUMTYQDhINCglEVF9RSU5U", - "MTYQDxIOCgpEVF9RVUlOVDE2EBASDQoJRFRfVUlOVDE2EBESEQoNRFRfQ09N", - "UExFWDEyOBASEgsKB0RUX0hBTEYQExIPCgtEVF9SRVNPVVJDRRAUEg4KCkRU", - "X1ZBUklBTlQQFRINCglEVF9VSU5UMzIQFhINCglEVF9VSU5UNjQQFxIQCgxE", - "VF9GTE9BVF9SRUYQZRIRCg1EVF9ET1VCTEVfUkVGEGYSEAoMRFRfSU5UMzJf", - "UkVGEGcSEAoMRFRfVUlOVDhfUkVGEGgSEAoMRFRfSU5UMTZfUkVGEGkSDwoL", - "RFRfSU5UOF9SRUYQahIRCg1EVF9TVFJJTkdfUkVGEGsSFAoQRFRfQ09NUExF", - "WDY0X1JFRhBsEhAKDERUX0lOVDY0X1JFRhBtEg8KC0RUX0JPT0xfUkVGEG4S", - "EAoMRFRfUUlOVDhfUkVGEG8SEQoNRFRfUVVJTlQ4X1JFRhBwEhEKDURUX1FJ", - "TlQzMl9SRUYQcRITCg9EVF9CRkxPQVQxNl9SRUYQchIRCg1EVF9RSU5UMTZf", - "UkVGEHMSEgoORFRfUVVJTlQxNl9SRUYQdBIRCg1EVF9VSU5UMTZfUkVGEHUS", - "FQoRRFRfQ09NUExFWDEyOF9SRUYQdhIPCgtEVF9IQUxGX1JFRhB3EhMKD0RU", - "X1JFU09VUkNFX1JFRhB4EhIKDkRUX1ZBUklBTlRfUkVGEHkSEQoNRFRfVUlO", - "VDMyX1JFRhB6EhEKDURUX1VJTlQ2NF9SRUYQe0JrChhvcmcudGVuc29yZmxv", - "dy5mcmFtZXdvcmtCC1R5cGVzUHJvdG9zUAFaPWdpdGh1Yi5jb20vdGVuc29y", - "Zmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9mcmFtZXdvcmv4", - "AQFiBnByb3RvMw==")); + "CiV0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3R5cGVzLnByb3RvEgp0ZW5z", + "b3JmbG93IjkKD1NlcmlhbGl6ZWREVHlwZRImCghkYXRhdHlwZRgBIAEoDjIU", + "LnRlbnNvcmZsb3cuRGF0YVR5cGUqqgYKCERhdGFUeXBlEg4KCkRUX0lOVkFM", + "SUQQABIMCghEVF9GTE9BVBABEg0KCURUX0RPVUJMRRACEgwKCERUX0lOVDMy", + "EAMSDAoIRFRfVUlOVDgQBBIMCghEVF9JTlQxNhAFEgsKB0RUX0lOVDgQBhIN", + "CglEVF9TVFJJTkcQBxIQCgxEVF9DT01QTEVYNjQQCBIMCghEVF9JTlQ2NBAJ", + "EgsKB0RUX0JPT0wQChIMCghEVF9RSU5UOBALEg0KCURUX1FVSU5UOBAMEg0K", + "CURUX1FJTlQzMhANEg8KC0RUX0JGTE9BVDE2EA4SDQoJRFRfUUlOVDE2EA8S", + "DgoKRFRfUVVJTlQxNhAQEg0KCURUX1VJTlQxNhAREhEKDURUX0NPTVBMRVgx", + "MjgQEhILCgdEVF9IQUxGEBMSDwoLRFRfUkVTT1VSQ0UQFBIOCgpEVF9WQVJJ", + "QU5UEBUSDQoJRFRfVUlOVDMyEBYSDQoJRFRfVUlOVDY0EBcSEAoMRFRfRkxP", + "QVRfUkVGEGUSEQoNRFRfRE9VQkxFX1JFRhBmEhAKDERUX0lOVDMyX1JFRhBn", + "EhAKDERUX1VJTlQ4X1JFRhBoEhAKDERUX0lOVDE2X1JFRhBpEg8KC0RUX0lO", + "VDhfUkVGEGoSEQoNRFRfU1RSSU5HX1JFRhBrEhQKEERUX0NPTVBMRVg2NF9S", + "RUYQbBIQCgxEVF9JTlQ2NF9SRUYQbRIPCgtEVF9CT09MX1JFRhBuEhAKDERU", + "X1FJTlQ4X1JFRhBvEhEKDURUX1FVSU5UOF9SRUYQcBIRCg1EVF9RSU5UMzJf", + "UkVGEHESEwoPRFRfQkZMT0FUMTZfUkVGEHISEQoNRFRfUUlOVDE2X1JFRhBz", + "EhIKDkRUX1FVSU5UMTZfUkVGEHQSEQoNRFRfVUlOVDE2X1JFRhB1EhUKEURU", + "X0NPTVBMRVgxMjhfUkVGEHYSDwoLRFRfSEFMRl9SRUYQdxITCg9EVF9SRVNP", + "VVJDRV9SRUYQeBISCg5EVF9WQVJJQU5UX1JFRhB5EhEKDURUX1VJTlQzMl9S", + "RUYQehIRCg1EVF9VSU5UNjRfUkVGEHtCegoYb3JnLnRlbnNvcmZsb3cuZnJh", + "bWV3b3JrQgtUeXBlc1Byb3Rvc1ABWkxnaXRodWIuY29tL3RlbnNvcmZsb3cv", + "dGVuc29yZmxvdy90ZW5zb3JmbG93L2dvL2NvcmUvZnJhbWV3b3JrL3R5cGVz", + "X2dvX3Byb3Rv+AEBYgZwcm90bzM=")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { }, - new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.DataType), }, null)); + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.DataType), }, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SerializedDType), global::Tensorflow.SerializedDType.Parser, new[]{ "Datatype" }, null, null, null, null) + })); } #endregion } #region Enums /// + /// (== suppress_warning documentation-presence ==) /// LINT.IfChange /// public enum DataType { @@ -148,6 +153,201 @@ public enum DataType { #endregion + #region Messages + /// + /// Represents a serialized tf.dtypes.Dtype + /// + public sealed partial class SerializedDType : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SerializedDType()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TypesReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SerializedDType() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SerializedDType(SerializedDType other) : this() { + datatype_ = other.datatype_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SerializedDType Clone() { + return new SerializedDType(this); + } + + /// Field number for the "datatype" field. + public const int DatatypeFieldNumber = 1; + private global::Tensorflow.DataType datatype_ = global::Tensorflow.DataType.DtInvalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.DataType Datatype { + get { return datatype_; } + set { + datatype_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SerializedDType); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SerializedDType other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Datatype != other.Datatype) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Datatype != global::Tensorflow.DataType.DtInvalid) hash ^= Datatype.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Datatype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(8); + output.WriteEnum((int) Datatype); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Datatype != global::Tensorflow.DataType.DtInvalid) { + output.WriteRawTag(8); + output.WriteEnum((int) Datatype); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Datatype != global::Tensorflow.DataType.DtInvalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Datatype); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SerializedDType other) { + if (other == null) { + return; + } + if (other.Datatype != global::Tensorflow.DataType.DtInvalid) { + Datatype = other.Datatype; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Datatype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Datatype = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + #endregion + } #endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Variable.cs b/src/TensorFlowNET.Core/Protobuf/Variable.cs new file mode 100644 index 000000000..1bb8f0120 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Variable.cs @@ -0,0 +1,930 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/variable.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/variable.proto + public static partial class VariableReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/variable.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static VariableReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cih0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3ZhcmlhYmxlLnByb3RvEgp0", + "ZW5zb3JmbG93IsgCCgtWYXJpYWJsZURlZhIVCg12YXJpYWJsZV9uYW1lGAEg", + "ASgJEhoKEmluaXRpYWxfdmFsdWVfbmFtZRgGIAEoCRIYChBpbml0aWFsaXpl", + "cl9uYW1lGAIgASgJEhUKDXNuYXBzaG90X25hbWUYAyABKAkSOQoTc2F2ZV9z", + "bGljZV9pbmZvX2RlZhgEIAEoCzIcLnRlbnNvcmZsb3cuU2F2ZVNsaWNlSW5m", + "b0RlZhITCgtpc19yZXNvdXJjZRgFIAEoCBIRCgl0cmFpbmFibGUYByABKAgS", + "PAoPc3luY2hyb25pemF0aW9uGAggASgOMiMudGVuc29yZmxvdy5WYXJpYWJs", + "ZVN5bmNocm9uaXphdGlvbhI0CgthZ2dyZWdhdGlvbhgJIAEoDjIfLnRlbnNv", + "cmZsb3cuVmFyaWFibGVBZ2dyZWdhdGlvbiJgChBTYXZlU2xpY2VJbmZvRGVm", + "EhEKCWZ1bGxfbmFtZRgBIAEoCRISCgpmdWxsX3NoYXBlGAIgAygDEhIKCnZh", + "cl9vZmZzZXQYAyADKAMSEQoJdmFyX3NoYXBlGAQgAygDKqwBChdWYXJpYWJs", + "ZVN5bmNocm9uaXphdGlvbhIhCh1WQVJJQUJMRV9TWU5DSFJPTklaQVRJT05f", + "QVVUTxAAEiEKHVZBUklBQkxFX1NZTkNIUk9OSVpBVElPTl9OT05FEAESJQoh", + "VkFSSUFCTEVfU1lOQ0hST05JWkFUSU9OX09OX1dSSVRFEAISJAogVkFSSUFC", + "TEVfU1lOQ0hST05JWkFUSU9OX09OX1JFQUQQAyqeAQoTVmFyaWFibGVBZ2dy", + "ZWdhdGlvbhIdChlWQVJJQUJMRV9BR0dSRUdBVElPTl9OT05FEAASHAoYVkFS", + "SUFCTEVfQUdHUkVHQVRJT05fU1VNEAESHQoZVkFSSUFCTEVfQUdHUkVHQVRJ", + "T05fTUVBThACEisKJ1ZBUklBQkxFX0FHR1JFR0FUSU9OX09OTFlfRklSU1Rf", + "UkVQTElDQRADQoABChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtCDlZhcmlh", + "YmxlUHJvdG9zUAFaT2dpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93", + "L3RlbnNvcmZsb3cvZ28vY29yZS9mcmFtZXdvcmsvdmFyaWFibGVfZ29fcHJv", + "dG/4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.VariableSynchronization), typeof(global::Tensorflow.VariableAggregation), }, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VariableDef), global::Tensorflow.VariableDef.Parser, new[]{ "VariableName", "InitialValueName", "InitializerName", "SnapshotName", "SaveSliceInfoDef", "IsResource", "Trainable", "Synchronization", "Aggregation" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaveSliceInfoDef), global::Tensorflow.SaveSliceInfoDef.Parser, new[]{ "FullName", "FullShape", "VarOffset", "VarShape" }, null, null, null, null) + })); + } + #endregion + + } + #region Enums + /// + /// Indicates when a distributed variable will be synced. + /// + public enum VariableSynchronization { + /// + /// `AUTO`: Indicates that the synchronization will be determined by the + /// current `DistributionStrategy` (eg. With `MirroredStrategy` this would be + /// `ON_WRITE`). + /// + [pbr::OriginalName("VARIABLE_SYNCHRONIZATION_AUTO")] Auto = 0, + /// + /// `NONE`: Indicates that there will only be one copy of the variable, so + /// there is no need to sync. + /// + [pbr::OriginalName("VARIABLE_SYNCHRONIZATION_NONE")] None = 1, + /// + /// `ON_WRITE`: Indicates that the variable will be updated across devices + /// every time it is written. + /// + [pbr::OriginalName("VARIABLE_SYNCHRONIZATION_ON_WRITE")] OnWrite = 2, + /// + /// `ON_READ`: Indicates that the variable will be aggregated across devices + /// when it is read (eg. when checkpointing or when evaluating an op that uses + /// the variable). + /// + [pbr::OriginalName("VARIABLE_SYNCHRONIZATION_ON_READ")] OnRead = 3, + } + + /// + /// Indicates how a distributed variable will be aggregated. + /// + public enum VariableAggregation { + /// + /// `NONE`: This is the default, giving an error if you use a + /// variable-update operation with multiple replicas. + /// + [pbr::OriginalName("VARIABLE_AGGREGATION_NONE")] None = 0, + /// + /// `SUM`: Add the updates across replicas. + /// + [pbr::OriginalName("VARIABLE_AGGREGATION_SUM")] Sum = 1, + /// + /// `MEAN`: Take the arithmetic mean ("average") of the updates across + /// replicas. + /// + [pbr::OriginalName("VARIABLE_AGGREGATION_MEAN")] Mean = 2, + /// + /// `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same + /// update, but we only want to perform the update once. Used, e.g., for the + /// global step counter. + /// + [pbr::OriginalName("VARIABLE_AGGREGATION_ONLY_FIRST_REPLICA")] OnlyFirstReplica = 3, + } + + #endregion + + #region Messages + /// + /// Protocol buffer representing a Variable. + /// + public sealed partial class VariableDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new VariableDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.VariableReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public VariableDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public VariableDef(VariableDef other) : this() { + variableName_ = other.variableName_; + initialValueName_ = other.initialValueName_; + initializerName_ = other.initializerName_; + snapshotName_ = other.snapshotName_; + saveSliceInfoDef_ = other.saveSliceInfoDef_ != null ? other.saveSliceInfoDef_.Clone() : null; + isResource_ = other.isResource_; + trainable_ = other.trainable_; + synchronization_ = other.synchronization_; + aggregation_ = other.aggregation_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public VariableDef Clone() { + return new VariableDef(this); + } + + /// Field number for the "variable_name" field. + public const int VariableNameFieldNumber = 1; + private string variableName_ = ""; + /// + /// Name of the variable tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string VariableName { + get { return variableName_; } + set { + variableName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "initial_value_name" field. + public const int InitialValueNameFieldNumber = 6; + private string initialValueName_ = ""; + /// + /// Name of the tensor holding the variable's initial value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string InitialValueName { + get { return initialValueName_; } + set { + initialValueName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "initializer_name" field. + public const int InitializerNameFieldNumber = 2; + private string initializerName_ = ""; + /// + /// Name of the initializer op. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string InitializerName { + get { return initializerName_; } + set { + initializerName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "snapshot_name" field. + public const int SnapshotNameFieldNumber = 3; + private string snapshotName_ = ""; + /// + /// Name of the snapshot tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string SnapshotName { + get { return snapshotName_; } + set { + snapshotName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "save_slice_info_def" field. + public const int SaveSliceInfoDefFieldNumber = 4; + private global::Tensorflow.SaveSliceInfoDef saveSliceInfoDef_; + /// + /// Support for saving variables as slices of a larger variable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.SaveSliceInfoDef SaveSliceInfoDef { + get { return saveSliceInfoDef_; } + set { + saveSliceInfoDef_ = value; + } + } + + /// Field number for the "is_resource" field. + public const int IsResourceFieldNumber = 5; + private bool isResource_; + /// + /// Whether to represent this as a ResourceVariable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool IsResource { + get { return isResource_; } + set { + isResource_ = value; + } + } + + /// Field number for the "trainable" field. + public const int TrainableFieldNumber = 7; + private bool trainable_; + /// + /// Whether this variable should be trained. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Trainable { + get { return trainable_; } + set { + trainable_ = value; + } + } + + /// Field number for the "synchronization" field. + public const int SynchronizationFieldNumber = 8; + private global::Tensorflow.VariableSynchronization synchronization_ = global::Tensorflow.VariableSynchronization.Auto; + /// + /// Indicates when a distributed variable will be synced. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.VariableSynchronization Synchronization { + get { return synchronization_; } + set { + synchronization_ = value; + } + } + + /// Field number for the "aggregation" field. + public const int AggregationFieldNumber = 9; + private global::Tensorflow.VariableAggregation aggregation_ = global::Tensorflow.VariableAggregation.None; + /// + /// Indicates how a distributed variable will be aggregated. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.VariableAggregation Aggregation { + get { return aggregation_; } + set { + aggregation_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as VariableDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(VariableDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (VariableName != other.VariableName) return false; + if (InitialValueName != other.InitialValueName) return false; + if (InitializerName != other.InitializerName) return false; + if (SnapshotName != other.SnapshotName) return false; + if (!object.Equals(SaveSliceInfoDef, other.SaveSliceInfoDef)) return false; + if (IsResource != other.IsResource) return false; + if (Trainable != other.Trainable) return false; + if (Synchronization != other.Synchronization) return false; + if (Aggregation != other.Aggregation) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (VariableName.Length != 0) hash ^= VariableName.GetHashCode(); + if (InitialValueName.Length != 0) hash ^= InitialValueName.GetHashCode(); + if (InitializerName.Length != 0) hash ^= InitializerName.GetHashCode(); + if (SnapshotName.Length != 0) hash ^= SnapshotName.GetHashCode(); + if (saveSliceInfoDef_ != null) hash ^= SaveSliceInfoDef.GetHashCode(); + if (IsResource != false) hash ^= IsResource.GetHashCode(); + if (Trainable != false) hash ^= Trainable.GetHashCode(); + if (Synchronization != global::Tensorflow.VariableSynchronization.Auto) hash ^= Synchronization.GetHashCode(); + if (Aggregation != global::Tensorflow.VariableAggregation.None) hash ^= Aggregation.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (VariableName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(VariableName); + } + if (InitializerName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(InitializerName); + } + if (SnapshotName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(SnapshotName); + } + if (saveSliceInfoDef_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SaveSliceInfoDef); + } + if (IsResource != false) { + output.WriteRawTag(40); + output.WriteBool(IsResource); + } + if (InitialValueName.Length != 0) { + output.WriteRawTag(50); + output.WriteString(InitialValueName); + } + if (Trainable != false) { + output.WriteRawTag(56); + output.WriteBool(Trainable); + } + if (Synchronization != global::Tensorflow.VariableSynchronization.Auto) { + output.WriteRawTag(64); + output.WriteEnum((int) Synchronization); + } + if (Aggregation != global::Tensorflow.VariableAggregation.None) { + output.WriteRawTag(72); + output.WriteEnum((int) Aggregation); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (VariableName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(VariableName); + } + if (InitializerName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(InitializerName); + } + if (SnapshotName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(SnapshotName); + } + if (saveSliceInfoDef_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SaveSliceInfoDef); + } + if (IsResource != false) { + output.WriteRawTag(40); + output.WriteBool(IsResource); + } + if (InitialValueName.Length != 0) { + output.WriteRawTag(50); + output.WriteString(InitialValueName); + } + if (Trainable != false) { + output.WriteRawTag(56); + output.WriteBool(Trainable); + } + if (Synchronization != global::Tensorflow.VariableSynchronization.Auto) { + output.WriteRawTag(64); + output.WriteEnum((int) Synchronization); + } + if (Aggregation != global::Tensorflow.VariableAggregation.None) { + output.WriteRawTag(72); + output.WriteEnum((int) Aggregation); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (VariableName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(VariableName); + } + if (InitialValueName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(InitialValueName); + } + if (InitializerName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(InitializerName); + } + if (SnapshotName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(SnapshotName); + } + if (saveSliceInfoDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SaveSliceInfoDef); + } + if (IsResource != false) { + size += 1 + 1; + } + if (Trainable != false) { + size += 1 + 1; + } + if (Synchronization != global::Tensorflow.VariableSynchronization.Auto) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Synchronization); + } + if (Aggregation != global::Tensorflow.VariableAggregation.None) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Aggregation); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(VariableDef other) { + if (other == null) { + return; + } + if (other.VariableName.Length != 0) { + VariableName = other.VariableName; + } + if (other.InitialValueName.Length != 0) { + InitialValueName = other.InitialValueName; + } + if (other.InitializerName.Length != 0) { + InitializerName = other.InitializerName; + } + if (other.SnapshotName.Length != 0) { + SnapshotName = other.SnapshotName; + } + if (other.saveSliceInfoDef_ != null) { + if (saveSliceInfoDef_ == null) { + SaveSliceInfoDef = new global::Tensorflow.SaveSliceInfoDef(); + } + SaveSliceInfoDef.MergeFrom(other.SaveSliceInfoDef); + } + if (other.IsResource != false) { + IsResource = other.IsResource; + } + if (other.Trainable != false) { + Trainable = other.Trainable; + } + if (other.Synchronization != global::Tensorflow.VariableSynchronization.Auto) { + Synchronization = other.Synchronization; + } + if (other.Aggregation != global::Tensorflow.VariableAggregation.None) { + Aggregation = other.Aggregation; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + VariableName = input.ReadString(); + break; + } + case 18: { + InitializerName = input.ReadString(); + break; + } + case 26: { + SnapshotName = input.ReadString(); + break; + } + case 34: { + if (saveSliceInfoDef_ == null) { + SaveSliceInfoDef = new global::Tensorflow.SaveSliceInfoDef(); + } + input.ReadMessage(SaveSliceInfoDef); + break; + } + case 40: { + IsResource = input.ReadBool(); + break; + } + case 50: { + InitialValueName = input.ReadString(); + break; + } + case 56: { + Trainable = input.ReadBool(); + break; + } + case 64: { + Synchronization = (global::Tensorflow.VariableSynchronization) input.ReadEnum(); + break; + } + case 72: { + Aggregation = (global::Tensorflow.VariableAggregation) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + VariableName = input.ReadString(); + break; + } + case 18: { + InitializerName = input.ReadString(); + break; + } + case 26: { + SnapshotName = input.ReadString(); + break; + } + case 34: { + if (saveSliceInfoDef_ == null) { + SaveSliceInfoDef = new global::Tensorflow.SaveSliceInfoDef(); + } + input.ReadMessage(SaveSliceInfoDef); + break; + } + case 40: { + IsResource = input.ReadBool(); + break; + } + case 50: { + InitialValueName = input.ReadString(); + break; + } + case 56: { + Trainable = input.ReadBool(); + break; + } + case 64: { + Synchronization = (global::Tensorflow.VariableSynchronization) input.ReadEnum(); + break; + } + case 72: { + Aggregation = (global::Tensorflow.VariableAggregation) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + public sealed partial class SaveSliceInfoDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SaveSliceInfoDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.VariableReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SaveSliceInfoDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SaveSliceInfoDef(SaveSliceInfoDef other) : this() { + fullName_ = other.fullName_; + fullShape_ = other.fullShape_.Clone(); + varOffset_ = other.varOffset_.Clone(); + varShape_ = other.varShape_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SaveSliceInfoDef Clone() { + return new SaveSliceInfoDef(this); + } + + /// Field number for the "full_name" field. + public const int FullNameFieldNumber = 1; + private string fullName_ = ""; + /// + /// Name of the full variable of which this is a slice. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string FullName { + get { return fullName_; } + set { + fullName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "full_shape" field. + public const int FullShapeFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_fullShape_codec + = pb::FieldCodec.ForInt64(18); + private readonly pbc::RepeatedField fullShape_ = new pbc::RepeatedField(); + /// + /// Shape of the full variable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField FullShape { + get { return fullShape_; } + } + + /// Field number for the "var_offset" field. + public const int VarOffsetFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_varOffset_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField varOffset_ = new pbc::RepeatedField(); + /// + /// Offset of this variable into the full variable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField VarOffset { + get { return varOffset_; } + } + + /// Field number for the "var_shape" field. + public const int VarShapeFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_varShape_codec + = pb::FieldCodec.ForInt64(34); + private readonly pbc::RepeatedField varShape_ = new pbc::RepeatedField(); + /// + /// Shape of this variable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField VarShape { + get { return varShape_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SaveSliceInfoDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SaveSliceInfoDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (FullName != other.FullName) return false; + if(!fullShape_.Equals(other.fullShape_)) return false; + if(!varOffset_.Equals(other.varOffset_)) return false; + if(!varShape_.Equals(other.varShape_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (FullName.Length != 0) hash ^= FullName.GetHashCode(); + hash ^= fullShape_.GetHashCode(); + hash ^= varOffset_.GetHashCode(); + hash ^= varShape_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (FullName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(FullName); + } + fullShape_.WriteTo(output, _repeated_fullShape_codec); + varOffset_.WriteTo(output, _repeated_varOffset_codec); + varShape_.WriteTo(output, _repeated_varShape_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (FullName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(FullName); + } + fullShape_.WriteTo(ref output, _repeated_fullShape_codec); + varOffset_.WriteTo(ref output, _repeated_varOffset_codec); + varShape_.WriteTo(ref output, _repeated_varShape_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (FullName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(FullName); + } + size += fullShape_.CalculateSize(_repeated_fullShape_codec); + size += varOffset_.CalculateSize(_repeated_varOffset_codec); + size += varShape_.CalculateSize(_repeated_varShape_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SaveSliceInfoDef other) { + if (other == null) { + return; + } + if (other.FullName.Length != 0) { + FullName = other.FullName; + } + fullShape_.Add(other.fullShape_); + varOffset_.Add(other.varOffset_); + varShape_.Add(other.varShape_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + FullName = input.ReadString(); + break; + } + case 18: + case 16: { + fullShape_.AddEntriesFrom(input, _repeated_fullShape_codec); + break; + } + case 26: + case 24: { + varOffset_.AddEntriesFrom(input, _repeated_varOffset_codec); + break; + } + case 34: + case 32: { + varShape_.AddEntriesFrom(input, _repeated_varShape_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + FullName = input.ReadString(); + break; + } + case 18: + case 16: { + fullShape_.AddEntriesFrom(ref input, _repeated_fullShape_codec); + break; + } + case 26: + case 24: { + varOffset_.AddEntriesFrom(ref input, _repeated_varOffset_codec); + break; + } + case 34: + case 32: { + varShape_.AddEntriesFrom(ref input, _repeated_varShape_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/VerifierConfig.cs b/src/TensorFlowNET.Core/Protobuf/VerifierConfig.cs new file mode 100644 index 000000000..904196b1f --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/VerifierConfig.cs @@ -0,0 +1,300 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/verifier_config.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/verifier_config.proto + public static partial class VerifierConfigReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/verifier_config.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static VerifierConfigReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Ci50ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvdmVyaWZpZXJfY29uZmlnLnBy", + "b3RvEgp0ZW5zb3JmbG93IpsBCg5WZXJpZmllckNvbmZpZxIiChp2ZXJpZmlj", + "YXRpb25fdGltZW91dF9pbl9tcxgBIAEoAxI9ChJzdHJ1Y3R1cmVfdmVyaWZp", + "ZXIYAiABKA4yIS50ZW5zb3JmbG93LlZlcmlmaWVyQ29uZmlnLlRvZ2dsZSIm", + "CgZUb2dnbGUSCwoHREVGQVVMVBAAEgYKAk9OEAESBwoDT0ZGEAJCjAEKGG9y", + "Zy50ZW5zb3JmbG93LmZyYW1ld29ya0IUVmVyaWZpZXJDb25maWdQcm90b3NQ", + "AVpVZ2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxv", + "dy9nby9jb3JlL3Byb3RvYnVmL2Zvcl9jb3JlX3Byb3Rvc19nb19wcm90b/gB", + "AWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VerifierConfig), global::Tensorflow.VerifierConfig.Parser, new[]{ "VerificationTimeoutInMs", "StructureVerifier" }, null, new[]{ typeof(global::Tensorflow.VerifierConfig.Types.Toggle) }, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// The config for graph verifiers. + /// + public sealed partial class VerifierConfig : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new VerifierConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.VerifierConfigReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public VerifierConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public VerifierConfig(VerifierConfig other) : this() { + verificationTimeoutInMs_ = other.verificationTimeoutInMs_; + structureVerifier_ = other.structureVerifier_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public VerifierConfig Clone() { + return new VerifierConfig(this); + } + + /// Field number for the "verification_timeout_in_ms" field. + public const int VerificationTimeoutInMsFieldNumber = 1; + private long verificationTimeoutInMs_; + /// + /// Deadline for completion of all verification i.e. all the Toggle ON + /// verifiers must complete execution within this time. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long VerificationTimeoutInMs { + get { return verificationTimeoutInMs_; } + set { + verificationTimeoutInMs_ = value; + } + } + + /// Field number for the "structure_verifier" field. + public const int StructureVerifierFieldNumber = 2; + private global::Tensorflow.VerifierConfig.Types.Toggle structureVerifier_ = global::Tensorflow.VerifierConfig.Types.Toggle.Default; + /// + /// Perform structural validation on a tensorflow graph. Default is OFF. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Tensorflow.VerifierConfig.Types.Toggle StructureVerifier { + get { return structureVerifier_; } + set { + structureVerifier_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as VerifierConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(VerifierConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (VerificationTimeoutInMs != other.VerificationTimeoutInMs) return false; + if (StructureVerifier != other.StructureVerifier) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (VerificationTimeoutInMs != 0L) hash ^= VerificationTimeoutInMs.GetHashCode(); + if (StructureVerifier != global::Tensorflow.VerifierConfig.Types.Toggle.Default) hash ^= StructureVerifier.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (VerificationTimeoutInMs != 0L) { + output.WriteRawTag(8); + output.WriteInt64(VerificationTimeoutInMs); + } + if (StructureVerifier != global::Tensorflow.VerifierConfig.Types.Toggle.Default) { + output.WriteRawTag(16); + output.WriteEnum((int) StructureVerifier); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (VerificationTimeoutInMs != 0L) { + output.WriteRawTag(8); + output.WriteInt64(VerificationTimeoutInMs); + } + if (StructureVerifier != global::Tensorflow.VerifierConfig.Types.Toggle.Default) { + output.WriteRawTag(16); + output.WriteEnum((int) StructureVerifier); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (VerificationTimeoutInMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(VerificationTimeoutInMs); + } + if (StructureVerifier != global::Tensorflow.VerifierConfig.Types.Toggle.Default) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) StructureVerifier); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(VerifierConfig other) { + if (other == null) { + return; + } + if (other.VerificationTimeoutInMs != 0L) { + VerificationTimeoutInMs = other.VerificationTimeoutInMs; + } + if (other.StructureVerifier != global::Tensorflow.VerifierConfig.Types.Toggle.Default) { + StructureVerifier = other.StructureVerifier; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + VerificationTimeoutInMs = input.ReadInt64(); + break; + } + case 16: { + StructureVerifier = (global::Tensorflow.VerifierConfig.Types.Toggle) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + VerificationTimeoutInMs = input.ReadInt64(); + break; + } + case 16: { + StructureVerifier = (global::Tensorflow.VerifierConfig.Types.Toggle) input.ReadEnum(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the VerifierConfig message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum Toggle { + [pbr::OriginalName("DEFAULT")] Default = 0, + [pbr::OriginalName("ON")] On = 1, + [pbr::OriginalName("OFF")] Off = 2, + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Versions.cs b/src/TensorFlowNET.Core/Protobuf/Versions.cs new file mode 100644 index 000000000..d3e9fc512 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Versions.cs @@ -0,0 +1,324 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/versions.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/versions.proto + public static partial class VersionsReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/versions.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static VersionsReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cih0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3ZlcnNpb25zLnByb3RvEgp0", + "ZW5zb3JmbG93IksKClZlcnNpb25EZWYSEAoIcHJvZHVjZXIYASABKAUSFAoM", + "bWluX2NvbnN1bWVyGAIgASgFEhUKDWJhZF9jb25zdW1lcnMYAyADKAVCgAEK", + "GG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IOVmVyc2lvbnNQcm90b3NQAVpP", + "Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9n", + "by9jb3JlL2ZyYW1ld29yay92ZXJzaW9uc19nb19wcm90b/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VersionDef), global::Tensorflow.VersionDef.Parser, new[]{ "Producer", "MinConsumer", "BadConsumers" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Version information for a piece of serialized data + /// + /// There are different types of versions for each type of data + /// (GraphDef, etc.), but they all have the same common shape + /// described here. + /// + /// Each consumer has "consumer" and "min_producer" versions (specified + /// elsewhere). A consumer is allowed to consume this data if + /// + /// producer >= min_producer + /// consumer >= min_consumer + /// consumer not in bad_consumers + /// + public sealed partial class VersionDef : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new VersionDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.VersionsReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public VersionDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public VersionDef(VersionDef other) : this() { + producer_ = other.producer_; + minConsumer_ = other.minConsumer_; + badConsumers_ = other.badConsumers_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public VersionDef Clone() { + return new VersionDef(this); + } + + /// Field number for the "producer" field. + public const int ProducerFieldNumber = 1; + private int producer_; + /// + /// The version of the code that produced this data. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int Producer { + get { return producer_; } + set { + producer_ = value; + } + } + + /// Field number for the "min_consumer" field. + public const int MinConsumerFieldNumber = 2; + private int minConsumer_; + /// + /// Any consumer below this version is not allowed to consume this data. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int MinConsumer { + get { return minConsumer_; } + set { + minConsumer_ = value; + } + } + + /// Field number for the "bad_consumers" field. + public const int BadConsumersFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_badConsumers_codec + = pb::FieldCodec.ForInt32(26); + private readonly pbc::RepeatedField badConsumers_ = new pbc::RepeatedField(); + /// + /// Specific consumer versions which are disallowed (e.g. due to bugs). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField BadConsumers { + get { return badConsumers_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as VersionDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(VersionDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Producer != other.Producer) return false; + if (MinConsumer != other.MinConsumer) return false; + if(!badConsumers_.Equals(other.badConsumers_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Producer != 0) hash ^= Producer.GetHashCode(); + if (MinConsumer != 0) hash ^= MinConsumer.GetHashCode(); + hash ^= badConsumers_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Producer != 0) { + output.WriteRawTag(8); + output.WriteInt32(Producer); + } + if (MinConsumer != 0) { + output.WriteRawTag(16); + output.WriteInt32(MinConsumer); + } + badConsumers_.WriteTo(output, _repeated_badConsumers_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Producer != 0) { + output.WriteRawTag(8); + output.WriteInt32(Producer); + } + if (MinConsumer != 0) { + output.WriteRawTag(16); + output.WriteInt32(MinConsumer); + } + badConsumers_.WriteTo(ref output, _repeated_badConsumers_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Producer != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Producer); + } + if (MinConsumer != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinConsumer); + } + size += badConsumers_.CalculateSize(_repeated_badConsumers_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(VersionDef other) { + if (other == null) { + return; + } + if (other.Producer != 0) { + Producer = other.Producer; + } + if (other.MinConsumer != 0) { + MinConsumer = other.MinConsumer; + } + badConsumers_.Add(other.badConsumers_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Producer = input.ReadInt32(); + break; + } + case 16: { + MinConsumer = input.ReadInt32(); + break; + } + case 26: + case 24: { + badConsumers_.AddEntriesFrom(input, _repeated_badConsumers_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Producer = input.ReadInt32(); + break; + } + case 16: { + MinConsumer = input.ReadInt32(); + break; + } + case 26: + case 24: { + badConsumers_.AddEntriesFrom(ref input, _repeated_badConsumers_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Xla.cs b/src/TensorFlowNET.Core/Protobuf/Xla.cs new file mode 100644 index 000000000..24f46594c --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Xla.cs @@ -0,0 +1,12788 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/compiler/xla/xla.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Xla { + + /// Holder for reflection information generated from tensorflow/compiler/xla/xla.proto + public static partial class XlaReflection { + + #region Descriptor + /// File descriptor for tensorflow/compiler/xla/xla.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static XlaReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiF0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS94bGEucHJvdG8SA3hsYRopdGVu", + "c29yZmxvdy9jb21waWxlci94bGEvc2VydmljZS9obG8ucHJvdG8aJnRlbnNv", + "cmZsb3cvY29tcGlsZXIveGxhL3hsYV9kYXRhLnByb3RvIscdCgxEZWJ1Z09w", + "dGlvbnMSHwoXeGxhX2hsb19ncmFwaF9hZGRyZXNzZXMYAiABKAgSFwoPeGxh", + "X2hsb19wcm9maWxlGAkgASgIEh4KFnhsYV9kaXNhYmxlX2hsb19wYXNzZXMY", + "HiADKAkSIgoaeGxhX2VuYWJsZV9obG9fcGFzc2VzX29ubHkYfCADKAkSIgoa", + "eGxhX2Rpc2FibGVfYWxsX2hsb19wYXNzZXMYaCABKAgSJgoeeGxhX2JhY2tl", + "bmRfb3B0aW1pemF0aW9uX2xldmVsGB8gASgFEiIKGnhsYV9lbWJlZF9pcl9p", + "bl9leGVjdXRhYmxlGCEgASgIEiwKJHhsYV9lbGltaW5hdGVfaGxvX2ltcGxp", + "Y2l0X2Jyb2FkY2FzdBgjIAEoCBIiChp4bGFfY3B1X211bHRpX3RocmVhZF9l", + "aWdlbhg8IAEoCBIdChV4bGFfZ3B1X2N1ZGFfZGF0YV9kaXIYPSABKAkSEwoL", + "eGxhX2dwdV9mdHoYPiABKAgSLAokeGxhX2xsdm1fZW5hYmxlX2FsaWFzX3Nj", + "b3BlX21ldGFkYXRhGEYgASgIEigKIHhsYV9sbHZtX2VuYWJsZV9ub2FsaWFz", + "X21ldGFkYXRhGEcgASgIEi8KJ3hsYV9sbHZtX2VuYWJsZV9pbnZhcmlhbnRf", + "bG9hZF9tZXRhZGF0YRhIIAEoCBIpCiF4bGFfbGx2bV9kaXNhYmxlX2V4cGVu", + "c2l2ZV9wYXNzZXMYSSABKAgSIwobeGxhX3Rlc3RfYWxsX291dHB1dF9sYXlv", + "dXRzGFogASgIEiIKGnhsYV90ZXN0X2FsbF9pbnB1dF9sYXlvdXRzGFsgASgI", + "EiQKHHhsYV9obG9fZ3JhcGhfc2hhcmRpbmdfY29sb3IYXCABKAgSGwoTeGxh", + "X2NwdV91c2VfbWtsX2RubhhhIAEoCBIgChd4bGFfY3B1X3VzZV94bGFfcnVu", + "dGltZRixASABKAgSKAogeGxhX2dwdV9tYXhfa2VybmVsX3Vucm9sbF9mYWN0", + "b3IYYiABKAUSIAoYeGxhX2NwdV9lbmFibGVfZmFzdF9tYXRoGGMgASgIEiQK", + "HHhsYV9jcHVfZmFzdF9tYXRoX2hvbm9yX25hbnMYeCABKAgSJAoceGxhX2Nw", + "dV9mYXN0X21hdGhfaG9ub3JfaW5mcxh5IAEoCBIoCiB4bGFfY3B1X2Zhc3Rf", + "bWF0aF9ob25vcl9kaXZpc2lvbhh+IAEoCBIqCiF4bGFfY3B1X2Zhc3RfbWF0", + "aF9ob25vcl9mdW5jdGlvbnMYgQEgASgIEiQKG3hsYV9jcHVfZW5hYmxlX2Zh", + "c3RfbWluX21heBiMASABKAgSIwobeGxhX2dwdV9lbmFibGVfZmFzdF9taW5f", + "bWF4GGQgASgIEiIKGnhsYV9hbGxvd19leGNlc3NfcHJlY2lzaW9uGHogASgI", + "Ei4KJnhsYV9ncHVfY3Jhc2hfb25fdmVyaWZpY2F0aW9uX2ZhaWx1cmVzGGUg", + "ASgIEh4KFnhsYV9ncHVfYXV0b3R1bmVfbGV2ZWwYeyABKAUSLAokeGxhX2Zv", + "cmNlX2hvc3RfcGxhdGZvcm1fZGV2aWNlX2NvdW50GGYgASgFEiwKJHhsYV9n", + "cHVfZGlzYWJsZV9ncHVhc21fb3B0aW1pemF0aW9ucxhnIAEoCBI8ChR4bGFf", + "Z3B1X3NoYXBlX2NoZWNrcxiqASABKA4yHS54bGEuRGVidWdPcHRpb25zLlNo", + "YXBlQ2hlY2tzEiUKHHhsYV9jcHVfZW5hYmxlX21saXJfbG93ZXJpbmcYqwEg", + "ASgIEiUKHHhsYV9ncHVfZW5hYmxlX21saXJfbG93ZXJpbmcYrQEgASgIEicK", + "H3hsYV9obG9fZXZhbHVhdG9yX3VzZV9mYXN0X3BhdGgYaiABKAgSKgoieGxh", + "X2FsbG93X3NjYWxhcl9pbmRleF9keW5hbWljX29wcxhrIAEoCBJGChh4bGFf", + "c3RlcF9tYXJrZXJfbG9jYXRpb24YbCABKA4yJC54bGEuRGVidWdPcHRpb25z", + "LlN0ZXBNYXJrZXJMb2NhdGlvbhITCgt4bGFfZHVtcF90bxhtIAEoCRIeChZ4", + "bGFfZHVtcF9obG9fbW9kdWxlX3JlGG4gASgJEhwKFHhsYV9kdW1wX2hsb19w", + "YXNzX3JlGG8gASgJEhwKFHhsYV9kdW1wX2hsb19hc190ZXh0GHAgASgIEh0K", + "FXhsYV9kdW1wX2hsb19hc19wcm90bxhxIAEoCBIbChN4bGFfZHVtcF9obG9f", + "YXNfZG90GHIgASgIEhsKE3hsYV9kdW1wX2hsb19hc191cmwYcyABKAgSHAoU", + "eGxhX2R1bXBfaGxvX2FzX2h0bWwYdCABKAgSJgodeGxhX2R1bXBfZnVzaW9u", + "X3Zpc3VhbGl6YXRpb24YlQEgASgIEh4KFnhsYV9kdW1wX2hsb19zbmFwc2hv", + "dHMYdiABKAgSIwoaeGxhX2R1bXBfaW5jbHVkZV90aW1lc3RhbXAYgwEgASgI", + "EiEKGHhsYV9kdW1wX21heF9obG9fbW9kdWxlcxiEASABKAUSIQoYeGxhX2R1", + "bXBfbW9kdWxlX21ldGFkYXRhGJABIAEoCBIhChh4bGFfZHVtcF9jb21wcmVz", + "c19wcm90b3MYlwEgASgIEiIKGXhsYV9kdW1wX2hsb19hc19sb25nX3RleHQY", + "pAEgASgIEh8KF3hsYV9ncHVfZm9yY2VfY29udl9uY2h3GH0gASgIEiAKF3hs", + "YV9ncHVfZm9yY2VfY29udl9uaHdjGJIBIAEoCBIYChB4bGFfZ3B1X3B0eF9m", + "aWxlGH8gAygJEhwKE3hsYV9ncHVfZHVtcF9sbHZtaXIYmwEgASgIEigKH3hs", + "YV9ncHVfYWxnb3JpdGhtX2RlbnlsaXN0X3BhdGgYgAEgASgJEhsKEnhsYV90", + "cHVfZGV0ZWN0X25hbhiHASABKAgSGwoSeGxhX3RwdV9kZXRlY3RfaW5mGIgB", + "IAEoCBIlChx4bGFfY3B1X2VuYWJsZV94cHJvZl90cmFjZW1lGIkBIAEoCBI9", + "CjR4bGFfZ3B1X3Vuc2FmZV9mYWxsYmFja190b19kcml2ZXJfb25fcHR4YXNf", + "bm90X2ZvdW5kGIoBIAEoCBIgChd4bGFfZ3B1X2FzbV9leHRyYV9mbGFncxiN", + "ASABKAkSLwomeGxhX211bHRpaGVhcF9zaXplX2NvbnN0cmFpbnRfcGVyX2hl", + "YXAYjgEgASgFEikKIHhsYV9kZXRhaWxlZF9sb2dnaW5nX2FuZF9kdW1waW5n", + "GI8BIAEoCBIuCiV4bGFfZ3B1X2ZvcmNlX2NvbXBpbGF0aW9uX3BhcmFsbGVs", + "aXNtGJMBIAEoBRIiChl4bGFfZ3B1X2RldGVybWluaXN0aWNfb3BzGJQBIAEo", + "CBIdChR4bGFfZ3B1X2xsdm1faXJfZmlsZRiWASADKAkSKAofeGxhX2dwdV9l", + "bmFibGVfYXN5bmNfYWxsX3JlZHVjZRiYASABKAgSMwoqeGxhX2dwdV9hbGxf", + "cmVkdWNlX2NvbWJpbmVfdGhyZXNob2xkX2J5dGVzGJ0BIAEoAxImCh14bGFf", + "Z3B1X2FsbF9yZWR1Y2VfY29udGlndW91cxieASABKAgSPAozeGxhX2dwdV9h", + "bGxfcmVkdWNlX2JsdWVjb25uZWN0X251bV9kZXZpY2VzX3Blcl9ob3N0GJ8B", + "IAEoBRImCh14bGFfZ3B1X2VuYWJsZV9jdWRubl9mcm9udGVuZBigASABKAgS", + "IgoZeGxhX2R1bXBfZGlzYWJsZV9tZXRhZGF0YRiZASABKAgSIQoYeGxhX2R1", + "bXBfaGxvX3BpcGVsaW5lX3JlGJoBIAEoCRItCiR4bGFfZ3B1X3N0cmljdF9j", + "b252X2FsZ29yaXRobV9waWNrZXIYnAEgASgIEi4KJXhsYV9ncHVfZW5hYmxl", + "X3hsYV9ydW50aW1lX2V4ZWN1dGFibGUYqQEgASgIEjEKKHhsYV9ncHVfbmNj", + "bF90ZXJtaW5hdGlvbl90aW1lb3V0X3NlY29uZHMYowEgASgDEigKH3hsYV9n", + "cHVfZW5hYmxlX3NoYXJlZF9jb25zdGFudHMYpQEgASgIEiAKF3hsYV9ncHVf", + "ZW5hYmxlX2N1Ymxhc2x0GKYBIAEoCBIuCiV4bGFfZ3B1X3JlZHpvbmVfc2Ny", + "YXRjaF9tYXhfbWVnYWJ5dGVzGKcBIAEoAxIsCiN4bGFfZ3B1X3NpbXBsaWZ5", + "X2FsbF9mcF9jb252ZXJzaW9ucxioASABKAgSIgoZeGxhX2dwdV9ub3JtYWxp", + "emVfbGF5b3V0cxisASABKAgSGAoPeGxhX2NwdV91c2VfYWNsGK4BIAEoCBIl", + "Chx4bGFfY3B1X3N0cmljdF9kb3RfY29udl9tYXRoGK8BIAEoCBJRChl4bGFf", + "YmFja2VuZF9leHRyYV9vcHRpb25zGPQDIAMoCzItLnhsYS5EZWJ1Z09wdGlv", + "bnMuWGxhQmFja2VuZEV4dHJhT3B0aW9uc0VudHJ5Gj0KG1hsYUJhY2tlbmRF", + "eHRyYU9wdGlvbnNFbnRyeRILCgNrZXkYASABKAkSDQoFdmFsdWUYAiABKAk6", + "AjgBIjgKC1NoYXBlQ2hlY2tzEgoKBklHTk9SRRAAEgsKB1JVTlRJTUUQARIQ", + "CgxDT01QSUxFX1RJTUUQAiKRAQoSU3RlcE1hcmtlckxvY2F0aW9uEhYKElNU", + "RVBfTUFSS19BVF9FTlRSWRAAEiUKIVNURVBfTUFSS19BVF9UT1BfTEVWRUxf", + "V0hJTEVfTE9PUBABEigKJFNURVBfTUFSS19BVF9TRUNPTkRfTEVWRUxfV0hJ", + "TEVfTE9PUBADEhIKDlNURVBfTUFSS19OT05FEAJKBAg/EEBKBgiGARCHAUoE", + "CFAQUUoECF0QXkoECF4QX0oGCIIBEIMBSgYIoQEQogFKBgiiARCjAUoECAUQ", + "BkoECHUQdkoGCIUBEIYBSgYIiwEQjAFKBgiwARCxAUoGCLIBELMBIqsEChBF", + "eGVjdXRpb25PcHRpb25zEjEKGHNoYXBlX3dpdGhfb3V0cHV0X2xheW91dBgC", + "IAEoCzIPLnhsYS5TaGFwZVByb3RvEgwKBHNlZWQYAyABKAQSKAoNZGVidWdf", + "b3B0aW9ucxgEIAEoCzIRLnhsYS5EZWJ1Z09wdGlvbnMSKQoOZGV2aWNlX2hh", + "bmRsZXMYBSADKAsyES54bGEuRGV2aWNlSGFuZGxlEhQKDG51bV9yZXBsaWNh", + "cxgGIAEoBRI1ChFkZXZpY2VfYXNzaWdubWVudBgHIAEoCzIaLnhsYS5EZXZp", + "Y2VBc3NpZ25tZW50UHJvdG8SIAoYYWxpYXNfcGFzc3Rocm91Z2hfcGFyYW1z", + "GAggASgIEhYKDm51bV9wYXJ0aXRpb25zGAkgASgFEhEKCWxhdW5jaF9pZBgK", + "IAEoBRIdChV1c2Vfc3BtZF9wYXJ0aXRpb25pbmcYCyABKAgSIgoadXNlX2F1", + "dG9fc3BtZF9wYXJ0aXRpb25pbmcYDyABKAgSKQohYXV0b19zcG1kX3BhcnRp", + "dGlvbmluZ19tZXNoX3NoYXBlGBAgAygDEicKH2F1dG9fc3BtZF9wYXJ0aXRp", + "b25pbmdfbWVzaF9pZHMYESADKAMSFwoPZGVkdXBsaWNhdGVfaGxvGAwgASgI", + "EjEKKWFsbG93X3NwbWRfc2hhcmRpbmdfcHJvcGFnYXRpb25fdG9fb3V0cHV0", + "GA4gASgISgQIDRAOIi8KF0dldERldmljZUhhbmRsZXNSZXF1ZXN0EhQKDGRl", + "dmljZV9jb3VudBgBIAEoAyJFChhHZXREZXZpY2VIYW5kbGVzUmVzcG9uc2US", + "KQoOZGV2aWNlX2hhbmRsZXMYASADKAsyES54bGEuRGV2aWNlSGFuZGxlImoK", + "F1RyYW5zZmVyVG9DbGllbnRSZXF1ZXN0EiMKBGRhdGEYASABKAsyFS54bGEu", + "R2xvYmFsRGF0YUhhbmRsZRIqChFzaGFwZV93aXRoX2xheW91dBgCIAEoCzIP", + "LnhsYS5TaGFwZVByb3RvIj4KGFRyYW5zZmVyVG9DbGllbnRSZXNwb25zZRIi", + "CgdsaXRlcmFsGAEgASgLMhEueGxhLkxpdGVyYWxQcm90byJnChdUcmFuc2Zl", + "clRvU2VydmVyUmVxdWVzdBIiCgdsaXRlcmFsGAEgASgLMhEueGxhLkxpdGVy", + "YWxQcm90bxIoCg1kZXZpY2VfaGFuZGxlGAIgASgLMhEueGxhLkRldmljZUhh", + "bmRsZSI/ChhUcmFuc2ZlclRvU2VydmVyUmVzcG9uc2USIwoEZGF0YRgBIAEo", + "CzIVLnhsYS5HbG9iYWxEYXRhSGFuZGxlInsKF1RyYW5zZmVyVG9JbmZlZWRS", + "ZXF1ZXN0EiIKB2xpdGVyYWwYASABKAsyES54bGEuTGl0ZXJhbFByb3RvEhIK", + "CnJlcGxpY2FfaWQYAiABKAMSKAoNZGV2aWNlX2hhbmRsZRgDIAEoCzIRLnhs", + "YS5EZXZpY2VIYW5kbGUiGgoYVHJhbnNmZXJUb0luZmVlZFJlc3BvbnNlIoYB", + "ChpUcmFuc2ZlckZyb21PdXRmZWVkUmVxdWVzdBIqChFzaGFwZV93aXRoX2xh", + "eW91dBgBIAEoCzIPLnhsYS5TaGFwZVByb3RvEhIKCnJlcGxpY2FfaWQYAiAB", + "KAMSKAoNZGV2aWNlX2hhbmRsZRgDIAEoCzIRLnhsYS5EZXZpY2VIYW5kbGUi", + "QQobVHJhbnNmZXJGcm9tT3V0ZmVlZFJlc3BvbnNlEiIKB2xpdGVyYWwYASAB", + "KAsyES54bGEuTGl0ZXJhbFByb3RvIj4KElJlc2V0RGV2aWNlUmVxdWVzdBIo", + "Cg1kZXZpY2VfaGFuZGxlGAEgASgLMhEueGxhLkRldmljZUhhbmRsZSIVChNS", + "ZXNldERldmljZVJlc3BvbnNlInIKHENvbXB1dGF0aW9uR3JhcGhTdGF0c1Jl", + "cXVlc3QSKAoLY29tcHV0YXRpb24YASABKAsyEy54bGEuSGxvTW9kdWxlUHJv", + "dG8SKAoNZGVidWdfb3B0aW9ucxgCIAEoCzIRLnhsYS5EZWJ1Z09wdGlvbnMi", + "QAoYQ29tcHV0YXRpb25TdGF0c1Jlc3BvbnNlEiQKBXN0YXRzGAEgASgLMhUu", + "eGxhLkNvbXB1dGF0aW9uU3RhdHMiUgoaQ3JlYXRlQ2hhbm5lbEhhbmRsZVJl", + "cXVlc3QSNAoMY2hhbm5lbF90eXBlGAEgASgOMh4ueGxhLkNoYW5uZWxIYW5k", + "bGUuQ2hhbm5lbFR5cGUiQgobQ3JlYXRlQ2hhbm5lbEhhbmRsZVJlc3BvbnNl", + "EiMKB2NoYW5uZWwYASABKAsyEi54bGEuQ2hhbm5lbEhhbmRsZSI4ChFVbnJl", + "Z2lzdGVyUmVxdWVzdBIjCgRkYXRhGAEgAygLMhUueGxhLkdsb2JhbERhdGFI", + "YW5kbGUiFAoSVW5yZWdpc3RlclJlc3BvbnNlIp4BCg5Db21waWxlUmVxdWVz", + "dBIoCgtjb21wdXRhdGlvbhgBIAEoCzITLnhsYS5IbG9Nb2R1bGVQcm90bxIw", + "ChFleGVjdXRpb25fb3B0aW9ucxgCIAEoCzIVLnhsYS5FeGVjdXRpb25PcHRp", + "b25zEjAKF2lucHV0X3NoYXBlX3dpdGhfbGF5b3V0GAMgAygLMg8ueGxhLlNo", + "YXBlUHJvdG8iNwoPQ29tcGlsZVJlc3BvbnNlEiQKBmhhbmRsZRgBIAEoCzIU", + "LnhsYS5FeGVjdXRpb25IYW5kbGUiYAoORXhlY3V0ZVJlcXVlc3QSJAoGaGFu", + "ZGxlGAEgASgLMhQueGxhLkV4ZWN1dGlvbkhhbmRsZRIoCglhcmd1bWVudHMY", + "AiADKAsyFS54bGEuR2xvYmFsRGF0YUhhbmRsZSKbAQoTRXhlY3V0ZUdyYXBo", + "UmVxdWVzdBIoCgtjb21wdXRhdGlvbhgBIAEoCzITLnhsYS5IbG9Nb2R1bGVQ", + "cm90bxIoCglhcmd1bWVudHMYAiADKAsyFS54bGEuR2xvYmFsRGF0YUhhbmRs", + "ZRIwChFleGVjdXRpb25fb3B0aW9ucxgDIAEoCzIVLnhsYS5FeGVjdXRpb25P", + "cHRpb25zIkkKG0V4ZWN1dGVHcmFwaFBhcmFsbGVsUmVxdWVzdBIqCghyZXF1", + "ZXN0cxgBIAMoCzIYLnhsYS5FeGVjdXRlR3JhcGhSZXF1ZXN0ImAKD0V4ZWN1", + "dGVSZXNwb25zZRIlCgZvdXRwdXQYASABKAsyFS54bGEuR2xvYmFsRGF0YUhh", + "bmRsZRImCgdwcm9maWxlGAIgASgLMhUueGxhLkV4ZWN1dGlvblByb2ZpbGUi", + "QgoXRXhlY3V0ZVBhcmFsbGVsUmVzcG9uc2USJwoJcmVzcG9uc2VzGAEgAygL", + "MhQueGxhLkV4ZWN1dGVSZXNwb25zZSJCChdXYWl0Rm9yRXhlY3V0aW9uUmVx", + "dWVzdBInCglleGVjdXRpb24YASABKAsyFC54bGEuRXhlY3V0aW9uSGFuZGxl", + "ImkKGFdhaXRGb3JFeGVjdXRpb25SZXNwb25zZRIlCgZvdXRwdXQYASABKAsy", + "FS54bGEuR2xvYmFsRGF0YUhhbmRsZRImCgdwcm9maWxlGAIgASgLMhUueGxh", + "LkV4ZWN1dGlvblByb2ZpbGUicAobQ29tcHV0ZUNvbnN0YW50R3JhcGhSZXF1", + "ZXN0EigKC2NvbXB1dGF0aW9uGAEgASgLMhMueGxhLkhsb01vZHVsZVByb3Rv", + "EicKDW91dHB1dF9sYXlvdXQYAiABKAsyEC54bGEuTGF5b3V0UHJvdG8iPQoX", + "Q29tcHV0ZUNvbnN0YW50UmVzcG9uc2USIgoHbGl0ZXJhbBgBIAEoCzIRLnhs", + "YS5MaXRlcmFsUHJvdG8iRgoXRGVjb25zdHJ1Y3RUdXBsZVJlcXVlc3QSKwoM", + "dHVwbGVfaGFuZGxlGAIgASgLMhUueGxhLkdsb2JhbERhdGFIYW5kbGUiSgoY", + "RGVjb25zdHJ1Y3RUdXBsZVJlc3BvbnNlEi4KD2VsZW1lbnRfaGFuZGxlcxgB", + "IAMoCzIVLnhsYS5HbG9iYWxEYXRhSGFuZGxlIpsBCg9Mb2FkRGF0YVJlcXVl", + "c3QSHAoUY29sdW1uaW9fdGFibGV0X3BhdGgYASABKAkSFgoOY29sdW1uaW9f", + "ZmllbGQYAiABKAkSJgoNZWxlbWVudF9zaGFwZRgDIAEoCzIPLnhsYS5TaGFw", + "ZVByb3RvEg4KBm9mZnNldBgEIAEoAxINCgVsaW1pdBgFIAEoAxILCgN6aXAY", + "BiABKAgingEKEExvYWREYXRhUmVzcG9uc2USIwoEZGF0YRgBIAEoCzIVLnhs", + "YS5HbG9iYWxEYXRhSGFuZGxlEiMKCmRhdGFfc2hhcGUYAiABKAsyDy54bGEu", + "U2hhcGVQcm90bxIWCg5hdmFpbGFibGVfcm93cxgDIAEoAxITCgtyb3dzX2xv", + "YWRlZBgEIAEoAxITCgtuYW5vc2Vjb25kcxgFIAEoAyI2Cg9HZXRTaGFwZVJl", + "cXVlc3QSIwoEZGF0YRgBIAEoCzIVLnhsYS5HbG9iYWxEYXRhSGFuZGxlIjIK", + "EEdldFNoYXBlUmVzcG9uc2USHgoFc2hhcGUYASABKAsyDy54bGEuU2hhcGVQ", + "cm90byI0Cg1VbnBhY2tSZXF1ZXN0EiMKBGRhdGEYASABKAsyFS54bGEuR2xv", + "YmFsRGF0YUhhbmRsZSI6Cg5VbnBhY2tSZXNwb25zZRIoCgl0aWVkX2RhdGEY", + "ASADKAsyFS54bGEuR2xvYmFsRGF0YUhhbmRsZWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Xla.HloReflection.Descriptor, global::Xla.XlaDataReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.DebugOptions), global::Xla.DebugOptions.Parser, new[]{ "XlaHloGraphAddresses", "XlaHloProfile", "XlaDisableHloPasses", "XlaEnableHloPassesOnly", "XlaDisableAllHloPasses", "XlaBackendOptimizationLevel", "XlaEmbedIrInExecutable", "XlaEliminateHloImplicitBroadcast", "XlaCpuMultiThreadEigen", "XlaGpuCudaDataDir", "XlaGpuFtz", "XlaLlvmEnableAliasScopeMetadata", "XlaLlvmEnableNoaliasMetadata", "XlaLlvmEnableInvariantLoadMetadata", "XlaLlvmDisableExpensivePasses", "XlaTestAllOutputLayouts", "XlaTestAllInputLayouts", "XlaHloGraphShardingColor", "XlaCpuUseMklDnn", "XlaCpuUseXlaRuntime", "XlaGpuMaxKernelUnrollFactor", "XlaCpuEnableFastMath", "XlaCpuFastMathHonorNans", "XlaCpuFastMathHonorInfs", "XlaCpuFastMathHonorDivision", "XlaCpuFastMathHonorFunctions", "XlaCpuEnableFastMinMax", "XlaGpuEnableFastMinMax", "XlaAllowExcessPrecision", "XlaGpuCrashOnVerificationFailures", "XlaGpuAutotuneLevel", "XlaForceHostPlatformDeviceCount", "XlaGpuDisableGpuasmOptimizations", "XlaGpuShapeChecks", "XlaCpuEnableMlirLowering", "XlaGpuEnableMlirLowering", "XlaHloEvaluatorUseFastPath", "XlaAllowScalarIndexDynamicOps", "XlaStepMarkerLocation", "XlaDumpTo", "XlaDumpHloModuleRe", "XlaDumpHloPassRe", "XlaDumpHloAsText", "XlaDumpHloAsProto", "XlaDumpHloAsDot", "XlaDumpHloAsUrl", "XlaDumpHloAsHtml", "XlaDumpFusionVisualization", "XlaDumpHloSnapshots", "XlaDumpIncludeTimestamp", "XlaDumpMaxHloModules", "XlaDumpModuleMetadata", "XlaDumpCompressProtos", "XlaDumpHloAsLongText", "XlaGpuForceConvNchw", "XlaGpuForceConvNhwc", "XlaGpuPtxFile", "XlaGpuDumpLlvmir", "XlaGpuAlgorithmDenylistPath", "XlaTpuDetectNan", "XlaTpuDetectInf", "XlaCpuEnableXprofTraceme", "XlaGpuUnsafeFallbackToDriverOnPtxasNotFound", "XlaGpuAsmExtraFlags", "XlaMultiheapSizeConstraintPerHeap", "XlaDetailedLoggingAndDumping", "XlaGpuForceCompilationParallelism", "XlaGpuDeterministicOps", "XlaGpuLlvmIrFile", "XlaGpuEnableAsyncAllReduce", "XlaGpuAllReduceCombineThresholdBytes", "XlaGpuAllReduceContiguous", "XlaGpuAllReduceBlueconnectNumDevicesPerHost", "XlaGpuEnableCudnnFrontend", "XlaDumpDisableMetadata", "XlaDumpHloPipelineRe", "XlaGpuStrictConvAlgorithmPicker", "XlaGpuEnableXlaRuntimeExecutable", "XlaGpuNcclTerminationTimeoutSeconds", "XlaGpuEnableSharedConstants", "XlaGpuEnableCublaslt", "XlaGpuRedzoneScratchMaxMegabytes", "XlaGpuSimplifyAllFpConversions", "XlaGpuNormalizeLayouts", "XlaCpuUseAcl", "XlaCpuStrictDotConvMath", "XlaBackendExtraOptions" }, null, new[]{ typeof(global::Xla.DebugOptions.Types.ShapeChecks), typeof(global::Xla.DebugOptions.Types.StepMarkerLocation) }, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ExecutionOptions), global::Xla.ExecutionOptions.Parser, new[]{ "ShapeWithOutputLayout", "Seed", "DebugOptions", "DeviceHandles", "NumReplicas", "DeviceAssignment", "AliasPassthroughParams", "NumPartitions", "LaunchId", "UseSpmdPartitioning", "UseAutoSpmdPartitioning", "AutoSpmdPartitioningMeshShape", "AutoSpmdPartitioningMeshIds", "DeduplicateHlo", "AllowSpmdShardingPropagationToOutput" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.GetDeviceHandlesRequest), global::Xla.GetDeviceHandlesRequest.Parser, new[]{ "DeviceCount" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.GetDeviceHandlesResponse), global::Xla.GetDeviceHandlesResponse.Parser, new[]{ "DeviceHandles" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.TransferToClientRequest), global::Xla.TransferToClientRequest.Parser, new[]{ "Data", "ShapeWithLayout" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.TransferToClientResponse), global::Xla.TransferToClientResponse.Parser, new[]{ "Literal" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.TransferToServerRequest), global::Xla.TransferToServerRequest.Parser, new[]{ "Literal", "DeviceHandle" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.TransferToServerResponse), global::Xla.TransferToServerResponse.Parser, new[]{ "Data" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.TransferToInfeedRequest), global::Xla.TransferToInfeedRequest.Parser, new[]{ "Literal", "ReplicaId", "DeviceHandle" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.TransferToInfeedResponse), global::Xla.TransferToInfeedResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.TransferFromOutfeedRequest), global::Xla.TransferFromOutfeedRequest.Parser, new[]{ "ShapeWithLayout", "ReplicaId", "DeviceHandle" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.TransferFromOutfeedResponse), global::Xla.TransferFromOutfeedResponse.Parser, new[]{ "Literal" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ResetDeviceRequest), global::Xla.ResetDeviceRequest.Parser, new[]{ "DeviceHandle" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ResetDeviceResponse), global::Xla.ResetDeviceResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ComputationGraphStatsRequest), global::Xla.ComputationGraphStatsRequest.Parser, new[]{ "Computation", "DebugOptions" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ComputationStatsResponse), global::Xla.ComputationStatsResponse.Parser, new[]{ "Stats" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.CreateChannelHandleRequest), global::Xla.CreateChannelHandleRequest.Parser, new[]{ "ChannelType" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.CreateChannelHandleResponse), global::Xla.CreateChannelHandleResponse.Parser, new[]{ "Channel" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.UnregisterRequest), global::Xla.UnregisterRequest.Parser, new[]{ "Data" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.UnregisterResponse), global::Xla.UnregisterResponse.Parser, null, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.CompileRequest), global::Xla.CompileRequest.Parser, new[]{ "Computation", "ExecutionOptions", "InputShapeWithLayout" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.CompileResponse), global::Xla.CompileResponse.Parser, new[]{ "Handle" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ExecuteRequest), global::Xla.ExecuteRequest.Parser, new[]{ "Handle", "Arguments" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ExecuteGraphRequest), global::Xla.ExecuteGraphRequest.Parser, new[]{ "Computation", "Arguments", "ExecutionOptions" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ExecuteGraphParallelRequest), global::Xla.ExecuteGraphParallelRequest.Parser, new[]{ "Requests" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ExecuteResponse), global::Xla.ExecuteResponse.Parser, new[]{ "Output", "Profile" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ExecuteParallelResponse), global::Xla.ExecuteParallelResponse.Parser, new[]{ "Responses" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.WaitForExecutionRequest), global::Xla.WaitForExecutionRequest.Parser, new[]{ "Execution" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.WaitForExecutionResponse), global::Xla.WaitForExecutionResponse.Parser, new[]{ "Output", "Profile" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ComputeConstantGraphRequest), global::Xla.ComputeConstantGraphRequest.Parser, new[]{ "Computation", "OutputLayout" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ComputeConstantResponse), global::Xla.ComputeConstantResponse.Parser, new[]{ "Literal" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.DeconstructTupleRequest), global::Xla.DeconstructTupleRequest.Parser, new[]{ "TupleHandle" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.DeconstructTupleResponse), global::Xla.DeconstructTupleResponse.Parser, new[]{ "ElementHandles" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.LoadDataRequest), global::Xla.LoadDataRequest.Parser, new[]{ "ColumnioTabletPath", "ColumnioField", "ElementShape", "Offset", "Limit", "Zip" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.LoadDataResponse), global::Xla.LoadDataResponse.Parser, new[]{ "Data", "DataShape", "AvailableRows", "RowsLoaded", "Nanoseconds" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.GetShapeRequest), global::Xla.GetShapeRequest.Parser, new[]{ "Data" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.GetShapeResponse), global::Xla.GetShapeResponse.Parser, new[]{ "Shape" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.UnpackRequest), global::Xla.UnpackRequest.Parser, new[]{ "Data" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.UnpackResponse), global::Xla.UnpackResponse.Parser, new[]{ "TiedData" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Debugging options for XLA. These options may change at any time - there are + /// no guarantees about backward or forward compatibility for these fields. + /// + public sealed partial class DebugOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DebugOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebugOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebugOptions(DebugOptions other) : this() { + xlaHloGraphAddresses_ = other.xlaHloGraphAddresses_; + xlaHloProfile_ = other.xlaHloProfile_; + xlaDisableHloPasses_ = other.xlaDisableHloPasses_.Clone(); + xlaEnableHloPassesOnly_ = other.xlaEnableHloPassesOnly_.Clone(); + xlaDisableAllHloPasses_ = other.xlaDisableAllHloPasses_; + xlaBackendOptimizationLevel_ = other.xlaBackendOptimizationLevel_; + xlaEmbedIrInExecutable_ = other.xlaEmbedIrInExecutable_; + xlaEliminateHloImplicitBroadcast_ = other.xlaEliminateHloImplicitBroadcast_; + xlaCpuMultiThreadEigen_ = other.xlaCpuMultiThreadEigen_; + xlaGpuCudaDataDir_ = other.xlaGpuCudaDataDir_; + xlaGpuFtz_ = other.xlaGpuFtz_; + xlaLlvmEnableAliasScopeMetadata_ = other.xlaLlvmEnableAliasScopeMetadata_; + xlaLlvmEnableNoaliasMetadata_ = other.xlaLlvmEnableNoaliasMetadata_; + xlaLlvmEnableInvariantLoadMetadata_ = other.xlaLlvmEnableInvariantLoadMetadata_; + xlaLlvmDisableExpensivePasses_ = other.xlaLlvmDisableExpensivePasses_; + xlaTestAllOutputLayouts_ = other.xlaTestAllOutputLayouts_; + xlaTestAllInputLayouts_ = other.xlaTestAllInputLayouts_; + xlaHloGraphShardingColor_ = other.xlaHloGraphShardingColor_; + xlaCpuUseMklDnn_ = other.xlaCpuUseMklDnn_; + xlaCpuUseXlaRuntime_ = other.xlaCpuUseXlaRuntime_; + xlaGpuMaxKernelUnrollFactor_ = other.xlaGpuMaxKernelUnrollFactor_; + xlaCpuEnableFastMath_ = other.xlaCpuEnableFastMath_; + xlaCpuFastMathHonorNans_ = other.xlaCpuFastMathHonorNans_; + xlaCpuFastMathHonorInfs_ = other.xlaCpuFastMathHonorInfs_; + xlaCpuFastMathHonorDivision_ = other.xlaCpuFastMathHonorDivision_; + xlaCpuFastMathHonorFunctions_ = other.xlaCpuFastMathHonorFunctions_; + xlaCpuEnableFastMinMax_ = other.xlaCpuEnableFastMinMax_; + xlaGpuEnableFastMinMax_ = other.xlaGpuEnableFastMinMax_; + xlaAllowExcessPrecision_ = other.xlaAllowExcessPrecision_; + xlaGpuCrashOnVerificationFailures_ = other.xlaGpuCrashOnVerificationFailures_; + xlaGpuAutotuneLevel_ = other.xlaGpuAutotuneLevel_; + xlaForceHostPlatformDeviceCount_ = other.xlaForceHostPlatformDeviceCount_; + xlaGpuDisableGpuasmOptimizations_ = other.xlaGpuDisableGpuasmOptimizations_; + xlaGpuShapeChecks_ = other.xlaGpuShapeChecks_; + xlaCpuEnableMlirLowering_ = other.xlaCpuEnableMlirLowering_; + xlaGpuEnableMlirLowering_ = other.xlaGpuEnableMlirLowering_; + xlaHloEvaluatorUseFastPath_ = other.xlaHloEvaluatorUseFastPath_; + xlaAllowScalarIndexDynamicOps_ = other.xlaAllowScalarIndexDynamicOps_; + xlaStepMarkerLocation_ = other.xlaStepMarkerLocation_; + xlaDumpTo_ = other.xlaDumpTo_; + xlaDumpHloModuleRe_ = other.xlaDumpHloModuleRe_; + xlaDumpHloPassRe_ = other.xlaDumpHloPassRe_; + xlaDumpHloAsText_ = other.xlaDumpHloAsText_; + xlaDumpHloAsProto_ = other.xlaDumpHloAsProto_; + xlaDumpHloAsDot_ = other.xlaDumpHloAsDot_; + xlaDumpHloAsUrl_ = other.xlaDumpHloAsUrl_; + xlaDumpHloAsHtml_ = other.xlaDumpHloAsHtml_; + xlaDumpFusionVisualization_ = other.xlaDumpFusionVisualization_; + xlaDumpHloSnapshots_ = other.xlaDumpHloSnapshots_; + xlaDumpIncludeTimestamp_ = other.xlaDumpIncludeTimestamp_; + xlaDumpMaxHloModules_ = other.xlaDumpMaxHloModules_; + xlaDumpModuleMetadata_ = other.xlaDumpModuleMetadata_; + xlaDumpCompressProtos_ = other.xlaDumpCompressProtos_; + xlaDumpHloAsLongText_ = other.xlaDumpHloAsLongText_; + xlaGpuForceConvNchw_ = other.xlaGpuForceConvNchw_; + xlaGpuForceConvNhwc_ = other.xlaGpuForceConvNhwc_; + xlaGpuPtxFile_ = other.xlaGpuPtxFile_.Clone(); + xlaGpuDumpLlvmir_ = other.xlaGpuDumpLlvmir_; + xlaGpuAlgorithmDenylistPath_ = other.xlaGpuAlgorithmDenylistPath_; + xlaTpuDetectNan_ = other.xlaTpuDetectNan_; + xlaTpuDetectInf_ = other.xlaTpuDetectInf_; + xlaCpuEnableXprofTraceme_ = other.xlaCpuEnableXprofTraceme_; + xlaGpuUnsafeFallbackToDriverOnPtxasNotFound_ = other.xlaGpuUnsafeFallbackToDriverOnPtxasNotFound_; + xlaGpuAsmExtraFlags_ = other.xlaGpuAsmExtraFlags_; + xlaMultiheapSizeConstraintPerHeap_ = other.xlaMultiheapSizeConstraintPerHeap_; + xlaDetailedLoggingAndDumping_ = other.xlaDetailedLoggingAndDumping_; + xlaGpuForceCompilationParallelism_ = other.xlaGpuForceCompilationParallelism_; + xlaGpuDeterministicOps_ = other.xlaGpuDeterministicOps_; + xlaGpuLlvmIrFile_ = other.xlaGpuLlvmIrFile_.Clone(); + xlaGpuEnableAsyncAllReduce_ = other.xlaGpuEnableAsyncAllReduce_; + xlaGpuAllReduceCombineThresholdBytes_ = other.xlaGpuAllReduceCombineThresholdBytes_; + xlaGpuAllReduceContiguous_ = other.xlaGpuAllReduceContiguous_; + xlaGpuAllReduceBlueconnectNumDevicesPerHost_ = other.xlaGpuAllReduceBlueconnectNumDevicesPerHost_; + xlaGpuEnableCudnnFrontend_ = other.xlaGpuEnableCudnnFrontend_; + xlaDumpDisableMetadata_ = other.xlaDumpDisableMetadata_; + xlaDumpHloPipelineRe_ = other.xlaDumpHloPipelineRe_; + xlaGpuStrictConvAlgorithmPicker_ = other.xlaGpuStrictConvAlgorithmPicker_; + xlaGpuEnableXlaRuntimeExecutable_ = other.xlaGpuEnableXlaRuntimeExecutable_; + xlaGpuNcclTerminationTimeoutSeconds_ = other.xlaGpuNcclTerminationTimeoutSeconds_; + xlaGpuEnableSharedConstants_ = other.xlaGpuEnableSharedConstants_; + xlaGpuEnableCublaslt_ = other.xlaGpuEnableCublaslt_; + xlaGpuRedzoneScratchMaxMegabytes_ = other.xlaGpuRedzoneScratchMaxMegabytes_; + xlaGpuSimplifyAllFpConversions_ = other.xlaGpuSimplifyAllFpConversions_; + xlaGpuNormalizeLayouts_ = other.xlaGpuNormalizeLayouts_; + xlaCpuUseAcl_ = other.xlaCpuUseAcl_; + xlaCpuStrictDotConvMath_ = other.xlaCpuStrictDotConvMath_; + xlaBackendExtraOptions_ = other.xlaBackendExtraOptions_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DebugOptions Clone() { + return new DebugOptions(this); + } + + /// Field number for the "xla_hlo_graph_addresses" field. + public const int XlaHloGraphAddressesFieldNumber = 2; + private bool xlaHloGraphAddresses_; + /// + /// Show addresses of HLO ops in graph dump. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaHloGraphAddresses { + get { return xlaHloGraphAddresses_; } + set { + xlaHloGraphAddresses_ = value; + } + } + + /// Field number for the "xla_hlo_profile" field. + public const int XlaHloProfileFieldNumber = 9; + private bool xlaHloProfile_; + /// + /// Instrument the computation to collect per-HLO cycle counts. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaHloProfile { + get { return xlaHloProfile_; } + set { + xlaHloProfile_ = value; + } + } + + /// Field number for the "xla_disable_hlo_passes" field. + public const int XlaDisableHloPassesFieldNumber = 30; + private static readonly pb::FieldCodec _repeated_xlaDisableHloPasses_codec + = pb::FieldCodec.ForString(242); + private readonly pbc::RepeatedField xlaDisableHloPasses_ = new pbc::RepeatedField(); + /// + /// List of HLO passes to disable/enable. These names must exactly match the + /// pass names as specified by the HloPassInterface::name() method. + /// + /// At least one of xla_disable_hlo_passes and xla_enable_hlo_passes_only must + /// be empty. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField XlaDisableHloPasses { + get { return xlaDisableHloPasses_; } + } + + /// Field number for the "xla_enable_hlo_passes_only" field. + public const int XlaEnableHloPassesOnlyFieldNumber = 124; + private static readonly pb::FieldCodec _repeated_xlaEnableHloPassesOnly_codec + = pb::FieldCodec.ForString(994); + private readonly pbc::RepeatedField xlaEnableHloPassesOnly_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField XlaEnableHloPassesOnly { + get { return xlaEnableHloPassesOnly_; } + } + + /// Field number for the "xla_disable_all_hlo_passes" field. + public const int XlaDisableAllHloPassesFieldNumber = 104; + private bool xlaDisableAllHloPasses_; + /// + /// Disables all HLO passes. Notes that some passes are necessary for + /// correctness and the invariants that must be satisfied by "fully optimized" + /// HLO are different for different devices and may change over time. The only + /// "guarantee", such as it is, is that if you compile XLA and dump the + /// optimized HLO for some graph, you should be able to run it again on the + /// same device with the same build of XLA. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDisableAllHloPasses { + get { return xlaDisableAllHloPasses_; } + set { + xlaDisableAllHloPasses_ = value; + } + } + + /// Field number for the "xla_backend_optimization_level" field. + public const int XlaBackendOptimizationLevelFieldNumber = 31; + private int xlaBackendOptimizationLevel_; + /// + /// Numerical optimization level for the XLA compiler backend; the specific + /// interpretation of this value is left to the backends. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int XlaBackendOptimizationLevel { + get { return xlaBackendOptimizationLevel_; } + set { + xlaBackendOptimizationLevel_ = value; + } + } + + /// Field number for the "xla_embed_ir_in_executable" field. + public const int XlaEmbedIrInExecutableFieldNumber = 33; + private bool xlaEmbedIrInExecutable_; + /// + /// Embed the compiler IR as a string in the executable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaEmbedIrInExecutable { + get { return xlaEmbedIrInExecutable_; } + set { + xlaEmbedIrInExecutable_ = value; + } + } + + /// Field number for the "xla_eliminate_hlo_implicit_broadcast" field. + public const int XlaEliminateHloImplicitBroadcastFieldNumber = 35; + private bool xlaEliminateHloImplicitBroadcast_; + /// + /// Eliminate implicit broadcasts when lowering user computations to HLO + /// instructions; use explicit broadcast instead. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaEliminateHloImplicitBroadcast { + get { return xlaEliminateHloImplicitBroadcast_; } + set { + xlaEliminateHloImplicitBroadcast_ = value; + } + } + + /// Field number for the "xla_cpu_multi_thread_eigen" field. + public const int XlaCpuMultiThreadEigenFieldNumber = 60; + private bool xlaCpuMultiThreadEigen_; + /// + /// When generating calls to Eigen in the CPU backend, use multi-threaded Eigen + /// mode. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuMultiThreadEigen { + get { return xlaCpuMultiThreadEigen_; } + set { + xlaCpuMultiThreadEigen_ = value; + } + } + + /// Field number for the "xla_gpu_cuda_data_dir" field. + public const int XlaGpuCudaDataDirFieldNumber = 61; + private string xlaGpuCudaDataDir_ = ""; + /// + /// Path to directory with cuda/ptx tools and libraries. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string XlaGpuCudaDataDir { + get { return xlaGpuCudaDataDir_; } + set { + xlaGpuCudaDataDir_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "xla_gpu_ftz" field. + public const int XlaGpuFtzFieldNumber = 62; + private bool xlaGpuFtz_; + /// + /// Enable flush-to-zero semantics in the GPU backend. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuFtz { + get { return xlaGpuFtz_; } + set { + xlaGpuFtz_ = value; + } + } + + /// Field number for the "xla_llvm_enable_alias_scope_metadata" field. + public const int XlaLlvmEnableAliasScopeMetadataFieldNumber = 70; + private bool xlaLlvmEnableAliasScopeMetadata_; + /// + /// If true, in LLVM-based backends, emit !alias.scope metadata in + /// generated IR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaLlvmEnableAliasScopeMetadata { + get { return xlaLlvmEnableAliasScopeMetadata_; } + set { + xlaLlvmEnableAliasScopeMetadata_ = value; + } + } + + /// Field number for the "xla_llvm_enable_noalias_metadata" field. + public const int XlaLlvmEnableNoaliasMetadataFieldNumber = 71; + private bool xlaLlvmEnableNoaliasMetadata_; + /// + /// If true, in LLVM-based backends, emit !noalias metadata in the + /// generated IR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaLlvmEnableNoaliasMetadata { + get { return xlaLlvmEnableNoaliasMetadata_; } + set { + xlaLlvmEnableNoaliasMetadata_ = value; + } + } + + /// Field number for the "xla_llvm_enable_invariant_load_metadata" field. + public const int XlaLlvmEnableInvariantLoadMetadataFieldNumber = 72; + private bool xlaLlvmEnableInvariantLoadMetadata_; + /// + /// If true, in LLVM-based backends, emit !invariant.load metadata in + /// the generated IR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaLlvmEnableInvariantLoadMetadata { + get { return xlaLlvmEnableInvariantLoadMetadata_; } + set { + xlaLlvmEnableInvariantLoadMetadata_ = value; + } + } + + /// Field number for the "xla_llvm_disable_expensive_passes" field. + public const int XlaLlvmDisableExpensivePassesFieldNumber = 73; + private bool xlaLlvmDisableExpensivePasses_; + /// + /// If true, a set of expensive LLVM optimization passes will not be run. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaLlvmDisableExpensivePasses { + get { return xlaLlvmDisableExpensivePasses_; } + set { + xlaLlvmDisableExpensivePasses_ = value; + } + } + + /// Field number for the "xla_test_all_output_layouts" field. + public const int XlaTestAllOutputLayoutsFieldNumber = 90; + private bool xlaTestAllOutputLayouts_; + /// + /// This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the + /// computation will run n! times with all permunations of layouts for the + /// output shape in rank n. For example, with a 3D shape, all permutations of + /// the set {0, 1, 2} are tried. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaTestAllOutputLayouts { + get { return xlaTestAllOutputLayouts_; } + set { + xlaTestAllOutputLayouts_ = value; + } + } + + /// Field number for the "xla_test_all_input_layouts" field. + public const int XlaTestAllInputLayoutsFieldNumber = 91; + private bool xlaTestAllInputLayouts_; + /// + /// This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the + /// computation will run for all permunations of layouts of all input + /// arguments. For example, with 2 input arguments in 2D and 4D shapes, the + /// computation will run 2! * 4! times. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaTestAllInputLayouts { + get { return xlaTestAllInputLayouts_; } + set { + xlaTestAllInputLayouts_ = value; + } + } + + /// Field number for the "xla_hlo_graph_sharding_color" field. + public const int XlaHloGraphShardingColorFieldNumber = 92; + private bool xlaHloGraphShardingColor_; + /// + /// Assign colors based on sharding information when generating the Graphviz + /// HLO graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaHloGraphShardingColor { + get { return xlaHloGraphShardingColor_; } + set { + xlaHloGraphShardingColor_ = value; + } + } + + /// Field number for the "xla_cpu_use_mkl_dnn" field. + public const int XlaCpuUseMklDnnFieldNumber = 97; + private bool xlaCpuUseMklDnn_; + /// + /// Generate calls to MKL-DNN in the CPU backend. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuUseMklDnn { + get { return xlaCpuUseMklDnn_; } + set { + xlaCpuUseMklDnn_ = value; + } + } + + /// Field number for the "xla_cpu_use_xla_runtime" field. + public const int XlaCpuUseXlaRuntimeFieldNumber = 177; + private bool xlaCpuUseXlaRuntime_; + /// + /// Enable XLA Runtime in the CPU backend. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuUseXlaRuntime { + get { return xlaCpuUseXlaRuntime_; } + set { + xlaCpuUseXlaRuntime_ = value; + } + } + + /// Field number for the "xla_gpu_max_kernel_unroll_factor" field. + public const int XlaGpuMaxKernelUnrollFactorFieldNumber = 98; + private int xlaGpuMaxKernelUnrollFactor_; + /// + /// Maximum kernel unroll factor for the GPU backend. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int XlaGpuMaxKernelUnrollFactor { + get { return xlaGpuMaxKernelUnrollFactor_; } + set { + xlaGpuMaxKernelUnrollFactor_ = value; + } + } + + /// Field number for the "xla_cpu_enable_fast_math" field. + public const int XlaCpuEnableFastMathFieldNumber = 99; + private bool xlaCpuEnableFastMath_; + /// + /// When true, "unsafe" mathematical optimizations are enabled. These + /// transformations include but are not limited to: + /// + /// - Reducing the precision of operations (e.g. using an approximate sin + /// function, or transforming x/y into x * (1/y)). + /// - Assuming that operations never produce or consume NaN or +/- Inf (this + /// behavior can be adjusted using xla_cpu_fast_math_allow_{nans|infs}). + /// - Assuming that +0 and -0 are indistinguishable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuEnableFastMath { + get { return xlaCpuEnableFastMath_; } + set { + xlaCpuEnableFastMath_ = value; + } + } + + /// Field number for the "xla_cpu_fast_math_honor_nans" field. + public const int XlaCpuFastMathHonorNansFieldNumber = 120; + private bool xlaCpuFastMathHonorNans_; + /// + /// When xla_cpu_enable_fast_math is true then this controls whether we allow + /// operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is + /// false. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuFastMathHonorNans { + get { return xlaCpuFastMathHonorNans_; } + set { + xlaCpuFastMathHonorNans_ = value; + } + } + + /// Field number for the "xla_cpu_fast_math_honor_infs" field. + public const int XlaCpuFastMathHonorInfsFieldNumber = 121; + private bool xlaCpuFastMathHonorInfs_; + /// + /// When xla_cpu_enable_fast_math is true then this controls whether we allow + /// operations to produce infinites. Ignored when xla_cpu_enable_fast_math is + /// false. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuFastMathHonorInfs { + get { return xlaCpuFastMathHonorInfs_; } + set { + xlaCpuFastMathHonorInfs_ = value; + } + } + + /// Field number for the "xla_cpu_fast_math_honor_division" field. + public const int XlaCpuFastMathHonorDivisionFieldNumber = 126; + private bool xlaCpuFastMathHonorDivision_; + /// + /// When xla_cpu_enable_fast_math is true then this controls whether we forbid + /// to use the reciprocal of an argument instead of division. Ignored when + /// xla_cpu_enable_fast_math is false. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuFastMathHonorDivision { + get { return xlaCpuFastMathHonorDivision_; } + set { + xlaCpuFastMathHonorDivision_ = value; + } + } + + /// Field number for the "xla_cpu_fast_math_honor_functions" field. + public const int XlaCpuFastMathHonorFunctionsFieldNumber = 129; + private bool xlaCpuFastMathHonorFunctions_; + /// + /// When xla_cpu_enable_fast_math is true then this controls whether we forbid + /// to approximate calculations for functions. Ignored when + /// xla_cpu_enable_fast_math is false. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuFastMathHonorFunctions { + get { return xlaCpuFastMathHonorFunctions_; } + set { + xlaCpuFastMathHonorFunctions_ = value; + } + } + + /// Field number for the "xla_cpu_enable_fast_min_max" field. + public const int XlaCpuEnableFastMinMaxFieldNumber = 140; + private bool xlaCpuEnableFastMinMax_; + /// + /// When false we lower the Minimum and Maximum hlos in the CPU backend such + /// that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if flag + /// this is false we always propagate NaNs through Min and Max. + /// + /// Note, this does not correspond to the exact same behavior as the gpu flag + /// below! + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuEnableFastMinMax { + get { return xlaCpuEnableFastMinMax_; } + set { + xlaCpuEnableFastMinMax_ = value; + } + } + + /// Field number for the "xla_gpu_enable_fast_min_max" field. + public const int XlaGpuEnableFastMinMaxFieldNumber = 100; + private bool xlaGpuEnableFastMinMax_; + /// + /// When true we lower the Minimum and Maximum hlos in the GPU backend such + /// that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag + /// this is true we don't propagate NaNs through Min and Max. + /// + /// Note, this does not correspond to the exact same behavior as the cpu flag + /// above! + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuEnableFastMinMax { + get { return xlaGpuEnableFastMinMax_; } + set { + xlaGpuEnableFastMinMax_ = value; + } + } + + /// Field number for the "xla_allow_excess_precision" field. + public const int XlaAllowExcessPrecisionFieldNumber = 122; + private bool xlaAllowExcessPrecision_; + /// + /// Allows xla to increase the output precision of floating point operations. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaAllowExcessPrecision { + get { return xlaAllowExcessPrecision_; } + set { + xlaAllowExcessPrecision_ = value; + } + } + + /// Field number for the "xla_gpu_crash_on_verification_failures" field. + public const int XlaGpuCrashOnVerificationFailuresFieldNumber = 101; + private bool xlaGpuCrashOnVerificationFailures_; + /// + /// Crashes the program when any kind of verification fails, instead of just + /// logging the failures. One example is cross checking of convolution results + /// among different algorithms. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuCrashOnVerificationFailures { + get { return xlaGpuCrashOnVerificationFailures_; } + set { + xlaGpuCrashOnVerificationFailures_ = value; + } + } + + /// Field number for the "xla_gpu_autotune_level" field. + public const int XlaGpuAutotuneLevelFieldNumber = 123; + private int xlaGpuAutotuneLevel_; + /// + /// 0: Disable gemm and convolution autotuning. + /// 1: Enable autotuning, but disable correctness checking. + /// 2: Also set output buffers to random numbers during autotuning. + /// 3: Also reset output buffers to random numbers after autotuning each + /// algorithm. + /// 4+: Also check for correct outputs and for out-of-bounds reads/writes. + /// + /// Default: 4. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int XlaGpuAutotuneLevel { + get { return xlaGpuAutotuneLevel_; } + set { + xlaGpuAutotuneLevel_ = value; + } + } + + /// Field number for the "xla_force_host_platform_device_count" field. + public const int XlaForceHostPlatformDeviceCountFieldNumber = 102; + private int xlaForceHostPlatformDeviceCount_; + /// + /// Force the host platform to pretend that there are these many host + /// "devices". All these devices are backed by the same threadpool. Defaults + /// to 1. + /// + /// Setting this to anything other than 1 can increase overhead from context + /// switching but we let the user override this behavior to help run tests on + /// the host that run models in parallel across multiple devices. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int XlaForceHostPlatformDeviceCount { + get { return xlaForceHostPlatformDeviceCount_; } + set { + xlaForceHostPlatformDeviceCount_ = value; + } + } + + /// Field number for the "xla_gpu_disable_gpuasm_optimizations" field. + public const int XlaGpuDisableGpuasmOptimizationsFieldNumber = 103; + private bool xlaGpuDisableGpuasmOptimizations_; + /// + /// If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuDisableGpuasmOptimizations { + get { return xlaGpuDisableGpuasmOptimizations_; } + set { + xlaGpuDisableGpuasmOptimizations_ = value; + } + } + + /// Field number for the "xla_gpu_shape_checks" field. + public const int XlaGpuShapeChecksFieldNumber = 170; + private global::Xla.DebugOptions.Types.ShapeChecks xlaGpuShapeChecks_ = global::Xla.DebugOptions.Types.ShapeChecks.Ignore; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.DebugOptions.Types.ShapeChecks XlaGpuShapeChecks { + get { return xlaGpuShapeChecks_; } + set { + xlaGpuShapeChecks_ = value; + } + } + + /// Field number for the "xla_cpu_enable_mlir_lowering" field. + public const int XlaCpuEnableMlirLoweringFieldNumber = 171; + private bool xlaCpuEnableMlirLowering_; + /// + /// Enable MLIR-based lowering in XLA:CPU instead of LLVM emitters. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuEnableMlirLowering { + get { return xlaCpuEnableMlirLowering_; } + set { + xlaCpuEnableMlirLowering_ = value; + } + } + + /// Field number for the "xla_gpu_enable_mlir_lowering" field. + public const int XlaGpuEnableMlirLoweringFieldNumber = 173; + private bool xlaGpuEnableMlirLowering_; + /// + /// If true, use MLIR instead of IR emitter to generate device code for + /// supported lmhlo.fusion ops. See xla::gpu::RewriteFusionOps() for details. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuEnableMlirLowering { + get { return xlaGpuEnableMlirLowering_; } + set { + xlaGpuEnableMlirLowering_ = value; + } + } + + /// Field number for the "xla_hlo_evaluator_use_fast_path" field. + public const int XlaHloEvaluatorUseFastPathFieldNumber = 106; + private bool xlaHloEvaluatorUseFastPath_; + /// + /// Enable fast math with eigen in the HLO evaluator. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaHloEvaluatorUseFastPath { + get { return xlaHloEvaluatorUseFastPath_; } + set { + xlaHloEvaluatorUseFastPath_ = value; + } + } + + /// Field number for the "xla_allow_scalar_index_dynamic_ops" field. + public const int XlaAllowScalarIndexDynamicOpsFieldNumber = 107; + private bool xlaAllowScalarIndexDynamicOps_; + /// + /// Temporary option to allow support for both the R1 and the scalar index + /// versions of DynamicSlice and DynamicUpdateSlice. Only used for testing. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaAllowScalarIndexDynamicOps { + get { return xlaAllowScalarIndexDynamicOps_; } + set { + xlaAllowScalarIndexDynamicOps_ = value; + } + } + + /// Field number for the "xla_step_marker_location" field. + public const int XlaStepMarkerLocationFieldNumber = 108; + private global::Xla.DebugOptions.Types.StepMarkerLocation xlaStepMarkerLocation_ = global::Xla.DebugOptions.Types.StepMarkerLocation.StepMarkAtEntry; + /// + /// Option to emit a target-specific marker to indicate the start of a training + /// step. The location of the marker (if any) is determined by the option + /// value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.DebugOptions.Types.StepMarkerLocation XlaStepMarkerLocation { + get { return xlaStepMarkerLocation_; } + set { + xlaStepMarkerLocation_ = value; + } + } + + /// Field number for the "xla_dump_to" field. + public const int XlaDumpToFieldNumber = 109; + private string xlaDumpTo_ = ""; + /// + /// Directory to dump into. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string XlaDumpTo { + get { return xlaDumpTo_; } + set { + xlaDumpTo_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "xla_dump_hlo_module_re" field. + public const int XlaDumpHloModuleReFieldNumber = 110; + private string xlaDumpHloModuleRe_ = ""; + /// + /// If specified, will only dump modules which match this regexp. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string XlaDumpHloModuleRe { + get { return xlaDumpHloModuleRe_; } + set { + xlaDumpHloModuleRe_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "xla_dump_hlo_pass_re" field. + public const int XlaDumpHloPassReFieldNumber = 111; + private string xlaDumpHloPassRe_ = ""; + /// + /// If this flag is specified, will also dump HLO before and after passes that + /// match this regular expression. Set to .* to dump before/after all passes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string XlaDumpHloPassRe { + get { return xlaDumpHloPassRe_; } + set { + xlaDumpHloPassRe_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "xla_dump_hlo_as_text" field. + public const int XlaDumpHloAsTextFieldNumber = 112; + private bool xlaDumpHloAsText_; + /// + /// Specifies the format that HLO is dumped in. Multiple of these may be + /// specified. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDumpHloAsText { + get { return xlaDumpHloAsText_; } + set { + xlaDumpHloAsText_ = value; + } + } + + /// Field number for the "xla_dump_hlo_as_proto" field. + public const int XlaDumpHloAsProtoFieldNumber = 113; + private bool xlaDumpHloAsProto_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDumpHloAsProto { + get { return xlaDumpHloAsProto_; } + set { + xlaDumpHloAsProto_ = value; + } + } + + /// Field number for the "xla_dump_hlo_as_dot" field. + public const int XlaDumpHloAsDotFieldNumber = 114; + private bool xlaDumpHloAsDot_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDumpHloAsDot { + get { return xlaDumpHloAsDot_; } + set { + xlaDumpHloAsDot_ = value; + } + } + + /// Field number for the "xla_dump_hlo_as_url" field. + public const int XlaDumpHloAsUrlFieldNumber = 115; + private bool xlaDumpHloAsUrl_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDumpHloAsUrl { + get { return xlaDumpHloAsUrl_; } + set { + xlaDumpHloAsUrl_ = value; + } + } + + /// Field number for the "xla_dump_hlo_as_html" field. + public const int XlaDumpHloAsHtmlFieldNumber = 116; + private bool xlaDumpHloAsHtml_; + /// + /// Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML) + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDumpHloAsHtml { + get { return xlaDumpHloAsHtml_; } + set { + xlaDumpHloAsHtml_ = value; + } + } + + /// Field number for the "xla_dump_fusion_visualization" field. + public const int XlaDumpFusionVisualizationFieldNumber = 149; + private bool xlaDumpFusionVisualization_; + /// + /// Dump the visualization of the fusion progress. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDumpFusionVisualization { + get { return xlaDumpFusionVisualization_; } + set { + xlaDumpFusionVisualization_ = value; + } + } + + /// Field number for the "xla_dump_hlo_snapshots" field. + public const int XlaDumpHloSnapshotsFieldNumber = 118; + private bool xlaDumpHloSnapshots_; + /// + /// If true, every time an HLO module is run, we will dump an HloSnapshot + /// (essentially, a serialized module plus its inputs) to the --xla_dump_to + /// directory. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDumpHloSnapshots { + get { return xlaDumpHloSnapshots_; } + set { + xlaDumpHloSnapshots_ = value; + } + } + + /// Field number for the "xla_dump_include_timestamp" field. + public const int XlaDumpIncludeTimestampFieldNumber = 131; + private bool xlaDumpIncludeTimestamp_; + /// + /// Include a timestamp in the dumped filenames. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDumpIncludeTimestamp { + get { return xlaDumpIncludeTimestamp_; } + set { + xlaDumpIncludeTimestamp_ = value; + } + } + + /// Field number for the "xla_dump_max_hlo_modules" field. + public const int XlaDumpMaxHloModulesFieldNumber = 132; + private int xlaDumpMaxHloModules_; + /// + /// Max number of hlo module dumps in a directory. Set to < 0 for unbounded. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int XlaDumpMaxHloModules { + get { return xlaDumpMaxHloModules_; } + set { + xlaDumpMaxHloModules_ = value; + } + } + + /// Field number for the "xla_dump_module_metadata" field. + public const int XlaDumpModuleMetadataFieldNumber = 144; + private bool xlaDumpModuleMetadata_; + /// + /// Dump HloModuleMetadata as a text proto for each HLO module. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDumpModuleMetadata { + get { return xlaDumpModuleMetadata_; } + set { + xlaDumpModuleMetadata_ = value; + } + } + + /// Field number for the "xla_dump_compress_protos" field. + public const int XlaDumpCompressProtosFieldNumber = 151; + private bool xlaDumpCompressProtos_; + /// + /// GZip-compress protos dumped via --xla_dump_hlo_as_proto. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDumpCompressProtos { + get { return xlaDumpCompressProtos_; } + set { + xlaDumpCompressProtos_ = value; + } + } + + /// Field number for the "xla_dump_hlo_as_long_text" field. + public const int XlaDumpHloAsLongTextFieldNumber = 164; + private bool xlaDumpHloAsLongText_; + /// + /// Dump HLO in long text format. Ignored unless xla_dump_hlo_as_text is true. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDumpHloAsLongText { + get { return xlaDumpHloAsLongText_; } + set { + xlaDumpHloAsLongText_ = value; + } + } + + /// Field number for the "xla_gpu_force_conv_nchw" field. + public const int XlaGpuForceConvNchwFieldNumber = 125; + private bool xlaGpuForceConvNchw_; + /// + /// Overrides for XLA GPU's convolution layout heuristic. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuForceConvNchw { + get { return xlaGpuForceConvNchw_; } + set { + xlaGpuForceConvNchw_ = value; + } + } + + /// Field number for the "xla_gpu_force_conv_nhwc" field. + public const int XlaGpuForceConvNhwcFieldNumber = 146; + private bool xlaGpuForceConvNhwc_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuForceConvNhwc { + get { return xlaGpuForceConvNhwc_; } + set { + xlaGpuForceConvNhwc_ = value; + } + } + + /// Field number for the "xla_gpu_ptx_file" field. + public const int XlaGpuPtxFileFieldNumber = 127; + private static readonly pb::FieldCodec _repeated_xlaGpuPtxFile_codec + = pb::FieldCodec.ForString(1018); + private readonly pbc::RepeatedField xlaGpuPtxFile_ = new pbc::RepeatedField(); + /// + /// Paths to files with ptx code. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField XlaGpuPtxFile { + get { return xlaGpuPtxFile_; } + } + + /// Field number for the "xla_gpu_dump_llvmir" field. + public const int XlaGpuDumpLlvmirFieldNumber = 155; + private bool xlaGpuDumpLlvmir_; + /// + /// Whether to dump llvm ir when compiling to ptx. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuDumpLlvmir { + get { return xlaGpuDumpLlvmir_; } + set { + xlaGpuDumpLlvmir_ = value; + } + } + + /// Field number for the "xla_gpu_algorithm_denylist_path" field. + public const int XlaGpuAlgorithmDenylistPathFieldNumber = 128; + private string xlaGpuAlgorithmDenylistPath_ = ""; + /// + /// Denylist for cuDNN convolutions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string XlaGpuAlgorithmDenylistPath { + get { return xlaGpuAlgorithmDenylistPath_; } + set { + xlaGpuAlgorithmDenylistPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "xla_tpu_detect_nan" field. + public const int XlaTpuDetectNanFieldNumber = 135; + private bool xlaTpuDetectNan_; + /// + /// Debug options that trigger execution errors when NaN or Inf are detected. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaTpuDetectNan { + get { return xlaTpuDetectNan_; } + set { + xlaTpuDetectNan_ = value; + } + } + + /// Field number for the "xla_tpu_detect_inf" field. + public const int XlaTpuDetectInfFieldNumber = 136; + private bool xlaTpuDetectInf_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaTpuDetectInf { + get { return xlaTpuDetectInf_; } + set { + xlaTpuDetectInf_ = value; + } + } + + /// Field number for the "xla_cpu_enable_xprof_traceme" field. + public const int XlaCpuEnableXprofTracemeFieldNumber = 137; + private bool xlaCpuEnableXprofTraceme_; + /// + /// True if TraceMe annotations are enabled for XLA:CPU. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuEnableXprofTraceme { + get { return xlaCpuEnableXprofTraceme_; } + set { + xlaCpuEnableXprofTraceme_ = value; + } + } + + /// Field number for the "xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found" field. + public const int XlaGpuUnsafeFallbackToDriverOnPtxasNotFoundFieldNumber = 138; + private bool xlaGpuUnsafeFallbackToDriverOnPtxasNotFound_; + /// + /// It is usually preferable to not fallback to the driver; it can consume more + /// memory, or have bugs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuUnsafeFallbackToDriverOnPtxasNotFound { + get { return xlaGpuUnsafeFallbackToDriverOnPtxasNotFound_; } + set { + xlaGpuUnsafeFallbackToDriverOnPtxasNotFound_ = value; + } + } + + /// Field number for the "xla_gpu_asm_extra_flags" field. + public const int XlaGpuAsmExtraFlagsFieldNumber = 141; + private string xlaGpuAsmExtraFlags_ = ""; + /// + /// Extra parameters to pass the GPU assembler. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string XlaGpuAsmExtraFlags { + get { return xlaGpuAsmExtraFlags_; } + set { + xlaGpuAsmExtraFlags_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "xla_multiheap_size_constraint_per_heap" field. + public const int XlaMultiheapSizeConstraintPerHeapFieldNumber = 142; + private int xlaMultiheapSizeConstraintPerHeap_; + /// + /// Per-heap size constraint. New heaps will be created if per-heap max size is + /// reached. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int XlaMultiheapSizeConstraintPerHeap { + get { return xlaMultiheapSizeConstraintPerHeap_; } + set { + xlaMultiheapSizeConstraintPerHeap_ = value; + } + } + + /// Field number for the "xla_detailed_logging_and_dumping" field. + public const int XlaDetailedLoggingAndDumpingFieldNumber = 143; + private bool xlaDetailedLoggingAndDumping_; + /// + /// Enable detailed logging into vlog and xla dumping. If this is disabled, no + /// compilation summary will be printed in the end of computation and no hlo + /// modules will be dumped. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDetailedLoggingAndDumping { + get { return xlaDetailedLoggingAndDumping_; } + set { + xlaDetailedLoggingAndDumping_ = value; + } + } + + /// Field number for the "xla_gpu_force_compilation_parallelism" field. + public const int XlaGpuForceCompilationParallelismFieldNumber = 147; + private int xlaGpuForceCompilationParallelism_; + /// + /// Overrides normal multi-threaded compilation settting to use this many + /// threads. Setting to 0 (the default value) means no enforcement. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int XlaGpuForceCompilationParallelism { + get { return xlaGpuForceCompilationParallelism_; } + set { + xlaGpuForceCompilationParallelism_ = value; + } + } + + /// Field number for the "xla_gpu_deterministic_ops" field. + public const int XlaGpuDeterministicOpsFieldNumber = 148; + private bool xlaGpuDeterministicOps_; + /// + /// Guarantees run-to-run determinism. At present, the HLO ops Scatter and + /// SelectAndScatter do not have deterministic XLA:GPU implementations. + /// Compilation errors out if these ops are encountered. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuDeterministicOps { + get { return xlaGpuDeterministicOps_; } + set { + xlaGpuDeterministicOps_ = value; + } + } + + /// Field number for the "xla_gpu_llvm_ir_file" field. + public const int XlaGpuLlvmIrFileFieldNumber = 150; + private static readonly pb::FieldCodec _repeated_xlaGpuLlvmIrFile_codec + = pb::FieldCodec.ForString(1202); + private readonly pbc::RepeatedField xlaGpuLlvmIrFile_ = new pbc::RepeatedField(); + /// + /// Paths to files with LLVM code. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField XlaGpuLlvmIrFile { + get { return xlaGpuLlvmIrFile_; } + } + + /// Field number for the "xla_gpu_enable_async_all_reduce" field. + public const int XlaGpuEnableAsyncAllReduceFieldNumber = 152; + private bool xlaGpuEnableAsyncAllReduce_; + /// + /// Convert synchronous all-reduces ops into asynchronous. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuEnableAsyncAllReduce { + get { return xlaGpuEnableAsyncAllReduce_; } + set { + xlaGpuEnableAsyncAllReduce_ = value; + } + } + + /// Field number for the "xla_gpu_all_reduce_combine_threshold_bytes" field. + public const int XlaGpuAllReduceCombineThresholdBytesFieldNumber = 157; + private long xlaGpuAllReduceCombineThresholdBytes_; + /// + /// Size threshold (in bytes) for the GPU all-reduce combiner. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long XlaGpuAllReduceCombineThresholdBytes { + get { return xlaGpuAllReduceCombineThresholdBytes_; } + set { + xlaGpuAllReduceCombineThresholdBytes_ = value; + } + } + + /// Field number for the "xla_gpu_all_reduce_contiguous" field. + public const int XlaGpuAllReduceContiguousFieldNumber = 158; + private bool xlaGpuAllReduceContiguous_; + /// + /// Combine GPU all-reduces into a single operation over a contiguous buffer. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuAllReduceContiguous { + get { return xlaGpuAllReduceContiguous_; } + set { + xlaGpuAllReduceContiguous_ = value; + } + } + + /// Field number for the "xla_gpu_all_reduce_blueconnect_num_devices_per_host" field. + public const int XlaGpuAllReduceBlueconnectNumDevicesPerHostFieldNumber = 159; + private int xlaGpuAllReduceBlueconnectNumDevicesPerHost_; + /// + /// Number of devices per host for first stage of BlueConnect decomposition + /// pass. The pass will attempt to decompose all-reduces ops into a + /// ReduceScatter-AllReduce-AllGather sequence, with the initial ReduceScatter + /// being performed over all of the devices in the same host. Set to < 1 to + /// disable all-reduce decomposition. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int XlaGpuAllReduceBlueconnectNumDevicesPerHost { + get { return xlaGpuAllReduceBlueconnectNumDevicesPerHost_; } + set { + xlaGpuAllReduceBlueconnectNumDevicesPerHost_ = value; + } + } + + /// Field number for the "xla_gpu_enable_cudnn_frontend" field. + public const int XlaGpuEnableCudnnFrontendFieldNumber = 160; + private bool xlaGpuEnableCudnnFrontend_; + /// + /// Whether to use the cuDNN frontend API for convolutions when possible. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuEnableCudnnFrontend { + get { return xlaGpuEnableCudnnFrontend_; } + set { + xlaGpuEnableCudnnFrontend_ = value; + } + } + + /// Field number for the "xla_dump_disable_metadata" field. + public const int XlaDumpDisableMetadataFieldNumber = 153; + private bool xlaDumpDisableMetadata_; + /// + /// Disable dumping metadata in HLO dumps. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaDumpDisableMetadata { + get { return xlaDumpDisableMetadata_; } + set { + xlaDumpDisableMetadata_ = value; + } + } + + /// Field number for the "xla_dump_hlo_pipeline_re" field. + public const int XlaDumpHloPipelineReFieldNumber = 154; + private string xlaDumpHloPipelineRe_ = ""; + /// + /// If this flag is specified, will only dump HLO before and after passes in + /// the pass pipeline that matches this regular expression. Default empty value + /// enables dumping in all pipelines. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string XlaDumpHloPipelineRe { + get { return xlaDumpHloPipelineRe_; } + set { + xlaDumpHloPipelineRe_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "xla_gpu_strict_conv_algorithm_picker" field. + public const int XlaGpuStrictConvAlgorithmPickerFieldNumber = 156; + private bool xlaGpuStrictConvAlgorithmPicker_; + /// + /// If true, abort immediately when conv algorithm picker fails, rather than + /// logging a warning and proceeding with fallback. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuStrictConvAlgorithmPicker { + get { return xlaGpuStrictConvAlgorithmPicker_; } + set { + xlaGpuStrictConvAlgorithmPicker_ = value; + } + } + + /// Field number for the "xla_gpu_enable_xla_runtime_executable" field. + public const int XlaGpuEnableXlaRuntimeExecutableFieldNumber = 169; + private bool xlaGpuEnableXlaRuntimeExecutable_; + /// + /// If true, use XLA runtime for XLA:GPU backend. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuEnableXlaRuntimeExecutable { + get { return xlaGpuEnableXlaRuntimeExecutable_; } + set { + xlaGpuEnableXlaRuntimeExecutable_ = value; + } + } + + /// Field number for the "xla_gpu_nccl_termination_timeout_seconds" field. + public const int XlaGpuNcclTerminationTimeoutSecondsFieldNumber = 163; + private long xlaGpuNcclTerminationTimeoutSeconds_; + /// + /// Timeout in seconds before terminating jobs that are stuck in a NCCL + /// Rendezvous. Negative value disables the timeout and will not terminate. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long XlaGpuNcclTerminationTimeoutSeconds { + get { return xlaGpuNcclTerminationTimeoutSeconds_; } + set { + xlaGpuNcclTerminationTimeoutSeconds_ = value; + } + } + + /// Field number for the "xla_gpu_enable_shared_constants" field. + public const int XlaGpuEnableSharedConstantsFieldNumber = 165; + private bool xlaGpuEnableSharedConstants_; + /// + /// Enables shared constants for XLA/GPU. This allows large constants to be + /// shared among multiple GPU executables. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuEnableSharedConstants { + get { return xlaGpuEnableSharedConstants_; } + set { + xlaGpuEnableSharedConstants_ = value; + } + } + + /// Field number for the "xla_gpu_enable_cublaslt" field. + public const int XlaGpuEnableCublasltFieldNumber = 166; + private bool xlaGpuEnableCublaslt_; + /// + /// Whether to use cuBLASLt for GEMMs on GPUs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuEnableCublaslt { + get { return xlaGpuEnableCublaslt_; } + set { + xlaGpuEnableCublaslt_ = value; + } + } + + /// Field number for the "xla_gpu_redzone_scratch_max_megabytes" field. + public const int XlaGpuRedzoneScratchMaxMegabytesFieldNumber = 167; + private long xlaGpuRedzoneScratchMaxMegabytes_; + /// + /// Size threshold (in megabytes) for the GPU redzone scratch allocator. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long XlaGpuRedzoneScratchMaxMegabytes { + get { return xlaGpuRedzoneScratchMaxMegabytes_; } + set { + xlaGpuRedzoneScratchMaxMegabytes_ = value; + } + } + + /// Field number for the "xla_gpu_simplify_all_fp_conversions" field. + public const int XlaGpuSimplifyAllFpConversionsFieldNumber = 168; + private bool xlaGpuSimplifyAllFpConversions_; + /// + /// Allows all floating-point conversions to be simplified, including those + /// that affect the numerics. The `BFloat16Normalization` pass inserts many + /// `f32 -> bf16 -> f32` conversion pairs. These are not removed by the + /// `AlgebraicSimplifier`, as that will only simplify conversions that are + /// no-ops, e.g. `bf16 -> f32 -> bf16`. Removing these improves accuracy. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuSimplifyAllFpConversions { + get { return xlaGpuSimplifyAllFpConversions_; } + set { + xlaGpuSimplifyAllFpConversions_ = value; + } + } + + /// Field number for the "xla_gpu_normalize_layouts" field. + public const int XlaGpuNormalizeLayoutsFieldNumber = 172; + private bool xlaGpuNormalizeLayouts_; + /// + /// An experimental option to force all layouts present in the + /// after-optimizations HLO to be descending, e.g. + /// ShapeUtil::MakeShapeWithDescendingLayout is an identity on all + /// instructions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaGpuNormalizeLayouts { + get { return xlaGpuNormalizeLayouts_; } + set { + xlaGpuNormalizeLayouts_ = value; + } + } + + /// Field number for the "xla_cpu_use_acl" field. + public const int XlaCpuUseAclFieldNumber = 174; + private bool xlaCpuUseAcl_; + /// + /// Generate calls to Arm Compute Library in the CPU backend. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuUseAcl { + get { return xlaCpuUseAcl_; } + set { + xlaCpuUseAcl_ = value; + } + } + + /// Field number for the "xla_cpu_strict_dot_conv_math" field. + public const int XlaCpuStrictDotConvMathFieldNumber = 175; + private bool xlaCpuStrictDotConvMath_; + /// + /// By default, XLA:CPU will run fp16 dot/conv as fp32, as this is generally + /// (much) faster on our hardware. Set this flag to disable this behavior. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool XlaCpuStrictDotConvMath { + get { return xlaCpuStrictDotConvMath_; } + set { + xlaCpuStrictDotConvMath_ = value; + } + } + + /// Field number for the "xla_backend_extra_options" field. + public const int XlaBackendExtraOptionsFieldNumber = 500; + private static readonly pbc::MapField.Codec _map_xlaBackendExtraOptions_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForString(18, ""), 4002); + private readonly pbc::MapField xlaBackendExtraOptions_ = new pbc::MapField(); + /// + /// Extra options to pass to the compilation backend (e.g. LLVM); specific + /// interpretation of these values is left to the backend. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField XlaBackendExtraOptions { + get { return xlaBackendExtraOptions_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DebugOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DebugOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (XlaHloGraphAddresses != other.XlaHloGraphAddresses) return false; + if (XlaHloProfile != other.XlaHloProfile) return false; + if(!xlaDisableHloPasses_.Equals(other.xlaDisableHloPasses_)) return false; + if(!xlaEnableHloPassesOnly_.Equals(other.xlaEnableHloPassesOnly_)) return false; + if (XlaDisableAllHloPasses != other.XlaDisableAllHloPasses) return false; + if (XlaBackendOptimizationLevel != other.XlaBackendOptimizationLevel) return false; + if (XlaEmbedIrInExecutable != other.XlaEmbedIrInExecutable) return false; + if (XlaEliminateHloImplicitBroadcast != other.XlaEliminateHloImplicitBroadcast) return false; + if (XlaCpuMultiThreadEigen != other.XlaCpuMultiThreadEigen) return false; + if (XlaGpuCudaDataDir != other.XlaGpuCudaDataDir) return false; + if (XlaGpuFtz != other.XlaGpuFtz) return false; + if (XlaLlvmEnableAliasScopeMetadata != other.XlaLlvmEnableAliasScopeMetadata) return false; + if (XlaLlvmEnableNoaliasMetadata != other.XlaLlvmEnableNoaliasMetadata) return false; + if (XlaLlvmEnableInvariantLoadMetadata != other.XlaLlvmEnableInvariantLoadMetadata) return false; + if (XlaLlvmDisableExpensivePasses != other.XlaLlvmDisableExpensivePasses) return false; + if (XlaTestAllOutputLayouts != other.XlaTestAllOutputLayouts) return false; + if (XlaTestAllInputLayouts != other.XlaTestAllInputLayouts) return false; + if (XlaHloGraphShardingColor != other.XlaHloGraphShardingColor) return false; + if (XlaCpuUseMklDnn != other.XlaCpuUseMklDnn) return false; + if (XlaCpuUseXlaRuntime != other.XlaCpuUseXlaRuntime) return false; + if (XlaGpuMaxKernelUnrollFactor != other.XlaGpuMaxKernelUnrollFactor) return false; + if (XlaCpuEnableFastMath != other.XlaCpuEnableFastMath) return false; + if (XlaCpuFastMathHonorNans != other.XlaCpuFastMathHonorNans) return false; + if (XlaCpuFastMathHonorInfs != other.XlaCpuFastMathHonorInfs) return false; + if (XlaCpuFastMathHonorDivision != other.XlaCpuFastMathHonorDivision) return false; + if (XlaCpuFastMathHonorFunctions != other.XlaCpuFastMathHonorFunctions) return false; + if (XlaCpuEnableFastMinMax != other.XlaCpuEnableFastMinMax) return false; + if (XlaGpuEnableFastMinMax != other.XlaGpuEnableFastMinMax) return false; + if (XlaAllowExcessPrecision != other.XlaAllowExcessPrecision) return false; + if (XlaGpuCrashOnVerificationFailures != other.XlaGpuCrashOnVerificationFailures) return false; + if (XlaGpuAutotuneLevel != other.XlaGpuAutotuneLevel) return false; + if (XlaForceHostPlatformDeviceCount != other.XlaForceHostPlatformDeviceCount) return false; + if (XlaGpuDisableGpuasmOptimizations != other.XlaGpuDisableGpuasmOptimizations) return false; + if (XlaGpuShapeChecks != other.XlaGpuShapeChecks) return false; + if (XlaCpuEnableMlirLowering != other.XlaCpuEnableMlirLowering) return false; + if (XlaGpuEnableMlirLowering != other.XlaGpuEnableMlirLowering) return false; + if (XlaHloEvaluatorUseFastPath != other.XlaHloEvaluatorUseFastPath) return false; + if (XlaAllowScalarIndexDynamicOps != other.XlaAllowScalarIndexDynamicOps) return false; + if (XlaStepMarkerLocation != other.XlaStepMarkerLocation) return false; + if (XlaDumpTo != other.XlaDumpTo) return false; + if (XlaDumpHloModuleRe != other.XlaDumpHloModuleRe) return false; + if (XlaDumpHloPassRe != other.XlaDumpHloPassRe) return false; + if (XlaDumpHloAsText != other.XlaDumpHloAsText) return false; + if (XlaDumpHloAsProto != other.XlaDumpHloAsProto) return false; + if (XlaDumpHloAsDot != other.XlaDumpHloAsDot) return false; + if (XlaDumpHloAsUrl != other.XlaDumpHloAsUrl) return false; + if (XlaDumpHloAsHtml != other.XlaDumpHloAsHtml) return false; + if (XlaDumpFusionVisualization != other.XlaDumpFusionVisualization) return false; + if (XlaDumpHloSnapshots != other.XlaDumpHloSnapshots) return false; + if (XlaDumpIncludeTimestamp != other.XlaDumpIncludeTimestamp) return false; + if (XlaDumpMaxHloModules != other.XlaDumpMaxHloModules) return false; + if (XlaDumpModuleMetadata != other.XlaDumpModuleMetadata) return false; + if (XlaDumpCompressProtos != other.XlaDumpCompressProtos) return false; + if (XlaDumpHloAsLongText != other.XlaDumpHloAsLongText) return false; + if (XlaGpuForceConvNchw != other.XlaGpuForceConvNchw) return false; + if (XlaGpuForceConvNhwc != other.XlaGpuForceConvNhwc) return false; + if(!xlaGpuPtxFile_.Equals(other.xlaGpuPtxFile_)) return false; + if (XlaGpuDumpLlvmir != other.XlaGpuDumpLlvmir) return false; + if (XlaGpuAlgorithmDenylistPath != other.XlaGpuAlgorithmDenylistPath) return false; + if (XlaTpuDetectNan != other.XlaTpuDetectNan) return false; + if (XlaTpuDetectInf != other.XlaTpuDetectInf) return false; + if (XlaCpuEnableXprofTraceme != other.XlaCpuEnableXprofTraceme) return false; + if (XlaGpuUnsafeFallbackToDriverOnPtxasNotFound != other.XlaGpuUnsafeFallbackToDriverOnPtxasNotFound) return false; + if (XlaGpuAsmExtraFlags != other.XlaGpuAsmExtraFlags) return false; + if (XlaMultiheapSizeConstraintPerHeap != other.XlaMultiheapSizeConstraintPerHeap) return false; + if (XlaDetailedLoggingAndDumping != other.XlaDetailedLoggingAndDumping) return false; + if (XlaGpuForceCompilationParallelism != other.XlaGpuForceCompilationParallelism) return false; + if (XlaGpuDeterministicOps != other.XlaGpuDeterministicOps) return false; + if(!xlaGpuLlvmIrFile_.Equals(other.xlaGpuLlvmIrFile_)) return false; + if (XlaGpuEnableAsyncAllReduce != other.XlaGpuEnableAsyncAllReduce) return false; + if (XlaGpuAllReduceCombineThresholdBytes != other.XlaGpuAllReduceCombineThresholdBytes) return false; + if (XlaGpuAllReduceContiguous != other.XlaGpuAllReduceContiguous) return false; + if (XlaGpuAllReduceBlueconnectNumDevicesPerHost != other.XlaGpuAllReduceBlueconnectNumDevicesPerHost) return false; + if (XlaGpuEnableCudnnFrontend != other.XlaGpuEnableCudnnFrontend) return false; + if (XlaDumpDisableMetadata != other.XlaDumpDisableMetadata) return false; + if (XlaDumpHloPipelineRe != other.XlaDumpHloPipelineRe) return false; + if (XlaGpuStrictConvAlgorithmPicker != other.XlaGpuStrictConvAlgorithmPicker) return false; + if (XlaGpuEnableXlaRuntimeExecutable != other.XlaGpuEnableXlaRuntimeExecutable) return false; + if (XlaGpuNcclTerminationTimeoutSeconds != other.XlaGpuNcclTerminationTimeoutSeconds) return false; + if (XlaGpuEnableSharedConstants != other.XlaGpuEnableSharedConstants) return false; + if (XlaGpuEnableCublaslt != other.XlaGpuEnableCublaslt) return false; + if (XlaGpuRedzoneScratchMaxMegabytes != other.XlaGpuRedzoneScratchMaxMegabytes) return false; + if (XlaGpuSimplifyAllFpConversions != other.XlaGpuSimplifyAllFpConversions) return false; + if (XlaGpuNormalizeLayouts != other.XlaGpuNormalizeLayouts) return false; + if (XlaCpuUseAcl != other.XlaCpuUseAcl) return false; + if (XlaCpuStrictDotConvMath != other.XlaCpuStrictDotConvMath) return false; + if (!XlaBackendExtraOptions.Equals(other.XlaBackendExtraOptions)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (XlaHloGraphAddresses != false) hash ^= XlaHloGraphAddresses.GetHashCode(); + if (XlaHloProfile != false) hash ^= XlaHloProfile.GetHashCode(); + hash ^= xlaDisableHloPasses_.GetHashCode(); + hash ^= xlaEnableHloPassesOnly_.GetHashCode(); + if (XlaDisableAllHloPasses != false) hash ^= XlaDisableAllHloPasses.GetHashCode(); + if (XlaBackendOptimizationLevel != 0) hash ^= XlaBackendOptimizationLevel.GetHashCode(); + if (XlaEmbedIrInExecutable != false) hash ^= XlaEmbedIrInExecutable.GetHashCode(); + if (XlaEliminateHloImplicitBroadcast != false) hash ^= XlaEliminateHloImplicitBroadcast.GetHashCode(); + if (XlaCpuMultiThreadEigen != false) hash ^= XlaCpuMultiThreadEigen.GetHashCode(); + if (XlaGpuCudaDataDir.Length != 0) hash ^= XlaGpuCudaDataDir.GetHashCode(); + if (XlaGpuFtz != false) hash ^= XlaGpuFtz.GetHashCode(); + if (XlaLlvmEnableAliasScopeMetadata != false) hash ^= XlaLlvmEnableAliasScopeMetadata.GetHashCode(); + if (XlaLlvmEnableNoaliasMetadata != false) hash ^= XlaLlvmEnableNoaliasMetadata.GetHashCode(); + if (XlaLlvmEnableInvariantLoadMetadata != false) hash ^= XlaLlvmEnableInvariantLoadMetadata.GetHashCode(); + if (XlaLlvmDisableExpensivePasses != false) hash ^= XlaLlvmDisableExpensivePasses.GetHashCode(); + if (XlaTestAllOutputLayouts != false) hash ^= XlaTestAllOutputLayouts.GetHashCode(); + if (XlaTestAllInputLayouts != false) hash ^= XlaTestAllInputLayouts.GetHashCode(); + if (XlaHloGraphShardingColor != false) hash ^= XlaHloGraphShardingColor.GetHashCode(); + if (XlaCpuUseMklDnn != false) hash ^= XlaCpuUseMklDnn.GetHashCode(); + if (XlaCpuUseXlaRuntime != false) hash ^= XlaCpuUseXlaRuntime.GetHashCode(); + if (XlaGpuMaxKernelUnrollFactor != 0) hash ^= XlaGpuMaxKernelUnrollFactor.GetHashCode(); + if (XlaCpuEnableFastMath != false) hash ^= XlaCpuEnableFastMath.GetHashCode(); + if (XlaCpuFastMathHonorNans != false) hash ^= XlaCpuFastMathHonorNans.GetHashCode(); + if (XlaCpuFastMathHonorInfs != false) hash ^= XlaCpuFastMathHonorInfs.GetHashCode(); + if (XlaCpuFastMathHonorDivision != false) hash ^= XlaCpuFastMathHonorDivision.GetHashCode(); + if (XlaCpuFastMathHonorFunctions != false) hash ^= XlaCpuFastMathHonorFunctions.GetHashCode(); + if (XlaCpuEnableFastMinMax != false) hash ^= XlaCpuEnableFastMinMax.GetHashCode(); + if (XlaGpuEnableFastMinMax != false) hash ^= XlaGpuEnableFastMinMax.GetHashCode(); + if (XlaAllowExcessPrecision != false) hash ^= XlaAllowExcessPrecision.GetHashCode(); + if (XlaGpuCrashOnVerificationFailures != false) hash ^= XlaGpuCrashOnVerificationFailures.GetHashCode(); + if (XlaGpuAutotuneLevel != 0) hash ^= XlaGpuAutotuneLevel.GetHashCode(); + if (XlaForceHostPlatformDeviceCount != 0) hash ^= XlaForceHostPlatformDeviceCount.GetHashCode(); + if (XlaGpuDisableGpuasmOptimizations != false) hash ^= XlaGpuDisableGpuasmOptimizations.GetHashCode(); + if (XlaGpuShapeChecks != global::Xla.DebugOptions.Types.ShapeChecks.Ignore) hash ^= XlaGpuShapeChecks.GetHashCode(); + if (XlaCpuEnableMlirLowering != false) hash ^= XlaCpuEnableMlirLowering.GetHashCode(); + if (XlaGpuEnableMlirLowering != false) hash ^= XlaGpuEnableMlirLowering.GetHashCode(); + if (XlaHloEvaluatorUseFastPath != false) hash ^= XlaHloEvaluatorUseFastPath.GetHashCode(); + if (XlaAllowScalarIndexDynamicOps != false) hash ^= XlaAllowScalarIndexDynamicOps.GetHashCode(); + if (XlaStepMarkerLocation != global::Xla.DebugOptions.Types.StepMarkerLocation.StepMarkAtEntry) hash ^= XlaStepMarkerLocation.GetHashCode(); + if (XlaDumpTo.Length != 0) hash ^= XlaDumpTo.GetHashCode(); + if (XlaDumpHloModuleRe.Length != 0) hash ^= XlaDumpHloModuleRe.GetHashCode(); + if (XlaDumpHloPassRe.Length != 0) hash ^= XlaDumpHloPassRe.GetHashCode(); + if (XlaDumpHloAsText != false) hash ^= XlaDumpHloAsText.GetHashCode(); + if (XlaDumpHloAsProto != false) hash ^= XlaDumpHloAsProto.GetHashCode(); + if (XlaDumpHloAsDot != false) hash ^= XlaDumpHloAsDot.GetHashCode(); + if (XlaDumpHloAsUrl != false) hash ^= XlaDumpHloAsUrl.GetHashCode(); + if (XlaDumpHloAsHtml != false) hash ^= XlaDumpHloAsHtml.GetHashCode(); + if (XlaDumpFusionVisualization != false) hash ^= XlaDumpFusionVisualization.GetHashCode(); + if (XlaDumpHloSnapshots != false) hash ^= XlaDumpHloSnapshots.GetHashCode(); + if (XlaDumpIncludeTimestamp != false) hash ^= XlaDumpIncludeTimestamp.GetHashCode(); + if (XlaDumpMaxHloModules != 0) hash ^= XlaDumpMaxHloModules.GetHashCode(); + if (XlaDumpModuleMetadata != false) hash ^= XlaDumpModuleMetadata.GetHashCode(); + if (XlaDumpCompressProtos != false) hash ^= XlaDumpCompressProtos.GetHashCode(); + if (XlaDumpHloAsLongText != false) hash ^= XlaDumpHloAsLongText.GetHashCode(); + if (XlaGpuForceConvNchw != false) hash ^= XlaGpuForceConvNchw.GetHashCode(); + if (XlaGpuForceConvNhwc != false) hash ^= XlaGpuForceConvNhwc.GetHashCode(); + hash ^= xlaGpuPtxFile_.GetHashCode(); + if (XlaGpuDumpLlvmir != false) hash ^= XlaGpuDumpLlvmir.GetHashCode(); + if (XlaGpuAlgorithmDenylistPath.Length != 0) hash ^= XlaGpuAlgorithmDenylistPath.GetHashCode(); + if (XlaTpuDetectNan != false) hash ^= XlaTpuDetectNan.GetHashCode(); + if (XlaTpuDetectInf != false) hash ^= XlaTpuDetectInf.GetHashCode(); + if (XlaCpuEnableXprofTraceme != false) hash ^= XlaCpuEnableXprofTraceme.GetHashCode(); + if (XlaGpuUnsafeFallbackToDriverOnPtxasNotFound != false) hash ^= XlaGpuUnsafeFallbackToDriverOnPtxasNotFound.GetHashCode(); + if (XlaGpuAsmExtraFlags.Length != 0) hash ^= XlaGpuAsmExtraFlags.GetHashCode(); + if (XlaMultiheapSizeConstraintPerHeap != 0) hash ^= XlaMultiheapSizeConstraintPerHeap.GetHashCode(); + if (XlaDetailedLoggingAndDumping != false) hash ^= XlaDetailedLoggingAndDumping.GetHashCode(); + if (XlaGpuForceCompilationParallelism != 0) hash ^= XlaGpuForceCompilationParallelism.GetHashCode(); + if (XlaGpuDeterministicOps != false) hash ^= XlaGpuDeterministicOps.GetHashCode(); + hash ^= xlaGpuLlvmIrFile_.GetHashCode(); + if (XlaGpuEnableAsyncAllReduce != false) hash ^= XlaGpuEnableAsyncAllReduce.GetHashCode(); + if (XlaGpuAllReduceCombineThresholdBytes != 0L) hash ^= XlaGpuAllReduceCombineThresholdBytes.GetHashCode(); + if (XlaGpuAllReduceContiguous != false) hash ^= XlaGpuAllReduceContiguous.GetHashCode(); + if (XlaGpuAllReduceBlueconnectNumDevicesPerHost != 0) hash ^= XlaGpuAllReduceBlueconnectNumDevicesPerHost.GetHashCode(); + if (XlaGpuEnableCudnnFrontend != false) hash ^= XlaGpuEnableCudnnFrontend.GetHashCode(); + if (XlaDumpDisableMetadata != false) hash ^= XlaDumpDisableMetadata.GetHashCode(); + if (XlaDumpHloPipelineRe.Length != 0) hash ^= XlaDumpHloPipelineRe.GetHashCode(); + if (XlaGpuStrictConvAlgorithmPicker != false) hash ^= XlaGpuStrictConvAlgorithmPicker.GetHashCode(); + if (XlaGpuEnableXlaRuntimeExecutable != false) hash ^= XlaGpuEnableXlaRuntimeExecutable.GetHashCode(); + if (XlaGpuNcclTerminationTimeoutSeconds != 0L) hash ^= XlaGpuNcclTerminationTimeoutSeconds.GetHashCode(); + if (XlaGpuEnableSharedConstants != false) hash ^= XlaGpuEnableSharedConstants.GetHashCode(); + if (XlaGpuEnableCublaslt != false) hash ^= XlaGpuEnableCublaslt.GetHashCode(); + if (XlaGpuRedzoneScratchMaxMegabytes != 0L) hash ^= XlaGpuRedzoneScratchMaxMegabytes.GetHashCode(); + if (XlaGpuSimplifyAllFpConversions != false) hash ^= XlaGpuSimplifyAllFpConversions.GetHashCode(); + if (XlaGpuNormalizeLayouts != false) hash ^= XlaGpuNormalizeLayouts.GetHashCode(); + if (XlaCpuUseAcl != false) hash ^= XlaCpuUseAcl.GetHashCode(); + if (XlaCpuStrictDotConvMath != false) hash ^= XlaCpuStrictDotConvMath.GetHashCode(); + hash ^= XlaBackendExtraOptions.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (XlaHloGraphAddresses != false) { + output.WriteRawTag(16); + output.WriteBool(XlaHloGraphAddresses); + } + if (XlaHloProfile != false) { + output.WriteRawTag(72); + output.WriteBool(XlaHloProfile); + } + xlaDisableHloPasses_.WriteTo(output, _repeated_xlaDisableHloPasses_codec); + if (XlaBackendOptimizationLevel != 0) { + output.WriteRawTag(248, 1); + output.WriteInt32(XlaBackendOptimizationLevel); + } + if (XlaEmbedIrInExecutable != false) { + output.WriteRawTag(136, 2); + output.WriteBool(XlaEmbedIrInExecutable); + } + if (XlaEliminateHloImplicitBroadcast != false) { + output.WriteRawTag(152, 2); + output.WriteBool(XlaEliminateHloImplicitBroadcast); + } + if (XlaCpuMultiThreadEigen != false) { + output.WriteRawTag(224, 3); + output.WriteBool(XlaCpuMultiThreadEigen); + } + if (XlaGpuCudaDataDir.Length != 0) { + output.WriteRawTag(234, 3); + output.WriteString(XlaGpuCudaDataDir); + } + if (XlaGpuFtz != false) { + output.WriteRawTag(240, 3); + output.WriteBool(XlaGpuFtz); + } + if (XlaLlvmEnableAliasScopeMetadata != false) { + output.WriteRawTag(176, 4); + output.WriteBool(XlaLlvmEnableAliasScopeMetadata); + } + if (XlaLlvmEnableNoaliasMetadata != false) { + output.WriteRawTag(184, 4); + output.WriteBool(XlaLlvmEnableNoaliasMetadata); + } + if (XlaLlvmEnableInvariantLoadMetadata != false) { + output.WriteRawTag(192, 4); + output.WriteBool(XlaLlvmEnableInvariantLoadMetadata); + } + if (XlaLlvmDisableExpensivePasses != false) { + output.WriteRawTag(200, 4); + output.WriteBool(XlaLlvmDisableExpensivePasses); + } + if (XlaTestAllOutputLayouts != false) { + output.WriteRawTag(208, 5); + output.WriteBool(XlaTestAllOutputLayouts); + } + if (XlaTestAllInputLayouts != false) { + output.WriteRawTag(216, 5); + output.WriteBool(XlaTestAllInputLayouts); + } + if (XlaHloGraphShardingColor != false) { + output.WriteRawTag(224, 5); + output.WriteBool(XlaHloGraphShardingColor); + } + if (XlaCpuUseMklDnn != false) { + output.WriteRawTag(136, 6); + output.WriteBool(XlaCpuUseMklDnn); + } + if (XlaGpuMaxKernelUnrollFactor != 0) { + output.WriteRawTag(144, 6); + output.WriteInt32(XlaGpuMaxKernelUnrollFactor); + } + if (XlaCpuEnableFastMath != false) { + output.WriteRawTag(152, 6); + output.WriteBool(XlaCpuEnableFastMath); + } + if (XlaGpuEnableFastMinMax != false) { + output.WriteRawTag(160, 6); + output.WriteBool(XlaGpuEnableFastMinMax); + } + if (XlaGpuCrashOnVerificationFailures != false) { + output.WriteRawTag(168, 6); + output.WriteBool(XlaGpuCrashOnVerificationFailures); + } + if (XlaForceHostPlatformDeviceCount != 0) { + output.WriteRawTag(176, 6); + output.WriteInt32(XlaForceHostPlatformDeviceCount); + } + if (XlaGpuDisableGpuasmOptimizations != false) { + output.WriteRawTag(184, 6); + output.WriteBool(XlaGpuDisableGpuasmOptimizations); + } + if (XlaDisableAllHloPasses != false) { + output.WriteRawTag(192, 6); + output.WriteBool(XlaDisableAllHloPasses); + } + if (XlaHloEvaluatorUseFastPath != false) { + output.WriteRawTag(208, 6); + output.WriteBool(XlaHloEvaluatorUseFastPath); + } + if (XlaAllowScalarIndexDynamicOps != false) { + output.WriteRawTag(216, 6); + output.WriteBool(XlaAllowScalarIndexDynamicOps); + } + if (XlaStepMarkerLocation != global::Xla.DebugOptions.Types.StepMarkerLocation.StepMarkAtEntry) { + output.WriteRawTag(224, 6); + output.WriteEnum((int) XlaStepMarkerLocation); + } + if (XlaDumpTo.Length != 0) { + output.WriteRawTag(234, 6); + output.WriteString(XlaDumpTo); + } + if (XlaDumpHloModuleRe.Length != 0) { + output.WriteRawTag(242, 6); + output.WriteString(XlaDumpHloModuleRe); + } + if (XlaDumpHloPassRe.Length != 0) { + output.WriteRawTag(250, 6); + output.WriteString(XlaDumpHloPassRe); + } + if (XlaDumpHloAsText != false) { + output.WriteRawTag(128, 7); + output.WriteBool(XlaDumpHloAsText); + } + if (XlaDumpHloAsProto != false) { + output.WriteRawTag(136, 7); + output.WriteBool(XlaDumpHloAsProto); + } + if (XlaDumpHloAsDot != false) { + output.WriteRawTag(144, 7); + output.WriteBool(XlaDumpHloAsDot); + } + if (XlaDumpHloAsUrl != false) { + output.WriteRawTag(152, 7); + output.WriteBool(XlaDumpHloAsUrl); + } + if (XlaDumpHloAsHtml != false) { + output.WriteRawTag(160, 7); + output.WriteBool(XlaDumpHloAsHtml); + } + if (XlaDumpHloSnapshots != false) { + output.WriteRawTag(176, 7); + output.WriteBool(XlaDumpHloSnapshots); + } + if (XlaCpuFastMathHonorNans != false) { + output.WriteRawTag(192, 7); + output.WriteBool(XlaCpuFastMathHonorNans); + } + if (XlaCpuFastMathHonorInfs != false) { + output.WriteRawTag(200, 7); + output.WriteBool(XlaCpuFastMathHonorInfs); + } + if (XlaAllowExcessPrecision != false) { + output.WriteRawTag(208, 7); + output.WriteBool(XlaAllowExcessPrecision); + } + if (XlaGpuAutotuneLevel != 0) { + output.WriteRawTag(216, 7); + output.WriteInt32(XlaGpuAutotuneLevel); + } + xlaEnableHloPassesOnly_.WriteTo(output, _repeated_xlaEnableHloPassesOnly_codec); + if (XlaGpuForceConvNchw != false) { + output.WriteRawTag(232, 7); + output.WriteBool(XlaGpuForceConvNchw); + } + if (XlaCpuFastMathHonorDivision != false) { + output.WriteRawTag(240, 7); + output.WriteBool(XlaCpuFastMathHonorDivision); + } + xlaGpuPtxFile_.WriteTo(output, _repeated_xlaGpuPtxFile_codec); + if (XlaGpuAlgorithmDenylistPath.Length != 0) { + output.WriteRawTag(130, 8); + output.WriteString(XlaGpuAlgorithmDenylistPath); + } + if (XlaCpuFastMathHonorFunctions != false) { + output.WriteRawTag(136, 8); + output.WriteBool(XlaCpuFastMathHonorFunctions); + } + if (XlaDumpIncludeTimestamp != false) { + output.WriteRawTag(152, 8); + output.WriteBool(XlaDumpIncludeTimestamp); + } + if (XlaDumpMaxHloModules != 0) { + output.WriteRawTag(160, 8); + output.WriteInt32(XlaDumpMaxHloModules); + } + if (XlaTpuDetectNan != false) { + output.WriteRawTag(184, 8); + output.WriteBool(XlaTpuDetectNan); + } + if (XlaTpuDetectInf != false) { + output.WriteRawTag(192, 8); + output.WriteBool(XlaTpuDetectInf); + } + if (XlaCpuEnableXprofTraceme != false) { + output.WriteRawTag(200, 8); + output.WriteBool(XlaCpuEnableXprofTraceme); + } + if (XlaGpuUnsafeFallbackToDriverOnPtxasNotFound != false) { + output.WriteRawTag(208, 8); + output.WriteBool(XlaGpuUnsafeFallbackToDriverOnPtxasNotFound); + } + if (XlaCpuEnableFastMinMax != false) { + output.WriteRawTag(224, 8); + output.WriteBool(XlaCpuEnableFastMinMax); + } + if (XlaGpuAsmExtraFlags.Length != 0) { + output.WriteRawTag(234, 8); + output.WriteString(XlaGpuAsmExtraFlags); + } + if (XlaMultiheapSizeConstraintPerHeap != 0) { + output.WriteRawTag(240, 8); + output.WriteInt32(XlaMultiheapSizeConstraintPerHeap); + } + if (XlaDetailedLoggingAndDumping != false) { + output.WriteRawTag(248, 8); + output.WriteBool(XlaDetailedLoggingAndDumping); + } + if (XlaDumpModuleMetadata != false) { + output.WriteRawTag(128, 9); + output.WriteBool(XlaDumpModuleMetadata); + } + if (XlaGpuForceConvNhwc != false) { + output.WriteRawTag(144, 9); + output.WriteBool(XlaGpuForceConvNhwc); + } + if (XlaGpuForceCompilationParallelism != 0) { + output.WriteRawTag(152, 9); + output.WriteInt32(XlaGpuForceCompilationParallelism); + } + if (XlaGpuDeterministicOps != false) { + output.WriteRawTag(160, 9); + output.WriteBool(XlaGpuDeterministicOps); + } + if (XlaDumpFusionVisualization != false) { + output.WriteRawTag(168, 9); + output.WriteBool(XlaDumpFusionVisualization); + } + xlaGpuLlvmIrFile_.WriteTo(output, _repeated_xlaGpuLlvmIrFile_codec); + if (XlaDumpCompressProtos != false) { + output.WriteRawTag(184, 9); + output.WriteBool(XlaDumpCompressProtos); + } + if (XlaGpuEnableAsyncAllReduce != false) { + output.WriteRawTag(192, 9); + output.WriteBool(XlaGpuEnableAsyncAllReduce); + } + if (XlaDumpDisableMetadata != false) { + output.WriteRawTag(200, 9); + output.WriteBool(XlaDumpDisableMetadata); + } + if (XlaDumpHloPipelineRe.Length != 0) { + output.WriteRawTag(210, 9); + output.WriteString(XlaDumpHloPipelineRe); + } + if (XlaGpuDumpLlvmir != false) { + output.WriteRawTag(216, 9); + output.WriteBool(XlaGpuDumpLlvmir); + } + if (XlaGpuStrictConvAlgorithmPicker != false) { + output.WriteRawTag(224, 9); + output.WriteBool(XlaGpuStrictConvAlgorithmPicker); + } + if (XlaGpuAllReduceCombineThresholdBytes != 0L) { + output.WriteRawTag(232, 9); + output.WriteInt64(XlaGpuAllReduceCombineThresholdBytes); + } + if (XlaGpuAllReduceContiguous != false) { + output.WriteRawTag(240, 9); + output.WriteBool(XlaGpuAllReduceContiguous); + } + if (XlaGpuAllReduceBlueconnectNumDevicesPerHost != 0) { + output.WriteRawTag(248, 9); + output.WriteInt32(XlaGpuAllReduceBlueconnectNumDevicesPerHost); + } + if (XlaGpuEnableCudnnFrontend != false) { + output.WriteRawTag(128, 10); + output.WriteBool(XlaGpuEnableCudnnFrontend); + } + if (XlaGpuNcclTerminationTimeoutSeconds != 0L) { + output.WriteRawTag(152, 10); + output.WriteInt64(XlaGpuNcclTerminationTimeoutSeconds); + } + if (XlaDumpHloAsLongText != false) { + output.WriteRawTag(160, 10); + output.WriteBool(XlaDumpHloAsLongText); + } + if (XlaGpuEnableSharedConstants != false) { + output.WriteRawTag(168, 10); + output.WriteBool(XlaGpuEnableSharedConstants); + } + if (XlaGpuEnableCublaslt != false) { + output.WriteRawTag(176, 10); + output.WriteBool(XlaGpuEnableCublaslt); + } + if (XlaGpuRedzoneScratchMaxMegabytes != 0L) { + output.WriteRawTag(184, 10); + output.WriteInt64(XlaGpuRedzoneScratchMaxMegabytes); + } + if (XlaGpuSimplifyAllFpConversions != false) { + output.WriteRawTag(192, 10); + output.WriteBool(XlaGpuSimplifyAllFpConversions); + } + if (XlaGpuEnableXlaRuntimeExecutable != false) { + output.WriteRawTag(200, 10); + output.WriteBool(XlaGpuEnableXlaRuntimeExecutable); + } + if (XlaGpuShapeChecks != global::Xla.DebugOptions.Types.ShapeChecks.Ignore) { + output.WriteRawTag(208, 10); + output.WriteEnum((int) XlaGpuShapeChecks); + } + if (XlaCpuEnableMlirLowering != false) { + output.WriteRawTag(216, 10); + output.WriteBool(XlaCpuEnableMlirLowering); + } + if (XlaGpuNormalizeLayouts != false) { + output.WriteRawTag(224, 10); + output.WriteBool(XlaGpuNormalizeLayouts); + } + if (XlaGpuEnableMlirLowering != false) { + output.WriteRawTag(232, 10); + output.WriteBool(XlaGpuEnableMlirLowering); + } + if (XlaCpuUseAcl != false) { + output.WriteRawTag(240, 10); + output.WriteBool(XlaCpuUseAcl); + } + if (XlaCpuStrictDotConvMath != false) { + output.WriteRawTag(248, 10); + output.WriteBool(XlaCpuStrictDotConvMath); + } + if (XlaCpuUseXlaRuntime != false) { + output.WriteRawTag(136, 11); + output.WriteBool(XlaCpuUseXlaRuntime); + } + xlaBackendExtraOptions_.WriteTo(output, _map_xlaBackendExtraOptions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (XlaHloGraphAddresses != false) { + output.WriteRawTag(16); + output.WriteBool(XlaHloGraphAddresses); + } + if (XlaHloProfile != false) { + output.WriteRawTag(72); + output.WriteBool(XlaHloProfile); + } + xlaDisableHloPasses_.WriteTo(ref output, _repeated_xlaDisableHloPasses_codec); + if (XlaBackendOptimizationLevel != 0) { + output.WriteRawTag(248, 1); + output.WriteInt32(XlaBackendOptimizationLevel); + } + if (XlaEmbedIrInExecutable != false) { + output.WriteRawTag(136, 2); + output.WriteBool(XlaEmbedIrInExecutable); + } + if (XlaEliminateHloImplicitBroadcast != false) { + output.WriteRawTag(152, 2); + output.WriteBool(XlaEliminateHloImplicitBroadcast); + } + if (XlaCpuMultiThreadEigen != false) { + output.WriteRawTag(224, 3); + output.WriteBool(XlaCpuMultiThreadEigen); + } + if (XlaGpuCudaDataDir.Length != 0) { + output.WriteRawTag(234, 3); + output.WriteString(XlaGpuCudaDataDir); + } + if (XlaGpuFtz != false) { + output.WriteRawTag(240, 3); + output.WriteBool(XlaGpuFtz); + } + if (XlaLlvmEnableAliasScopeMetadata != false) { + output.WriteRawTag(176, 4); + output.WriteBool(XlaLlvmEnableAliasScopeMetadata); + } + if (XlaLlvmEnableNoaliasMetadata != false) { + output.WriteRawTag(184, 4); + output.WriteBool(XlaLlvmEnableNoaliasMetadata); + } + if (XlaLlvmEnableInvariantLoadMetadata != false) { + output.WriteRawTag(192, 4); + output.WriteBool(XlaLlvmEnableInvariantLoadMetadata); + } + if (XlaLlvmDisableExpensivePasses != false) { + output.WriteRawTag(200, 4); + output.WriteBool(XlaLlvmDisableExpensivePasses); + } + if (XlaTestAllOutputLayouts != false) { + output.WriteRawTag(208, 5); + output.WriteBool(XlaTestAllOutputLayouts); + } + if (XlaTestAllInputLayouts != false) { + output.WriteRawTag(216, 5); + output.WriteBool(XlaTestAllInputLayouts); + } + if (XlaHloGraphShardingColor != false) { + output.WriteRawTag(224, 5); + output.WriteBool(XlaHloGraphShardingColor); + } + if (XlaCpuUseMklDnn != false) { + output.WriteRawTag(136, 6); + output.WriteBool(XlaCpuUseMklDnn); + } + if (XlaGpuMaxKernelUnrollFactor != 0) { + output.WriteRawTag(144, 6); + output.WriteInt32(XlaGpuMaxKernelUnrollFactor); + } + if (XlaCpuEnableFastMath != false) { + output.WriteRawTag(152, 6); + output.WriteBool(XlaCpuEnableFastMath); + } + if (XlaGpuEnableFastMinMax != false) { + output.WriteRawTag(160, 6); + output.WriteBool(XlaGpuEnableFastMinMax); + } + if (XlaGpuCrashOnVerificationFailures != false) { + output.WriteRawTag(168, 6); + output.WriteBool(XlaGpuCrashOnVerificationFailures); + } + if (XlaForceHostPlatformDeviceCount != 0) { + output.WriteRawTag(176, 6); + output.WriteInt32(XlaForceHostPlatformDeviceCount); + } + if (XlaGpuDisableGpuasmOptimizations != false) { + output.WriteRawTag(184, 6); + output.WriteBool(XlaGpuDisableGpuasmOptimizations); + } + if (XlaDisableAllHloPasses != false) { + output.WriteRawTag(192, 6); + output.WriteBool(XlaDisableAllHloPasses); + } + if (XlaHloEvaluatorUseFastPath != false) { + output.WriteRawTag(208, 6); + output.WriteBool(XlaHloEvaluatorUseFastPath); + } + if (XlaAllowScalarIndexDynamicOps != false) { + output.WriteRawTag(216, 6); + output.WriteBool(XlaAllowScalarIndexDynamicOps); + } + if (XlaStepMarkerLocation != global::Xla.DebugOptions.Types.StepMarkerLocation.StepMarkAtEntry) { + output.WriteRawTag(224, 6); + output.WriteEnum((int) XlaStepMarkerLocation); + } + if (XlaDumpTo.Length != 0) { + output.WriteRawTag(234, 6); + output.WriteString(XlaDumpTo); + } + if (XlaDumpHloModuleRe.Length != 0) { + output.WriteRawTag(242, 6); + output.WriteString(XlaDumpHloModuleRe); + } + if (XlaDumpHloPassRe.Length != 0) { + output.WriteRawTag(250, 6); + output.WriteString(XlaDumpHloPassRe); + } + if (XlaDumpHloAsText != false) { + output.WriteRawTag(128, 7); + output.WriteBool(XlaDumpHloAsText); + } + if (XlaDumpHloAsProto != false) { + output.WriteRawTag(136, 7); + output.WriteBool(XlaDumpHloAsProto); + } + if (XlaDumpHloAsDot != false) { + output.WriteRawTag(144, 7); + output.WriteBool(XlaDumpHloAsDot); + } + if (XlaDumpHloAsUrl != false) { + output.WriteRawTag(152, 7); + output.WriteBool(XlaDumpHloAsUrl); + } + if (XlaDumpHloAsHtml != false) { + output.WriteRawTag(160, 7); + output.WriteBool(XlaDumpHloAsHtml); + } + if (XlaDumpHloSnapshots != false) { + output.WriteRawTag(176, 7); + output.WriteBool(XlaDumpHloSnapshots); + } + if (XlaCpuFastMathHonorNans != false) { + output.WriteRawTag(192, 7); + output.WriteBool(XlaCpuFastMathHonorNans); + } + if (XlaCpuFastMathHonorInfs != false) { + output.WriteRawTag(200, 7); + output.WriteBool(XlaCpuFastMathHonorInfs); + } + if (XlaAllowExcessPrecision != false) { + output.WriteRawTag(208, 7); + output.WriteBool(XlaAllowExcessPrecision); + } + if (XlaGpuAutotuneLevel != 0) { + output.WriteRawTag(216, 7); + output.WriteInt32(XlaGpuAutotuneLevel); + } + xlaEnableHloPassesOnly_.WriteTo(ref output, _repeated_xlaEnableHloPassesOnly_codec); + if (XlaGpuForceConvNchw != false) { + output.WriteRawTag(232, 7); + output.WriteBool(XlaGpuForceConvNchw); + } + if (XlaCpuFastMathHonorDivision != false) { + output.WriteRawTag(240, 7); + output.WriteBool(XlaCpuFastMathHonorDivision); + } + xlaGpuPtxFile_.WriteTo(ref output, _repeated_xlaGpuPtxFile_codec); + if (XlaGpuAlgorithmDenylistPath.Length != 0) { + output.WriteRawTag(130, 8); + output.WriteString(XlaGpuAlgorithmDenylistPath); + } + if (XlaCpuFastMathHonorFunctions != false) { + output.WriteRawTag(136, 8); + output.WriteBool(XlaCpuFastMathHonorFunctions); + } + if (XlaDumpIncludeTimestamp != false) { + output.WriteRawTag(152, 8); + output.WriteBool(XlaDumpIncludeTimestamp); + } + if (XlaDumpMaxHloModules != 0) { + output.WriteRawTag(160, 8); + output.WriteInt32(XlaDumpMaxHloModules); + } + if (XlaTpuDetectNan != false) { + output.WriteRawTag(184, 8); + output.WriteBool(XlaTpuDetectNan); + } + if (XlaTpuDetectInf != false) { + output.WriteRawTag(192, 8); + output.WriteBool(XlaTpuDetectInf); + } + if (XlaCpuEnableXprofTraceme != false) { + output.WriteRawTag(200, 8); + output.WriteBool(XlaCpuEnableXprofTraceme); + } + if (XlaGpuUnsafeFallbackToDriverOnPtxasNotFound != false) { + output.WriteRawTag(208, 8); + output.WriteBool(XlaGpuUnsafeFallbackToDriverOnPtxasNotFound); + } + if (XlaCpuEnableFastMinMax != false) { + output.WriteRawTag(224, 8); + output.WriteBool(XlaCpuEnableFastMinMax); + } + if (XlaGpuAsmExtraFlags.Length != 0) { + output.WriteRawTag(234, 8); + output.WriteString(XlaGpuAsmExtraFlags); + } + if (XlaMultiheapSizeConstraintPerHeap != 0) { + output.WriteRawTag(240, 8); + output.WriteInt32(XlaMultiheapSizeConstraintPerHeap); + } + if (XlaDetailedLoggingAndDumping != false) { + output.WriteRawTag(248, 8); + output.WriteBool(XlaDetailedLoggingAndDumping); + } + if (XlaDumpModuleMetadata != false) { + output.WriteRawTag(128, 9); + output.WriteBool(XlaDumpModuleMetadata); + } + if (XlaGpuForceConvNhwc != false) { + output.WriteRawTag(144, 9); + output.WriteBool(XlaGpuForceConvNhwc); + } + if (XlaGpuForceCompilationParallelism != 0) { + output.WriteRawTag(152, 9); + output.WriteInt32(XlaGpuForceCompilationParallelism); + } + if (XlaGpuDeterministicOps != false) { + output.WriteRawTag(160, 9); + output.WriteBool(XlaGpuDeterministicOps); + } + if (XlaDumpFusionVisualization != false) { + output.WriteRawTag(168, 9); + output.WriteBool(XlaDumpFusionVisualization); + } + xlaGpuLlvmIrFile_.WriteTo(ref output, _repeated_xlaGpuLlvmIrFile_codec); + if (XlaDumpCompressProtos != false) { + output.WriteRawTag(184, 9); + output.WriteBool(XlaDumpCompressProtos); + } + if (XlaGpuEnableAsyncAllReduce != false) { + output.WriteRawTag(192, 9); + output.WriteBool(XlaGpuEnableAsyncAllReduce); + } + if (XlaDumpDisableMetadata != false) { + output.WriteRawTag(200, 9); + output.WriteBool(XlaDumpDisableMetadata); + } + if (XlaDumpHloPipelineRe.Length != 0) { + output.WriteRawTag(210, 9); + output.WriteString(XlaDumpHloPipelineRe); + } + if (XlaGpuDumpLlvmir != false) { + output.WriteRawTag(216, 9); + output.WriteBool(XlaGpuDumpLlvmir); + } + if (XlaGpuStrictConvAlgorithmPicker != false) { + output.WriteRawTag(224, 9); + output.WriteBool(XlaGpuStrictConvAlgorithmPicker); + } + if (XlaGpuAllReduceCombineThresholdBytes != 0L) { + output.WriteRawTag(232, 9); + output.WriteInt64(XlaGpuAllReduceCombineThresholdBytes); + } + if (XlaGpuAllReduceContiguous != false) { + output.WriteRawTag(240, 9); + output.WriteBool(XlaGpuAllReduceContiguous); + } + if (XlaGpuAllReduceBlueconnectNumDevicesPerHost != 0) { + output.WriteRawTag(248, 9); + output.WriteInt32(XlaGpuAllReduceBlueconnectNumDevicesPerHost); + } + if (XlaGpuEnableCudnnFrontend != false) { + output.WriteRawTag(128, 10); + output.WriteBool(XlaGpuEnableCudnnFrontend); + } + if (XlaGpuNcclTerminationTimeoutSeconds != 0L) { + output.WriteRawTag(152, 10); + output.WriteInt64(XlaGpuNcclTerminationTimeoutSeconds); + } + if (XlaDumpHloAsLongText != false) { + output.WriteRawTag(160, 10); + output.WriteBool(XlaDumpHloAsLongText); + } + if (XlaGpuEnableSharedConstants != false) { + output.WriteRawTag(168, 10); + output.WriteBool(XlaGpuEnableSharedConstants); + } + if (XlaGpuEnableCublaslt != false) { + output.WriteRawTag(176, 10); + output.WriteBool(XlaGpuEnableCublaslt); + } + if (XlaGpuRedzoneScratchMaxMegabytes != 0L) { + output.WriteRawTag(184, 10); + output.WriteInt64(XlaGpuRedzoneScratchMaxMegabytes); + } + if (XlaGpuSimplifyAllFpConversions != false) { + output.WriteRawTag(192, 10); + output.WriteBool(XlaGpuSimplifyAllFpConversions); + } + if (XlaGpuEnableXlaRuntimeExecutable != false) { + output.WriteRawTag(200, 10); + output.WriteBool(XlaGpuEnableXlaRuntimeExecutable); + } + if (XlaGpuShapeChecks != global::Xla.DebugOptions.Types.ShapeChecks.Ignore) { + output.WriteRawTag(208, 10); + output.WriteEnum((int) XlaGpuShapeChecks); + } + if (XlaCpuEnableMlirLowering != false) { + output.WriteRawTag(216, 10); + output.WriteBool(XlaCpuEnableMlirLowering); + } + if (XlaGpuNormalizeLayouts != false) { + output.WriteRawTag(224, 10); + output.WriteBool(XlaGpuNormalizeLayouts); + } + if (XlaGpuEnableMlirLowering != false) { + output.WriteRawTag(232, 10); + output.WriteBool(XlaGpuEnableMlirLowering); + } + if (XlaCpuUseAcl != false) { + output.WriteRawTag(240, 10); + output.WriteBool(XlaCpuUseAcl); + } + if (XlaCpuStrictDotConvMath != false) { + output.WriteRawTag(248, 10); + output.WriteBool(XlaCpuStrictDotConvMath); + } + if (XlaCpuUseXlaRuntime != false) { + output.WriteRawTag(136, 11); + output.WriteBool(XlaCpuUseXlaRuntime); + } + xlaBackendExtraOptions_.WriteTo(ref output, _map_xlaBackendExtraOptions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (XlaHloGraphAddresses != false) { + size += 1 + 1; + } + if (XlaHloProfile != false) { + size += 1 + 1; + } + size += xlaDisableHloPasses_.CalculateSize(_repeated_xlaDisableHloPasses_codec); + size += xlaEnableHloPassesOnly_.CalculateSize(_repeated_xlaEnableHloPassesOnly_codec); + if (XlaDisableAllHloPasses != false) { + size += 2 + 1; + } + if (XlaBackendOptimizationLevel != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(XlaBackendOptimizationLevel); + } + if (XlaEmbedIrInExecutable != false) { + size += 2 + 1; + } + if (XlaEliminateHloImplicitBroadcast != false) { + size += 2 + 1; + } + if (XlaCpuMultiThreadEigen != false) { + size += 2 + 1; + } + if (XlaGpuCudaDataDir.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(XlaGpuCudaDataDir); + } + if (XlaGpuFtz != false) { + size += 2 + 1; + } + if (XlaLlvmEnableAliasScopeMetadata != false) { + size += 2 + 1; + } + if (XlaLlvmEnableNoaliasMetadata != false) { + size += 2 + 1; + } + if (XlaLlvmEnableInvariantLoadMetadata != false) { + size += 2 + 1; + } + if (XlaLlvmDisableExpensivePasses != false) { + size += 2 + 1; + } + if (XlaTestAllOutputLayouts != false) { + size += 2 + 1; + } + if (XlaTestAllInputLayouts != false) { + size += 2 + 1; + } + if (XlaHloGraphShardingColor != false) { + size += 2 + 1; + } + if (XlaCpuUseMklDnn != false) { + size += 2 + 1; + } + if (XlaCpuUseXlaRuntime != false) { + size += 2 + 1; + } + if (XlaGpuMaxKernelUnrollFactor != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(XlaGpuMaxKernelUnrollFactor); + } + if (XlaCpuEnableFastMath != false) { + size += 2 + 1; + } + if (XlaCpuFastMathHonorNans != false) { + size += 2 + 1; + } + if (XlaCpuFastMathHonorInfs != false) { + size += 2 + 1; + } + if (XlaCpuFastMathHonorDivision != false) { + size += 2 + 1; + } + if (XlaCpuFastMathHonorFunctions != false) { + size += 2 + 1; + } + if (XlaCpuEnableFastMinMax != false) { + size += 2 + 1; + } + if (XlaGpuEnableFastMinMax != false) { + size += 2 + 1; + } + if (XlaAllowExcessPrecision != false) { + size += 2 + 1; + } + if (XlaGpuCrashOnVerificationFailures != false) { + size += 2 + 1; + } + if (XlaGpuAutotuneLevel != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(XlaGpuAutotuneLevel); + } + if (XlaForceHostPlatformDeviceCount != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(XlaForceHostPlatformDeviceCount); + } + if (XlaGpuDisableGpuasmOptimizations != false) { + size += 2 + 1; + } + if (XlaGpuShapeChecks != global::Xla.DebugOptions.Types.ShapeChecks.Ignore) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) XlaGpuShapeChecks); + } + if (XlaCpuEnableMlirLowering != false) { + size += 2 + 1; + } + if (XlaGpuEnableMlirLowering != false) { + size += 2 + 1; + } + if (XlaHloEvaluatorUseFastPath != false) { + size += 2 + 1; + } + if (XlaAllowScalarIndexDynamicOps != false) { + size += 2 + 1; + } + if (XlaStepMarkerLocation != global::Xla.DebugOptions.Types.StepMarkerLocation.StepMarkAtEntry) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) XlaStepMarkerLocation); + } + if (XlaDumpTo.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(XlaDumpTo); + } + if (XlaDumpHloModuleRe.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(XlaDumpHloModuleRe); + } + if (XlaDumpHloPassRe.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(XlaDumpHloPassRe); + } + if (XlaDumpHloAsText != false) { + size += 2 + 1; + } + if (XlaDumpHloAsProto != false) { + size += 2 + 1; + } + if (XlaDumpHloAsDot != false) { + size += 2 + 1; + } + if (XlaDumpHloAsUrl != false) { + size += 2 + 1; + } + if (XlaDumpHloAsHtml != false) { + size += 2 + 1; + } + if (XlaDumpFusionVisualization != false) { + size += 2 + 1; + } + if (XlaDumpHloSnapshots != false) { + size += 2 + 1; + } + if (XlaDumpIncludeTimestamp != false) { + size += 2 + 1; + } + if (XlaDumpMaxHloModules != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(XlaDumpMaxHloModules); + } + if (XlaDumpModuleMetadata != false) { + size += 2 + 1; + } + if (XlaDumpCompressProtos != false) { + size += 2 + 1; + } + if (XlaDumpHloAsLongText != false) { + size += 2 + 1; + } + if (XlaGpuForceConvNchw != false) { + size += 2 + 1; + } + if (XlaGpuForceConvNhwc != false) { + size += 2 + 1; + } + size += xlaGpuPtxFile_.CalculateSize(_repeated_xlaGpuPtxFile_codec); + if (XlaGpuDumpLlvmir != false) { + size += 2 + 1; + } + if (XlaGpuAlgorithmDenylistPath.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(XlaGpuAlgorithmDenylistPath); + } + if (XlaTpuDetectNan != false) { + size += 2 + 1; + } + if (XlaTpuDetectInf != false) { + size += 2 + 1; + } + if (XlaCpuEnableXprofTraceme != false) { + size += 2 + 1; + } + if (XlaGpuUnsafeFallbackToDriverOnPtxasNotFound != false) { + size += 2 + 1; + } + if (XlaGpuAsmExtraFlags.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(XlaGpuAsmExtraFlags); + } + if (XlaMultiheapSizeConstraintPerHeap != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(XlaMultiheapSizeConstraintPerHeap); + } + if (XlaDetailedLoggingAndDumping != false) { + size += 2 + 1; + } + if (XlaGpuForceCompilationParallelism != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(XlaGpuForceCompilationParallelism); + } + if (XlaGpuDeterministicOps != false) { + size += 2 + 1; + } + size += xlaGpuLlvmIrFile_.CalculateSize(_repeated_xlaGpuLlvmIrFile_codec); + if (XlaGpuEnableAsyncAllReduce != false) { + size += 2 + 1; + } + if (XlaGpuAllReduceCombineThresholdBytes != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(XlaGpuAllReduceCombineThresholdBytes); + } + if (XlaGpuAllReduceContiguous != false) { + size += 2 + 1; + } + if (XlaGpuAllReduceBlueconnectNumDevicesPerHost != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(XlaGpuAllReduceBlueconnectNumDevicesPerHost); + } + if (XlaGpuEnableCudnnFrontend != false) { + size += 2 + 1; + } + if (XlaDumpDisableMetadata != false) { + size += 2 + 1; + } + if (XlaDumpHloPipelineRe.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(XlaDumpHloPipelineRe); + } + if (XlaGpuStrictConvAlgorithmPicker != false) { + size += 2 + 1; + } + if (XlaGpuEnableXlaRuntimeExecutable != false) { + size += 2 + 1; + } + if (XlaGpuNcclTerminationTimeoutSeconds != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(XlaGpuNcclTerminationTimeoutSeconds); + } + if (XlaGpuEnableSharedConstants != false) { + size += 2 + 1; + } + if (XlaGpuEnableCublaslt != false) { + size += 2 + 1; + } + if (XlaGpuRedzoneScratchMaxMegabytes != 0L) { + size += 2 + pb::CodedOutputStream.ComputeInt64Size(XlaGpuRedzoneScratchMaxMegabytes); + } + if (XlaGpuSimplifyAllFpConversions != false) { + size += 2 + 1; + } + if (XlaGpuNormalizeLayouts != false) { + size += 2 + 1; + } + if (XlaCpuUseAcl != false) { + size += 2 + 1; + } + if (XlaCpuStrictDotConvMath != false) { + size += 2 + 1; + } + size += xlaBackendExtraOptions_.CalculateSize(_map_xlaBackendExtraOptions_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DebugOptions other) { + if (other == null) { + return; + } + if (other.XlaHloGraphAddresses != false) { + XlaHloGraphAddresses = other.XlaHloGraphAddresses; + } + if (other.XlaHloProfile != false) { + XlaHloProfile = other.XlaHloProfile; + } + xlaDisableHloPasses_.Add(other.xlaDisableHloPasses_); + xlaEnableHloPassesOnly_.Add(other.xlaEnableHloPassesOnly_); + if (other.XlaDisableAllHloPasses != false) { + XlaDisableAllHloPasses = other.XlaDisableAllHloPasses; + } + if (other.XlaBackendOptimizationLevel != 0) { + XlaBackendOptimizationLevel = other.XlaBackendOptimizationLevel; + } + if (other.XlaEmbedIrInExecutable != false) { + XlaEmbedIrInExecutable = other.XlaEmbedIrInExecutable; + } + if (other.XlaEliminateHloImplicitBroadcast != false) { + XlaEliminateHloImplicitBroadcast = other.XlaEliminateHloImplicitBroadcast; + } + if (other.XlaCpuMultiThreadEigen != false) { + XlaCpuMultiThreadEigen = other.XlaCpuMultiThreadEigen; + } + if (other.XlaGpuCudaDataDir.Length != 0) { + XlaGpuCudaDataDir = other.XlaGpuCudaDataDir; + } + if (other.XlaGpuFtz != false) { + XlaGpuFtz = other.XlaGpuFtz; + } + if (other.XlaLlvmEnableAliasScopeMetadata != false) { + XlaLlvmEnableAliasScopeMetadata = other.XlaLlvmEnableAliasScopeMetadata; + } + if (other.XlaLlvmEnableNoaliasMetadata != false) { + XlaLlvmEnableNoaliasMetadata = other.XlaLlvmEnableNoaliasMetadata; + } + if (other.XlaLlvmEnableInvariantLoadMetadata != false) { + XlaLlvmEnableInvariantLoadMetadata = other.XlaLlvmEnableInvariantLoadMetadata; + } + if (other.XlaLlvmDisableExpensivePasses != false) { + XlaLlvmDisableExpensivePasses = other.XlaLlvmDisableExpensivePasses; + } + if (other.XlaTestAllOutputLayouts != false) { + XlaTestAllOutputLayouts = other.XlaTestAllOutputLayouts; + } + if (other.XlaTestAllInputLayouts != false) { + XlaTestAllInputLayouts = other.XlaTestAllInputLayouts; + } + if (other.XlaHloGraphShardingColor != false) { + XlaHloGraphShardingColor = other.XlaHloGraphShardingColor; + } + if (other.XlaCpuUseMklDnn != false) { + XlaCpuUseMklDnn = other.XlaCpuUseMklDnn; + } + if (other.XlaCpuUseXlaRuntime != false) { + XlaCpuUseXlaRuntime = other.XlaCpuUseXlaRuntime; + } + if (other.XlaGpuMaxKernelUnrollFactor != 0) { + XlaGpuMaxKernelUnrollFactor = other.XlaGpuMaxKernelUnrollFactor; + } + if (other.XlaCpuEnableFastMath != false) { + XlaCpuEnableFastMath = other.XlaCpuEnableFastMath; + } + if (other.XlaCpuFastMathHonorNans != false) { + XlaCpuFastMathHonorNans = other.XlaCpuFastMathHonorNans; + } + if (other.XlaCpuFastMathHonorInfs != false) { + XlaCpuFastMathHonorInfs = other.XlaCpuFastMathHonorInfs; + } + if (other.XlaCpuFastMathHonorDivision != false) { + XlaCpuFastMathHonorDivision = other.XlaCpuFastMathHonorDivision; + } + if (other.XlaCpuFastMathHonorFunctions != false) { + XlaCpuFastMathHonorFunctions = other.XlaCpuFastMathHonorFunctions; + } + if (other.XlaCpuEnableFastMinMax != false) { + XlaCpuEnableFastMinMax = other.XlaCpuEnableFastMinMax; + } + if (other.XlaGpuEnableFastMinMax != false) { + XlaGpuEnableFastMinMax = other.XlaGpuEnableFastMinMax; + } + if (other.XlaAllowExcessPrecision != false) { + XlaAllowExcessPrecision = other.XlaAllowExcessPrecision; + } + if (other.XlaGpuCrashOnVerificationFailures != false) { + XlaGpuCrashOnVerificationFailures = other.XlaGpuCrashOnVerificationFailures; + } + if (other.XlaGpuAutotuneLevel != 0) { + XlaGpuAutotuneLevel = other.XlaGpuAutotuneLevel; + } + if (other.XlaForceHostPlatformDeviceCount != 0) { + XlaForceHostPlatformDeviceCount = other.XlaForceHostPlatformDeviceCount; + } + if (other.XlaGpuDisableGpuasmOptimizations != false) { + XlaGpuDisableGpuasmOptimizations = other.XlaGpuDisableGpuasmOptimizations; + } + if (other.XlaGpuShapeChecks != global::Xla.DebugOptions.Types.ShapeChecks.Ignore) { + XlaGpuShapeChecks = other.XlaGpuShapeChecks; + } + if (other.XlaCpuEnableMlirLowering != false) { + XlaCpuEnableMlirLowering = other.XlaCpuEnableMlirLowering; + } + if (other.XlaGpuEnableMlirLowering != false) { + XlaGpuEnableMlirLowering = other.XlaGpuEnableMlirLowering; + } + if (other.XlaHloEvaluatorUseFastPath != false) { + XlaHloEvaluatorUseFastPath = other.XlaHloEvaluatorUseFastPath; + } + if (other.XlaAllowScalarIndexDynamicOps != false) { + XlaAllowScalarIndexDynamicOps = other.XlaAllowScalarIndexDynamicOps; + } + if (other.XlaStepMarkerLocation != global::Xla.DebugOptions.Types.StepMarkerLocation.StepMarkAtEntry) { + XlaStepMarkerLocation = other.XlaStepMarkerLocation; + } + if (other.XlaDumpTo.Length != 0) { + XlaDumpTo = other.XlaDumpTo; + } + if (other.XlaDumpHloModuleRe.Length != 0) { + XlaDumpHloModuleRe = other.XlaDumpHloModuleRe; + } + if (other.XlaDumpHloPassRe.Length != 0) { + XlaDumpHloPassRe = other.XlaDumpHloPassRe; + } + if (other.XlaDumpHloAsText != false) { + XlaDumpHloAsText = other.XlaDumpHloAsText; + } + if (other.XlaDumpHloAsProto != false) { + XlaDumpHloAsProto = other.XlaDumpHloAsProto; + } + if (other.XlaDumpHloAsDot != false) { + XlaDumpHloAsDot = other.XlaDumpHloAsDot; + } + if (other.XlaDumpHloAsUrl != false) { + XlaDumpHloAsUrl = other.XlaDumpHloAsUrl; + } + if (other.XlaDumpHloAsHtml != false) { + XlaDumpHloAsHtml = other.XlaDumpHloAsHtml; + } + if (other.XlaDumpFusionVisualization != false) { + XlaDumpFusionVisualization = other.XlaDumpFusionVisualization; + } + if (other.XlaDumpHloSnapshots != false) { + XlaDumpHloSnapshots = other.XlaDumpHloSnapshots; + } + if (other.XlaDumpIncludeTimestamp != false) { + XlaDumpIncludeTimestamp = other.XlaDumpIncludeTimestamp; + } + if (other.XlaDumpMaxHloModules != 0) { + XlaDumpMaxHloModules = other.XlaDumpMaxHloModules; + } + if (other.XlaDumpModuleMetadata != false) { + XlaDumpModuleMetadata = other.XlaDumpModuleMetadata; + } + if (other.XlaDumpCompressProtos != false) { + XlaDumpCompressProtos = other.XlaDumpCompressProtos; + } + if (other.XlaDumpHloAsLongText != false) { + XlaDumpHloAsLongText = other.XlaDumpHloAsLongText; + } + if (other.XlaGpuForceConvNchw != false) { + XlaGpuForceConvNchw = other.XlaGpuForceConvNchw; + } + if (other.XlaGpuForceConvNhwc != false) { + XlaGpuForceConvNhwc = other.XlaGpuForceConvNhwc; + } + xlaGpuPtxFile_.Add(other.xlaGpuPtxFile_); + if (other.XlaGpuDumpLlvmir != false) { + XlaGpuDumpLlvmir = other.XlaGpuDumpLlvmir; + } + if (other.XlaGpuAlgorithmDenylistPath.Length != 0) { + XlaGpuAlgorithmDenylistPath = other.XlaGpuAlgorithmDenylistPath; + } + if (other.XlaTpuDetectNan != false) { + XlaTpuDetectNan = other.XlaTpuDetectNan; + } + if (other.XlaTpuDetectInf != false) { + XlaTpuDetectInf = other.XlaTpuDetectInf; + } + if (other.XlaCpuEnableXprofTraceme != false) { + XlaCpuEnableXprofTraceme = other.XlaCpuEnableXprofTraceme; + } + if (other.XlaGpuUnsafeFallbackToDriverOnPtxasNotFound != false) { + XlaGpuUnsafeFallbackToDriverOnPtxasNotFound = other.XlaGpuUnsafeFallbackToDriverOnPtxasNotFound; + } + if (other.XlaGpuAsmExtraFlags.Length != 0) { + XlaGpuAsmExtraFlags = other.XlaGpuAsmExtraFlags; + } + if (other.XlaMultiheapSizeConstraintPerHeap != 0) { + XlaMultiheapSizeConstraintPerHeap = other.XlaMultiheapSizeConstraintPerHeap; + } + if (other.XlaDetailedLoggingAndDumping != false) { + XlaDetailedLoggingAndDumping = other.XlaDetailedLoggingAndDumping; + } + if (other.XlaGpuForceCompilationParallelism != 0) { + XlaGpuForceCompilationParallelism = other.XlaGpuForceCompilationParallelism; + } + if (other.XlaGpuDeterministicOps != false) { + XlaGpuDeterministicOps = other.XlaGpuDeterministicOps; + } + xlaGpuLlvmIrFile_.Add(other.xlaGpuLlvmIrFile_); + if (other.XlaGpuEnableAsyncAllReduce != false) { + XlaGpuEnableAsyncAllReduce = other.XlaGpuEnableAsyncAllReduce; + } + if (other.XlaGpuAllReduceCombineThresholdBytes != 0L) { + XlaGpuAllReduceCombineThresholdBytes = other.XlaGpuAllReduceCombineThresholdBytes; + } + if (other.XlaGpuAllReduceContiguous != false) { + XlaGpuAllReduceContiguous = other.XlaGpuAllReduceContiguous; + } + if (other.XlaGpuAllReduceBlueconnectNumDevicesPerHost != 0) { + XlaGpuAllReduceBlueconnectNumDevicesPerHost = other.XlaGpuAllReduceBlueconnectNumDevicesPerHost; + } + if (other.XlaGpuEnableCudnnFrontend != false) { + XlaGpuEnableCudnnFrontend = other.XlaGpuEnableCudnnFrontend; + } + if (other.XlaDumpDisableMetadata != false) { + XlaDumpDisableMetadata = other.XlaDumpDisableMetadata; + } + if (other.XlaDumpHloPipelineRe.Length != 0) { + XlaDumpHloPipelineRe = other.XlaDumpHloPipelineRe; + } + if (other.XlaGpuStrictConvAlgorithmPicker != false) { + XlaGpuStrictConvAlgorithmPicker = other.XlaGpuStrictConvAlgorithmPicker; + } + if (other.XlaGpuEnableXlaRuntimeExecutable != false) { + XlaGpuEnableXlaRuntimeExecutable = other.XlaGpuEnableXlaRuntimeExecutable; + } + if (other.XlaGpuNcclTerminationTimeoutSeconds != 0L) { + XlaGpuNcclTerminationTimeoutSeconds = other.XlaGpuNcclTerminationTimeoutSeconds; + } + if (other.XlaGpuEnableSharedConstants != false) { + XlaGpuEnableSharedConstants = other.XlaGpuEnableSharedConstants; + } + if (other.XlaGpuEnableCublaslt != false) { + XlaGpuEnableCublaslt = other.XlaGpuEnableCublaslt; + } + if (other.XlaGpuRedzoneScratchMaxMegabytes != 0L) { + XlaGpuRedzoneScratchMaxMegabytes = other.XlaGpuRedzoneScratchMaxMegabytes; + } + if (other.XlaGpuSimplifyAllFpConversions != false) { + XlaGpuSimplifyAllFpConversions = other.XlaGpuSimplifyAllFpConversions; + } + if (other.XlaGpuNormalizeLayouts != false) { + XlaGpuNormalizeLayouts = other.XlaGpuNormalizeLayouts; + } + if (other.XlaCpuUseAcl != false) { + XlaCpuUseAcl = other.XlaCpuUseAcl; + } + if (other.XlaCpuStrictDotConvMath != false) { + XlaCpuStrictDotConvMath = other.XlaCpuStrictDotConvMath; + } + xlaBackendExtraOptions_.Add(other.xlaBackendExtraOptions_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 16: { + XlaHloGraphAddresses = input.ReadBool(); + break; + } + case 72: { + XlaHloProfile = input.ReadBool(); + break; + } + case 242: { + xlaDisableHloPasses_.AddEntriesFrom(input, _repeated_xlaDisableHloPasses_codec); + break; + } + case 248: { + XlaBackendOptimizationLevel = input.ReadInt32(); + break; + } + case 264: { + XlaEmbedIrInExecutable = input.ReadBool(); + break; + } + case 280: { + XlaEliminateHloImplicitBroadcast = input.ReadBool(); + break; + } + case 480: { + XlaCpuMultiThreadEigen = input.ReadBool(); + break; + } + case 490: { + XlaGpuCudaDataDir = input.ReadString(); + break; + } + case 496: { + XlaGpuFtz = input.ReadBool(); + break; + } + case 560: { + XlaLlvmEnableAliasScopeMetadata = input.ReadBool(); + break; + } + case 568: { + XlaLlvmEnableNoaliasMetadata = input.ReadBool(); + break; + } + case 576: { + XlaLlvmEnableInvariantLoadMetadata = input.ReadBool(); + break; + } + case 584: { + XlaLlvmDisableExpensivePasses = input.ReadBool(); + break; + } + case 720: { + XlaTestAllOutputLayouts = input.ReadBool(); + break; + } + case 728: { + XlaTestAllInputLayouts = input.ReadBool(); + break; + } + case 736: { + XlaHloGraphShardingColor = input.ReadBool(); + break; + } + case 776: { + XlaCpuUseMklDnn = input.ReadBool(); + break; + } + case 784: { + XlaGpuMaxKernelUnrollFactor = input.ReadInt32(); + break; + } + case 792: { + XlaCpuEnableFastMath = input.ReadBool(); + break; + } + case 800: { + XlaGpuEnableFastMinMax = input.ReadBool(); + break; + } + case 808: { + XlaGpuCrashOnVerificationFailures = input.ReadBool(); + break; + } + case 816: { + XlaForceHostPlatformDeviceCount = input.ReadInt32(); + break; + } + case 824: { + XlaGpuDisableGpuasmOptimizations = input.ReadBool(); + break; + } + case 832: { + XlaDisableAllHloPasses = input.ReadBool(); + break; + } + case 848: { + XlaHloEvaluatorUseFastPath = input.ReadBool(); + break; + } + case 856: { + XlaAllowScalarIndexDynamicOps = input.ReadBool(); + break; + } + case 864: { + XlaStepMarkerLocation = (global::Xla.DebugOptions.Types.StepMarkerLocation) input.ReadEnum(); + break; + } + case 874: { + XlaDumpTo = input.ReadString(); + break; + } + case 882: { + XlaDumpHloModuleRe = input.ReadString(); + break; + } + case 890: { + XlaDumpHloPassRe = input.ReadString(); + break; + } + case 896: { + XlaDumpHloAsText = input.ReadBool(); + break; + } + case 904: { + XlaDumpHloAsProto = input.ReadBool(); + break; + } + case 912: { + XlaDumpHloAsDot = input.ReadBool(); + break; + } + case 920: { + XlaDumpHloAsUrl = input.ReadBool(); + break; + } + case 928: { + XlaDumpHloAsHtml = input.ReadBool(); + break; + } + case 944: { + XlaDumpHloSnapshots = input.ReadBool(); + break; + } + case 960: { + XlaCpuFastMathHonorNans = input.ReadBool(); + break; + } + case 968: { + XlaCpuFastMathHonorInfs = input.ReadBool(); + break; + } + case 976: { + XlaAllowExcessPrecision = input.ReadBool(); + break; + } + case 984: { + XlaGpuAutotuneLevel = input.ReadInt32(); + break; + } + case 994: { + xlaEnableHloPassesOnly_.AddEntriesFrom(input, _repeated_xlaEnableHloPassesOnly_codec); + break; + } + case 1000: { + XlaGpuForceConvNchw = input.ReadBool(); + break; + } + case 1008: { + XlaCpuFastMathHonorDivision = input.ReadBool(); + break; + } + case 1018: { + xlaGpuPtxFile_.AddEntriesFrom(input, _repeated_xlaGpuPtxFile_codec); + break; + } + case 1026: { + XlaGpuAlgorithmDenylistPath = input.ReadString(); + break; + } + case 1032: { + XlaCpuFastMathHonorFunctions = input.ReadBool(); + break; + } + case 1048: { + XlaDumpIncludeTimestamp = input.ReadBool(); + break; + } + case 1056: { + XlaDumpMaxHloModules = input.ReadInt32(); + break; + } + case 1080: { + XlaTpuDetectNan = input.ReadBool(); + break; + } + case 1088: { + XlaTpuDetectInf = input.ReadBool(); + break; + } + case 1096: { + XlaCpuEnableXprofTraceme = input.ReadBool(); + break; + } + case 1104: { + XlaGpuUnsafeFallbackToDriverOnPtxasNotFound = input.ReadBool(); + break; + } + case 1120: { + XlaCpuEnableFastMinMax = input.ReadBool(); + break; + } + case 1130: { + XlaGpuAsmExtraFlags = input.ReadString(); + break; + } + case 1136: { + XlaMultiheapSizeConstraintPerHeap = input.ReadInt32(); + break; + } + case 1144: { + XlaDetailedLoggingAndDumping = input.ReadBool(); + break; + } + case 1152: { + XlaDumpModuleMetadata = input.ReadBool(); + break; + } + case 1168: { + XlaGpuForceConvNhwc = input.ReadBool(); + break; + } + case 1176: { + XlaGpuForceCompilationParallelism = input.ReadInt32(); + break; + } + case 1184: { + XlaGpuDeterministicOps = input.ReadBool(); + break; + } + case 1192: { + XlaDumpFusionVisualization = input.ReadBool(); + break; + } + case 1202: { + xlaGpuLlvmIrFile_.AddEntriesFrom(input, _repeated_xlaGpuLlvmIrFile_codec); + break; + } + case 1208: { + XlaDumpCompressProtos = input.ReadBool(); + break; + } + case 1216: { + XlaGpuEnableAsyncAllReduce = input.ReadBool(); + break; + } + case 1224: { + XlaDumpDisableMetadata = input.ReadBool(); + break; + } + case 1234: { + XlaDumpHloPipelineRe = input.ReadString(); + break; + } + case 1240: { + XlaGpuDumpLlvmir = input.ReadBool(); + break; + } + case 1248: { + XlaGpuStrictConvAlgorithmPicker = input.ReadBool(); + break; + } + case 1256: { + XlaGpuAllReduceCombineThresholdBytes = input.ReadInt64(); + break; + } + case 1264: { + XlaGpuAllReduceContiguous = input.ReadBool(); + break; + } + case 1272: { + XlaGpuAllReduceBlueconnectNumDevicesPerHost = input.ReadInt32(); + break; + } + case 1280: { + XlaGpuEnableCudnnFrontend = input.ReadBool(); + break; + } + case 1304: { + XlaGpuNcclTerminationTimeoutSeconds = input.ReadInt64(); + break; + } + case 1312: { + XlaDumpHloAsLongText = input.ReadBool(); + break; + } + case 1320: { + XlaGpuEnableSharedConstants = input.ReadBool(); + break; + } + case 1328: { + XlaGpuEnableCublaslt = input.ReadBool(); + break; + } + case 1336: { + XlaGpuRedzoneScratchMaxMegabytes = input.ReadInt64(); + break; + } + case 1344: { + XlaGpuSimplifyAllFpConversions = input.ReadBool(); + break; + } + case 1352: { + XlaGpuEnableXlaRuntimeExecutable = input.ReadBool(); + break; + } + case 1360: { + XlaGpuShapeChecks = (global::Xla.DebugOptions.Types.ShapeChecks) input.ReadEnum(); + break; + } + case 1368: { + XlaCpuEnableMlirLowering = input.ReadBool(); + break; + } + case 1376: { + XlaGpuNormalizeLayouts = input.ReadBool(); + break; + } + case 1384: { + XlaGpuEnableMlirLowering = input.ReadBool(); + break; + } + case 1392: { + XlaCpuUseAcl = input.ReadBool(); + break; + } + case 1400: { + XlaCpuStrictDotConvMath = input.ReadBool(); + break; + } + case 1416: { + XlaCpuUseXlaRuntime = input.ReadBool(); + break; + } + case 4002: { + xlaBackendExtraOptions_.AddEntriesFrom(input, _map_xlaBackendExtraOptions_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 16: { + XlaHloGraphAddresses = input.ReadBool(); + break; + } + case 72: { + XlaHloProfile = input.ReadBool(); + break; + } + case 242: { + xlaDisableHloPasses_.AddEntriesFrom(ref input, _repeated_xlaDisableHloPasses_codec); + break; + } + case 248: { + XlaBackendOptimizationLevel = input.ReadInt32(); + break; + } + case 264: { + XlaEmbedIrInExecutable = input.ReadBool(); + break; + } + case 280: { + XlaEliminateHloImplicitBroadcast = input.ReadBool(); + break; + } + case 480: { + XlaCpuMultiThreadEigen = input.ReadBool(); + break; + } + case 490: { + XlaGpuCudaDataDir = input.ReadString(); + break; + } + case 496: { + XlaGpuFtz = input.ReadBool(); + break; + } + case 560: { + XlaLlvmEnableAliasScopeMetadata = input.ReadBool(); + break; + } + case 568: { + XlaLlvmEnableNoaliasMetadata = input.ReadBool(); + break; + } + case 576: { + XlaLlvmEnableInvariantLoadMetadata = input.ReadBool(); + break; + } + case 584: { + XlaLlvmDisableExpensivePasses = input.ReadBool(); + break; + } + case 720: { + XlaTestAllOutputLayouts = input.ReadBool(); + break; + } + case 728: { + XlaTestAllInputLayouts = input.ReadBool(); + break; + } + case 736: { + XlaHloGraphShardingColor = input.ReadBool(); + break; + } + case 776: { + XlaCpuUseMklDnn = input.ReadBool(); + break; + } + case 784: { + XlaGpuMaxKernelUnrollFactor = input.ReadInt32(); + break; + } + case 792: { + XlaCpuEnableFastMath = input.ReadBool(); + break; + } + case 800: { + XlaGpuEnableFastMinMax = input.ReadBool(); + break; + } + case 808: { + XlaGpuCrashOnVerificationFailures = input.ReadBool(); + break; + } + case 816: { + XlaForceHostPlatformDeviceCount = input.ReadInt32(); + break; + } + case 824: { + XlaGpuDisableGpuasmOptimizations = input.ReadBool(); + break; + } + case 832: { + XlaDisableAllHloPasses = input.ReadBool(); + break; + } + case 848: { + XlaHloEvaluatorUseFastPath = input.ReadBool(); + break; + } + case 856: { + XlaAllowScalarIndexDynamicOps = input.ReadBool(); + break; + } + case 864: { + XlaStepMarkerLocation = (global::Xla.DebugOptions.Types.StepMarkerLocation) input.ReadEnum(); + break; + } + case 874: { + XlaDumpTo = input.ReadString(); + break; + } + case 882: { + XlaDumpHloModuleRe = input.ReadString(); + break; + } + case 890: { + XlaDumpHloPassRe = input.ReadString(); + break; + } + case 896: { + XlaDumpHloAsText = input.ReadBool(); + break; + } + case 904: { + XlaDumpHloAsProto = input.ReadBool(); + break; + } + case 912: { + XlaDumpHloAsDot = input.ReadBool(); + break; + } + case 920: { + XlaDumpHloAsUrl = input.ReadBool(); + break; + } + case 928: { + XlaDumpHloAsHtml = input.ReadBool(); + break; + } + case 944: { + XlaDumpHloSnapshots = input.ReadBool(); + break; + } + case 960: { + XlaCpuFastMathHonorNans = input.ReadBool(); + break; + } + case 968: { + XlaCpuFastMathHonorInfs = input.ReadBool(); + break; + } + case 976: { + XlaAllowExcessPrecision = input.ReadBool(); + break; + } + case 984: { + XlaGpuAutotuneLevel = input.ReadInt32(); + break; + } + case 994: { + xlaEnableHloPassesOnly_.AddEntriesFrom(ref input, _repeated_xlaEnableHloPassesOnly_codec); + break; + } + case 1000: { + XlaGpuForceConvNchw = input.ReadBool(); + break; + } + case 1008: { + XlaCpuFastMathHonorDivision = input.ReadBool(); + break; + } + case 1018: { + xlaGpuPtxFile_.AddEntriesFrom(ref input, _repeated_xlaGpuPtxFile_codec); + break; + } + case 1026: { + XlaGpuAlgorithmDenylistPath = input.ReadString(); + break; + } + case 1032: { + XlaCpuFastMathHonorFunctions = input.ReadBool(); + break; + } + case 1048: { + XlaDumpIncludeTimestamp = input.ReadBool(); + break; + } + case 1056: { + XlaDumpMaxHloModules = input.ReadInt32(); + break; + } + case 1080: { + XlaTpuDetectNan = input.ReadBool(); + break; + } + case 1088: { + XlaTpuDetectInf = input.ReadBool(); + break; + } + case 1096: { + XlaCpuEnableXprofTraceme = input.ReadBool(); + break; + } + case 1104: { + XlaGpuUnsafeFallbackToDriverOnPtxasNotFound = input.ReadBool(); + break; + } + case 1120: { + XlaCpuEnableFastMinMax = input.ReadBool(); + break; + } + case 1130: { + XlaGpuAsmExtraFlags = input.ReadString(); + break; + } + case 1136: { + XlaMultiheapSizeConstraintPerHeap = input.ReadInt32(); + break; + } + case 1144: { + XlaDetailedLoggingAndDumping = input.ReadBool(); + break; + } + case 1152: { + XlaDumpModuleMetadata = input.ReadBool(); + break; + } + case 1168: { + XlaGpuForceConvNhwc = input.ReadBool(); + break; + } + case 1176: { + XlaGpuForceCompilationParallelism = input.ReadInt32(); + break; + } + case 1184: { + XlaGpuDeterministicOps = input.ReadBool(); + break; + } + case 1192: { + XlaDumpFusionVisualization = input.ReadBool(); + break; + } + case 1202: { + xlaGpuLlvmIrFile_.AddEntriesFrom(ref input, _repeated_xlaGpuLlvmIrFile_codec); + break; + } + case 1208: { + XlaDumpCompressProtos = input.ReadBool(); + break; + } + case 1216: { + XlaGpuEnableAsyncAllReduce = input.ReadBool(); + break; + } + case 1224: { + XlaDumpDisableMetadata = input.ReadBool(); + break; + } + case 1234: { + XlaDumpHloPipelineRe = input.ReadString(); + break; + } + case 1240: { + XlaGpuDumpLlvmir = input.ReadBool(); + break; + } + case 1248: { + XlaGpuStrictConvAlgorithmPicker = input.ReadBool(); + break; + } + case 1256: { + XlaGpuAllReduceCombineThresholdBytes = input.ReadInt64(); + break; + } + case 1264: { + XlaGpuAllReduceContiguous = input.ReadBool(); + break; + } + case 1272: { + XlaGpuAllReduceBlueconnectNumDevicesPerHost = input.ReadInt32(); + break; + } + case 1280: { + XlaGpuEnableCudnnFrontend = input.ReadBool(); + break; + } + case 1304: { + XlaGpuNcclTerminationTimeoutSeconds = input.ReadInt64(); + break; + } + case 1312: { + XlaDumpHloAsLongText = input.ReadBool(); + break; + } + case 1320: { + XlaGpuEnableSharedConstants = input.ReadBool(); + break; + } + case 1328: { + XlaGpuEnableCublaslt = input.ReadBool(); + break; + } + case 1336: { + XlaGpuRedzoneScratchMaxMegabytes = input.ReadInt64(); + break; + } + case 1344: { + XlaGpuSimplifyAllFpConversions = input.ReadBool(); + break; + } + case 1352: { + XlaGpuEnableXlaRuntimeExecutable = input.ReadBool(); + break; + } + case 1360: { + XlaGpuShapeChecks = (global::Xla.DebugOptions.Types.ShapeChecks) input.ReadEnum(); + break; + } + case 1368: { + XlaCpuEnableMlirLowering = input.ReadBool(); + break; + } + case 1376: { + XlaGpuNormalizeLayouts = input.ReadBool(); + break; + } + case 1384: { + XlaGpuEnableMlirLowering = input.ReadBool(); + break; + } + case 1392: { + XlaCpuUseAcl = input.ReadBool(); + break; + } + case 1400: { + XlaCpuStrictDotConvMath = input.ReadBool(); + break; + } + case 1416: { + XlaCpuUseXlaRuntime = input.ReadBool(); + break; + } + case 4002: { + xlaBackendExtraOptions_.AddEntriesFrom(ref input, _map_xlaBackendExtraOptions_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the DebugOptions message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum ShapeChecks { + /// + /// Do not insert any shape checks for dynamically shaped operations; output + /// buffers might contain garbage data if shapes don't match. + /// + [pbr::OriginalName("IGNORE")] Ignore = 0, + /// + /// Check shapes at runtime, will insert an extra synchronization if shapes + /// cannot be proven correct at compile time. + /// + [pbr::OriginalName("RUNTIME")] Runtime = 1, + /// + /// Will refuse to compile any program where shape correctness can not be + /// established at compile time. + /// + [pbr::OriginalName("COMPILE_TIME")] CompileTime = 2, + } + + public enum StepMarkerLocation { + /// + /// Generate a step marker at the program entry. This handles the case where + /// each step is done by one or multiple program execution(s). Only the first + /// program will be tagged for generating a step marker at the program entry. + /// This is the default. + /// + [pbr::OriginalName("STEP_MARK_AT_ENTRY")] StepMarkAtEntry = 0, + /// + /// Generate a step marker at each iteration of the top level while loop, + /// which is assumed to be a training loop. + /// + [pbr::OriginalName("STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP")] StepMarkAtTopLevelWhileLoop = 1, + /// + /// Generate a step marker at each iteration of the second level while loops, + /// which is assumed to be a training or eval loop. + /// + [pbr::OriginalName("STEP_MARK_AT_SECOND_LEVEL_WHILE_LOOP")] StepMarkAtSecondLevelWhileLoop = 3, + /// + /// No step marker generated. + /// + [pbr::OriginalName("STEP_MARK_NONE")] StepMarkNone = 2, + } + + } + #endregion + + } + + /// + /// These settings control how XLA compiles and/or runs code. Not all settings + /// will have an effect on every platform. + /// + /// When adding new fields, keep in mind that boolean fields default to false. + /// + public sealed partial class ExecutionOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ExecutionOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecutionOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecutionOptions(ExecutionOptions other) : this() { + shapeWithOutputLayout_ = other.shapeWithOutputLayout_ != null ? other.shapeWithOutputLayout_.Clone() : null; + seed_ = other.seed_; + debugOptions_ = other.debugOptions_ != null ? other.debugOptions_.Clone() : null; + deviceHandles_ = other.deviceHandles_.Clone(); + numReplicas_ = other.numReplicas_; + deviceAssignment_ = other.deviceAssignment_ != null ? other.deviceAssignment_.Clone() : null; + aliasPassthroughParams_ = other.aliasPassthroughParams_; + numPartitions_ = other.numPartitions_; + launchId_ = other.launchId_; + useSpmdPartitioning_ = other.useSpmdPartitioning_; + useAutoSpmdPartitioning_ = other.useAutoSpmdPartitioning_; + autoSpmdPartitioningMeshShape_ = other.autoSpmdPartitioningMeshShape_.Clone(); + autoSpmdPartitioningMeshIds_ = other.autoSpmdPartitioningMeshIds_.Clone(); + deduplicateHlo_ = other.deduplicateHlo_; + allowSpmdShardingPropagationToOutput_ = other.allowSpmdShardingPropagationToOutput_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecutionOptions Clone() { + return new ExecutionOptions(this); + } + + /// Field number for the "shape_with_output_layout" field. + public const int ShapeWithOutputLayoutFieldNumber = 2; + private global::Xla.ShapeProto shapeWithOutputLayout_; + /// + /// This optional field's layout is used as a hint when storing the output of + /// this computation. Subsequent transfers of this output array to the client + /// may be faster when using this layout. + /// + /// We use a Shape here to accommodate computations that return a tuple. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ShapeProto ShapeWithOutputLayout { + get { return shapeWithOutputLayout_; } + set { + shapeWithOutputLayout_ = value; + } + } + + /// Field number for the "seed" field. + public const int SeedFieldNumber = 3; + private ulong seed_; + /// + /// Used to seed random-number generators used in this computation. If this is + /// 0, we generate a seed ourselves. + /// + /// TODO(b/32083678): Changing the seed unnecessarily forces a recompilation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ulong Seed { + get { return seed_; } + set { + seed_ = value; + } + } + + /// Field number for the "debug_options" field. + public const int DebugOptionsFieldNumber = 4; + private global::Xla.DebugOptions debugOptions_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.DebugOptions DebugOptions { + get { return debugOptions_; } + set { + debugOptions_ = value; + } + } + + /// Field number for the "device_handles" field. + public const int DeviceHandlesFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_deviceHandles_codec + = pb::FieldCodec.ForMessage(42, global::Xla.DeviceHandle.Parser); + private readonly pbc::RepeatedField deviceHandles_ = new pbc::RepeatedField(); + /// + /// This optional field specifies a particular set of devices to run the + /// computation on. The computation will be partitioned across these devices. + /// If not provided, the default device will be chosen. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DeviceHandles { + get { return deviceHandles_; } + } + + /// Field number for the "num_replicas" field. + public const int NumReplicasFieldNumber = 6; + private int numReplicas_; + /// + /// Number of replicas of the computation to run. If zero, uses the default + /// number of replicas for the XLA service. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NumReplicas { + get { return numReplicas_; } + set { + numReplicas_ = value; + } + } + + /// Field number for the "device_assignment" field. + public const int DeviceAssignmentFieldNumber = 7; + private global::Xla.DeviceAssignmentProto deviceAssignment_; + /// + /// This optional field specifies the device assignment if known at compile + /// time. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.DeviceAssignmentProto DeviceAssignment { + get { return deviceAssignment_; } + set { + deviceAssignment_ = value; + } + } + + /// Field number for the "alias_passthrough_params" field. + public const int AliasPassthroughParamsFieldNumber = 8; + private bool aliasPassthroughParams_; + /// + /// Alias input and output buffers for parameters that are passed-through XLA + /// modules without being changed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool AliasPassthroughParams { + get { return aliasPassthroughParams_; } + set { + aliasPassthroughParams_ = value; + } + } + + /// Field number for the "num_partitions" field. + public const int NumPartitionsFieldNumber = 9; + private int numPartitions_; + /// + /// Number of partitions of the computation to run (model parallelism). + /// If zero, uses the default number of partitions for the XLA service. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int NumPartitions { + get { return numPartitions_; } + set { + numPartitions_ = value; + } + } + + /// Field number for the "launch_id" field. + public const int LaunchIdFieldNumber = 10; + private int launchId_; + /// + /// Used to identify a set of programs that should be launch together. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int LaunchId { + get { return launchId_; } + set { + launchId_ = value; + } + } + + /// Field number for the "use_spmd_partitioning" field. + public const int UseSpmdPartitioningFieldNumber = 11; + private bool useSpmdPartitioning_; + /// + /// Indicates whether to use SPMD (true) or MPMD (false) partitioning when + /// num_partitions > 1 and XLA is requested to partition the input program. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UseSpmdPartitioning { + get { return useSpmdPartitioning_; } + set { + useSpmdPartitioning_ = value; + } + } + + /// Field number for the "use_auto_spmd_partitioning" field. + public const int UseAutoSpmdPartitioningFieldNumber = 15; + private bool useAutoSpmdPartitioning_; + /// + /// Whether to automatically generate XLA shardings for SPMD partitioner. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UseAutoSpmdPartitioning { + get { return useAutoSpmdPartitioning_; } + set { + useAutoSpmdPartitioning_ = value; + } + } + + /// Field number for the "auto_spmd_partitioning_mesh_shape" field. + public const int AutoSpmdPartitioningMeshShapeFieldNumber = 16; + private static readonly pb::FieldCodec _repeated_autoSpmdPartitioningMeshShape_codec + = pb::FieldCodec.ForInt64(130); + private readonly pbc::RepeatedField autoSpmdPartitioningMeshShape_ = new pbc::RepeatedField(); + /// + /// Device mesh shape used to create the sharding search space when + /// use_auto_spmd_partitioning=true. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField AutoSpmdPartitioningMeshShape { + get { return autoSpmdPartitioningMeshShape_; } + } + + /// Field number for the "auto_spmd_partitioning_mesh_ids" field. + public const int AutoSpmdPartitioningMeshIdsFieldNumber = 17; + private static readonly pb::FieldCodec _repeated_autoSpmdPartitioningMeshIds_codec + = pb::FieldCodec.ForInt64(138); + private readonly pbc::RepeatedField autoSpmdPartitioningMeshIds_ = new pbc::RepeatedField(); + /// + /// Device mesh ids compatible with the above mesh_shape used when + /// use_auto_spmd_partitioning=true. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField AutoSpmdPartitioningMeshIds { + get { return autoSpmdPartitioningMeshIds_; } + } + + /// Field number for the "deduplicate_hlo" field. + public const int DeduplicateHloFieldNumber = 12; + private bool deduplicateHlo_; + /// + /// If set, deduplicate hlo into function calls to reduce binary size. Only + /// works on TPU. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool DeduplicateHlo { + get { return deduplicateHlo_; } + set { + deduplicateHlo_ = value; + } + } + + /// Field number for the "allow_spmd_sharding_propagation_to_output" field. + public const int AllowSpmdShardingPropagationToOutputFieldNumber = 14; + private bool allowSpmdShardingPropagationToOutput_; + /// + /// Allows sharding propagation to propagate to the outputs. This changes the + /// output shape of the computation (which is undesirable), but it can be used + /// to allow to run partial compilation to determine what would be the output + /// sharding of a computation if XLA would be allowed to propagate the sharding + /// which can be used by higher level framework as a way to query intermediate + /// sharding of operations when multiple computation would be chained and + /// merged together. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool AllowSpmdShardingPropagationToOutput { + get { return allowSpmdShardingPropagationToOutput_; } + set { + allowSpmdShardingPropagationToOutput_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ExecutionOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ExecutionOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(ShapeWithOutputLayout, other.ShapeWithOutputLayout)) return false; + if (Seed != other.Seed) return false; + if (!object.Equals(DebugOptions, other.DebugOptions)) return false; + if(!deviceHandles_.Equals(other.deviceHandles_)) return false; + if (NumReplicas != other.NumReplicas) return false; + if (!object.Equals(DeviceAssignment, other.DeviceAssignment)) return false; + if (AliasPassthroughParams != other.AliasPassthroughParams) return false; + if (NumPartitions != other.NumPartitions) return false; + if (LaunchId != other.LaunchId) return false; + if (UseSpmdPartitioning != other.UseSpmdPartitioning) return false; + if (UseAutoSpmdPartitioning != other.UseAutoSpmdPartitioning) return false; + if(!autoSpmdPartitioningMeshShape_.Equals(other.autoSpmdPartitioningMeshShape_)) return false; + if(!autoSpmdPartitioningMeshIds_.Equals(other.autoSpmdPartitioningMeshIds_)) return false; + if (DeduplicateHlo != other.DeduplicateHlo) return false; + if (AllowSpmdShardingPropagationToOutput != other.AllowSpmdShardingPropagationToOutput) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (shapeWithOutputLayout_ != null) hash ^= ShapeWithOutputLayout.GetHashCode(); + if (Seed != 0UL) hash ^= Seed.GetHashCode(); + if (debugOptions_ != null) hash ^= DebugOptions.GetHashCode(); + hash ^= deviceHandles_.GetHashCode(); + if (NumReplicas != 0) hash ^= NumReplicas.GetHashCode(); + if (deviceAssignment_ != null) hash ^= DeviceAssignment.GetHashCode(); + if (AliasPassthroughParams != false) hash ^= AliasPassthroughParams.GetHashCode(); + if (NumPartitions != 0) hash ^= NumPartitions.GetHashCode(); + if (LaunchId != 0) hash ^= LaunchId.GetHashCode(); + if (UseSpmdPartitioning != false) hash ^= UseSpmdPartitioning.GetHashCode(); + if (UseAutoSpmdPartitioning != false) hash ^= UseAutoSpmdPartitioning.GetHashCode(); + hash ^= autoSpmdPartitioningMeshShape_.GetHashCode(); + hash ^= autoSpmdPartitioningMeshIds_.GetHashCode(); + if (DeduplicateHlo != false) hash ^= DeduplicateHlo.GetHashCode(); + if (AllowSpmdShardingPropagationToOutput != false) hash ^= AllowSpmdShardingPropagationToOutput.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (shapeWithOutputLayout_ != null) { + output.WriteRawTag(18); + output.WriteMessage(ShapeWithOutputLayout); + } + if (Seed != 0UL) { + output.WriteRawTag(24); + output.WriteUInt64(Seed); + } + if (debugOptions_ != null) { + output.WriteRawTag(34); + output.WriteMessage(DebugOptions); + } + deviceHandles_.WriteTo(output, _repeated_deviceHandles_codec); + if (NumReplicas != 0) { + output.WriteRawTag(48); + output.WriteInt32(NumReplicas); + } + if (deviceAssignment_ != null) { + output.WriteRawTag(58); + output.WriteMessage(DeviceAssignment); + } + if (AliasPassthroughParams != false) { + output.WriteRawTag(64); + output.WriteBool(AliasPassthroughParams); + } + if (NumPartitions != 0) { + output.WriteRawTag(72); + output.WriteInt32(NumPartitions); + } + if (LaunchId != 0) { + output.WriteRawTag(80); + output.WriteInt32(LaunchId); + } + if (UseSpmdPartitioning != false) { + output.WriteRawTag(88); + output.WriteBool(UseSpmdPartitioning); + } + if (DeduplicateHlo != false) { + output.WriteRawTag(96); + output.WriteBool(DeduplicateHlo); + } + if (AllowSpmdShardingPropagationToOutput != false) { + output.WriteRawTag(112); + output.WriteBool(AllowSpmdShardingPropagationToOutput); + } + if (UseAutoSpmdPartitioning != false) { + output.WriteRawTag(120); + output.WriteBool(UseAutoSpmdPartitioning); + } + autoSpmdPartitioningMeshShape_.WriteTo(output, _repeated_autoSpmdPartitioningMeshShape_codec); + autoSpmdPartitioningMeshIds_.WriteTo(output, _repeated_autoSpmdPartitioningMeshIds_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (shapeWithOutputLayout_ != null) { + output.WriteRawTag(18); + output.WriteMessage(ShapeWithOutputLayout); + } + if (Seed != 0UL) { + output.WriteRawTag(24); + output.WriteUInt64(Seed); + } + if (debugOptions_ != null) { + output.WriteRawTag(34); + output.WriteMessage(DebugOptions); + } + deviceHandles_.WriteTo(ref output, _repeated_deviceHandles_codec); + if (NumReplicas != 0) { + output.WriteRawTag(48); + output.WriteInt32(NumReplicas); + } + if (deviceAssignment_ != null) { + output.WriteRawTag(58); + output.WriteMessage(DeviceAssignment); + } + if (AliasPassthroughParams != false) { + output.WriteRawTag(64); + output.WriteBool(AliasPassthroughParams); + } + if (NumPartitions != 0) { + output.WriteRawTag(72); + output.WriteInt32(NumPartitions); + } + if (LaunchId != 0) { + output.WriteRawTag(80); + output.WriteInt32(LaunchId); + } + if (UseSpmdPartitioning != false) { + output.WriteRawTag(88); + output.WriteBool(UseSpmdPartitioning); + } + if (DeduplicateHlo != false) { + output.WriteRawTag(96); + output.WriteBool(DeduplicateHlo); + } + if (AllowSpmdShardingPropagationToOutput != false) { + output.WriteRawTag(112); + output.WriteBool(AllowSpmdShardingPropagationToOutput); + } + if (UseAutoSpmdPartitioning != false) { + output.WriteRawTag(120); + output.WriteBool(UseAutoSpmdPartitioning); + } + autoSpmdPartitioningMeshShape_.WriteTo(ref output, _repeated_autoSpmdPartitioningMeshShape_codec); + autoSpmdPartitioningMeshIds_.WriteTo(ref output, _repeated_autoSpmdPartitioningMeshIds_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (shapeWithOutputLayout_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ShapeWithOutputLayout); + } + if (Seed != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(Seed); + } + if (debugOptions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DebugOptions); + } + size += deviceHandles_.CalculateSize(_repeated_deviceHandles_codec); + if (NumReplicas != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumReplicas); + } + if (deviceAssignment_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DeviceAssignment); + } + if (AliasPassthroughParams != false) { + size += 1 + 1; + } + if (NumPartitions != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumPartitions); + } + if (LaunchId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(LaunchId); + } + if (UseSpmdPartitioning != false) { + size += 1 + 1; + } + if (UseAutoSpmdPartitioning != false) { + size += 1 + 1; + } + size += autoSpmdPartitioningMeshShape_.CalculateSize(_repeated_autoSpmdPartitioningMeshShape_codec); + size += autoSpmdPartitioningMeshIds_.CalculateSize(_repeated_autoSpmdPartitioningMeshIds_codec); + if (DeduplicateHlo != false) { + size += 1 + 1; + } + if (AllowSpmdShardingPropagationToOutput != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ExecutionOptions other) { + if (other == null) { + return; + } + if (other.shapeWithOutputLayout_ != null) { + if (shapeWithOutputLayout_ == null) { + ShapeWithOutputLayout = new global::Xla.ShapeProto(); + } + ShapeWithOutputLayout.MergeFrom(other.ShapeWithOutputLayout); + } + if (other.Seed != 0UL) { + Seed = other.Seed; + } + if (other.debugOptions_ != null) { + if (debugOptions_ == null) { + DebugOptions = new global::Xla.DebugOptions(); + } + DebugOptions.MergeFrom(other.DebugOptions); + } + deviceHandles_.Add(other.deviceHandles_); + if (other.NumReplicas != 0) { + NumReplicas = other.NumReplicas; + } + if (other.deviceAssignment_ != null) { + if (deviceAssignment_ == null) { + DeviceAssignment = new global::Xla.DeviceAssignmentProto(); + } + DeviceAssignment.MergeFrom(other.DeviceAssignment); + } + if (other.AliasPassthroughParams != false) { + AliasPassthroughParams = other.AliasPassthroughParams; + } + if (other.NumPartitions != 0) { + NumPartitions = other.NumPartitions; + } + if (other.LaunchId != 0) { + LaunchId = other.LaunchId; + } + if (other.UseSpmdPartitioning != false) { + UseSpmdPartitioning = other.UseSpmdPartitioning; + } + if (other.UseAutoSpmdPartitioning != false) { + UseAutoSpmdPartitioning = other.UseAutoSpmdPartitioning; + } + autoSpmdPartitioningMeshShape_.Add(other.autoSpmdPartitioningMeshShape_); + autoSpmdPartitioningMeshIds_.Add(other.autoSpmdPartitioningMeshIds_); + if (other.DeduplicateHlo != false) { + DeduplicateHlo = other.DeduplicateHlo; + } + if (other.AllowSpmdShardingPropagationToOutput != false) { + AllowSpmdShardingPropagationToOutput = other.AllowSpmdShardingPropagationToOutput; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 18: { + if (shapeWithOutputLayout_ == null) { + ShapeWithOutputLayout = new global::Xla.ShapeProto(); + } + input.ReadMessage(ShapeWithOutputLayout); + break; + } + case 24: { + Seed = input.ReadUInt64(); + break; + } + case 34: { + if (debugOptions_ == null) { + DebugOptions = new global::Xla.DebugOptions(); + } + input.ReadMessage(DebugOptions); + break; + } + case 42: { + deviceHandles_.AddEntriesFrom(input, _repeated_deviceHandles_codec); + break; + } + case 48: { + NumReplicas = input.ReadInt32(); + break; + } + case 58: { + if (deviceAssignment_ == null) { + DeviceAssignment = new global::Xla.DeviceAssignmentProto(); + } + input.ReadMessage(DeviceAssignment); + break; + } + case 64: { + AliasPassthroughParams = input.ReadBool(); + break; + } + case 72: { + NumPartitions = input.ReadInt32(); + break; + } + case 80: { + LaunchId = input.ReadInt32(); + break; + } + case 88: { + UseSpmdPartitioning = input.ReadBool(); + break; + } + case 96: { + DeduplicateHlo = input.ReadBool(); + break; + } + case 112: { + AllowSpmdShardingPropagationToOutput = input.ReadBool(); + break; + } + case 120: { + UseAutoSpmdPartitioning = input.ReadBool(); + break; + } + case 130: + case 128: { + autoSpmdPartitioningMeshShape_.AddEntriesFrom(input, _repeated_autoSpmdPartitioningMeshShape_codec); + break; + } + case 138: + case 136: { + autoSpmdPartitioningMeshIds_.AddEntriesFrom(input, _repeated_autoSpmdPartitioningMeshIds_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 18: { + if (shapeWithOutputLayout_ == null) { + ShapeWithOutputLayout = new global::Xla.ShapeProto(); + } + input.ReadMessage(ShapeWithOutputLayout); + break; + } + case 24: { + Seed = input.ReadUInt64(); + break; + } + case 34: { + if (debugOptions_ == null) { + DebugOptions = new global::Xla.DebugOptions(); + } + input.ReadMessage(DebugOptions); + break; + } + case 42: { + deviceHandles_.AddEntriesFrom(ref input, _repeated_deviceHandles_codec); + break; + } + case 48: { + NumReplicas = input.ReadInt32(); + break; + } + case 58: { + if (deviceAssignment_ == null) { + DeviceAssignment = new global::Xla.DeviceAssignmentProto(); + } + input.ReadMessage(DeviceAssignment); + break; + } + case 64: { + AliasPassthroughParams = input.ReadBool(); + break; + } + case 72: { + NumPartitions = input.ReadInt32(); + break; + } + case 80: { + LaunchId = input.ReadInt32(); + break; + } + case 88: { + UseSpmdPartitioning = input.ReadBool(); + break; + } + case 96: { + DeduplicateHlo = input.ReadBool(); + break; + } + case 112: { + AllowSpmdShardingPropagationToOutput = input.ReadBool(); + break; + } + case 120: { + UseAutoSpmdPartitioning = input.ReadBool(); + break; + } + case 130: + case 128: { + autoSpmdPartitioningMeshShape_.AddEntriesFrom(ref input, _repeated_autoSpmdPartitioningMeshShape_codec); + break; + } + case 138: + case 136: { + autoSpmdPartitioningMeshIds_.AddEntriesFrom(ref input, _repeated_autoSpmdPartitioningMeshIds_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class GetDeviceHandlesRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GetDeviceHandlesRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetDeviceHandlesRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetDeviceHandlesRequest(GetDeviceHandlesRequest other) : this() { + deviceCount_ = other.deviceCount_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetDeviceHandlesRequest Clone() { + return new GetDeviceHandlesRequest(this); + } + + /// Field number for the "device_count" field. + public const int DeviceCountFieldNumber = 1; + private long deviceCount_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long DeviceCount { + get { return deviceCount_; } + set { + deviceCount_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GetDeviceHandlesRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GetDeviceHandlesRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (DeviceCount != other.DeviceCount) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (DeviceCount != 0L) hash ^= DeviceCount.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (DeviceCount != 0L) { + output.WriteRawTag(8); + output.WriteInt64(DeviceCount); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (DeviceCount != 0L) { + output.WriteRawTag(8); + output.WriteInt64(DeviceCount); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (DeviceCount != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(DeviceCount); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GetDeviceHandlesRequest other) { + if (other == null) { + return; + } + if (other.DeviceCount != 0L) { + DeviceCount = other.DeviceCount; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + DeviceCount = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + DeviceCount = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + public sealed partial class GetDeviceHandlesResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GetDeviceHandlesResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetDeviceHandlesResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetDeviceHandlesResponse(GetDeviceHandlesResponse other) : this() { + deviceHandles_ = other.deviceHandles_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetDeviceHandlesResponse Clone() { + return new GetDeviceHandlesResponse(this); + } + + /// Field number for the "device_handles" field. + public const int DeviceHandlesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_deviceHandles_codec + = pb::FieldCodec.ForMessage(10, global::Xla.DeviceHandle.Parser); + private readonly pbc::RepeatedField deviceHandles_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DeviceHandles { + get { return deviceHandles_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GetDeviceHandlesResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GetDeviceHandlesResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!deviceHandles_.Equals(other.deviceHandles_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= deviceHandles_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + deviceHandles_.WriteTo(output, _repeated_deviceHandles_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + deviceHandles_.WriteTo(ref output, _repeated_deviceHandles_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += deviceHandles_.CalculateSize(_repeated_deviceHandles_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GetDeviceHandlesResponse other) { + if (other == null) { + return; + } + deviceHandles_.Add(other.deviceHandles_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + deviceHandles_.AddEntriesFrom(input, _repeated_deviceHandles_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + deviceHandles_.AddEntriesFrom(ref input, _repeated_deviceHandles_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class TransferToClientRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TransferToClientRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToClientRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToClientRequest(TransferToClientRequest other) : this() { + data_ = other.data_ != null ? other.data_.Clone() : null; + shapeWithLayout_ = other.shapeWithLayout_ != null ? other.shapeWithLayout_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToClientRequest Clone() { + return new TransferToClientRequest(this); + } + + /// Field number for the "data" field. + public const int DataFieldNumber = 1; + private global::Xla.GlobalDataHandle data_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.GlobalDataHandle Data { + get { return data_; } + set { + data_ = value; + } + } + + /// Field number for the "shape_with_layout" field. + public const int ShapeWithLayoutFieldNumber = 2; + private global::Xla.ShapeProto shapeWithLayout_; + /// + /// This optional field directs the service to return the literal in this + /// layout. A shape is used to hold the layout to accommodate tuples. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ShapeProto ShapeWithLayout { + get { return shapeWithLayout_; } + set { + shapeWithLayout_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TransferToClientRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TransferToClientRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Data, other.Data)) return false; + if (!object.Equals(ShapeWithLayout, other.ShapeWithLayout)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (data_ != null) hash ^= Data.GetHashCode(); + if (shapeWithLayout_ != null) hash ^= ShapeWithLayout.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (data_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Data); + } + if (shapeWithLayout_ != null) { + output.WriteRawTag(18); + output.WriteMessage(ShapeWithLayout); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (data_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Data); + } + if (shapeWithLayout_ != null) { + output.WriteRawTag(18); + output.WriteMessage(ShapeWithLayout); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (data_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Data); + } + if (shapeWithLayout_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ShapeWithLayout); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TransferToClientRequest other) { + if (other == null) { + return; + } + if (other.data_ != null) { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + Data.MergeFrom(other.Data); + } + if (other.shapeWithLayout_ != null) { + if (shapeWithLayout_ == null) { + ShapeWithLayout = new global::Xla.ShapeProto(); + } + ShapeWithLayout.MergeFrom(other.ShapeWithLayout); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Data); + break; + } + case 18: { + if (shapeWithLayout_ == null) { + ShapeWithLayout = new global::Xla.ShapeProto(); + } + input.ReadMessage(ShapeWithLayout); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Data); + break; + } + case 18: { + if (shapeWithLayout_ == null) { + ShapeWithLayout = new global::Xla.ShapeProto(); + } + input.ReadMessage(ShapeWithLayout); + break; + } + } + } + } + #endif + + } + + public sealed partial class TransferToClientResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TransferToClientResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToClientResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToClientResponse(TransferToClientResponse other) : this() { + literal_ = other.literal_ != null ? other.literal_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToClientResponse Clone() { + return new TransferToClientResponse(this); + } + + /// Field number for the "literal" field. + public const int LiteralFieldNumber = 1; + private global::Xla.LiteralProto literal_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.LiteralProto Literal { + get { return literal_; } + set { + literal_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TransferToClientResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TransferToClientResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Literal, other.Literal)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (literal_ != null) hash ^= Literal.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (literal_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Literal); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (literal_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Literal); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (literal_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Literal); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TransferToClientResponse other) { + if (other == null) { + return; + } + if (other.literal_ != null) { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + Literal.MergeFrom(other.Literal); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + input.ReadMessage(Literal); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + input.ReadMessage(Literal); + break; + } + } + } + } + #endif + + } + + public sealed partial class TransferToServerRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TransferToServerRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToServerRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToServerRequest(TransferToServerRequest other) : this() { + literal_ = other.literal_ != null ? other.literal_.Clone() : null; + deviceHandle_ = other.deviceHandle_ != null ? other.deviceHandle_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToServerRequest Clone() { + return new TransferToServerRequest(this); + } + + /// Field number for the "literal" field. + public const int LiteralFieldNumber = 1; + private global::Xla.LiteralProto literal_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.LiteralProto Literal { + get { return literal_; } + set { + literal_ = value; + } + } + + /// Field number for the "device_handle" field. + public const int DeviceHandleFieldNumber = 2; + private global::Xla.DeviceHandle deviceHandle_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.DeviceHandle DeviceHandle { + get { return deviceHandle_; } + set { + deviceHandle_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TransferToServerRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TransferToServerRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Literal, other.Literal)) return false; + if (!object.Equals(DeviceHandle, other.DeviceHandle)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (literal_ != null) hash ^= Literal.GetHashCode(); + if (deviceHandle_ != null) hash ^= DeviceHandle.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (literal_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Literal); + } + if (deviceHandle_ != null) { + output.WriteRawTag(18); + output.WriteMessage(DeviceHandle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (literal_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Literal); + } + if (deviceHandle_ != null) { + output.WriteRawTag(18); + output.WriteMessage(DeviceHandle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (literal_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Literal); + } + if (deviceHandle_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DeviceHandle); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TransferToServerRequest other) { + if (other == null) { + return; + } + if (other.literal_ != null) { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + Literal.MergeFrom(other.Literal); + } + if (other.deviceHandle_ != null) { + if (deviceHandle_ == null) { + DeviceHandle = new global::Xla.DeviceHandle(); + } + DeviceHandle.MergeFrom(other.DeviceHandle); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + input.ReadMessage(Literal); + break; + } + case 18: { + if (deviceHandle_ == null) { + DeviceHandle = new global::Xla.DeviceHandle(); + } + input.ReadMessage(DeviceHandle); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + input.ReadMessage(Literal); + break; + } + case 18: { + if (deviceHandle_ == null) { + DeviceHandle = new global::Xla.DeviceHandle(); + } + input.ReadMessage(DeviceHandle); + break; + } + } + } + } + #endif + + } + + public sealed partial class TransferToServerResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TransferToServerResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToServerResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToServerResponse(TransferToServerResponse other) : this() { + data_ = other.data_ != null ? other.data_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToServerResponse Clone() { + return new TransferToServerResponse(this); + } + + /// Field number for the "data" field. + public const int DataFieldNumber = 1; + private global::Xla.GlobalDataHandle data_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.GlobalDataHandle Data { + get { return data_; } + set { + data_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TransferToServerResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TransferToServerResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Data, other.Data)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (data_ != null) hash ^= Data.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (data_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Data); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (data_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Data); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (data_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Data); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TransferToServerResponse other) { + if (other == null) { + return; + } + if (other.data_ != null) { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + Data.MergeFrom(other.Data); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Data); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Data); + break; + } + } + } + } + #endif + + } + + public sealed partial class TransferToInfeedRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TransferToInfeedRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[8]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToInfeedRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToInfeedRequest(TransferToInfeedRequest other) : this() { + literal_ = other.literal_ != null ? other.literal_.Clone() : null; + replicaId_ = other.replicaId_; + deviceHandle_ = other.deviceHandle_ != null ? other.deviceHandle_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToInfeedRequest Clone() { + return new TransferToInfeedRequest(this); + } + + /// Field number for the "literal" field. + public const int LiteralFieldNumber = 1; + private global::Xla.LiteralProto literal_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.LiteralProto Literal { + get { return literal_; } + set { + literal_ = value; + } + } + + /// Field number for the "replica_id" field. + public const int ReplicaIdFieldNumber = 2; + private long replicaId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ReplicaId { + get { return replicaId_; } + set { + replicaId_ = value; + } + } + + /// Field number for the "device_handle" field. + public const int DeviceHandleFieldNumber = 3; + private global::Xla.DeviceHandle deviceHandle_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.DeviceHandle DeviceHandle { + get { return deviceHandle_; } + set { + deviceHandle_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TransferToInfeedRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TransferToInfeedRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Literal, other.Literal)) return false; + if (ReplicaId != other.ReplicaId) return false; + if (!object.Equals(DeviceHandle, other.DeviceHandle)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (literal_ != null) hash ^= Literal.GetHashCode(); + if (ReplicaId != 0L) hash ^= ReplicaId.GetHashCode(); + if (deviceHandle_ != null) hash ^= DeviceHandle.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (literal_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Literal); + } + if (ReplicaId != 0L) { + output.WriteRawTag(16); + output.WriteInt64(ReplicaId); + } + if (deviceHandle_ != null) { + output.WriteRawTag(26); + output.WriteMessage(DeviceHandle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (literal_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Literal); + } + if (ReplicaId != 0L) { + output.WriteRawTag(16); + output.WriteInt64(ReplicaId); + } + if (deviceHandle_ != null) { + output.WriteRawTag(26); + output.WriteMessage(DeviceHandle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (literal_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Literal); + } + if (ReplicaId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ReplicaId); + } + if (deviceHandle_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DeviceHandle); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TransferToInfeedRequest other) { + if (other == null) { + return; + } + if (other.literal_ != null) { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + Literal.MergeFrom(other.Literal); + } + if (other.ReplicaId != 0L) { + ReplicaId = other.ReplicaId; + } + if (other.deviceHandle_ != null) { + if (deviceHandle_ == null) { + DeviceHandle = new global::Xla.DeviceHandle(); + } + DeviceHandle.MergeFrom(other.DeviceHandle); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + input.ReadMessage(Literal); + break; + } + case 16: { + ReplicaId = input.ReadInt64(); + break; + } + case 26: { + if (deviceHandle_ == null) { + DeviceHandle = new global::Xla.DeviceHandle(); + } + input.ReadMessage(DeviceHandle); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + input.ReadMessage(Literal); + break; + } + case 16: { + ReplicaId = input.ReadInt64(); + break; + } + case 26: { + if (deviceHandle_ == null) { + DeviceHandle = new global::Xla.DeviceHandle(); + } + input.ReadMessage(DeviceHandle); + break; + } + } + } + } + #endif + + } + + public sealed partial class TransferToInfeedResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TransferToInfeedResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[9]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToInfeedResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToInfeedResponse(TransferToInfeedResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferToInfeedResponse Clone() { + return new TransferToInfeedResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TransferToInfeedResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TransferToInfeedResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TransferToInfeedResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + public sealed partial class TransferFromOutfeedRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TransferFromOutfeedRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[10]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferFromOutfeedRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferFromOutfeedRequest(TransferFromOutfeedRequest other) : this() { + shapeWithLayout_ = other.shapeWithLayout_ != null ? other.shapeWithLayout_.Clone() : null; + replicaId_ = other.replicaId_; + deviceHandle_ = other.deviceHandle_ != null ? other.deviceHandle_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferFromOutfeedRequest Clone() { + return new TransferFromOutfeedRequest(this); + } + + /// Field number for the "shape_with_layout" field. + public const int ShapeWithLayoutFieldNumber = 1; + private global::Xla.ShapeProto shapeWithLayout_; + /// + /// This optional field directs the service to return the literal in this + /// layout. A shape is used to hold the layout to accommodate tuples. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ShapeProto ShapeWithLayout { + get { return shapeWithLayout_; } + set { + shapeWithLayout_ = value; + } + } + + /// Field number for the "replica_id" field. + public const int ReplicaIdFieldNumber = 2; + private long replicaId_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ReplicaId { + get { return replicaId_; } + set { + replicaId_ = value; + } + } + + /// Field number for the "device_handle" field. + public const int DeviceHandleFieldNumber = 3; + private global::Xla.DeviceHandle deviceHandle_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.DeviceHandle DeviceHandle { + get { return deviceHandle_; } + set { + deviceHandle_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TransferFromOutfeedRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TransferFromOutfeedRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(ShapeWithLayout, other.ShapeWithLayout)) return false; + if (ReplicaId != other.ReplicaId) return false; + if (!object.Equals(DeviceHandle, other.DeviceHandle)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (shapeWithLayout_ != null) hash ^= ShapeWithLayout.GetHashCode(); + if (ReplicaId != 0L) hash ^= ReplicaId.GetHashCode(); + if (deviceHandle_ != null) hash ^= DeviceHandle.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (shapeWithLayout_ != null) { + output.WriteRawTag(10); + output.WriteMessage(ShapeWithLayout); + } + if (ReplicaId != 0L) { + output.WriteRawTag(16); + output.WriteInt64(ReplicaId); + } + if (deviceHandle_ != null) { + output.WriteRawTag(26); + output.WriteMessage(DeviceHandle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (shapeWithLayout_ != null) { + output.WriteRawTag(10); + output.WriteMessage(ShapeWithLayout); + } + if (ReplicaId != 0L) { + output.WriteRawTag(16); + output.WriteInt64(ReplicaId); + } + if (deviceHandle_ != null) { + output.WriteRawTag(26); + output.WriteMessage(DeviceHandle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (shapeWithLayout_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ShapeWithLayout); + } + if (ReplicaId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ReplicaId); + } + if (deviceHandle_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DeviceHandle); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TransferFromOutfeedRequest other) { + if (other == null) { + return; + } + if (other.shapeWithLayout_ != null) { + if (shapeWithLayout_ == null) { + ShapeWithLayout = new global::Xla.ShapeProto(); + } + ShapeWithLayout.MergeFrom(other.ShapeWithLayout); + } + if (other.ReplicaId != 0L) { + ReplicaId = other.ReplicaId; + } + if (other.deviceHandle_ != null) { + if (deviceHandle_ == null) { + DeviceHandle = new global::Xla.DeviceHandle(); + } + DeviceHandle.MergeFrom(other.DeviceHandle); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (shapeWithLayout_ == null) { + ShapeWithLayout = new global::Xla.ShapeProto(); + } + input.ReadMessage(ShapeWithLayout); + break; + } + case 16: { + ReplicaId = input.ReadInt64(); + break; + } + case 26: { + if (deviceHandle_ == null) { + DeviceHandle = new global::Xla.DeviceHandle(); + } + input.ReadMessage(DeviceHandle); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (shapeWithLayout_ == null) { + ShapeWithLayout = new global::Xla.ShapeProto(); + } + input.ReadMessage(ShapeWithLayout); + break; + } + case 16: { + ReplicaId = input.ReadInt64(); + break; + } + case 26: { + if (deviceHandle_ == null) { + DeviceHandle = new global::Xla.DeviceHandle(); + } + input.ReadMessage(DeviceHandle); + break; + } + } + } + } + #endif + + } + + public sealed partial class TransferFromOutfeedResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TransferFromOutfeedResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[11]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferFromOutfeedResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferFromOutfeedResponse(TransferFromOutfeedResponse other) : this() { + literal_ = other.literal_ != null ? other.literal_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TransferFromOutfeedResponse Clone() { + return new TransferFromOutfeedResponse(this); + } + + /// Field number for the "literal" field. + public const int LiteralFieldNumber = 1; + private global::Xla.LiteralProto literal_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.LiteralProto Literal { + get { return literal_; } + set { + literal_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TransferFromOutfeedResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TransferFromOutfeedResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Literal, other.Literal)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (literal_ != null) hash ^= Literal.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (literal_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Literal); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (literal_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Literal); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (literal_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Literal); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TransferFromOutfeedResponse other) { + if (other == null) { + return; + } + if (other.literal_ != null) { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + Literal.MergeFrom(other.Literal); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + input.ReadMessage(Literal); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + input.ReadMessage(Literal); + break; + } + } + } + } + #endif + + } + + public sealed partial class ResetDeviceRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ResetDeviceRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[12]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ResetDeviceRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ResetDeviceRequest(ResetDeviceRequest other) : this() { + deviceHandle_ = other.deviceHandle_ != null ? other.deviceHandle_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ResetDeviceRequest Clone() { + return new ResetDeviceRequest(this); + } + + /// Field number for the "device_handle" field. + public const int DeviceHandleFieldNumber = 1; + private global::Xla.DeviceHandle deviceHandle_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.DeviceHandle DeviceHandle { + get { return deviceHandle_; } + set { + deviceHandle_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ResetDeviceRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ResetDeviceRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(DeviceHandle, other.DeviceHandle)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (deviceHandle_ != null) hash ^= DeviceHandle.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (deviceHandle_ != null) { + output.WriteRawTag(10); + output.WriteMessage(DeviceHandle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (deviceHandle_ != null) { + output.WriteRawTag(10); + output.WriteMessage(DeviceHandle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (deviceHandle_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DeviceHandle); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ResetDeviceRequest other) { + if (other == null) { + return; + } + if (other.deviceHandle_ != null) { + if (deviceHandle_ == null) { + DeviceHandle = new global::Xla.DeviceHandle(); + } + DeviceHandle.MergeFrom(other.DeviceHandle); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (deviceHandle_ == null) { + DeviceHandle = new global::Xla.DeviceHandle(); + } + input.ReadMessage(DeviceHandle); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (deviceHandle_ == null) { + DeviceHandle = new global::Xla.DeviceHandle(); + } + input.ReadMessage(DeviceHandle); + break; + } + } + } + } + #endif + + } + + public sealed partial class ResetDeviceResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ResetDeviceResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[13]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ResetDeviceResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ResetDeviceResponse(ResetDeviceResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ResetDeviceResponse Clone() { + return new ResetDeviceResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ResetDeviceResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ResetDeviceResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ResetDeviceResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + public sealed partial class ComputationGraphStatsRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ComputationGraphStatsRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[14]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputationGraphStatsRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputationGraphStatsRequest(ComputationGraphStatsRequest other) : this() { + computation_ = other.computation_ != null ? other.computation_.Clone() : null; + debugOptions_ = other.debugOptions_ != null ? other.debugOptions_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputationGraphStatsRequest Clone() { + return new ComputationGraphStatsRequest(this); + } + + /// Field number for the "computation" field. + public const int ComputationFieldNumber = 1; + private global::Xla.HloModuleProto computation_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.HloModuleProto Computation { + get { return computation_; } + set { + computation_ = value; + } + } + + /// Field number for the "debug_options" field. + public const int DebugOptionsFieldNumber = 2; + private global::Xla.DebugOptions debugOptions_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.DebugOptions DebugOptions { + get { return debugOptions_; } + set { + debugOptions_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ComputationGraphStatsRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ComputationGraphStatsRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Computation, other.Computation)) return false; + if (!object.Equals(DebugOptions, other.DebugOptions)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (computation_ != null) hash ^= Computation.GetHashCode(); + if (debugOptions_ != null) hash ^= DebugOptions.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (computation_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Computation); + } + if (debugOptions_ != null) { + output.WriteRawTag(18); + output.WriteMessage(DebugOptions); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (computation_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Computation); + } + if (debugOptions_ != null) { + output.WriteRawTag(18); + output.WriteMessage(DebugOptions); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (computation_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Computation); + } + if (debugOptions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DebugOptions); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ComputationGraphStatsRequest other) { + if (other == null) { + return; + } + if (other.computation_ != null) { + if (computation_ == null) { + Computation = new global::Xla.HloModuleProto(); + } + Computation.MergeFrom(other.Computation); + } + if (other.debugOptions_ != null) { + if (debugOptions_ == null) { + DebugOptions = new global::Xla.DebugOptions(); + } + DebugOptions.MergeFrom(other.DebugOptions); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (computation_ == null) { + Computation = new global::Xla.HloModuleProto(); + } + input.ReadMessage(Computation); + break; + } + case 18: { + if (debugOptions_ == null) { + DebugOptions = new global::Xla.DebugOptions(); + } + input.ReadMessage(DebugOptions); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (computation_ == null) { + Computation = new global::Xla.HloModuleProto(); + } + input.ReadMessage(Computation); + break; + } + case 18: { + if (debugOptions_ == null) { + DebugOptions = new global::Xla.DebugOptions(); + } + input.ReadMessage(DebugOptions); + break; + } + } + } + } + #endif + + } + + public sealed partial class ComputationStatsResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ComputationStatsResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[15]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputationStatsResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputationStatsResponse(ComputationStatsResponse other) : this() { + stats_ = other.stats_ != null ? other.stats_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputationStatsResponse Clone() { + return new ComputationStatsResponse(this); + } + + /// Field number for the "stats" field. + public const int StatsFieldNumber = 1; + private global::Xla.ComputationStats stats_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ComputationStats Stats { + get { return stats_; } + set { + stats_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ComputationStatsResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ComputationStatsResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Stats, other.Stats)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (stats_ != null) hash ^= Stats.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (stats_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Stats); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (stats_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Stats); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (stats_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Stats); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ComputationStatsResponse other) { + if (other == null) { + return; + } + if (other.stats_ != null) { + if (stats_ == null) { + Stats = new global::Xla.ComputationStats(); + } + Stats.MergeFrom(other.Stats); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (stats_ == null) { + Stats = new global::Xla.ComputationStats(); + } + input.ReadMessage(Stats); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (stats_ == null) { + Stats = new global::Xla.ComputationStats(); + } + input.ReadMessage(Stats); + break; + } + } + } + } + #endif + + } + + public sealed partial class CreateChannelHandleRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CreateChannelHandleRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[16]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CreateChannelHandleRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CreateChannelHandleRequest(CreateChannelHandleRequest other) : this() { + channelType_ = other.channelType_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CreateChannelHandleRequest Clone() { + return new CreateChannelHandleRequest(this); + } + + /// Field number for the "channel_type" field. + public const int ChannelTypeFieldNumber = 1; + private global::Xla.ChannelHandle.Types.ChannelType channelType_ = global::Xla.ChannelHandle.Types.ChannelType.Invalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ChannelHandle.Types.ChannelType ChannelType { + get { return channelType_; } + set { + channelType_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CreateChannelHandleRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CreateChannelHandleRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ChannelType != other.ChannelType) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ChannelType != global::Xla.ChannelHandle.Types.ChannelType.Invalid) hash ^= ChannelType.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ChannelType != global::Xla.ChannelHandle.Types.ChannelType.Invalid) { + output.WriteRawTag(8); + output.WriteEnum((int) ChannelType); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ChannelType != global::Xla.ChannelHandle.Types.ChannelType.Invalid) { + output.WriteRawTag(8); + output.WriteEnum((int) ChannelType); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ChannelType != global::Xla.ChannelHandle.Types.ChannelType.Invalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ChannelType); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CreateChannelHandleRequest other) { + if (other == null) { + return; + } + if (other.ChannelType != global::Xla.ChannelHandle.Types.ChannelType.Invalid) { + ChannelType = other.ChannelType; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + ChannelType = (global::Xla.ChannelHandle.Types.ChannelType) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + ChannelType = (global::Xla.ChannelHandle.Types.ChannelType) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + public sealed partial class CreateChannelHandleResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CreateChannelHandleResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[17]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CreateChannelHandleResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CreateChannelHandleResponse(CreateChannelHandleResponse other) : this() { + channel_ = other.channel_ != null ? other.channel_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CreateChannelHandleResponse Clone() { + return new CreateChannelHandleResponse(this); + } + + /// Field number for the "channel" field. + public const int ChannelFieldNumber = 1; + private global::Xla.ChannelHandle channel_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ChannelHandle Channel { + get { return channel_; } + set { + channel_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CreateChannelHandleResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CreateChannelHandleResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Channel, other.Channel)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (channel_ != null) hash ^= Channel.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (channel_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Channel); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (channel_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Channel); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (channel_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Channel); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CreateChannelHandleResponse other) { + if (other == null) { + return; + } + if (other.channel_ != null) { + if (channel_ == null) { + Channel = new global::Xla.ChannelHandle(); + } + Channel.MergeFrom(other.Channel); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (channel_ == null) { + Channel = new global::Xla.ChannelHandle(); + } + input.ReadMessage(Channel); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (channel_ == null) { + Channel = new global::Xla.ChannelHandle(); + } + input.ReadMessage(Channel); + break; + } + } + } + } + #endif + + } + + public sealed partial class UnregisterRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnregisterRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[18]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public UnregisterRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public UnregisterRequest(UnregisterRequest other) : this() { + data_ = other.data_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public UnregisterRequest Clone() { + return new UnregisterRequest(this); + } + + /// Field number for the "data" field. + public const int DataFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_data_codec + = pb::FieldCodec.ForMessage(10, global::Xla.GlobalDataHandle.Parser); + private readonly pbc::RepeatedField data_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Data { + get { return data_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as UnregisterRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(UnregisterRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!data_.Equals(other.data_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= data_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + data_.WriteTo(output, _repeated_data_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + data_.WriteTo(ref output, _repeated_data_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += data_.CalculateSize(_repeated_data_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(UnregisterRequest other) { + if (other == null) { + return; + } + data_.Add(other.data_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + data_.AddEntriesFrom(input, _repeated_data_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + data_.AddEntriesFrom(ref input, _repeated_data_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class UnregisterResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnregisterResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[19]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public UnregisterResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public UnregisterResponse(UnregisterResponse other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public UnregisterResponse Clone() { + return new UnregisterResponse(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as UnregisterResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(UnregisterResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(UnregisterResponse other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + } + } + } + #endif + + } + + public sealed partial class CompileRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CompileRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[20]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CompileRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CompileRequest(CompileRequest other) : this() { + computation_ = other.computation_ != null ? other.computation_.Clone() : null; + executionOptions_ = other.executionOptions_ != null ? other.executionOptions_.Clone() : null; + inputShapeWithLayout_ = other.inputShapeWithLayout_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CompileRequest Clone() { + return new CompileRequest(this); + } + + /// Field number for the "computation" field. + public const int ComputationFieldNumber = 1; + private global::Xla.HloModuleProto computation_; + /// + /// The graph to be compiled. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.HloModuleProto Computation { + get { return computation_; } + set { + computation_ = value; + } + } + + /// Field number for the "execution_options" field. + public const int ExecutionOptionsFieldNumber = 2; + private global::Xla.ExecutionOptions executionOptions_; + /// + /// Options that affect how XLA compiles code to service this request. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ExecutionOptions ExecutionOptions { + get { return executionOptions_; } + set { + executionOptions_ = value; + } + } + + /// Field number for the "input_shape_with_layout" field. + public const int InputShapeWithLayoutFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_inputShapeWithLayout_codec + = pb::FieldCodec.ForMessage(26, global::Xla.ShapeProto.Parser); + private readonly pbc::RepeatedField inputShapeWithLayout_ = new pbc::RepeatedField(); + /// + /// The layouts of the input arguments. If not set, the default layout will be + /// used. Although the real arguments are not needed in compilation, the + /// layouts of the arguments can affect the compilation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField InputShapeWithLayout { + get { return inputShapeWithLayout_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CompileRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CompileRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Computation, other.Computation)) return false; + if (!object.Equals(ExecutionOptions, other.ExecutionOptions)) return false; + if(!inputShapeWithLayout_.Equals(other.inputShapeWithLayout_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (computation_ != null) hash ^= Computation.GetHashCode(); + if (executionOptions_ != null) hash ^= ExecutionOptions.GetHashCode(); + hash ^= inputShapeWithLayout_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (computation_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Computation); + } + if (executionOptions_ != null) { + output.WriteRawTag(18); + output.WriteMessage(ExecutionOptions); + } + inputShapeWithLayout_.WriteTo(output, _repeated_inputShapeWithLayout_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (computation_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Computation); + } + if (executionOptions_ != null) { + output.WriteRawTag(18); + output.WriteMessage(ExecutionOptions); + } + inputShapeWithLayout_.WriteTo(ref output, _repeated_inputShapeWithLayout_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (computation_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Computation); + } + if (executionOptions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ExecutionOptions); + } + size += inputShapeWithLayout_.CalculateSize(_repeated_inputShapeWithLayout_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CompileRequest other) { + if (other == null) { + return; + } + if (other.computation_ != null) { + if (computation_ == null) { + Computation = new global::Xla.HloModuleProto(); + } + Computation.MergeFrom(other.Computation); + } + if (other.executionOptions_ != null) { + if (executionOptions_ == null) { + ExecutionOptions = new global::Xla.ExecutionOptions(); + } + ExecutionOptions.MergeFrom(other.ExecutionOptions); + } + inputShapeWithLayout_.Add(other.inputShapeWithLayout_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (computation_ == null) { + Computation = new global::Xla.HloModuleProto(); + } + input.ReadMessage(Computation); + break; + } + case 18: { + if (executionOptions_ == null) { + ExecutionOptions = new global::Xla.ExecutionOptions(); + } + input.ReadMessage(ExecutionOptions); + break; + } + case 26: { + inputShapeWithLayout_.AddEntriesFrom(input, _repeated_inputShapeWithLayout_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (computation_ == null) { + Computation = new global::Xla.HloModuleProto(); + } + input.ReadMessage(Computation); + break; + } + case 18: { + if (executionOptions_ == null) { + ExecutionOptions = new global::Xla.ExecutionOptions(); + } + input.ReadMessage(ExecutionOptions); + break; + } + case 26: { + inputShapeWithLayout_.AddEntriesFrom(ref input, _repeated_inputShapeWithLayout_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class CompileResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CompileResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[21]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CompileResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CompileResponse(CompileResponse other) : this() { + handle_ = other.handle_ != null ? other.handle_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CompileResponse Clone() { + return new CompileResponse(this); + } + + /// Field number for the "handle" field. + public const int HandleFieldNumber = 1; + private global::Xla.ExecutionHandle handle_; + /// + /// The handle to the executable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ExecutionHandle Handle { + get { return handle_; } + set { + handle_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CompileResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CompileResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Handle, other.Handle)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (handle_ != null) hash ^= Handle.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (handle_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Handle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (handle_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Handle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (handle_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Handle); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CompileResponse other) { + if (other == null) { + return; + } + if (other.handle_ != null) { + if (handle_ == null) { + Handle = new global::Xla.ExecutionHandle(); + } + Handle.MergeFrom(other.Handle); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (handle_ == null) { + Handle = new global::Xla.ExecutionHandle(); + } + input.ReadMessage(Handle); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (handle_ == null) { + Handle = new global::Xla.ExecutionHandle(); + } + input.ReadMessage(Handle); + break; + } + } + } + } + #endif + + } + + public sealed partial class ExecuteRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ExecuteRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[22]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteRequest(ExecuteRequest other) : this() { + handle_ = other.handle_ != null ? other.handle_.Clone() : null; + arguments_ = other.arguments_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteRequest Clone() { + return new ExecuteRequest(this); + } + + /// Field number for the "handle" field. + public const int HandleFieldNumber = 1; + private global::Xla.ExecutionHandle handle_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ExecutionHandle Handle { + get { return handle_; } + set { + handle_ = value; + } + } + + /// Field number for the "arguments" field. + public const int ArgumentsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_arguments_codec + = pb::FieldCodec.ForMessage(18, global::Xla.GlobalDataHandle.Parser); + private readonly pbc::RepeatedField arguments_ = new pbc::RepeatedField(); + /// + /// The shape and layout of the arguments must be the same as the those of the + /// executable's parameters. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Arguments { + get { return arguments_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ExecuteRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ExecuteRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Handle, other.Handle)) return false; + if(!arguments_.Equals(other.arguments_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (handle_ != null) hash ^= Handle.GetHashCode(); + hash ^= arguments_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (handle_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Handle); + } + arguments_.WriteTo(output, _repeated_arguments_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (handle_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Handle); + } + arguments_.WriteTo(ref output, _repeated_arguments_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (handle_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Handle); + } + size += arguments_.CalculateSize(_repeated_arguments_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ExecuteRequest other) { + if (other == null) { + return; + } + if (other.handle_ != null) { + if (handle_ == null) { + Handle = new global::Xla.ExecutionHandle(); + } + Handle.MergeFrom(other.Handle); + } + arguments_.Add(other.arguments_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (handle_ == null) { + Handle = new global::Xla.ExecutionHandle(); + } + input.ReadMessage(Handle); + break; + } + case 18: { + arguments_.AddEntriesFrom(input, _repeated_arguments_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (handle_ == null) { + Handle = new global::Xla.ExecutionHandle(); + } + input.ReadMessage(Handle); + break; + } + case 18: { + arguments_.AddEntriesFrom(ref input, _repeated_arguments_codec); + break; + } + } + } + } + #endif + + } + + /// + /// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace + /// the uses with calls to Compile and Execute. + /// + public sealed partial class ExecuteGraphRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ExecuteGraphRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[23]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteGraphRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteGraphRequest(ExecuteGraphRequest other) : this() { + computation_ = other.computation_ != null ? other.computation_.Clone() : null; + arguments_ = other.arguments_.Clone(); + executionOptions_ = other.executionOptions_ != null ? other.executionOptions_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteGraphRequest Clone() { + return new ExecuteGraphRequest(this); + } + + /// Field number for the "computation" field. + public const int ComputationFieldNumber = 1; + private global::Xla.HloModuleProto computation_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.HloModuleProto Computation { + get { return computation_; } + set { + computation_ = value; + } + } + + /// Field number for the "arguments" field. + public const int ArgumentsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_arguments_codec + = pb::FieldCodec.ForMessage(18, global::Xla.GlobalDataHandle.Parser); + private readonly pbc::RepeatedField arguments_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Arguments { + get { return arguments_; } + } + + /// Field number for the "execution_options" field. + public const int ExecutionOptionsFieldNumber = 3; + private global::Xla.ExecutionOptions executionOptions_; + /// + /// Options that affect how XLA compiles and runs code to service this request. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ExecutionOptions ExecutionOptions { + get { return executionOptions_; } + set { + executionOptions_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ExecuteGraphRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ExecuteGraphRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Computation, other.Computation)) return false; + if(!arguments_.Equals(other.arguments_)) return false; + if (!object.Equals(ExecutionOptions, other.ExecutionOptions)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (computation_ != null) hash ^= Computation.GetHashCode(); + hash ^= arguments_.GetHashCode(); + if (executionOptions_ != null) hash ^= ExecutionOptions.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (computation_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Computation); + } + arguments_.WriteTo(output, _repeated_arguments_codec); + if (executionOptions_ != null) { + output.WriteRawTag(26); + output.WriteMessage(ExecutionOptions); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (computation_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Computation); + } + arguments_.WriteTo(ref output, _repeated_arguments_codec); + if (executionOptions_ != null) { + output.WriteRawTag(26); + output.WriteMessage(ExecutionOptions); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (computation_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Computation); + } + size += arguments_.CalculateSize(_repeated_arguments_codec); + if (executionOptions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ExecutionOptions); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ExecuteGraphRequest other) { + if (other == null) { + return; + } + if (other.computation_ != null) { + if (computation_ == null) { + Computation = new global::Xla.HloModuleProto(); + } + Computation.MergeFrom(other.Computation); + } + arguments_.Add(other.arguments_); + if (other.executionOptions_ != null) { + if (executionOptions_ == null) { + ExecutionOptions = new global::Xla.ExecutionOptions(); + } + ExecutionOptions.MergeFrom(other.ExecutionOptions); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (computation_ == null) { + Computation = new global::Xla.HloModuleProto(); + } + input.ReadMessage(Computation); + break; + } + case 18: { + arguments_.AddEntriesFrom(input, _repeated_arguments_codec); + break; + } + case 26: { + if (executionOptions_ == null) { + ExecutionOptions = new global::Xla.ExecutionOptions(); + } + input.ReadMessage(ExecutionOptions); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (computation_ == null) { + Computation = new global::Xla.HloModuleProto(); + } + input.ReadMessage(Computation); + break; + } + case 18: { + arguments_.AddEntriesFrom(ref input, _repeated_arguments_codec); + break; + } + case 26: { + if (executionOptions_ == null) { + ExecutionOptions = new global::Xla.ExecutionOptions(); + } + input.ReadMessage(ExecutionOptions); + break; + } + } + } + } + #endif + + } + + public sealed partial class ExecuteGraphParallelRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ExecuteGraphParallelRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[24]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteGraphParallelRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteGraphParallelRequest(ExecuteGraphParallelRequest other) : this() { + requests_ = other.requests_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteGraphParallelRequest Clone() { + return new ExecuteGraphParallelRequest(this); + } + + /// Field number for the "requests" field. + public const int RequestsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_requests_codec + = pb::FieldCodec.ForMessage(10, global::Xla.ExecuteGraphRequest.Parser); + private readonly pbc::RepeatedField requests_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Requests { + get { return requests_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ExecuteGraphParallelRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ExecuteGraphParallelRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!requests_.Equals(other.requests_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= requests_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + requests_.WriteTo(output, _repeated_requests_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + requests_.WriteTo(ref output, _repeated_requests_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += requests_.CalculateSize(_repeated_requests_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ExecuteGraphParallelRequest other) { + if (other == null) { + return; + } + requests_.Add(other.requests_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + requests_.AddEntriesFrom(input, _repeated_requests_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + requests_.AddEntriesFrom(ref input, _repeated_requests_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class ExecuteResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ExecuteResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[25]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteResponse(ExecuteResponse other) : this() { + output_ = other.output_ != null ? other.output_.Clone() : null; + profile_ = other.profile_ != null ? other.profile_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteResponse Clone() { + return new ExecuteResponse(this); + } + + /// Field number for the "output" field. + public const int OutputFieldNumber = 1; + private global::Xla.GlobalDataHandle output_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.GlobalDataHandle Output { + get { return output_; } + set { + output_ = value; + } + } + + /// Field number for the "profile" field. + public const int ProfileFieldNumber = 2; + private global::Xla.ExecutionProfile profile_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ExecutionProfile Profile { + get { return profile_; } + set { + profile_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ExecuteResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ExecuteResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Output, other.Output)) return false; + if (!object.Equals(Profile, other.Profile)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (output_ != null) hash ^= Output.GetHashCode(); + if (profile_ != null) hash ^= Profile.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (output_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Output); + } + if (profile_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Profile); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (output_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Output); + } + if (profile_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Profile); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (output_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Output); + } + if (profile_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Profile); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ExecuteResponse other) { + if (other == null) { + return; + } + if (other.output_ != null) { + if (output_ == null) { + Output = new global::Xla.GlobalDataHandle(); + } + Output.MergeFrom(other.Output); + } + if (other.profile_ != null) { + if (profile_ == null) { + Profile = new global::Xla.ExecutionProfile(); + } + Profile.MergeFrom(other.Profile); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (output_ == null) { + Output = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Output); + break; + } + case 18: { + if (profile_ == null) { + Profile = new global::Xla.ExecutionProfile(); + } + input.ReadMessage(Profile); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (output_ == null) { + Output = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Output); + break; + } + case 18: { + if (profile_ == null) { + Profile = new global::Xla.ExecutionProfile(); + } + input.ReadMessage(Profile); + break; + } + } + } + } + #endif + + } + + public sealed partial class ExecuteParallelResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ExecuteParallelResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[26]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteParallelResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteParallelResponse(ExecuteParallelResponse other) : this() { + responses_ = other.responses_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecuteParallelResponse Clone() { + return new ExecuteParallelResponse(this); + } + + /// Field number for the "responses" field. + public const int ResponsesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_responses_codec + = pb::FieldCodec.ForMessage(10, global::Xla.ExecuteResponse.Parser); + private readonly pbc::RepeatedField responses_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Responses { + get { return responses_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ExecuteParallelResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ExecuteParallelResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!responses_.Equals(other.responses_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= responses_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + responses_.WriteTo(output, _repeated_responses_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + responses_.WriteTo(ref output, _repeated_responses_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += responses_.CalculateSize(_repeated_responses_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ExecuteParallelResponse other) { + if (other == null) { + return; + } + responses_.Add(other.responses_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + responses_.AddEntriesFrom(input, _repeated_responses_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + responses_.AddEntriesFrom(ref input, _repeated_responses_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class WaitForExecutionRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WaitForExecutionRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[27]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitForExecutionRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitForExecutionRequest(WaitForExecutionRequest other) : this() { + execution_ = other.execution_ != null ? other.execution_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitForExecutionRequest Clone() { + return new WaitForExecutionRequest(this); + } + + /// Field number for the "execution" field. + public const int ExecutionFieldNumber = 1; + private global::Xla.ExecutionHandle execution_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ExecutionHandle Execution { + get { return execution_; } + set { + execution_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WaitForExecutionRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WaitForExecutionRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Execution, other.Execution)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (execution_ != null) hash ^= Execution.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (execution_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Execution); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (execution_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Execution); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (execution_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Execution); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WaitForExecutionRequest other) { + if (other == null) { + return; + } + if (other.execution_ != null) { + if (execution_ == null) { + Execution = new global::Xla.ExecutionHandle(); + } + Execution.MergeFrom(other.Execution); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (execution_ == null) { + Execution = new global::Xla.ExecutionHandle(); + } + input.ReadMessage(Execution); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (execution_ == null) { + Execution = new global::Xla.ExecutionHandle(); + } + input.ReadMessage(Execution); + break; + } + } + } + } + #endif + + } + + public sealed partial class WaitForExecutionResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WaitForExecutionResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[28]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitForExecutionResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitForExecutionResponse(WaitForExecutionResponse other) : this() { + output_ = other.output_ != null ? other.output_.Clone() : null; + profile_ = other.profile_ != null ? other.profile_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WaitForExecutionResponse Clone() { + return new WaitForExecutionResponse(this); + } + + /// Field number for the "output" field. + public const int OutputFieldNumber = 1; + private global::Xla.GlobalDataHandle output_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.GlobalDataHandle Output { + get { return output_; } + set { + output_ = value; + } + } + + /// Field number for the "profile" field. + public const int ProfileFieldNumber = 2; + private global::Xla.ExecutionProfile profile_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ExecutionProfile Profile { + get { return profile_; } + set { + profile_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WaitForExecutionResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WaitForExecutionResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Output, other.Output)) return false; + if (!object.Equals(Profile, other.Profile)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (output_ != null) hash ^= Output.GetHashCode(); + if (profile_ != null) hash ^= Profile.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (output_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Output); + } + if (profile_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Profile); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (output_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Output); + } + if (profile_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Profile); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (output_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Output); + } + if (profile_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Profile); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WaitForExecutionResponse other) { + if (other == null) { + return; + } + if (other.output_ != null) { + if (output_ == null) { + Output = new global::Xla.GlobalDataHandle(); + } + Output.MergeFrom(other.Output); + } + if (other.profile_ != null) { + if (profile_ == null) { + Profile = new global::Xla.ExecutionProfile(); + } + Profile.MergeFrom(other.Profile); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (output_ == null) { + Output = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Output); + break; + } + case 18: { + if (profile_ == null) { + Profile = new global::Xla.ExecutionProfile(); + } + input.ReadMessage(Profile); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (output_ == null) { + Output = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Output); + break; + } + case 18: { + if (profile_ == null) { + Profile = new global::Xla.ExecutionProfile(); + } + input.ReadMessage(Profile); + break; + } + } + } + } + #endif + + } + + public sealed partial class ComputeConstantGraphRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ComputeConstantGraphRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[29]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputeConstantGraphRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputeConstantGraphRequest(ComputeConstantGraphRequest other) : this() { + computation_ = other.computation_ != null ? other.computation_.Clone() : null; + outputLayout_ = other.outputLayout_ != null ? other.outputLayout_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputeConstantGraphRequest Clone() { + return new ComputeConstantGraphRequest(this); + } + + /// Field number for the "computation" field. + public const int ComputationFieldNumber = 1; + private global::Xla.HloModuleProto computation_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.HloModuleProto Computation { + get { return computation_; } + set { + computation_ = value; + } + } + + /// Field number for the "output_layout" field. + public const int OutputLayoutFieldNumber = 2; + private global::Xla.LayoutProto outputLayout_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.LayoutProto OutputLayout { + get { return outputLayout_; } + set { + outputLayout_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ComputeConstantGraphRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ComputeConstantGraphRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Computation, other.Computation)) return false; + if (!object.Equals(OutputLayout, other.OutputLayout)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (computation_ != null) hash ^= Computation.GetHashCode(); + if (outputLayout_ != null) hash ^= OutputLayout.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (computation_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Computation); + } + if (outputLayout_ != null) { + output.WriteRawTag(18); + output.WriteMessage(OutputLayout); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (computation_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Computation); + } + if (outputLayout_ != null) { + output.WriteRawTag(18); + output.WriteMessage(OutputLayout); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (computation_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Computation); + } + if (outputLayout_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(OutputLayout); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ComputeConstantGraphRequest other) { + if (other == null) { + return; + } + if (other.computation_ != null) { + if (computation_ == null) { + Computation = new global::Xla.HloModuleProto(); + } + Computation.MergeFrom(other.Computation); + } + if (other.outputLayout_ != null) { + if (outputLayout_ == null) { + OutputLayout = new global::Xla.LayoutProto(); + } + OutputLayout.MergeFrom(other.OutputLayout); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (computation_ == null) { + Computation = new global::Xla.HloModuleProto(); + } + input.ReadMessage(Computation); + break; + } + case 18: { + if (outputLayout_ == null) { + OutputLayout = new global::Xla.LayoutProto(); + } + input.ReadMessage(OutputLayout); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (computation_ == null) { + Computation = new global::Xla.HloModuleProto(); + } + input.ReadMessage(Computation); + break; + } + case 18: { + if (outputLayout_ == null) { + OutputLayout = new global::Xla.LayoutProto(); + } + input.ReadMessage(OutputLayout); + break; + } + } + } + } + #endif + + } + + public sealed partial class ComputeConstantResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ComputeConstantResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[30]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputeConstantResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputeConstantResponse(ComputeConstantResponse other) : this() { + literal_ = other.literal_ != null ? other.literal_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputeConstantResponse Clone() { + return new ComputeConstantResponse(this); + } + + /// Field number for the "literal" field. + public const int LiteralFieldNumber = 1; + private global::Xla.LiteralProto literal_; + /// + /// A LiteralProto is returned directly for this request. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.LiteralProto Literal { + get { return literal_; } + set { + literal_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ComputeConstantResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ComputeConstantResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Literal, other.Literal)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (literal_ != null) hash ^= Literal.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (literal_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Literal); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (literal_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Literal); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (literal_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Literal); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ComputeConstantResponse other) { + if (other == null) { + return; + } + if (other.literal_ != null) { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + Literal.MergeFrom(other.Literal); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + input.ReadMessage(Literal); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (literal_ == null) { + Literal = new global::Xla.LiteralProto(); + } + input.ReadMessage(Literal); + break; + } + } + } + } + #endif + + } + + public sealed partial class DeconstructTupleRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DeconstructTupleRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[31]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeconstructTupleRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeconstructTupleRequest(DeconstructTupleRequest other) : this() { + tupleHandle_ = other.tupleHandle_ != null ? other.tupleHandle_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeconstructTupleRequest Clone() { + return new DeconstructTupleRequest(this); + } + + /// Field number for the "tuple_handle" field. + public const int TupleHandleFieldNumber = 2; + private global::Xla.GlobalDataHandle tupleHandle_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.GlobalDataHandle TupleHandle { + get { return tupleHandle_; } + set { + tupleHandle_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DeconstructTupleRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DeconstructTupleRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(TupleHandle, other.TupleHandle)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (tupleHandle_ != null) hash ^= TupleHandle.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (tupleHandle_ != null) { + output.WriteRawTag(18); + output.WriteMessage(TupleHandle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (tupleHandle_ != null) { + output.WriteRawTag(18); + output.WriteMessage(TupleHandle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (tupleHandle_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TupleHandle); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DeconstructTupleRequest other) { + if (other == null) { + return; + } + if (other.tupleHandle_ != null) { + if (tupleHandle_ == null) { + TupleHandle = new global::Xla.GlobalDataHandle(); + } + TupleHandle.MergeFrom(other.TupleHandle); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 18: { + if (tupleHandle_ == null) { + TupleHandle = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(TupleHandle); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 18: { + if (tupleHandle_ == null) { + TupleHandle = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(TupleHandle); + break; + } + } + } + } + #endif + + } + + public sealed partial class DeconstructTupleResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DeconstructTupleResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[32]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeconstructTupleResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeconstructTupleResponse(DeconstructTupleResponse other) : this() { + elementHandles_ = other.elementHandles_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeconstructTupleResponse Clone() { + return new DeconstructTupleResponse(this); + } + + /// Field number for the "element_handles" field. + public const int ElementHandlesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_elementHandles_codec + = pb::FieldCodec.ForMessage(10, global::Xla.GlobalDataHandle.Parser); + private readonly pbc::RepeatedField elementHandles_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ElementHandles { + get { return elementHandles_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DeconstructTupleResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DeconstructTupleResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!elementHandles_.Equals(other.elementHandles_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= elementHandles_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + elementHandles_.WriteTo(output, _repeated_elementHandles_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + elementHandles_.WriteTo(ref output, _repeated_elementHandles_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += elementHandles_.CalculateSize(_repeated_elementHandles_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DeconstructTupleResponse other) { + if (other == null) { + return; + } + elementHandles_.Add(other.elementHandles_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + elementHandles_.AddEntriesFrom(input, _repeated_elementHandles_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + elementHandles_.AddEntriesFrom(ref input, _repeated_elementHandles_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class LoadDataRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new LoadDataRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[33]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LoadDataRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LoadDataRequest(LoadDataRequest other) : this() { + columnioTabletPath_ = other.columnioTabletPath_; + columnioField_ = other.columnioField_; + elementShape_ = other.elementShape_ != null ? other.elementShape_.Clone() : null; + offset_ = other.offset_; + limit_ = other.limit_; + zip_ = other.zip_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LoadDataRequest Clone() { + return new LoadDataRequest(this); + } + + /// Field number for the "columnio_tablet_path" field. + public const int ColumnioTabletPathFieldNumber = 1; + private string columnioTabletPath_ = ""; + /// + /// Describes the path of the ColumnIO tablet to load. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ColumnioTabletPath { + get { return columnioTabletPath_; } + set { + columnioTabletPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "columnio_field" field. + public const int ColumnioFieldFieldNumber = 2; + private string columnioField_ = ""; + /// + /// Describes the field to load within the ColumnIO tablet. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string ColumnioField { + get { return columnioField_; } + set { + columnioField_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "element_shape" field. + public const int ElementShapeFieldNumber = 3; + private global::Xla.ShapeProto elementShape_; + /// + /// Individual element shape, excluding rows. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ShapeProto ElementShape { + get { return elementShape_; } + set { + elementShape_ = value; + } + } + + /// Field number for the "offset" field. + public const int OffsetFieldNumber = 4; + private long offset_; + /// + /// Warning: ColumnIO does not support random-access, so use offset with + /// caution in performance-critical scenarios. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Offset { + get { return offset_; } + set { + offset_ = value; + } + } + + /// Field number for the "limit" field. + public const int LimitFieldNumber = 5; + private long limit_; + /// + /// Maximum number of elements (with shape element_shape) to load. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Limit { + get { return limit_; } + set { + limit_ = value; + } + } + + /// Field number for the "zip" field. + public const int ZipFieldNumber = 6; + private bool zip_; + /// + /// If more than one item is requested (via limit > 1), then this request + /// attribute zips together the produced vectors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Zip { + get { return zip_; } + set { + zip_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as LoadDataRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(LoadDataRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ColumnioTabletPath != other.ColumnioTabletPath) return false; + if (ColumnioField != other.ColumnioField) return false; + if (!object.Equals(ElementShape, other.ElementShape)) return false; + if (Offset != other.Offset) return false; + if (Limit != other.Limit) return false; + if (Zip != other.Zip) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ColumnioTabletPath.Length != 0) hash ^= ColumnioTabletPath.GetHashCode(); + if (ColumnioField.Length != 0) hash ^= ColumnioField.GetHashCode(); + if (elementShape_ != null) hash ^= ElementShape.GetHashCode(); + if (Offset != 0L) hash ^= Offset.GetHashCode(); + if (Limit != 0L) hash ^= Limit.GetHashCode(); + if (Zip != false) hash ^= Zip.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ColumnioTabletPath.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ColumnioTabletPath); + } + if (ColumnioField.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ColumnioField); + } + if (elementShape_ != null) { + output.WriteRawTag(26); + output.WriteMessage(ElementShape); + } + if (Offset != 0L) { + output.WriteRawTag(32); + output.WriteInt64(Offset); + } + if (Limit != 0L) { + output.WriteRawTag(40); + output.WriteInt64(Limit); + } + if (Zip != false) { + output.WriteRawTag(48); + output.WriteBool(Zip); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ColumnioTabletPath.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ColumnioTabletPath); + } + if (ColumnioField.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ColumnioField); + } + if (elementShape_ != null) { + output.WriteRawTag(26); + output.WriteMessage(ElementShape); + } + if (Offset != 0L) { + output.WriteRawTag(32); + output.WriteInt64(Offset); + } + if (Limit != 0L) { + output.WriteRawTag(40); + output.WriteInt64(Limit); + } + if (Zip != false) { + output.WriteRawTag(48); + output.WriteBool(Zip); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ColumnioTabletPath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ColumnioTabletPath); + } + if (ColumnioField.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ColumnioField); + } + if (elementShape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ElementShape); + } + if (Offset != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Offset); + } + if (Limit != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Limit); + } + if (Zip != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(LoadDataRequest other) { + if (other == null) { + return; + } + if (other.ColumnioTabletPath.Length != 0) { + ColumnioTabletPath = other.ColumnioTabletPath; + } + if (other.ColumnioField.Length != 0) { + ColumnioField = other.ColumnioField; + } + if (other.elementShape_ != null) { + if (elementShape_ == null) { + ElementShape = new global::Xla.ShapeProto(); + } + ElementShape.MergeFrom(other.ElementShape); + } + if (other.Offset != 0L) { + Offset = other.Offset; + } + if (other.Limit != 0L) { + Limit = other.Limit; + } + if (other.Zip != false) { + Zip = other.Zip; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ColumnioTabletPath = input.ReadString(); + break; + } + case 18: { + ColumnioField = input.ReadString(); + break; + } + case 26: { + if (elementShape_ == null) { + ElementShape = new global::Xla.ShapeProto(); + } + input.ReadMessage(ElementShape); + break; + } + case 32: { + Offset = input.ReadInt64(); + break; + } + case 40: { + Limit = input.ReadInt64(); + break; + } + case 48: { + Zip = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + ColumnioTabletPath = input.ReadString(); + break; + } + case 18: { + ColumnioField = input.ReadString(); + break; + } + case 26: { + if (elementShape_ == null) { + ElementShape = new global::Xla.ShapeProto(); + } + input.ReadMessage(ElementShape); + break; + } + case 32: { + Offset = input.ReadInt64(); + break; + } + case 40: { + Limit = input.ReadInt64(); + break; + } + case 48: { + Zip = input.ReadBool(); + break; + } + } + } + } + #endif + + } + + public sealed partial class LoadDataResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new LoadDataResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[34]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LoadDataResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LoadDataResponse(LoadDataResponse other) : this() { + data_ = other.data_ != null ? other.data_.Clone() : null; + dataShape_ = other.dataShape_ != null ? other.dataShape_.Clone() : null; + availableRows_ = other.availableRows_; + rowsLoaded_ = other.rowsLoaded_; + nanoseconds_ = other.nanoseconds_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LoadDataResponse Clone() { + return new LoadDataResponse(this); + } + + /// Field number for the "data" field. + public const int DataFieldNumber = 1; + private global::Xla.GlobalDataHandle data_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.GlobalDataHandle Data { + get { return data_; } + set { + data_ = value; + } + } + + /// Field number for the "data_shape" field. + public const int DataShapeFieldNumber = 2; + private global::Xla.ShapeProto dataShape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ShapeProto DataShape { + get { return dataShape_; } + set { + dataShape_ = value; + } + } + + /// Field number for the "available_rows" field. + public const int AvailableRowsFieldNumber = 3; + private long availableRows_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long AvailableRows { + get { return availableRows_; } + set { + availableRows_ = value; + } + } + + /// Field number for the "rows_loaded" field. + public const int RowsLoadedFieldNumber = 4; + private long rowsLoaded_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long RowsLoaded { + get { return rowsLoaded_; } + set { + rowsLoaded_ = value; + } + } + + /// Field number for the "nanoseconds" field. + public const int NanosecondsFieldNumber = 5; + private long nanoseconds_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Nanoseconds { + get { return nanoseconds_; } + set { + nanoseconds_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as LoadDataResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(LoadDataResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Data, other.Data)) return false; + if (!object.Equals(DataShape, other.DataShape)) return false; + if (AvailableRows != other.AvailableRows) return false; + if (RowsLoaded != other.RowsLoaded) return false; + if (Nanoseconds != other.Nanoseconds) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (data_ != null) hash ^= Data.GetHashCode(); + if (dataShape_ != null) hash ^= DataShape.GetHashCode(); + if (AvailableRows != 0L) hash ^= AvailableRows.GetHashCode(); + if (RowsLoaded != 0L) hash ^= RowsLoaded.GetHashCode(); + if (Nanoseconds != 0L) hash ^= Nanoseconds.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (data_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Data); + } + if (dataShape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(DataShape); + } + if (AvailableRows != 0L) { + output.WriteRawTag(24); + output.WriteInt64(AvailableRows); + } + if (RowsLoaded != 0L) { + output.WriteRawTag(32); + output.WriteInt64(RowsLoaded); + } + if (Nanoseconds != 0L) { + output.WriteRawTag(40); + output.WriteInt64(Nanoseconds); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (data_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Data); + } + if (dataShape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(DataShape); + } + if (AvailableRows != 0L) { + output.WriteRawTag(24); + output.WriteInt64(AvailableRows); + } + if (RowsLoaded != 0L) { + output.WriteRawTag(32); + output.WriteInt64(RowsLoaded); + } + if (Nanoseconds != 0L) { + output.WriteRawTag(40); + output.WriteInt64(Nanoseconds); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (data_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Data); + } + if (dataShape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DataShape); + } + if (AvailableRows != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(AvailableRows); + } + if (RowsLoaded != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(RowsLoaded); + } + if (Nanoseconds != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Nanoseconds); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(LoadDataResponse other) { + if (other == null) { + return; + } + if (other.data_ != null) { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + Data.MergeFrom(other.Data); + } + if (other.dataShape_ != null) { + if (dataShape_ == null) { + DataShape = new global::Xla.ShapeProto(); + } + DataShape.MergeFrom(other.DataShape); + } + if (other.AvailableRows != 0L) { + AvailableRows = other.AvailableRows; + } + if (other.RowsLoaded != 0L) { + RowsLoaded = other.RowsLoaded; + } + if (other.Nanoseconds != 0L) { + Nanoseconds = other.Nanoseconds; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Data); + break; + } + case 18: { + if (dataShape_ == null) { + DataShape = new global::Xla.ShapeProto(); + } + input.ReadMessage(DataShape); + break; + } + case 24: { + AvailableRows = input.ReadInt64(); + break; + } + case 32: { + RowsLoaded = input.ReadInt64(); + break; + } + case 40: { + Nanoseconds = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Data); + break; + } + case 18: { + if (dataShape_ == null) { + DataShape = new global::Xla.ShapeProto(); + } + input.ReadMessage(DataShape); + break; + } + case 24: { + AvailableRows = input.ReadInt64(); + break; + } + case 32: { + RowsLoaded = input.ReadInt64(); + break; + } + case 40: { + Nanoseconds = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + public sealed partial class GetShapeRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GetShapeRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[35]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetShapeRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetShapeRequest(GetShapeRequest other) : this() { + data_ = other.data_ != null ? other.data_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetShapeRequest Clone() { + return new GetShapeRequest(this); + } + + /// Field number for the "data" field. + public const int DataFieldNumber = 1; + private global::Xla.GlobalDataHandle data_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.GlobalDataHandle Data { + get { return data_; } + set { + data_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GetShapeRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GetShapeRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Data, other.Data)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (data_ != null) hash ^= Data.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (data_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Data); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (data_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Data); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (data_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Data); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GetShapeRequest other) { + if (other == null) { + return; + } + if (other.data_ != null) { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + Data.MergeFrom(other.Data); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Data); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Data); + break; + } + } + } + } + #endif + + } + + public sealed partial class GetShapeResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GetShapeResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[36]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetShapeResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetShapeResponse(GetShapeResponse other) : this() { + shape_ = other.shape_ != null ? other.shape_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GetShapeResponse Clone() { + return new GetShapeResponse(this); + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 1; + private global::Xla.ShapeProto shape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ShapeProto Shape { + get { return shape_; } + set { + shape_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GetShapeResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GetShapeResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Shape, other.Shape)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (shape_ != null) hash ^= Shape.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (shape_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Shape); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (shape_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Shape); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (shape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GetShapeResponse other) { + if (other == null) { + return; + } + if (other.shape_ != null) { + if (shape_ == null) { + Shape = new global::Xla.ShapeProto(); + } + Shape.MergeFrom(other.Shape); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (shape_ == null) { + Shape = new global::Xla.ShapeProto(); + } + input.ReadMessage(Shape); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (shape_ == null) { + Shape = new global::Xla.ShapeProto(); + } + input.ReadMessage(Shape); + break; + } + } + } + } + #endif + + } + + public sealed partial class UnpackRequest : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnpackRequest()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[37]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public UnpackRequest() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public UnpackRequest(UnpackRequest other) : this() { + data_ = other.data_ != null ? other.data_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public UnpackRequest Clone() { + return new UnpackRequest(this); + } + + /// Field number for the "data" field. + public const int DataFieldNumber = 1; + private global::Xla.GlobalDataHandle data_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.GlobalDataHandle Data { + get { return data_; } + set { + data_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as UnpackRequest); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(UnpackRequest other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Data, other.Data)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (data_ != null) hash ^= Data.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (data_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Data); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (data_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Data); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (data_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Data); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(UnpackRequest other) { + if (other == null) { + return; + } + if (other.data_ != null) { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + Data.MergeFrom(other.Data); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Data); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (data_ == null) { + Data = new global::Xla.GlobalDataHandle(); + } + input.ReadMessage(Data); + break; + } + } + } + } + #endif + + } + + public sealed partial class UnpackResponse : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new UnpackResponse()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaReflection.Descriptor.MessageTypes[38]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public UnpackResponse() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public UnpackResponse(UnpackResponse other) : this() { + tiedData_ = other.tiedData_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public UnpackResponse Clone() { + return new UnpackResponse(this); + } + + /// Field number for the "tied_data" field. + public const int TiedDataFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_tiedData_codec + = pb::FieldCodec.ForMessage(10, global::Xla.GlobalDataHandle.Parser); + private readonly pbc::RepeatedField tiedData_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField TiedData { + get { return tiedData_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as UnpackResponse); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(UnpackResponse other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!tiedData_.Equals(other.tiedData_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= tiedData_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + tiedData_.WriteTo(output, _repeated_tiedData_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + tiedData_.WriteTo(ref output, _repeated_tiedData_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += tiedData_.CalculateSize(_repeated_tiedData_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(UnpackResponse other) { + if (other == null) { + return; + } + tiedData_.Add(other.tiedData_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + tiedData_.AddEntriesFrom(input, _repeated_tiedData_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + tiedData_.AddEntriesFrom(ref input, _repeated_tiedData_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/XlaData.cs b/src/TensorFlowNET.Core/Protobuf/XlaData.cs new file mode 100644 index 000000000..b281ab778 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/XlaData.cs @@ -0,0 +1,10350 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/compiler/xla/xla_data.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Xla { + + /// Holder for reflection information generated from tensorflow/compiler/xla/xla_data.proto + public static partial class XlaDataReflection { + + #region Descriptor + /// File descriptor for tensorflow/compiler/xla/xla_data.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static XlaDataReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiZ0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS94bGFfZGF0YS5wcm90bxIDeGxh", + "IrcBCg1QYWRkaW5nQ29uZmlnEj0KCmRpbWVuc2lvbnMYASADKAsyKS54bGEu", + "UGFkZGluZ0NvbmZpZy5QYWRkaW5nQ29uZmlnRGltZW5zaW9uGmcKFlBhZGRp", + "bmdDb25maWdEaW1lbnNpb24SGAoQZWRnZV9wYWRkaW5nX2xvdxgBIAEoAxIZ", + "ChFlZGdlX3BhZGRpbmdfaGlnaBgCIAEoAxIYChBpbnRlcmlvcl9wYWRkaW5n", + "GAMgASgDIh8KCVRpbGVQcm90bxISCgpkaW1lbnNpb25zGAEgAygDIqQCCgtM", + "YXlvdXRQcm90bxIqCg9kaW1fbGV2ZWxfdHlwZXMYCSADKA4yES54bGEuRGlt", + "TGV2ZWxUeXBlEhYKDm1pbm9yX3RvX21ham9yGAEgAygDEh0KBXRpbGVzGAYg", + "AygLMg4ueGxhLlRpbGVQcm90bxIcChRlbGVtZW50X3NpemVfaW5fYml0cxgH", + "IAEoAxIUCgxtZW1vcnlfc3BhY2UYCCABKAMSJwoOcGh5c2ljYWxfc2hhcGUY", + "CiABKAsyDy54bGEuU2hhcGVQcm90b0oECAIQA0oECAMQBEoECAQQBUoECAUQ", + "BlIRcGFkZGVkX2RpbWVuc2lvbnNSDXBhZGRpbmdfdmFsdWVSBmZvcm1hdFIT", + "bWF4X3NwYXJzZV9lbGVtZW50cyK9AQoKU2hhcGVQcm90bxIoCgxlbGVtZW50", + "X3R5cGUYAiABKA4yEi54bGEuUHJpbWl0aXZlVHlwZRISCgpkaW1lbnNpb25z", + "GAMgAygDEiUKDHR1cGxlX3NoYXBlcxgEIAMoCzIPLnhsYS5TaGFwZVByb3Rv", + "EiAKBmxheW91dBgFIAEoCzIQLnhsYS5MYXlvdXRQcm90bxIcChRpc19keW5h", + "bWljX2RpbWVuc2lvbhgGIAMoCEoECAEQAlIEcmFuayJyChFQcm9ncmFtU2hh", + "cGVQcm90bxIjCgpwYXJhbWV0ZXJzGAEgAygLMg8ueGxhLlNoYXBlUHJvdG8S", + "HwoGcmVzdWx0GAIgASgLMg8ueGxhLlNoYXBlUHJvdG8SFwoPcGFyYW1ldGVy", + "X25hbWVzGAMgAygJIkQKEENvbXB1dGF0aW9uU3RhdHMSEgoKZmxvcF9jb3Vu", + "dBgBIAEoARIcChR0cmFuc2NlbmRlbnRhbF9jb3VudBgCIAEoASL/AwoKT3BN", + "ZXRhZGF0YRIPCgdvcF90eXBlGAEgASgJEg8KB29wX25hbWUYAiABKAkSEwoL", + "c291cmNlX2ZpbGUYAyABKAkSEwoLc291cmNlX2xpbmUYBCABKAUSKgoMcHJv", + "ZmlsZV90eXBlGAUgAygOMhAueGxhLlByb2ZpbGVUeXBlQgIYARIYChBjcmVh", + "dGlvbl9wYXNzX2lkGAYgASgDEiAKGGxvZ2ljYWxfY3JlYXRpb25fcGFzc19p", + "ZBgHIAEoAxInCh9zaXplX29mX2dlbmVyYXRlZF9jb2RlX2luX2J5dGVzGAgg", + "ASgDEisKI3NpemVfb2ZfbWVtb3J5X3dvcmtpbmdfc2V0X2luX2J5dGVzGAkg", + "ASgDEjEKDHByb2ZpbGVfaW5mbxgKIAEoCzIbLnhsYS5PcE1ldGFkYXRhLlBy", + "b2ZpbGVJbmZvGq0BCgtQcm9maWxlSW5mbxImCgxwcm9maWxlX3R5cGUYASAD", + "KA4yEC54bGEuUHJvZmlsZVR5cGUSGAoQcmVsYXRpdmVfc3BlZWR1cBgCIAEo", + "ARIqCg5wcm9maWxlX3NvdXJjZRgDIAEoDjISLnhsYS5Qcm9maWxlU291cmNl", + "EjAKEWNvbXBpbGF0aW9uX2V2ZW50GAQgASgOMhUueGxhLkNvbXBpbGF0aW9u", + "RXZlbnRKBAgLEAwi4wEKEEV4ZWN1dGlvblByb2ZpbGUSHQoVY29tcGlsYXRp", + "b25fY2FjaGVfaGl0GAEgASgIEhcKD2NvbXBpbGVfdGltZV9tcxgCIAEoAxIb", + "ChNjb21wdXRlX2N5Y2xlX2NvdW50GAMgASgDEhcKD2NvbXB1dGVfdGltZV9u", + "cxgEIAEoAxIkChxjb21wdXRlX2FuZF90cmFuc2Zlcl90aW1lX25zGAUgASgD", + "EiAKGGV4ZWN1dGFibGVfc2l6ZV9pbl9ieXRlcxgGIAEoAxIZChFwcm9maWxl", + "X2NhY2hlX2hpdBgHIAEoCCIhCg9FeGVjdXRpb25IYW5kbGUSDgoGaGFuZGxl", + "GAEgASgDIiIKEEdsb2JhbERhdGFIYW5kbGUSDgoGaGFuZGxlGAEgASgDIjQK", + "DERldmljZUhhbmRsZRIOCgZoYW5kbGUYASABKAMSFAoMZGV2aWNlX2NvdW50", + "GAIgASgDIrQBCg1DaGFubmVsSGFuZGxlEg4KBmhhbmRsZRgBIAEoAxIsCgR0", + "eXBlGAIgASgOMh4ueGxhLkNoYW5uZWxIYW5kbGUuQ2hhbm5lbFR5cGUiZQoL", + "Q2hhbm5lbFR5cGUSGAoUQ0hBTk5FTF9UWVBFX0lOVkFMSUQQABIUChBERVZJ", + "Q0VfVE9fREVWSUNFEAESEgoOREVWSUNFX1RPX0hPU1QQAhISCg5IT1NUX1RP", + "X0RFVklDRRADIsUBChVEZXZpY2VBc3NpZ25tZW50UHJvdG8SFQoNcmVwbGlj", + "YV9jb3VudBgBIAEoBRIZChFjb21wdXRhdGlvbl9jb3VudBgCIAEoBRJJChNj", + "b21wdXRhdGlvbl9kZXZpY2VzGAMgAygLMiwueGxhLkRldmljZUFzc2lnbm1l", + "bnRQcm90by5Db21wdXRhdGlvbkRldmljZRovChFDb21wdXRhdGlvbkRldmlj", + "ZRIaChJyZXBsaWNhX2RldmljZV9pZHMYASADKAUixAIKDExpdGVyYWxQcm90", + "bxIeCgVzaGFwZRgBIAEoCzIPLnhsYS5TaGFwZVByb3RvEg0KBXByZWRzGAIg", + "AygIEgsKA3M4cxgPIAEoDBILCgN1OHMYAyABKAwSDAoEczMycxgEIAMoBRIM", + "CgRzNjRzGAUgAygDEgwKBHUzMnMYBiADKA0SDAoEdTY0cxgHIAMoBBIMCgRm", + "MzJzGAggAygCEgwKBGY2NHMYCSADKAESDAoEYzY0cxgMIAMoAhINCgVjMTI4", + "cxgSIAMoARIpCg50dXBsZV9saXRlcmFscxgKIAMoCzIRLnhsYS5MaXRlcmFs", + "UHJvdG8SDAoEZjE2cxgLIAEoDBINCgViZjE2cxgNIAEoDBIMCgR1MTZzGBAg", + "ASgMEgwKBHMxNnMYESABKAwSFgoOc3BhcnNlX2luZGljZXMYDiADKAMiowEK", + "D1dpbmRvd0RpbWVuc2lvbhIMCgRzaXplGAEgASgDEg4KBnN0cmlkZRgCIAEo", + "AxITCgtwYWRkaW5nX2xvdxgDIAEoAxIUCgxwYWRkaW5nX2hpZ2gYBCABKAMS", + "FwoPd2luZG93X2RpbGF0aW9uGAUgASgDEhUKDWJhc2VfZGlsYXRpb24YBiAB", + "KAMSFwoPd2luZG93X3JldmVyc2FsGAcgASgIIjIKBldpbmRvdxIoCgpkaW1l", + "bnNpb25zGAEgAygLMhQueGxhLldpbmRvd0RpbWVuc2lvbiJ+ChZHYXRoZXJE", + "aW1lbnNpb25OdW1iZXJzEhMKC29mZnNldF9kaW1zGAEgAygDEhwKFGNvbGxh", + "cHNlZF9zbGljZV9kaW1zGAIgAygDEhcKD3N0YXJ0X2luZGV4X21hcBgDIAMo", + "AxIYChBpbmRleF92ZWN0b3JfZGltGAQgASgDIpMBChdTY2F0dGVyRGltZW5z", + "aW9uTnVtYmVycxIaChJ1cGRhdGVfd2luZG93X2RpbXMYASADKAMSHAoUaW5z", + "ZXJ0ZWRfd2luZG93X2RpbXMYAiADKAMSJAocc2NhdHRlcl9kaW1zX3RvX29w", + "ZXJhbmRfZGltcxgDIAMoAxIYChBpbmRleF92ZWN0b3JfZGltGAQgASgDItgC", + "ChtDb252b2x1dGlvbkRpbWVuc2lvbk51bWJlcnMSHQoVaW5wdXRfYmF0Y2hf", + "ZGltZW5zaW9uGAcgASgDEh8KF2lucHV0X2ZlYXR1cmVfZGltZW5zaW9uGAgg", + "ASgDEiAKGGlucHV0X3NwYXRpYWxfZGltZW5zaW9ucxgLIAMoAxImCh5rZXJu", + "ZWxfaW5wdXRfZmVhdHVyZV9kaW1lbnNpb24YAyABKAMSJwofa2VybmVsX291", + "dHB1dF9mZWF0dXJlX2RpbWVuc2lvbhgEIAEoAxIhChlrZXJuZWxfc3BhdGlh", + "bF9kaW1lbnNpb25zGAYgAygDEh4KFm91dHB1dF9iYXRjaF9kaW1lbnNpb24Y", + "CSABKAMSIAoYb3V0cHV0X2ZlYXR1cmVfZGltZW5zaW9uGAogASgDEiEKGW91", + "dHB1dF9zcGF0aWFsX2RpbWVuc2lvbnMYDCADKAMimQEKE0RvdERpbWVuc2lv", + "bk51bWJlcnMSIgoabGhzX2NvbnRyYWN0aW5nX2RpbWVuc2lvbnMYASADKAMS", + "IgoacmhzX2NvbnRyYWN0aW5nX2RpbWVuc2lvbnMYAiADKAMSHAoUbGhzX2Jh", + "dGNoX2RpbWVuc2lvbnMYAyADKAMSHAoUcmhzX2JhdGNoX2RpbWVuc2lvbnMY", + "BCADKAMi3wEKFlRyaWFuZ3VsYXJTb2x2ZU9wdGlvbnMSEQoJbGVmdF9zaWRl", + "GAEgASgIEg0KBWxvd2VyGAIgASgIEhUKDXVuaXRfZGlhZ29uYWwYAyABKAgS", + "OgoLdHJhbnNwb3NlX2EYBCABKA4yJS54bGEuVHJpYW5ndWxhclNvbHZlT3B0", + "aW9ucy5UcmFuc3Bvc2UiUAoJVHJhbnNwb3NlEhUKEVRSQU5TUE9TRV9JTlZB", + "TElEEAASEAoMTk9fVFJBTlNQT1NFEAESDQoJVFJBTlNQT1NFEAISCwoHQURK", + "T0lOVBADIiAKD0Nob2xlc2t5T3B0aW9ucxINCgVsb3dlchgBIAEoCCJvChJG", + "cm9udGVuZEF0dHJpYnV0ZXMSLQoDbWFwGAEgAygLMiAueGxhLkZyb250ZW5k", + "QXR0cmlidXRlcy5NYXBFbnRyeRoqCghNYXBFbnRyeRILCgNrZXkYASABKAkS", + "DQoFdmFsdWUYAiABKAk6AjgBIoADCgpPcFNoYXJkaW5nEiIKBHR5cGUYASAB", + "KA4yFC54bGEuT3BTaGFyZGluZy5UeXBlEiMKCnRpbGVfc2hhcGUYAiABKAsy", + "Dy54bGEuU2hhcGVQcm90bxIiChp0aWxlX2Fzc2lnbm1lbnRfZGltZW5zaW9u", + "cxgDIAMoAxIfChd0aWxlX2Fzc2lnbm1lbnRfZGV2aWNlcxgEIAMoAxIoCg90", + "dXBsZV9zaGFyZGluZ3MYBSADKAsyDy54bGEuT3BTaGFyZGluZxIiChpyZXBs", + "aWNhdGVfb25fbGFzdF90aWxlX2RpbRgGIAEoCBIhCghtZXRhZGF0YRgHIAMo", + "CzIPLnhsYS5PcE1ldGFkYXRhEiwKDmxhc3RfdGlsZV9kaW1zGAggAygOMhQu", + "eGxhLk9wU2hhcmRpbmcuVHlwZSJFCgRUeXBlEg4KClJFUExJQ0FURUQQABIL", + "CgdNQVhJTUFMEAESCQoFVFVQTEUQAhIJCgVPVEhFUhADEgoKBk1BTlVBTBAE", + "IiMKDFJlcGxpY2FHcm91cBITCgtyZXBsaWNhX2lkcxgBIAMoAyIuCgxTb3Vy", + "Y2VUYXJnZXQSDgoGc291cmNlGAEgASgDEg4KBnRhcmdldBgCIAEoAyKQAQoP", + "UHJlY2lzaW9uQ29uZmlnEjkKEW9wZXJhbmRfcHJlY2lzaW9uGAEgAygOMh4u", + "eGxhLlByZWNpc2lvbkNvbmZpZy5QcmVjaXNpb24iQgoJUHJlY2lzaW9uEgsK", + "B0RFRkFVTFQQABIICgRISUdIEAESCwoHSElHSEVTVBACEhEKDVBBQ0tFRF9O", + "SUJCTEUQAyI6ChRQYXJhbWV0ZXJSZXBsaWNhdGlvbhIiChpyZXBsaWNhdGVk", + "X2F0X2xlYWZfYnVmZmVycxgBIAMoCCJ7ChZXaGlsZUxvb3BCYWNrZW5kQ29u", + "ZmlnEkQKEGtub3duX3RyaXBfY291bnQYASABKAsyKi54bGEuV2hpbGVMb29w", + "QmFja2VuZENvbmZpZy5Lbm93blRyaXBDb3VudBobCg5Lbm93blRyaXBDb3Vu", + "dBIJCgFuGAEgASgDInEKH0N1c3RvbUNhbGxPdXRwdXRPcGVyYW5kQWxpYXNp", + "bmcSGgoSb3V0cHV0X3NoYXBlX2luZGV4GAEgAygDEhUKDW9wZXJhbmRfaW5k", + "ZXgYAiABKAMSGwoTb3BlcmFuZF9zaGFwZV9pbmRleBgDIAMoAyraAQoNUHJp", + "bWl0aXZlVHlwZRIaChZQUklNSVRJVkVfVFlQRV9JTlZBTElEEAASCAoEUFJF", + "RBABEgYKAlM4EAISBwoDUzE2EAMSBwoDUzMyEAQSBwoDUzY0EAUSBgoCVTgQ", + "BhIHCgNVMTYQBxIHCgNVMzIQCBIHCgNVNjQQCRIHCgNGMTYQChIHCgNGMzIQ", + "CxIICgRCRjE2EBASBwoDRjY0EAwSBwoDQzY0EA8SCAoEQzEyOBASEgkKBVRV", + "UExFEA0SDwoLT1BBUVVFX1RZUEUQDhIJCgVUT0tFThARKkQKDERpbUxldmVs", + "VHlwZRINCglESU1fREVOU0UQABISCg5ESU1fQ09NUFJFU1NFRBABEhEKDURJ", + "TV9TSU5HTEVUT04QAio9CgtQcm9maWxlVHlwZRILCgdJTlZBTElEEAASCgoG", + "V0lORE9XEAESCAoERkxBRxACEgsKB0lOVEVHRVIQAypqCg1Qcm9maWxlU291", + "cmNlEiEKHVBST0ZJTEVfU09VUkNFX1VOS05PV05fU09VUkNFEAASGwoXUFJP", + "RklMRV9TT1VSQ0VfRU1CRURERUQQARIZChVQUk9GSUxFX1NPVVJDRV9SRU1P", + "VEUQAiqFAQoQQ29tcGlsYXRpb25FdmVudBIjCh9DT01QSUxBVElPTl9FVkVO", + "VF9VTktOT1dOX0VWRU5UEAASJwojQ09NUElMQVRJT05fRVZFTlRfRklSU1Rf", + "Q09NUElMQVRJT04QARIjCh9DT01QSUxBVElPTl9FVkVOVF9SRUNPTVBJTEFU", + "SU9OEAIqRwoLUGFkZGluZ1R5cGUSEwoPUEFERElOR19JTlZBTElEEAASEQoN", + "UEFERElOR19WQUxJRBABEhAKDFBBRERJTkdfU0FNRRACKjEKB0ZmdFR5cGUS", + "BwoDRkZUEAASCAoESUZGVBABEggKBFJGRlQQAhIJCgVJUkZGVBADKkYKElJh", + "bmRvbURpc3RyaWJ1dGlvbhIPCgtSTkdfSU5WQUxJRBAAEg8KC1JOR19VTklG", + "T1JNEAESDgoKUk5HX05PUk1BTBACKkUKD1JhbmRvbUFsZ29yaXRobRIPCgtS", + "TkdfREVGQVVMVBAAEhEKDVJOR19USFJFRV9GUlkQARIOCgpSTkdfUEhJTE9Y", + "EAJCA/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Xla.PrimitiveType), typeof(global::Xla.DimLevelType), typeof(global::Xla.ProfileType), typeof(global::Xla.ProfileSource), typeof(global::Xla.CompilationEvent), typeof(global::Xla.PaddingType), typeof(global::Xla.FftType), typeof(global::Xla.RandomDistribution), typeof(global::Xla.RandomAlgorithm), }, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.PaddingConfig), global::Xla.PaddingConfig.Parser, new[]{ "Dimensions" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.PaddingConfig.Types.PaddingConfigDimension), global::Xla.PaddingConfig.Types.PaddingConfigDimension.Parser, new[]{ "EdgePaddingLow", "EdgePaddingHigh", "InteriorPadding" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.TileProto), global::Xla.TileProto.Parser, new[]{ "Dimensions" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.LayoutProto), global::Xla.LayoutProto.Parser, new[]{ "DimLevelTypes", "MinorToMajor", "Tiles", "ElementSizeInBits", "MemorySpace", "PhysicalShape" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ShapeProto), global::Xla.ShapeProto.Parser, new[]{ "ElementType", "Dimensions", "TupleShapes", "Layout", "IsDynamicDimension" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ProgramShapeProto), global::Xla.ProgramShapeProto.Parser, new[]{ "Parameters", "Result", "ParameterNames" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ComputationStats), global::Xla.ComputationStats.Parser, new[]{ "FlopCount", "TranscendentalCount" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.OpMetadata), global::Xla.OpMetadata.Parser, new[]{ "OpType", "OpName", "SourceFile", "SourceLine", "ProfileType", "CreationPassId", "LogicalCreationPassId", "SizeOfGeneratedCodeInBytes", "SizeOfMemoryWorkingSetInBytes", "ProfileInfo" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.OpMetadata.Types.ProfileInfo), global::Xla.OpMetadata.Types.ProfileInfo.Parser, new[]{ "ProfileType", "RelativeSpeedup", "ProfileSource", "CompilationEvent" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ExecutionProfile), global::Xla.ExecutionProfile.Parser, new[]{ "CompilationCacheHit", "CompileTimeMs", "ComputeCycleCount", "ComputeTimeNs", "ComputeAndTransferTimeNs", "ExecutableSizeInBytes", "ProfileCacheHit" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ExecutionHandle), global::Xla.ExecutionHandle.Parser, new[]{ "Handle" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.GlobalDataHandle), global::Xla.GlobalDataHandle.Parser, new[]{ "Handle" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.DeviceHandle), global::Xla.DeviceHandle.Parser, new[]{ "Handle", "DeviceCount" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ChannelHandle), global::Xla.ChannelHandle.Parser, new[]{ "Handle", "Type" }, null, new[]{ typeof(global::Xla.ChannelHandle.Types.ChannelType) }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.DeviceAssignmentProto), global::Xla.DeviceAssignmentProto.Parser, new[]{ "ReplicaCount", "ComputationCount", "ComputationDevices" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.DeviceAssignmentProto.Types.ComputationDevice), global::Xla.DeviceAssignmentProto.Types.ComputationDevice.Parser, new[]{ "ReplicaDeviceIds" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.LiteralProto), global::Xla.LiteralProto.Parser, new[]{ "Shape", "Preds", "S8S", "U8S", "S32S", "S64S", "U32S", "U64S", "F32S", "F64S", "C64S", "C128S", "TupleLiterals", "F16S", "Bf16S", "U16S", "S16S", "SparseIndices" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.WindowDimension), global::Xla.WindowDimension.Parser, new[]{ "Size", "Stride", "PaddingLow", "PaddingHigh", "WindowDilation", "BaseDilation", "WindowReversal" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.Window), global::Xla.Window.Parser, new[]{ "Dimensions" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.GatherDimensionNumbers), global::Xla.GatherDimensionNumbers.Parser, new[]{ "OffsetDims", "CollapsedSliceDims", "StartIndexMap", "IndexVectorDim" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ScatterDimensionNumbers), global::Xla.ScatterDimensionNumbers.Parser, new[]{ "UpdateWindowDims", "InsertedWindowDims", "ScatterDimsToOperandDims", "IndexVectorDim" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ConvolutionDimensionNumbers), global::Xla.ConvolutionDimensionNumbers.Parser, new[]{ "InputBatchDimension", "InputFeatureDimension", "InputSpatialDimensions", "KernelInputFeatureDimension", "KernelOutputFeatureDimension", "KernelSpatialDimensions", "OutputBatchDimension", "OutputFeatureDimension", "OutputSpatialDimensions" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.DotDimensionNumbers), global::Xla.DotDimensionNumbers.Parser, new[]{ "LhsContractingDimensions", "RhsContractingDimensions", "LhsBatchDimensions", "RhsBatchDimensions" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.TriangularSolveOptions), global::Xla.TriangularSolveOptions.Parser, new[]{ "LeftSide", "Lower", "UnitDiagonal", "TransposeA" }, null, new[]{ typeof(global::Xla.TriangularSolveOptions.Types.Transpose) }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.CholeskyOptions), global::Xla.CholeskyOptions.Parser, new[]{ "Lower" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.FrontendAttributes), global::Xla.FrontendAttributes.Parser, new[]{ "Map" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.OpSharding), global::Xla.OpSharding.Parser, new[]{ "Type", "TileShape", "TileAssignmentDimensions", "TileAssignmentDevices", "TupleShardings", "ReplicateOnLastTileDim", "Metadata", "LastTileDims" }, null, new[]{ typeof(global::Xla.OpSharding.Types.Type) }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ReplicaGroup), global::Xla.ReplicaGroup.Parser, new[]{ "ReplicaIds" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.SourceTarget), global::Xla.SourceTarget.Parser, new[]{ "Source", "Target" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.PrecisionConfig), global::Xla.PrecisionConfig.Parser, new[]{ "OperandPrecision" }, null, new[]{ typeof(global::Xla.PrecisionConfig.Types.Precision) }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.ParameterReplication), global::Xla.ParameterReplication.Parser, new[]{ "ReplicatedAtLeafBuffers" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.WhileLoopBackendConfig), global::Xla.WhileLoopBackendConfig.Parser, new[]{ "KnownTripCount" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Xla.WhileLoopBackendConfig.Types.KnownTripCount), global::Xla.WhileLoopBackendConfig.Types.KnownTripCount.Parser, new[]{ "N" }, null, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.CustomCallOutputOperandAliasing), global::Xla.CustomCallOutputOperandAliasing.Parser, new[]{ "OutputShapeIndex", "OperandIndex", "OperandShapeIndex" }, null, null, null, null) + })); + } + #endregion + + } + #region Enums + /// + /// Primitive types are the individual values that can be held in rectangular + /// multidimensional arrays. A description of the rectangular multidimensional + /// array dimensions / primitive type is given by Shape, below. + /// + /// LINT.IfChange + /// + public enum PrimitiveType { + /// + /// Invalid primitive type to serve as default. + /// + [pbr::OriginalName("PRIMITIVE_TYPE_INVALID")] Invalid = 0, + /// + /// Predicates are two-state booleans. + /// + [pbr::OriginalName("PRED")] Pred = 1, + /// + /// Signed integral values of fixed width. + /// + [pbr::OriginalName("S8")] S8 = 2, + [pbr::OriginalName("S16")] S16 = 3, + [pbr::OriginalName("S32")] S32 = 4, + [pbr::OriginalName("S64")] S64 = 5, + /// + /// Unsigned integral values of fixed width. + /// + [pbr::OriginalName("U8")] U8 = 6, + [pbr::OriginalName("U16")] U16 = 7, + [pbr::OriginalName("U32")] U32 = 8, + [pbr::OriginalName("U64")] U64 = 9, + /// + /// Floating-point values of fixed width. + /// + /// Note: if f16s are not natively supported on the device, they will be + /// converted to f16 from f32 at arbirary points in the computation. + /// + [pbr::OriginalName("F16")] F16 = 10, + [pbr::OriginalName("F32")] F32 = 11, + /// + /// Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + /// floating-point format, but uses 1 bit for the sign, 8 bits for the exponent + /// and 7 bits for the mantissa. + /// + [pbr::OriginalName("BF16")] Bf16 = 16, + [pbr::OriginalName("F64")] F64 = 12, + /// + /// Complex values of fixed width. + /// + [pbr::OriginalName("C64")] C64 = 15, + /// + /// Paired F64 (real, imag), as in std::complex<double>. + /// + [pbr::OriginalName("C128")] C128 = 18, + /// + /// A tuple is a polymorphic sequence; e.g. a shape that holds different + /// sub-shapes. They are used for things like returning multiple values from a + /// computation; e.g. a computation that returns weights and biases may have a + /// signature that results in a tuple like (f32[784x2000], f32[2000]) + /// + /// If a shape proto has the tuple element type, it may not have any entries + /// in the dimensions field. + /// + [pbr::OriginalName("TUPLE")] Tuple = 13, + /// + /// An opaque type used for passing context-specific data to a custom + /// operation. Shapes of this primitive type will have empty dimensions and + /// tuple_shapes fields. + /// + /// (OPAQUE would be a better name for this identifier, but that conflicts with + /// a macro defined in windows.h.) + /// + [pbr::OriginalName("OPAQUE_TYPE")] OpaqueType = 14, + /// + /// A token type threaded between side-effecting operations. Shapes of this + /// primitive type will have empty dimensions and tuple_shapes fields. + /// + [pbr::OriginalName("TOKEN")] Token = 17, + } + + /// + /// A DimLevelType indicates the encoding method for a dimension in an array. + /// The semantics of this field are identical to those of the MLIR SparseTensor + /// dialect. + /// This should be kept in sync with the SparseTensor DimLevelType enum: + /// https://github.com/llvm/llvm-project/blob/5674a3c88088e668b684326c2194a6282e8270ff/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td#L86 + /// + public enum DimLevelType { + /// + /// The corresponding dimension is Dense, every entry is stored. + /// + [pbr::OriginalName("DIM_DENSE")] DimDense = 0, + /// + /// The corresponding dimension is Compressed, only nonzeros are stored. + /// + [pbr::OriginalName("DIM_COMPRESSED")] DimCompressed = 1, + /// + /// The corresponding dimension contains a single coordinate, no sibling + /// elements for each parent. + /// + [pbr::OriginalName("DIM_SINGLETON")] DimSingleton = 2, + } + + /// + /// The type optimization profiles in use for Op-level optimizations. + /// + public enum ProfileType { + [pbr::OriginalName("INVALID")] Invalid = 0, + [pbr::OriginalName("WINDOW")] Window = 1, + [pbr::OriginalName("FLAG")] Flag = 2, + [pbr::OriginalName("INTEGER")] Integer = 3, + } + + /// + /// The source of the optimization profile. + /// + public enum ProfileSource { + [pbr::OriginalName("PROFILE_SOURCE_UNKNOWN_SOURCE")] UnknownSource = 0, + [pbr::OriginalName("PROFILE_SOURCE_EMBEDDED")] Embedded = 1, + [pbr::OriginalName("PROFILE_SOURCE_REMOTE")] Remote = 2, + } + + /// + /// The compilation event that triggered the use of the profile. + /// + public enum CompilationEvent { + [pbr::OriginalName("COMPILATION_EVENT_UNKNOWN_EVENT")] UnknownEvent = 0, + [pbr::OriginalName("COMPILATION_EVENT_FIRST_COMPILATION")] FirstCompilation = 1, + [pbr::OriginalName("COMPILATION_EVENT_RECOMPILATION")] Recompilation = 2, + } + + public enum PaddingType { + [pbr::OriginalName("PADDING_INVALID")] PaddingInvalid = 0, + /// + /// Only valid portion of the base are covered. + /// + [pbr::OriginalName("PADDING_VALID")] PaddingValid = 1, + /// + /// Extra is added to produce same output size as the input. + /// + [pbr::OriginalName("PADDING_SAME")] PaddingSame = 2, + } + + public enum FftType { + /// + /// Forward FFT; complex in, complex out. + /// + [pbr::OriginalName("FFT")] Fft = 0, + /// + /// Inverse FFT; complex in, complex out. + /// + [pbr::OriginalName("IFFT")] Ifft = 1, + /// + /// Forward real FFT; real in, fft_length / 2 + 1 complex out + /// + [pbr::OriginalName("RFFT")] Rfft = 2, + /// + /// Inverse real FFT; fft_length / 2 + 1 complex in, + /// + [pbr::OriginalName("IRFFT")] Irfft = 3, + } + + public enum RandomDistribution { + [pbr::OriginalName("RNG_INVALID")] RngInvalid = 0, + /// + /// Creates a uniform-distribution-generated random number on the semi-open + /// interval [parameter[0], parameter[1]). + /// + [pbr::OriginalName("RNG_UNIFORM")] RngUniform = 1, + /// + /// Creates a normal-distribution-generated random number with mean + /// parameter[0] and standard deviation parameter[1]. + /// + [pbr::OriginalName("RNG_NORMAL")] RngNormal = 2, + } + + public enum RandomAlgorithm { + /// + /// Backend dependent default algorithm. + /// + [pbr::OriginalName("RNG_DEFAULT")] RngDefault = 0, + [pbr::OriginalName("RNG_THREE_FRY")] RngThreeFry = 1, + /// + /// Next: 2 + /// + [pbr::OriginalName("RNG_PHILOX")] RngPhilox = 2, + } + + #endregion + + #region Messages + /// + /// Describes the padding configuration for Pad operation. The padding amount on + /// both edges as well as between the elements are specified for each dimension. + /// + public sealed partial class PaddingConfig : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new PaddingConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PaddingConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PaddingConfig(PaddingConfig other) : this() { + dimensions_ = other.dimensions_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PaddingConfig Clone() { + return new PaddingConfig(this); + } + + /// Field number for the "dimensions" field. + public const int DimensionsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_dimensions_codec + = pb::FieldCodec.ForMessage(10, global::Xla.PaddingConfig.Types.PaddingConfigDimension.Parser); + private readonly pbc::RepeatedField dimensions_ = new pbc::RepeatedField(); + /// + /// The padding configuration for all dimensions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Dimensions { + get { return dimensions_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as PaddingConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(PaddingConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!dimensions_.Equals(other.dimensions_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= dimensions_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + dimensions_.WriteTo(output, _repeated_dimensions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + dimensions_.WriteTo(ref output, _repeated_dimensions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += dimensions_.CalculateSize(_repeated_dimensions_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(PaddingConfig other) { + if (other == null) { + return; + } + dimensions_.Add(other.dimensions_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + dimensions_.AddEntriesFrom(input, _repeated_dimensions_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + dimensions_.AddEntriesFrom(ref input, _repeated_dimensions_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the PaddingConfig message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Describes the padding configuration for a dimension. + /// + public sealed partial class PaddingConfigDimension : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new PaddingConfigDimension()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.PaddingConfig.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PaddingConfigDimension() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PaddingConfigDimension(PaddingConfigDimension other) : this() { + edgePaddingLow_ = other.edgePaddingLow_; + edgePaddingHigh_ = other.edgePaddingHigh_; + interiorPadding_ = other.interiorPadding_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PaddingConfigDimension Clone() { + return new PaddingConfigDimension(this); + } + + /// Field number for the "edge_padding_low" field. + public const int EdgePaddingLowFieldNumber = 1; + private long edgePaddingLow_; + /// + /// Padding amount on the low-end (next to the index 0). May be negative. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long EdgePaddingLow { + get { return edgePaddingLow_; } + set { + edgePaddingLow_ = value; + } + } + + /// Field number for the "edge_padding_high" field. + public const int EdgePaddingHighFieldNumber = 2; + private long edgePaddingHigh_; + /// + /// Padding amount on the high-end (next to the highest index). May be + /// negative. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long EdgePaddingHigh { + get { return edgePaddingHigh_; } + set { + edgePaddingHigh_ = value; + } + } + + /// Field number for the "interior_padding" field. + public const int InteriorPaddingFieldNumber = 3; + private long interiorPadding_; + /// + /// Padding amount between the elements. May not be negative. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long InteriorPadding { + get { return interiorPadding_; } + set { + interiorPadding_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as PaddingConfigDimension); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(PaddingConfigDimension other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (EdgePaddingLow != other.EdgePaddingLow) return false; + if (EdgePaddingHigh != other.EdgePaddingHigh) return false; + if (InteriorPadding != other.InteriorPadding) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (EdgePaddingLow != 0L) hash ^= EdgePaddingLow.GetHashCode(); + if (EdgePaddingHigh != 0L) hash ^= EdgePaddingHigh.GetHashCode(); + if (InteriorPadding != 0L) hash ^= InteriorPadding.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (EdgePaddingLow != 0L) { + output.WriteRawTag(8); + output.WriteInt64(EdgePaddingLow); + } + if (EdgePaddingHigh != 0L) { + output.WriteRawTag(16); + output.WriteInt64(EdgePaddingHigh); + } + if (InteriorPadding != 0L) { + output.WriteRawTag(24); + output.WriteInt64(InteriorPadding); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (EdgePaddingLow != 0L) { + output.WriteRawTag(8); + output.WriteInt64(EdgePaddingLow); + } + if (EdgePaddingHigh != 0L) { + output.WriteRawTag(16); + output.WriteInt64(EdgePaddingHigh); + } + if (InteriorPadding != 0L) { + output.WriteRawTag(24); + output.WriteInt64(InteriorPadding); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (EdgePaddingLow != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(EdgePaddingLow); + } + if (EdgePaddingHigh != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(EdgePaddingHigh); + } + if (InteriorPadding != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(InteriorPadding); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(PaddingConfigDimension other) { + if (other == null) { + return; + } + if (other.EdgePaddingLow != 0L) { + EdgePaddingLow = other.EdgePaddingLow; + } + if (other.EdgePaddingHigh != 0L) { + EdgePaddingHigh = other.EdgePaddingHigh; + } + if (other.InteriorPadding != 0L) { + InteriorPadding = other.InteriorPadding; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + EdgePaddingLow = input.ReadInt64(); + break; + } + case 16: { + EdgePaddingHigh = input.ReadInt64(); + break; + } + case 24: { + InteriorPadding = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + EdgePaddingLow = input.ReadInt64(); + break; + } + case 16: { + EdgePaddingHigh = input.ReadInt64(); + break; + } + case 24: { + InteriorPadding = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// Describes a tile used in tiling-based layout. Refer to + /// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for + /// details about tiling-based layout. + /// + public sealed partial class TileProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TileProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TileProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TileProto(TileProto other) : this() { + dimensions_ = other.dimensions_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TileProto Clone() { + return new TileProto(this); + } + + /// Field number for the "dimensions" field. + public const int DimensionsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_dimensions_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField dimensions_ = new pbc::RepeatedField(); + /// + /// Number of elements in each dimension of the tile. It's ordered from the + /// most major dimension of the tile to the most minor dimension of the tile. + /// The dimensions correspond to a suffix of the dimensions of the shape being + /// tiled. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Dimensions { + get { return dimensions_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TileProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TileProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!dimensions_.Equals(other.dimensions_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= dimensions_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + dimensions_.WriteTo(output, _repeated_dimensions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + dimensions_.WriteTo(ref output, _repeated_dimensions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += dimensions_.CalculateSize(_repeated_dimensions_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TileProto other) { + if (other == null) { + return; + } + dimensions_.Add(other.dimensions_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + dimensions_.AddEntriesFrom(input, _repeated_dimensions_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + dimensions_.AddEntriesFrom(ref input, _repeated_dimensions_codec); + break; + } + } + } + } + #endif + + } + + /// + /// A layout describes how the array is placed in (1D) memory space. This + /// includes the minor-to-major ordering of dimensions within a shape. + /// + /// Clients must specify the layouts of input Literals to the + /// computation. Layouts specified in interior operations which take Shapes (for + /// example, Convert) are ignored. + /// + /// See the XLA documentation for more information on shapes and layouts. + /// + /// LINT.IfChange + /// + public sealed partial class LayoutProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new LayoutProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LayoutProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LayoutProto(LayoutProto other) : this() { + dimLevelTypes_ = other.dimLevelTypes_.Clone(); + minorToMajor_ = other.minorToMajor_.Clone(); + tiles_ = other.tiles_.Clone(); + elementSizeInBits_ = other.elementSizeInBits_; + memorySpace_ = other.memorySpace_; + physicalShape_ = other.physicalShape_ != null ? other.physicalShape_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LayoutProto Clone() { + return new LayoutProto(this); + } + + /// Field number for the "dim_level_types" field. + public const int DimLevelTypesFieldNumber = 9; + private static readonly pb::FieldCodec _repeated_dimLevelTypes_codec + = pb::FieldCodec.ForEnum(74, x => (int) x, x => (global::Xla.DimLevelType) x); + private readonly pbc::RepeatedField dimLevelTypes_ = new pbc::RepeatedField(); + /// + /// The dimension level type list for this array, specifying the way in which + /// each array dimension is represented in memory. If this list is empty, the + /// array is assumed to be dense. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField DimLevelTypes { + get { return dimLevelTypes_; } + } + + /// Field number for the "minor_to_major" field. + public const int MinorToMajorFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_minorToMajor_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField minorToMajor_ = new pbc::RepeatedField(); + /// + /// Sequence of dimension numbers, from minor (fastest varying index) to major + /// (slowest varying index). This field is required. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField MinorToMajor { + get { return minorToMajor_; } + } + + /// Field number for the "tiles" field. + public const int TilesFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_tiles_codec + = pb::FieldCodec.ForMessage(50, global::Xla.TileProto.Parser); + private readonly pbc::RepeatedField tiles_ = new pbc::RepeatedField(); + /// + /// A sequence of tiles, starting from the tile that's applied first to the + /// Shape. + /// + /// TODO(b/119839262): implement tiling in each backend or add Unimplemented + /// error. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Tiles { + get { return tiles_; } + } + + /// Field number for the "element_size_in_bits" field. + public const int ElementSizeInBitsFieldNumber = 7; + private long elementSizeInBits_; + /// + /// Bit size of each element. If the size is bigger than what the element + /// type requires, the value is stored in the least significant + /// bits and the additional most significant bits are filled with 0's. + /// + /// TODO(b/119839262): implement in each backend or add Unimplemented error. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ElementSizeInBits { + get { return elementSizeInBits_; } + set { + elementSizeInBits_ = value; + } + } + + /// Field number for the "memory_space" field. + public const int MemorySpaceFieldNumber = 8; + private long memorySpace_; + /// + /// Memory space where this array resides. The integer field is interpreted in + /// a backend-specific manner. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long MemorySpace { + get { return memorySpace_; } + set { + memorySpace_ = value; + } + } + + /// Field number for the "physical_shape" field. + public const int PhysicalShapeFieldNumber = 10; + private global::Xla.ShapeProto physicalShape_; + /// + /// The physical, on-device shape used to represent the shape this layout + /// belongs to. Only used for sparse arrays. + /// The layout(s) contained within the physical shape should not also contain + /// a physical shape. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ShapeProto PhysicalShape { + get { return physicalShape_; } + set { + physicalShape_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as LayoutProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(LayoutProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!dimLevelTypes_.Equals(other.dimLevelTypes_)) return false; + if(!minorToMajor_.Equals(other.minorToMajor_)) return false; + if(!tiles_.Equals(other.tiles_)) return false; + if (ElementSizeInBits != other.ElementSizeInBits) return false; + if (MemorySpace != other.MemorySpace) return false; + if (!object.Equals(PhysicalShape, other.PhysicalShape)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= dimLevelTypes_.GetHashCode(); + hash ^= minorToMajor_.GetHashCode(); + hash ^= tiles_.GetHashCode(); + if (ElementSizeInBits != 0L) hash ^= ElementSizeInBits.GetHashCode(); + if (MemorySpace != 0L) hash ^= MemorySpace.GetHashCode(); + if (physicalShape_ != null) hash ^= PhysicalShape.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + minorToMajor_.WriteTo(output, _repeated_minorToMajor_codec); + tiles_.WriteTo(output, _repeated_tiles_codec); + if (ElementSizeInBits != 0L) { + output.WriteRawTag(56); + output.WriteInt64(ElementSizeInBits); + } + if (MemorySpace != 0L) { + output.WriteRawTag(64); + output.WriteInt64(MemorySpace); + } + dimLevelTypes_.WriteTo(output, _repeated_dimLevelTypes_codec); + if (physicalShape_ != null) { + output.WriteRawTag(82); + output.WriteMessage(PhysicalShape); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + minorToMajor_.WriteTo(ref output, _repeated_minorToMajor_codec); + tiles_.WriteTo(ref output, _repeated_tiles_codec); + if (ElementSizeInBits != 0L) { + output.WriteRawTag(56); + output.WriteInt64(ElementSizeInBits); + } + if (MemorySpace != 0L) { + output.WriteRawTag(64); + output.WriteInt64(MemorySpace); + } + dimLevelTypes_.WriteTo(ref output, _repeated_dimLevelTypes_codec); + if (physicalShape_ != null) { + output.WriteRawTag(82); + output.WriteMessage(PhysicalShape); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += dimLevelTypes_.CalculateSize(_repeated_dimLevelTypes_codec); + size += minorToMajor_.CalculateSize(_repeated_minorToMajor_codec); + size += tiles_.CalculateSize(_repeated_tiles_codec); + if (ElementSizeInBits != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ElementSizeInBits); + } + if (MemorySpace != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(MemorySpace); + } + if (physicalShape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(PhysicalShape); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(LayoutProto other) { + if (other == null) { + return; + } + dimLevelTypes_.Add(other.dimLevelTypes_); + minorToMajor_.Add(other.minorToMajor_); + tiles_.Add(other.tiles_); + if (other.ElementSizeInBits != 0L) { + ElementSizeInBits = other.ElementSizeInBits; + } + if (other.MemorySpace != 0L) { + MemorySpace = other.MemorySpace; + } + if (other.physicalShape_ != null) { + if (physicalShape_ == null) { + PhysicalShape = new global::Xla.ShapeProto(); + } + PhysicalShape.MergeFrom(other.PhysicalShape); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + minorToMajor_.AddEntriesFrom(input, _repeated_minorToMajor_codec); + break; + } + case 50: { + tiles_.AddEntriesFrom(input, _repeated_tiles_codec); + break; + } + case 56: { + ElementSizeInBits = input.ReadInt64(); + break; + } + case 64: { + MemorySpace = input.ReadInt64(); + break; + } + case 74: + case 72: { + dimLevelTypes_.AddEntriesFrom(input, _repeated_dimLevelTypes_codec); + break; + } + case 82: { + if (physicalShape_ == null) { + PhysicalShape = new global::Xla.ShapeProto(); + } + input.ReadMessage(PhysicalShape); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + minorToMajor_.AddEntriesFrom(ref input, _repeated_minorToMajor_codec); + break; + } + case 50: { + tiles_.AddEntriesFrom(ref input, _repeated_tiles_codec); + break; + } + case 56: { + ElementSizeInBits = input.ReadInt64(); + break; + } + case 64: { + MemorySpace = input.ReadInt64(); + break; + } + case 74: + case 72: { + dimLevelTypes_.AddEntriesFrom(ref input, _repeated_dimLevelTypes_codec); + break; + } + case 82: { + if (physicalShape_ == null) { + PhysicalShape = new global::Xla.ShapeProto(); + } + input.ReadMessage(PhysicalShape); + break; + } + } + } + } + #endif + + } + + /// + /// A shape describes the number of dimensions in the array, the size of each + /// dimension, and the primitive component type. + /// + /// Tuples are a special case in that they have rank zero and have tuple_shapes + /// defined. + /// + /// See the XLA documentation for more information on shapes and layouts. + /// + /// LINT.IfChange + /// + public sealed partial class ShapeProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ShapeProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShapeProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShapeProto(ShapeProto other) : this() { + elementType_ = other.elementType_; + dimensions_ = other.dimensions_.Clone(); + tupleShapes_ = other.tupleShapes_.Clone(); + layout_ = other.layout_ != null ? other.layout_.Clone() : null; + isDynamicDimension_ = other.isDynamicDimension_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ShapeProto Clone() { + return new ShapeProto(this); + } + + /// Field number for the "element_type" field. + public const int ElementTypeFieldNumber = 2; + private global::Xla.PrimitiveType elementType_ = global::Xla.PrimitiveType.Invalid; + /// + /// The element type for this shape. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.PrimitiveType ElementType { + get { return elementType_; } + set { + elementType_ = value; + } + } + + /// Field number for the "dimensions" field. + public const int DimensionsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_dimensions_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField dimensions_ = new pbc::RepeatedField(); + /// + /// The size (number of elements) for each dimension, or an upper bound on the + /// size if the dimension is dynamic. In XLA, dimensions are numbered from 0 + /// to N-1 for an N-dimensional array. The first element of 'dimensions' is the + /// size of dimension 0, the second element is the size of dimension 1, and so + /// forth. Empty list indicates a scalar. + /// + /// If the respective element in 'is_dimension_dynamic' is true then the value + /// in this field represents an upper bound on the size of the dimension. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Dimensions { + get { return dimensions_; } + } + + /// Field number for the "tuple_shapes" field. + public const int TupleShapesFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_tupleShapes_codec + = pb::FieldCodec.ForMessage(34, global::Xla.ShapeProto.Parser); + private readonly pbc::RepeatedField tupleShapes_ = new pbc::RepeatedField(); + /// + /// For tuples only, the shapes of constituent shapes in the tuple sequence. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField TupleShapes { + get { return tupleShapes_; } + } + + /// Field number for the "layout" field. + public const int LayoutFieldNumber = 5; + private global::Xla.LayoutProto layout_; + /// + /// The layout used to back this shape. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.LayoutProto Layout { + get { return layout_; } + set { + layout_ = value; + } + } + + /// Field number for the "is_dynamic_dimension" field. + public const int IsDynamicDimensionFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_isDynamicDimension_codec + = pb::FieldCodec.ForBool(50); + private readonly pbc::RepeatedField isDynamicDimension_ = new pbc::RepeatedField(); + /// + /// For arrays, this indicates whether or not each dimension is + /// dynamically-sized. The number of elements in this repeated field should be + /// zero (indicating that no dimensions are dynamic) or equal to the number of + /// elements in the 'dimensions' field. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField IsDynamicDimension { + get { return isDynamicDimension_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ShapeProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ShapeProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ElementType != other.ElementType) return false; + if(!dimensions_.Equals(other.dimensions_)) return false; + if(!tupleShapes_.Equals(other.tupleShapes_)) return false; + if (!object.Equals(Layout, other.Layout)) return false; + if(!isDynamicDimension_.Equals(other.isDynamicDimension_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ElementType != global::Xla.PrimitiveType.Invalid) hash ^= ElementType.GetHashCode(); + hash ^= dimensions_.GetHashCode(); + hash ^= tupleShapes_.GetHashCode(); + if (layout_ != null) hash ^= Layout.GetHashCode(); + hash ^= isDynamicDimension_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ElementType != global::Xla.PrimitiveType.Invalid) { + output.WriteRawTag(16); + output.WriteEnum((int) ElementType); + } + dimensions_.WriteTo(output, _repeated_dimensions_codec); + tupleShapes_.WriteTo(output, _repeated_tupleShapes_codec); + if (layout_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Layout); + } + isDynamicDimension_.WriteTo(output, _repeated_isDynamicDimension_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ElementType != global::Xla.PrimitiveType.Invalid) { + output.WriteRawTag(16); + output.WriteEnum((int) ElementType); + } + dimensions_.WriteTo(ref output, _repeated_dimensions_codec); + tupleShapes_.WriteTo(ref output, _repeated_tupleShapes_codec); + if (layout_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Layout); + } + isDynamicDimension_.WriteTo(ref output, _repeated_isDynamicDimension_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ElementType != global::Xla.PrimitiveType.Invalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ElementType); + } + size += dimensions_.CalculateSize(_repeated_dimensions_codec); + size += tupleShapes_.CalculateSize(_repeated_tupleShapes_codec); + if (layout_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Layout); + } + size += isDynamicDimension_.CalculateSize(_repeated_isDynamicDimension_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ShapeProto other) { + if (other == null) { + return; + } + if (other.ElementType != global::Xla.PrimitiveType.Invalid) { + ElementType = other.ElementType; + } + dimensions_.Add(other.dimensions_); + tupleShapes_.Add(other.tupleShapes_); + if (other.layout_ != null) { + if (layout_ == null) { + Layout = new global::Xla.LayoutProto(); + } + Layout.MergeFrom(other.Layout); + } + isDynamicDimension_.Add(other.isDynamicDimension_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 16: { + ElementType = (global::Xla.PrimitiveType) input.ReadEnum(); + break; + } + case 26: + case 24: { + dimensions_.AddEntriesFrom(input, _repeated_dimensions_codec); + break; + } + case 34: { + tupleShapes_.AddEntriesFrom(input, _repeated_tupleShapes_codec); + break; + } + case 42: { + if (layout_ == null) { + Layout = new global::Xla.LayoutProto(); + } + input.ReadMessage(Layout); + break; + } + case 50: + case 48: { + isDynamicDimension_.AddEntriesFrom(input, _repeated_isDynamicDimension_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 16: { + ElementType = (global::Xla.PrimitiveType) input.ReadEnum(); + break; + } + case 26: + case 24: { + dimensions_.AddEntriesFrom(ref input, _repeated_dimensions_codec); + break; + } + case 34: { + tupleShapes_.AddEntriesFrom(ref input, _repeated_tupleShapes_codec); + break; + } + case 42: { + if (layout_ == null) { + Layout = new global::Xla.LayoutProto(); + } + input.ReadMessage(Layout); + break; + } + case 50: + case 48: { + isDynamicDimension_.AddEntriesFrom(ref input, _repeated_isDynamicDimension_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Shape of the parameters and output of a computation (like a traditional + /// function signature). + /// + public sealed partial class ProgramShapeProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ProgramShapeProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ProgramShapeProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ProgramShapeProto(ProgramShapeProto other) : this() { + parameters_ = other.parameters_.Clone(); + result_ = other.result_ != null ? other.result_.Clone() : null; + parameterNames_ = other.parameterNames_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ProgramShapeProto Clone() { + return new ProgramShapeProto(this); + } + + /// Field number for the "parameters" field. + public const int ParametersFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_parameters_codec + = pb::FieldCodec.ForMessage(10, global::Xla.ShapeProto.Parser); + private readonly pbc::RepeatedField parameters_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Parameters { + get { return parameters_; } + } + + /// Field number for the "result" field. + public const int ResultFieldNumber = 2; + private global::Xla.ShapeProto result_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ShapeProto Result { + get { return result_; } + set { + result_ = value; + } + } + + /// Field number for the "parameter_names" field. + public const int ParameterNamesFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_parameterNames_codec + = pb::FieldCodec.ForString(26); + private readonly pbc::RepeatedField parameterNames_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ParameterNames { + get { return parameterNames_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ProgramShapeProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ProgramShapeProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!parameters_.Equals(other.parameters_)) return false; + if (!object.Equals(Result, other.Result)) return false; + if(!parameterNames_.Equals(other.parameterNames_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= parameters_.GetHashCode(); + if (result_ != null) hash ^= Result.GetHashCode(); + hash ^= parameterNames_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + parameters_.WriteTo(output, _repeated_parameters_codec); + if (result_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Result); + } + parameterNames_.WriteTo(output, _repeated_parameterNames_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + parameters_.WriteTo(ref output, _repeated_parameters_codec); + if (result_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Result); + } + parameterNames_.WriteTo(ref output, _repeated_parameterNames_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += parameters_.CalculateSize(_repeated_parameters_codec); + if (result_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Result); + } + size += parameterNames_.CalculateSize(_repeated_parameterNames_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ProgramShapeProto other) { + if (other == null) { + return; + } + parameters_.Add(other.parameters_); + if (other.result_ != null) { + if (result_ == null) { + Result = new global::Xla.ShapeProto(); + } + Result.MergeFrom(other.Result); + } + parameterNames_.Add(other.parameterNames_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + parameters_.AddEntriesFrom(input, _repeated_parameters_codec); + break; + } + case 18: { + if (result_ == null) { + Result = new global::Xla.ShapeProto(); + } + input.ReadMessage(Result); + break; + } + case 26: { + parameterNames_.AddEntriesFrom(input, _repeated_parameterNames_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + parameters_.AddEntriesFrom(ref input, _repeated_parameters_codec); + break; + } + case 18: { + if (result_ == null) { + Result = new global::Xla.ShapeProto(); + } + input.ReadMessage(Result); + break; + } + case 26: { + parameterNames_.AddEntriesFrom(ref input, _repeated_parameterNames_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Statistics of a computation. + /// + public sealed partial class ComputationStats : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ComputationStats()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputationStats() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputationStats(ComputationStats other) : this() { + flopCount_ = other.flopCount_; + transcendentalCount_ = other.transcendentalCount_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputationStats Clone() { + return new ComputationStats(this); + } + + /// Field number for the "flop_count" field. + public const int FlopCountFieldNumber = 1; + private double flopCount_; + /// + /// The number of floating point operations in the computation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double FlopCount { + get { return flopCount_; } + set { + flopCount_ = value; + } + } + + /// Field number for the "transcendental_count" field. + public const int TranscendentalCountFieldNumber = 2; + private double transcendentalCount_; + /// + /// The number of transcendental operations (e.g., exp) in the computation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double TranscendentalCount { + get { return transcendentalCount_; } + set { + transcendentalCount_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ComputationStats); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ComputationStats other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(FlopCount, other.FlopCount)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(TranscendentalCount, other.TranscendentalCount)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (FlopCount != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(FlopCount); + if (TranscendentalCount != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(TranscendentalCount); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (FlopCount != 0D) { + output.WriteRawTag(9); + output.WriteDouble(FlopCount); + } + if (TranscendentalCount != 0D) { + output.WriteRawTag(17); + output.WriteDouble(TranscendentalCount); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (FlopCount != 0D) { + output.WriteRawTag(9); + output.WriteDouble(FlopCount); + } + if (TranscendentalCount != 0D) { + output.WriteRawTag(17); + output.WriteDouble(TranscendentalCount); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (FlopCount != 0D) { + size += 1 + 8; + } + if (TranscendentalCount != 0D) { + size += 1 + 8; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ComputationStats other) { + if (other == null) { + return; + } + if (other.FlopCount != 0D) { + FlopCount = other.FlopCount; + } + if (other.TranscendentalCount != 0D) { + TranscendentalCount = other.TranscendentalCount; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 9: { + FlopCount = input.ReadDouble(); + break; + } + case 17: { + TranscendentalCount = input.ReadDouble(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 9: { + FlopCount = input.ReadDouble(); + break; + } + case 17: { + TranscendentalCount = input.ReadDouble(); + break; + } + } + } + } + #endif + + } + + /// + /// Symbolization metadata for HLO Instructions. + /// + /// This metadata is used for debugging XLA code generation, as well as + /// performance profiling of XLA-generated executables. + /// + public sealed partial class OpMetadata : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OpMetadata()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OpMetadata() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OpMetadata(OpMetadata other) : this() { + opType_ = other.opType_; + opName_ = other.opName_; + sourceFile_ = other.sourceFile_; + sourceLine_ = other.sourceLine_; + profileType_ = other.profileType_.Clone(); + creationPassId_ = other.creationPassId_; + logicalCreationPassId_ = other.logicalCreationPassId_; + sizeOfGeneratedCodeInBytes_ = other.sizeOfGeneratedCodeInBytes_; + sizeOfMemoryWorkingSetInBytes_ = other.sizeOfMemoryWorkingSetInBytes_; + profileInfo_ = other.profileInfo_ != null ? other.profileInfo_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OpMetadata Clone() { + return new OpMetadata(this); + } + + /// Field number for the "op_type" field. + public const int OpTypeFieldNumber = 1; + private string opType_ = ""; + /// + /// The framework op name that generated this XLA op. + /// + /// Frameworks that build on top of XLA should mirror the names of their ops + /// back to users by specifying the op_type. In this way, even if the + /// framework's "ops" are implemented as multiple XLA HLO Ops, they can be + /// grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as + /// multiple ops, then each op should have the op_type be "SoftMax".) + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string OpType { + get { return opType_; } + set { + opType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "op_name" field. + public const int OpNameFieldNumber = 2; + private string opName_ = ""; + /// + /// The user-specified name of the op. + /// + /// This name is often unique within a computation. Note: some frameworks + /// add auto-generated names if the user does not provide one. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string OpName { + get { return opName_; } + set { + opName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "source_file" field. + public const int SourceFileFieldNumber = 3; + private string sourceFile_ = ""; + /// + /// Indicate a file and line that this op is associated to in a user's program. + /// + /// e.g. it could be the file and line of user code that generated the op. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string SourceFile { + get { return sourceFile_; } + set { + sourceFile_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "source_line" field. + public const int SourceLineFieldNumber = 4; + private int sourceLine_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int SourceLine { + get { return sourceLine_; } + set { + sourceLine_ = value; + } + } + + /// Field number for the "profile_type" field. + public const int ProfileTypeFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_profileType_codec + = pb::FieldCodec.ForEnum(42, x => (int) x, x => (global::Xla.ProfileType) x); + private readonly pbc::RepeatedField profileType_ = new pbc::RepeatedField(); + /// + /// Deprecated, use [ProfileInfo][profile_type] instead. + /// + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ProfileType { + get { return profileType_; } + } + + /// Field number for the "creation_pass_id" field. + public const int CreationPassIdFieldNumber = 6; + private long creationPassId_; + /// + /// HloPassMetadata.pass_id of the pass that created this HLO instruction + /// object. Should never be copied between HLO instructions. Zero if unset and + /// -1 if the instruction was created before HLO passes began. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long CreationPassId { + get { return creationPassId_; } + set { + creationPassId_ = value; + } + } + + /// Field number for the "logical_creation_pass_id" field. + public const int LogicalCreationPassIdFieldNumber = 7; + private long logicalCreationPassId_; + /// + /// HloPassMetadata.pass_id of the pass that created the logical functionality + /// that this HLO instruction represents. Should be copied between HLO + /// instructions that correspond across compilation passes. Zero if unset and + /// -1 if the instruction was created before HLO passes began. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long LogicalCreationPassId { + get { return logicalCreationPassId_; } + set { + logicalCreationPassId_ = value; + } + } + + /// Field number for the "size_of_generated_code_in_bytes" field. + public const int SizeOfGeneratedCodeInBytesFieldNumber = 8; + private long sizeOfGeneratedCodeInBytes_; + /// + /// The footprint of the generated code for the instruction. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long SizeOfGeneratedCodeInBytes { + get { return sizeOfGeneratedCodeInBytes_; } + set { + sizeOfGeneratedCodeInBytes_ = value; + } + } + + /// Field number for the "size_of_memory_working_set_in_bytes" field. + public const int SizeOfMemoryWorkingSetInBytesFieldNumber = 9; + private long sizeOfMemoryWorkingSetInBytes_; + /// + /// The size of the working set, i.e., the amount of memory, used by the + /// instruction in a compiler-managed fast device memory. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long SizeOfMemoryWorkingSetInBytes { + get { return sizeOfMemoryWorkingSetInBytes_; } + set { + sizeOfMemoryWorkingSetInBytes_ = value; + } + } + + /// Field number for the "profile_info" field. + public const int ProfileInfoFieldNumber = 10; + private global::Xla.OpMetadata.Types.ProfileInfo profileInfo_; + /// + /// Profile information for the Op. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.OpMetadata.Types.ProfileInfo ProfileInfo { + get { return profileInfo_; } + set { + profileInfo_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as OpMetadata); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(OpMetadata other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (OpType != other.OpType) return false; + if (OpName != other.OpName) return false; + if (SourceFile != other.SourceFile) return false; + if (SourceLine != other.SourceLine) return false; + if(!profileType_.Equals(other.profileType_)) return false; + if (CreationPassId != other.CreationPassId) return false; + if (LogicalCreationPassId != other.LogicalCreationPassId) return false; + if (SizeOfGeneratedCodeInBytes != other.SizeOfGeneratedCodeInBytes) return false; + if (SizeOfMemoryWorkingSetInBytes != other.SizeOfMemoryWorkingSetInBytes) return false; + if (!object.Equals(ProfileInfo, other.ProfileInfo)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (OpType.Length != 0) hash ^= OpType.GetHashCode(); + if (OpName.Length != 0) hash ^= OpName.GetHashCode(); + if (SourceFile.Length != 0) hash ^= SourceFile.GetHashCode(); + if (SourceLine != 0) hash ^= SourceLine.GetHashCode(); + hash ^= profileType_.GetHashCode(); + if (CreationPassId != 0L) hash ^= CreationPassId.GetHashCode(); + if (LogicalCreationPassId != 0L) hash ^= LogicalCreationPassId.GetHashCode(); + if (SizeOfGeneratedCodeInBytes != 0L) hash ^= SizeOfGeneratedCodeInBytes.GetHashCode(); + if (SizeOfMemoryWorkingSetInBytes != 0L) hash ^= SizeOfMemoryWorkingSetInBytes.GetHashCode(); + if (profileInfo_ != null) hash ^= ProfileInfo.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (OpType.Length != 0) { + output.WriteRawTag(10); + output.WriteString(OpType); + } + if (OpName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(OpName); + } + if (SourceFile.Length != 0) { + output.WriteRawTag(26); + output.WriteString(SourceFile); + } + if (SourceLine != 0) { + output.WriteRawTag(32); + output.WriteInt32(SourceLine); + } + profileType_.WriteTo(output, _repeated_profileType_codec); + if (CreationPassId != 0L) { + output.WriteRawTag(48); + output.WriteInt64(CreationPassId); + } + if (LogicalCreationPassId != 0L) { + output.WriteRawTag(56); + output.WriteInt64(LogicalCreationPassId); + } + if (SizeOfGeneratedCodeInBytes != 0L) { + output.WriteRawTag(64); + output.WriteInt64(SizeOfGeneratedCodeInBytes); + } + if (SizeOfMemoryWorkingSetInBytes != 0L) { + output.WriteRawTag(72); + output.WriteInt64(SizeOfMemoryWorkingSetInBytes); + } + if (profileInfo_ != null) { + output.WriteRawTag(82); + output.WriteMessage(ProfileInfo); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (OpType.Length != 0) { + output.WriteRawTag(10); + output.WriteString(OpType); + } + if (OpName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(OpName); + } + if (SourceFile.Length != 0) { + output.WriteRawTag(26); + output.WriteString(SourceFile); + } + if (SourceLine != 0) { + output.WriteRawTag(32); + output.WriteInt32(SourceLine); + } + profileType_.WriteTo(ref output, _repeated_profileType_codec); + if (CreationPassId != 0L) { + output.WriteRawTag(48); + output.WriteInt64(CreationPassId); + } + if (LogicalCreationPassId != 0L) { + output.WriteRawTag(56); + output.WriteInt64(LogicalCreationPassId); + } + if (SizeOfGeneratedCodeInBytes != 0L) { + output.WriteRawTag(64); + output.WriteInt64(SizeOfGeneratedCodeInBytes); + } + if (SizeOfMemoryWorkingSetInBytes != 0L) { + output.WriteRawTag(72); + output.WriteInt64(SizeOfMemoryWorkingSetInBytes); + } + if (profileInfo_ != null) { + output.WriteRawTag(82); + output.WriteMessage(ProfileInfo); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (OpType.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(OpType); + } + if (OpName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(OpName); + } + if (SourceFile.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(SourceFile); + } + if (SourceLine != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(SourceLine); + } + size += profileType_.CalculateSize(_repeated_profileType_codec); + if (CreationPassId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(CreationPassId); + } + if (LogicalCreationPassId != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(LogicalCreationPassId); + } + if (SizeOfGeneratedCodeInBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(SizeOfGeneratedCodeInBytes); + } + if (SizeOfMemoryWorkingSetInBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(SizeOfMemoryWorkingSetInBytes); + } + if (profileInfo_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ProfileInfo); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(OpMetadata other) { + if (other == null) { + return; + } + if (other.OpType.Length != 0) { + OpType = other.OpType; + } + if (other.OpName.Length != 0) { + OpName = other.OpName; + } + if (other.SourceFile.Length != 0) { + SourceFile = other.SourceFile; + } + if (other.SourceLine != 0) { + SourceLine = other.SourceLine; + } + profileType_.Add(other.profileType_); + if (other.CreationPassId != 0L) { + CreationPassId = other.CreationPassId; + } + if (other.LogicalCreationPassId != 0L) { + LogicalCreationPassId = other.LogicalCreationPassId; + } + if (other.SizeOfGeneratedCodeInBytes != 0L) { + SizeOfGeneratedCodeInBytes = other.SizeOfGeneratedCodeInBytes; + } + if (other.SizeOfMemoryWorkingSetInBytes != 0L) { + SizeOfMemoryWorkingSetInBytes = other.SizeOfMemoryWorkingSetInBytes; + } + if (other.profileInfo_ != null) { + if (profileInfo_ == null) { + ProfileInfo = new global::Xla.OpMetadata.Types.ProfileInfo(); + } + ProfileInfo.MergeFrom(other.ProfileInfo); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + OpType = input.ReadString(); + break; + } + case 18: { + OpName = input.ReadString(); + break; + } + case 26: { + SourceFile = input.ReadString(); + break; + } + case 32: { + SourceLine = input.ReadInt32(); + break; + } + case 42: + case 40: { + profileType_.AddEntriesFrom(input, _repeated_profileType_codec); + break; + } + case 48: { + CreationPassId = input.ReadInt64(); + break; + } + case 56: { + LogicalCreationPassId = input.ReadInt64(); + break; + } + case 64: { + SizeOfGeneratedCodeInBytes = input.ReadInt64(); + break; + } + case 72: { + SizeOfMemoryWorkingSetInBytes = input.ReadInt64(); + break; + } + case 82: { + if (profileInfo_ == null) { + ProfileInfo = new global::Xla.OpMetadata.Types.ProfileInfo(); + } + input.ReadMessage(ProfileInfo); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + OpType = input.ReadString(); + break; + } + case 18: { + OpName = input.ReadString(); + break; + } + case 26: { + SourceFile = input.ReadString(); + break; + } + case 32: { + SourceLine = input.ReadInt32(); + break; + } + case 42: + case 40: { + profileType_.AddEntriesFrom(ref input, _repeated_profileType_codec); + break; + } + case 48: { + CreationPassId = input.ReadInt64(); + break; + } + case 56: { + LogicalCreationPassId = input.ReadInt64(); + break; + } + case 64: { + SizeOfGeneratedCodeInBytes = input.ReadInt64(); + break; + } + case 72: { + SizeOfMemoryWorkingSetInBytes = input.ReadInt64(); + break; + } + case 82: { + if (profileInfo_ == null) { + ProfileInfo = new global::Xla.OpMetadata.Types.ProfileInfo(); + } + input.ReadMessage(ProfileInfo); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the OpMetadata message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Information about the optimization profile that this operation contains. + /// + public sealed partial class ProfileInfo : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ProfileInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.OpMetadata.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ProfileInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ProfileInfo(ProfileInfo other) : this() { + profileType_ = other.profileType_.Clone(); + relativeSpeedup_ = other.relativeSpeedup_; + profileSource_ = other.profileSource_; + compilationEvent_ = other.compilationEvent_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ProfileInfo Clone() { + return new ProfileInfo(this); + } + + /// Field number for the "profile_type" field. + public const int ProfileTypeFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_profileType_codec + = pb::FieldCodec.ForEnum(10, x => (int) x, x => (global::Xla.ProfileType) x); + private readonly pbc::RepeatedField profileType_ = new pbc::RepeatedField(); + /// + /// The type of optimization profiles that this operation contains. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ProfileType { + get { return profileType_; } + } + + /// Field number for the "relative_speedup" field. + public const int RelativeSpeedupFieldNumber = 2; + private double relativeSpeedup_; + /// + /// Speedup of tuned config compared to default config. + /// TODO(b/203817882) Set the relative_speedup. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public double RelativeSpeedup { + get { return relativeSpeedup_; } + set { + relativeSpeedup_ = value; + } + } + + /// Field number for the "profile_source" field. + public const int ProfileSourceFieldNumber = 3; + private global::Xla.ProfileSource profileSource_ = global::Xla.ProfileSource.UnknownSource; + /// + /// The source of the optimization profiles that this operation contains. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ProfileSource ProfileSource { + get { return profileSource_; } + set { + profileSource_ = value; + } + } + + /// Field number for the "compilation_event" field. + public const int CompilationEventFieldNumber = 4; + private global::Xla.CompilationEvent compilationEvent_ = global::Xla.CompilationEvent.UnknownEvent; + /// + /// The compilation event that triggered the use of the profiles. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.CompilationEvent CompilationEvent { + get { return compilationEvent_; } + set { + compilationEvent_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ProfileInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ProfileInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!profileType_.Equals(other.profileType_)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(RelativeSpeedup, other.RelativeSpeedup)) return false; + if (ProfileSource != other.ProfileSource) return false; + if (CompilationEvent != other.CompilationEvent) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= profileType_.GetHashCode(); + if (RelativeSpeedup != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(RelativeSpeedup); + if (ProfileSource != global::Xla.ProfileSource.UnknownSource) hash ^= ProfileSource.GetHashCode(); + if (CompilationEvent != global::Xla.CompilationEvent.UnknownEvent) hash ^= CompilationEvent.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + profileType_.WriteTo(output, _repeated_profileType_codec); + if (RelativeSpeedup != 0D) { + output.WriteRawTag(17); + output.WriteDouble(RelativeSpeedup); + } + if (ProfileSource != global::Xla.ProfileSource.UnknownSource) { + output.WriteRawTag(24); + output.WriteEnum((int) ProfileSource); + } + if (CompilationEvent != global::Xla.CompilationEvent.UnknownEvent) { + output.WriteRawTag(32); + output.WriteEnum((int) CompilationEvent); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + profileType_.WriteTo(ref output, _repeated_profileType_codec); + if (RelativeSpeedup != 0D) { + output.WriteRawTag(17); + output.WriteDouble(RelativeSpeedup); + } + if (ProfileSource != global::Xla.ProfileSource.UnknownSource) { + output.WriteRawTag(24); + output.WriteEnum((int) ProfileSource); + } + if (CompilationEvent != global::Xla.CompilationEvent.UnknownEvent) { + output.WriteRawTag(32); + output.WriteEnum((int) CompilationEvent); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += profileType_.CalculateSize(_repeated_profileType_codec); + if (RelativeSpeedup != 0D) { + size += 1 + 8; + } + if (ProfileSource != global::Xla.ProfileSource.UnknownSource) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ProfileSource); + } + if (CompilationEvent != global::Xla.CompilationEvent.UnknownEvent) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) CompilationEvent); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ProfileInfo other) { + if (other == null) { + return; + } + profileType_.Add(other.profileType_); + if (other.RelativeSpeedup != 0D) { + RelativeSpeedup = other.RelativeSpeedup; + } + if (other.ProfileSource != global::Xla.ProfileSource.UnknownSource) { + ProfileSource = other.ProfileSource; + } + if (other.CompilationEvent != global::Xla.CompilationEvent.UnknownEvent) { + CompilationEvent = other.CompilationEvent; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + profileType_.AddEntriesFrom(input, _repeated_profileType_codec); + break; + } + case 17: { + RelativeSpeedup = input.ReadDouble(); + break; + } + case 24: { + ProfileSource = (global::Xla.ProfileSource) input.ReadEnum(); + break; + } + case 32: { + CompilationEvent = (global::Xla.CompilationEvent) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + profileType_.AddEntriesFrom(ref input, _repeated_profileType_codec); + break; + } + case 17: { + RelativeSpeedup = input.ReadDouble(); + break; + } + case 24: { + ProfileSource = (global::Xla.ProfileSource) input.ReadEnum(); + break; + } + case 32: { + CompilationEvent = (global::Xla.CompilationEvent) input.ReadEnum(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// Profile data from the execution of a computation. + /// + public sealed partial class ExecutionProfile : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ExecutionProfile()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecutionProfile() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecutionProfile(ExecutionProfile other) : this() { + compilationCacheHit_ = other.compilationCacheHit_; + compileTimeMs_ = other.compileTimeMs_; + computeCycleCount_ = other.computeCycleCount_; + computeTimeNs_ = other.computeTimeNs_; + computeAndTransferTimeNs_ = other.computeAndTransferTimeNs_; + executableSizeInBytes_ = other.executableSizeInBytes_; + profileCacheHit_ = other.profileCacheHit_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecutionProfile Clone() { + return new ExecutionProfile(this); + } + + /// Field number for the "compilation_cache_hit" field. + public const int CompilationCacheHitFieldNumber = 1; + private bool compilationCacheHit_; + /// + /// Whether the executable was read from the compilation cache. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool CompilationCacheHit { + get { return compilationCacheHit_; } + set { + compilationCacheHit_ = value; + } + } + + /// Field number for the "compile_time_ms" field. + public const int CompileTimeMsFieldNumber = 2; + private long compileTimeMs_; + /// + /// The time in milliseconds spent to compile the computation. This only set if + /// the executable was not read from the compilation cache + /// (compilation_cache_hit == false). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long CompileTimeMs { + get { return compileTimeMs_; } + set { + compileTimeMs_ = value; + } + } + + /// Field number for the "compute_cycle_count" field. + public const int ComputeCycleCountFieldNumber = 3; + private long computeCycleCount_; + /// + /// The number of cycles spent for the computation. This does not include the + /// time taken for the data transfers between the host and the device. This is + /// a target-dependent field and only used for debugging purposes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ComputeCycleCount { + get { return computeCycleCount_; } + set { + computeCycleCount_ = value; + } + } + + /// Field number for the "compute_time_ns" field. + public const int ComputeTimeNsFieldNumber = 4; + private long computeTimeNs_; + /// + /// The time in nanoseconds spent for the computation, without data transfer. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ComputeTimeNs { + get { return computeTimeNs_; } + set { + computeTimeNs_ = value; + } + } + + /// Field number for the "compute_and_transfer_time_ns" field. + public const int ComputeAndTransferTimeNsFieldNumber = 5; + private long computeAndTransferTimeNs_; + /// + /// The time in nanoseconds spent for the entire computation, including the + /// result data transfer time. Current implementation does not spend any cycles + /// for the input data transfer since the memory is initialized with the proper + /// values before the execution. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ComputeAndTransferTimeNs { + get { return computeAndTransferTimeNs_; } + set { + computeAndTransferTimeNs_ = value; + } + } + + /// Field number for the "executable_size_in_bytes" field. + public const int ExecutableSizeInBytesFieldNumber = 6; + private long executableSizeInBytes_; + /// + /// The size of the binary code in the executable. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long ExecutableSizeInBytes { + get { return executableSizeInBytes_; } + set { + executableSizeInBytes_ = value; + } + } + + /// Field number for the "profile_cache_hit" field. + public const int ProfileCacheHitFieldNumber = 7; + private bool profileCacheHit_; + /// + /// Whether this profile was drawn from a cache of profiles instead of from + /// execution on the hardware. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool ProfileCacheHit { + get { return profileCacheHit_; } + set { + profileCacheHit_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ExecutionProfile); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ExecutionProfile other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (CompilationCacheHit != other.CompilationCacheHit) return false; + if (CompileTimeMs != other.CompileTimeMs) return false; + if (ComputeCycleCount != other.ComputeCycleCount) return false; + if (ComputeTimeNs != other.ComputeTimeNs) return false; + if (ComputeAndTransferTimeNs != other.ComputeAndTransferTimeNs) return false; + if (ExecutableSizeInBytes != other.ExecutableSizeInBytes) return false; + if (ProfileCacheHit != other.ProfileCacheHit) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (CompilationCacheHit != false) hash ^= CompilationCacheHit.GetHashCode(); + if (CompileTimeMs != 0L) hash ^= CompileTimeMs.GetHashCode(); + if (ComputeCycleCount != 0L) hash ^= ComputeCycleCount.GetHashCode(); + if (ComputeTimeNs != 0L) hash ^= ComputeTimeNs.GetHashCode(); + if (ComputeAndTransferTimeNs != 0L) hash ^= ComputeAndTransferTimeNs.GetHashCode(); + if (ExecutableSizeInBytes != 0L) hash ^= ExecutableSizeInBytes.GetHashCode(); + if (ProfileCacheHit != false) hash ^= ProfileCacheHit.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (CompilationCacheHit != false) { + output.WriteRawTag(8); + output.WriteBool(CompilationCacheHit); + } + if (CompileTimeMs != 0L) { + output.WriteRawTag(16); + output.WriteInt64(CompileTimeMs); + } + if (ComputeCycleCount != 0L) { + output.WriteRawTag(24); + output.WriteInt64(ComputeCycleCount); + } + if (ComputeTimeNs != 0L) { + output.WriteRawTag(32); + output.WriteInt64(ComputeTimeNs); + } + if (ComputeAndTransferTimeNs != 0L) { + output.WriteRawTag(40); + output.WriteInt64(ComputeAndTransferTimeNs); + } + if (ExecutableSizeInBytes != 0L) { + output.WriteRawTag(48); + output.WriteInt64(ExecutableSizeInBytes); + } + if (ProfileCacheHit != false) { + output.WriteRawTag(56); + output.WriteBool(ProfileCacheHit); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (CompilationCacheHit != false) { + output.WriteRawTag(8); + output.WriteBool(CompilationCacheHit); + } + if (CompileTimeMs != 0L) { + output.WriteRawTag(16); + output.WriteInt64(CompileTimeMs); + } + if (ComputeCycleCount != 0L) { + output.WriteRawTag(24); + output.WriteInt64(ComputeCycleCount); + } + if (ComputeTimeNs != 0L) { + output.WriteRawTag(32); + output.WriteInt64(ComputeTimeNs); + } + if (ComputeAndTransferTimeNs != 0L) { + output.WriteRawTag(40); + output.WriteInt64(ComputeAndTransferTimeNs); + } + if (ExecutableSizeInBytes != 0L) { + output.WriteRawTag(48); + output.WriteInt64(ExecutableSizeInBytes); + } + if (ProfileCacheHit != false) { + output.WriteRawTag(56); + output.WriteBool(ProfileCacheHit); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (CompilationCacheHit != false) { + size += 1 + 1; + } + if (CompileTimeMs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(CompileTimeMs); + } + if (ComputeCycleCount != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ComputeCycleCount); + } + if (ComputeTimeNs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ComputeTimeNs); + } + if (ComputeAndTransferTimeNs != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ComputeAndTransferTimeNs); + } + if (ExecutableSizeInBytes != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ExecutableSizeInBytes); + } + if (ProfileCacheHit != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ExecutionProfile other) { + if (other == null) { + return; + } + if (other.CompilationCacheHit != false) { + CompilationCacheHit = other.CompilationCacheHit; + } + if (other.CompileTimeMs != 0L) { + CompileTimeMs = other.CompileTimeMs; + } + if (other.ComputeCycleCount != 0L) { + ComputeCycleCount = other.ComputeCycleCount; + } + if (other.ComputeTimeNs != 0L) { + ComputeTimeNs = other.ComputeTimeNs; + } + if (other.ComputeAndTransferTimeNs != 0L) { + ComputeAndTransferTimeNs = other.ComputeAndTransferTimeNs; + } + if (other.ExecutableSizeInBytes != 0L) { + ExecutableSizeInBytes = other.ExecutableSizeInBytes; + } + if (other.ProfileCacheHit != false) { + ProfileCacheHit = other.ProfileCacheHit; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + CompilationCacheHit = input.ReadBool(); + break; + } + case 16: { + CompileTimeMs = input.ReadInt64(); + break; + } + case 24: { + ComputeCycleCount = input.ReadInt64(); + break; + } + case 32: { + ComputeTimeNs = input.ReadInt64(); + break; + } + case 40: { + ComputeAndTransferTimeNs = input.ReadInt64(); + break; + } + case 48: { + ExecutableSizeInBytes = input.ReadInt64(); + break; + } + case 56: { + ProfileCacheHit = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + CompilationCacheHit = input.ReadBool(); + break; + } + case 16: { + CompileTimeMs = input.ReadInt64(); + break; + } + case 24: { + ComputeCycleCount = input.ReadInt64(); + break; + } + case 32: { + ComputeTimeNs = input.ReadInt64(); + break; + } + case 40: { + ComputeAndTransferTimeNs = input.ReadInt64(); + break; + } + case 48: { + ExecutableSizeInBytes = input.ReadInt64(); + break; + } + case 56: { + ProfileCacheHit = input.ReadBool(); + break; + } + } + } + } + #endif + + } + + /// + /// Handle given to a user that represents an execution that the user launched + /// asynchronously on the device. + /// + public sealed partial class ExecutionHandle : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ExecutionHandle()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[8]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecutionHandle() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecutionHandle(ExecutionHandle other) : this() { + handle_ = other.handle_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ExecutionHandle Clone() { + return new ExecutionHandle(this); + } + + /// Field number for the "handle" field. + public const int HandleFieldNumber = 1; + private long handle_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Handle { + get { return handle_; } + set { + handle_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ExecutionHandle); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ExecutionHandle other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Handle != other.Handle) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Handle != 0L) hash ^= Handle.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Handle != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Handle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Handle != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Handle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Handle != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Handle); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ExecutionHandle other) { + if (other == null) { + return; + } + if (other.Handle != 0L) { + Handle = other.Handle; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Handle = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Handle = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + /// + /// Handle given to a user that represents a globally accessible allocation. + /// Contrast this against a ComputationDataHandle, which is not globally + /// accessible, since it only exists within a specific computation. + /// + public sealed partial class GlobalDataHandle : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GlobalDataHandle()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[9]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GlobalDataHandle() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GlobalDataHandle(GlobalDataHandle other) : this() { + handle_ = other.handle_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GlobalDataHandle Clone() { + return new GlobalDataHandle(this); + } + + /// Field number for the "handle" field. + public const int HandleFieldNumber = 1; + private long handle_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Handle { + get { return handle_; } + set { + handle_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GlobalDataHandle); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GlobalDataHandle other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Handle != other.Handle) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Handle != 0L) hash ^= Handle.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Handle != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Handle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Handle != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Handle); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Handle != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Handle); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GlobalDataHandle other) { + if (other == null) { + return; + } + if (other.Handle != 0L) { + Handle = other.Handle; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Handle = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Handle = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + /// + /// Handle given to a user that represents a replicated virtual device. Each + /// replicated device represents N physical devices for execution where N is the + /// number of replicas. + /// + public sealed partial class DeviceHandle : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DeviceHandle()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[10]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceHandle() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceHandle(DeviceHandle other) : this() { + handle_ = other.handle_; + deviceCount_ = other.deviceCount_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceHandle Clone() { + return new DeviceHandle(this); + } + + /// Field number for the "handle" field. + public const int HandleFieldNumber = 1; + private long handle_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Handle { + get { return handle_; } + set { + handle_ = value; + } + } + + /// Field number for the "device_count" field. + public const int DeviceCountFieldNumber = 2; + private long deviceCount_; + /// + /// The number of model-parallel virtual devices that communicate via XLA + /// Send/Recv instructions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long DeviceCount { + get { return deviceCount_; } + set { + deviceCount_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DeviceHandle); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DeviceHandle other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Handle != other.Handle) return false; + if (DeviceCount != other.DeviceCount) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Handle != 0L) hash ^= Handle.GetHashCode(); + if (DeviceCount != 0L) hash ^= DeviceCount.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Handle != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Handle); + } + if (DeviceCount != 0L) { + output.WriteRawTag(16); + output.WriteInt64(DeviceCount); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Handle != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Handle); + } + if (DeviceCount != 0L) { + output.WriteRawTag(16); + output.WriteInt64(DeviceCount); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Handle != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Handle); + } + if (DeviceCount != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(DeviceCount); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DeviceHandle other) { + if (other == null) { + return; + } + if (other.Handle != 0L) { + Handle = other.Handle; + } + if (other.DeviceCount != 0L) { + DeviceCount = other.DeviceCount; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Handle = input.ReadInt64(); + break; + } + case 16: { + DeviceCount = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Handle = input.ReadInt64(); + break; + } + case 16: { + DeviceCount = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + /// + /// Handle given to a user to represent a channel between two computations + /// via a Send and Recv instruction pair. Channels are unbuffered, so Send + /// Send instructions will be blocked until the data is transferred. + /// + public sealed partial class ChannelHandle : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ChannelHandle()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[11]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ChannelHandle() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ChannelHandle(ChannelHandle other) : this() { + handle_ = other.handle_; + type_ = other.type_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ChannelHandle Clone() { + return new ChannelHandle(this); + } + + /// Field number for the "handle" field. + public const int HandleFieldNumber = 1; + private long handle_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Handle { + get { return handle_; } + set { + handle_ = value; + } + } + + /// Field number for the "type" field. + public const int TypeFieldNumber = 2; + private global::Xla.ChannelHandle.Types.ChannelType type_ = global::Xla.ChannelHandle.Types.ChannelType.Invalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ChannelHandle.Types.ChannelType Type { + get { return type_; } + set { + type_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ChannelHandle); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ChannelHandle other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Handle != other.Handle) return false; + if (Type != other.Type) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Handle != 0L) hash ^= Handle.GetHashCode(); + if (Type != global::Xla.ChannelHandle.Types.ChannelType.Invalid) hash ^= Type.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Handle != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Handle); + } + if (Type != global::Xla.ChannelHandle.Types.ChannelType.Invalid) { + output.WriteRawTag(16); + output.WriteEnum((int) Type); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Handle != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Handle); + } + if (Type != global::Xla.ChannelHandle.Types.ChannelType.Invalid) { + output.WriteRawTag(16); + output.WriteEnum((int) Type); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Handle != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Handle); + } + if (Type != global::Xla.ChannelHandle.Types.ChannelType.Invalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Type); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ChannelHandle other) { + if (other == null) { + return; + } + if (other.Handle != 0L) { + Handle = other.Handle; + } + if (other.Type != global::Xla.ChannelHandle.Types.ChannelType.Invalid) { + Type = other.Type; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Handle = input.ReadInt64(); + break; + } + case 16: { + Type = (global::Xla.ChannelHandle.Types.ChannelType) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Handle = input.ReadInt64(); + break; + } + case 16: { + Type = (global::Xla.ChannelHandle.Types.ChannelType) input.ReadEnum(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the ChannelHandle message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum ChannelType { + /// + /// Invalid primitive type to serve as default. + /// + [pbr::OriginalName("CHANNEL_TYPE_INVALID")] Invalid = 0, + /// + /// A channel for sending data between devices. + /// + [pbr::OriginalName("DEVICE_TO_DEVICE")] DeviceToDevice = 1, + /// + /// A channel for sending data from the device to the host. Can only be used + /// with a Send operation. + /// + [pbr::OriginalName("DEVICE_TO_HOST")] DeviceToHost = 2, + /// + /// A channel for sending data from the host to the device. Can only be used + /// with a Recv operation. + /// + [pbr::OriginalName("HOST_TO_DEVICE")] HostToDevice = 3, + } + + } + #endregion + + } + + /// + /// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which + /// represents the device ids assigned to a set of replicated computations. + /// See xla::DeviceAssignment class comment for more details. + /// + public sealed partial class DeviceAssignmentProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DeviceAssignmentProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[12]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceAssignmentProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceAssignmentProto(DeviceAssignmentProto other) : this() { + replicaCount_ = other.replicaCount_; + computationCount_ = other.computationCount_; + computationDevices_ = other.computationDevices_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DeviceAssignmentProto Clone() { + return new DeviceAssignmentProto(this); + } + + /// Field number for the "replica_count" field. + public const int ReplicaCountFieldNumber = 1; + private int replicaCount_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ReplicaCount { + get { return replicaCount_; } + set { + replicaCount_ = value; + } + } + + /// Field number for the "computation_count" field. + public const int ComputationCountFieldNumber = 2; + private int computationCount_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ComputationCount { + get { return computationCount_; } + set { + computationCount_ = value; + } + } + + /// Field number for the "computation_devices" field. + public const int ComputationDevicesFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_computationDevices_codec + = pb::FieldCodec.ForMessage(26, global::Xla.DeviceAssignmentProto.Types.ComputationDevice.Parser); + private readonly pbc::RepeatedField computationDevices_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ComputationDevices { + get { return computationDevices_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DeviceAssignmentProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DeviceAssignmentProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ReplicaCount != other.ReplicaCount) return false; + if (ComputationCount != other.ComputationCount) return false; + if(!computationDevices_.Equals(other.computationDevices_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (ReplicaCount != 0) hash ^= ReplicaCount.GetHashCode(); + if (ComputationCount != 0) hash ^= ComputationCount.GetHashCode(); + hash ^= computationDevices_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (ReplicaCount != 0) { + output.WriteRawTag(8); + output.WriteInt32(ReplicaCount); + } + if (ComputationCount != 0) { + output.WriteRawTag(16); + output.WriteInt32(ComputationCount); + } + computationDevices_.WriteTo(output, _repeated_computationDevices_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (ReplicaCount != 0) { + output.WriteRawTag(8); + output.WriteInt32(ReplicaCount); + } + if (ComputationCount != 0) { + output.WriteRawTag(16); + output.WriteInt32(ComputationCount); + } + computationDevices_.WriteTo(ref output, _repeated_computationDevices_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (ReplicaCount != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ReplicaCount); + } + if (ComputationCount != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ComputationCount); + } + size += computationDevices_.CalculateSize(_repeated_computationDevices_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DeviceAssignmentProto other) { + if (other == null) { + return; + } + if (other.ReplicaCount != 0) { + ReplicaCount = other.ReplicaCount; + } + if (other.ComputationCount != 0) { + ComputationCount = other.ComputationCount; + } + computationDevices_.Add(other.computationDevices_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + ReplicaCount = input.ReadInt32(); + break; + } + case 16: { + ComputationCount = input.ReadInt32(); + break; + } + case 26: { + computationDevices_.AddEntriesFrom(input, _repeated_computationDevices_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + ReplicaCount = input.ReadInt32(); + break; + } + case 16: { + ComputationCount = input.ReadInt32(); + break; + } + case 26: { + computationDevices_.AddEntriesFrom(ref input, _repeated_computationDevices_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the DeviceAssignmentProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Each logical computation runs on replica_count physical devices. + /// ComputationDevice represents the device ids assinged to the replicas. + /// + public sealed partial class ComputationDevice : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ComputationDevice()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.DeviceAssignmentProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputationDevice() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputationDevice(ComputationDevice other) : this() { + replicaDeviceIds_ = other.replicaDeviceIds_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ComputationDevice Clone() { + return new ComputationDevice(this); + } + + /// Field number for the "replica_device_ids" field. + public const int ReplicaDeviceIdsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_replicaDeviceIds_codec + = pb::FieldCodec.ForInt32(10); + private readonly pbc::RepeatedField replicaDeviceIds_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ReplicaDeviceIds { + get { return replicaDeviceIds_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ComputationDevice); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ComputationDevice other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!replicaDeviceIds_.Equals(other.replicaDeviceIds_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= replicaDeviceIds_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + replicaDeviceIds_.WriteTo(output, _repeated_replicaDeviceIds_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + replicaDeviceIds_.WriteTo(ref output, _repeated_replicaDeviceIds_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += replicaDeviceIds_.CalculateSize(_repeated_replicaDeviceIds_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ComputationDevice other) { + if (other == null) { + return; + } + replicaDeviceIds_.Add(other.replicaDeviceIds_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + replicaDeviceIds_.AddEntriesFrom(input, _repeated_replicaDeviceIds_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + replicaDeviceIds_.AddEntriesFrom(ref input, _repeated_replicaDeviceIds_codec); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// Literals are used when the server and client need to exchange materialized + /// data / results. Literals are also used to describe constants used in + /// computations. + /// + /// Transfers to/from the client are encoded in literal form, and the structure + /// of the repeated fields is implied by the shape. + /// + public sealed partial class LiteralProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new LiteralProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[13]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LiteralProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LiteralProto(LiteralProto other) : this() { + shape_ = other.shape_ != null ? other.shape_.Clone() : null; + preds_ = other.preds_.Clone(); + s8S_ = other.s8S_; + u8S_ = other.u8S_; + s32S_ = other.s32S_.Clone(); + s64S_ = other.s64S_.Clone(); + u32S_ = other.u32S_.Clone(); + u64S_ = other.u64S_.Clone(); + f32S_ = other.f32S_.Clone(); + f64S_ = other.f64S_.Clone(); + c64S_ = other.c64S_.Clone(); + c128S_ = other.c128S_.Clone(); + tupleLiterals_ = other.tupleLiterals_.Clone(); + f16S_ = other.f16S_; + bf16S_ = other.bf16S_; + u16S_ = other.u16S_; + s16S_ = other.s16S_; + sparseIndices_ = other.sparseIndices_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public LiteralProto Clone() { + return new LiteralProto(this); + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 1; + private global::Xla.ShapeProto shape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ShapeProto Shape { + get { return shape_; } + set { + shape_ = value; + } + } + + /// Field number for the "preds" field. + public const int PredsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_preds_codec + = pb::FieldCodec.ForBool(18); + private readonly pbc::RepeatedField preds_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Preds { + get { return preds_; } + } + + /// Field number for the "s8s" field. + public const int S8SFieldNumber = 15; + private pb::ByteString s8S_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString S8S { + get { return s8S_; } + set { + s8S_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "u8s" field. + public const int U8SFieldNumber = 3; + private pb::ByteString u8S_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString U8S { + get { return u8S_; } + set { + u8S_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "s32s" field. + public const int S32SFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_s32S_codec + = pb::FieldCodec.ForInt32(34); + private readonly pbc::RepeatedField s32S_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField S32S { + get { return s32S_; } + } + + /// Field number for the "s64s" field. + public const int S64SFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_s64S_codec + = pb::FieldCodec.ForInt64(42); + private readonly pbc::RepeatedField s64S_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField S64S { + get { return s64S_; } + } + + /// Field number for the "u32s" field. + public const int U32SFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_u32S_codec + = pb::FieldCodec.ForUInt32(50); + private readonly pbc::RepeatedField u32S_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField U32S { + get { return u32S_; } + } + + /// Field number for the "u64s" field. + public const int U64SFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_u64S_codec + = pb::FieldCodec.ForUInt64(58); + private readonly pbc::RepeatedField u64S_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField U64S { + get { return u64S_; } + } + + /// Field number for the "f32s" field. + public const int F32SFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_f32S_codec + = pb::FieldCodec.ForFloat(66); + private readonly pbc::RepeatedField f32S_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField F32S { + get { return f32S_; } + } + + /// Field number for the "f64s" field. + public const int F64SFieldNumber = 9; + private static readonly pb::FieldCodec _repeated_f64S_codec + = pb::FieldCodec.ForDouble(74); + private readonly pbc::RepeatedField f64S_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField F64S { + get { return f64S_; } + } + + /// Field number for the "c64s" field. + public const int C64SFieldNumber = 12; + private static readonly pb::FieldCodec _repeated_c64S_codec + = pb::FieldCodec.ForFloat(98); + private readonly pbc::RepeatedField c64S_ = new pbc::RepeatedField(); + /// + /// Stored as interleaved real, imag floats. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField C64S { + get { return c64S_; } + } + + /// Field number for the "c128s" field. + public const int C128SFieldNumber = 18; + private static readonly pb::FieldCodec _repeated_c128S_codec + = pb::FieldCodec.ForDouble(146); + private readonly pbc::RepeatedField c128S_ = new pbc::RepeatedField(); + /// + /// Stored as interleaved real, imag doubles. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField C128S { + get { return c128S_; } + } + + /// Field number for the "tuple_literals" field. + public const int TupleLiteralsFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_tupleLiterals_codec + = pb::FieldCodec.ForMessage(82, global::Xla.LiteralProto.Parser); + private readonly pbc::RepeatedField tupleLiterals_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField TupleLiterals { + get { return tupleLiterals_; } + } + + /// Field number for the "f16s" field. + public const int F16SFieldNumber = 11; + private pb::ByteString f16S_ = pb::ByteString.Empty; + /// + /// The F16s, BF16s, U16s and S16s are encoded in little endian byte order + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString F16S { + get { return f16S_; } + set { + f16S_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "bf16s" field. + public const int Bf16SFieldNumber = 13; + private pb::ByteString bf16S_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString Bf16S { + get { return bf16S_; } + set { + bf16S_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "u16s" field. + public const int U16SFieldNumber = 16; + private pb::ByteString u16S_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString U16S { + get { return u16S_; } + set { + u16S_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "s16s" field. + public const int S16SFieldNumber = 17; + private pb::ByteString s16S_ = pb::ByteString.Empty; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pb::ByteString S16S { + get { return s16S_; } + set { + s16S_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "sparse_indices" field. + public const int SparseIndicesFieldNumber = 14; + private static readonly pb::FieldCodec _repeated_sparseIndices_codec + = pb::FieldCodec.ForInt64(114); + private readonly pbc::RepeatedField sparseIndices_ = new pbc::RepeatedField(); + /// + /// Next = 19 + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SparseIndices { + get { return sparseIndices_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as LiteralProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(LiteralProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Shape, other.Shape)) return false; + if(!preds_.Equals(other.preds_)) return false; + if (S8S != other.S8S) return false; + if (U8S != other.U8S) return false; + if(!s32S_.Equals(other.s32S_)) return false; + if(!s64S_.Equals(other.s64S_)) return false; + if(!u32S_.Equals(other.u32S_)) return false; + if(!u64S_.Equals(other.u64S_)) return false; + if(!f32S_.Equals(other.f32S_)) return false; + if(!f64S_.Equals(other.f64S_)) return false; + if(!c64S_.Equals(other.c64S_)) return false; + if(!c128S_.Equals(other.c128S_)) return false; + if(!tupleLiterals_.Equals(other.tupleLiterals_)) return false; + if (F16S != other.F16S) return false; + if (Bf16S != other.Bf16S) return false; + if (U16S != other.U16S) return false; + if (S16S != other.S16S) return false; + if(!sparseIndices_.Equals(other.sparseIndices_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (shape_ != null) hash ^= Shape.GetHashCode(); + hash ^= preds_.GetHashCode(); + if (S8S.Length != 0) hash ^= S8S.GetHashCode(); + if (U8S.Length != 0) hash ^= U8S.GetHashCode(); + hash ^= s32S_.GetHashCode(); + hash ^= s64S_.GetHashCode(); + hash ^= u32S_.GetHashCode(); + hash ^= u64S_.GetHashCode(); + hash ^= f32S_.GetHashCode(); + hash ^= f64S_.GetHashCode(); + hash ^= c64S_.GetHashCode(); + hash ^= c128S_.GetHashCode(); + hash ^= tupleLiterals_.GetHashCode(); + if (F16S.Length != 0) hash ^= F16S.GetHashCode(); + if (Bf16S.Length != 0) hash ^= Bf16S.GetHashCode(); + if (U16S.Length != 0) hash ^= U16S.GetHashCode(); + if (S16S.Length != 0) hash ^= S16S.GetHashCode(); + hash ^= sparseIndices_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (shape_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Shape); + } + preds_.WriteTo(output, _repeated_preds_codec); + if (U8S.Length != 0) { + output.WriteRawTag(26); + output.WriteBytes(U8S); + } + s32S_.WriteTo(output, _repeated_s32S_codec); + s64S_.WriteTo(output, _repeated_s64S_codec); + u32S_.WriteTo(output, _repeated_u32S_codec); + u64S_.WriteTo(output, _repeated_u64S_codec); + f32S_.WriteTo(output, _repeated_f32S_codec); + f64S_.WriteTo(output, _repeated_f64S_codec); + tupleLiterals_.WriteTo(output, _repeated_tupleLiterals_codec); + if (F16S.Length != 0) { + output.WriteRawTag(90); + output.WriteBytes(F16S); + } + c64S_.WriteTo(output, _repeated_c64S_codec); + if (Bf16S.Length != 0) { + output.WriteRawTag(106); + output.WriteBytes(Bf16S); + } + sparseIndices_.WriteTo(output, _repeated_sparseIndices_codec); + if (S8S.Length != 0) { + output.WriteRawTag(122); + output.WriteBytes(S8S); + } + if (U16S.Length != 0) { + output.WriteRawTag(130, 1); + output.WriteBytes(U16S); + } + if (S16S.Length != 0) { + output.WriteRawTag(138, 1); + output.WriteBytes(S16S); + } + c128S_.WriteTo(output, _repeated_c128S_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (shape_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Shape); + } + preds_.WriteTo(ref output, _repeated_preds_codec); + if (U8S.Length != 0) { + output.WriteRawTag(26); + output.WriteBytes(U8S); + } + s32S_.WriteTo(ref output, _repeated_s32S_codec); + s64S_.WriteTo(ref output, _repeated_s64S_codec); + u32S_.WriteTo(ref output, _repeated_u32S_codec); + u64S_.WriteTo(ref output, _repeated_u64S_codec); + f32S_.WriteTo(ref output, _repeated_f32S_codec); + f64S_.WriteTo(ref output, _repeated_f64S_codec); + tupleLiterals_.WriteTo(ref output, _repeated_tupleLiterals_codec); + if (F16S.Length != 0) { + output.WriteRawTag(90); + output.WriteBytes(F16S); + } + c64S_.WriteTo(ref output, _repeated_c64S_codec); + if (Bf16S.Length != 0) { + output.WriteRawTag(106); + output.WriteBytes(Bf16S); + } + sparseIndices_.WriteTo(ref output, _repeated_sparseIndices_codec); + if (S8S.Length != 0) { + output.WriteRawTag(122); + output.WriteBytes(S8S); + } + if (U16S.Length != 0) { + output.WriteRawTag(130, 1); + output.WriteBytes(U16S); + } + if (S16S.Length != 0) { + output.WriteRawTag(138, 1); + output.WriteBytes(S16S); + } + c128S_.WriteTo(ref output, _repeated_c128S_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (shape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + size += preds_.CalculateSize(_repeated_preds_codec); + if (S8S.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(S8S); + } + if (U8S.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(U8S); + } + size += s32S_.CalculateSize(_repeated_s32S_codec); + size += s64S_.CalculateSize(_repeated_s64S_codec); + size += u32S_.CalculateSize(_repeated_u32S_codec); + size += u64S_.CalculateSize(_repeated_u64S_codec); + size += f32S_.CalculateSize(_repeated_f32S_codec); + size += f64S_.CalculateSize(_repeated_f64S_codec); + size += c64S_.CalculateSize(_repeated_c64S_codec); + size += c128S_.CalculateSize(_repeated_c128S_codec); + size += tupleLiterals_.CalculateSize(_repeated_tupleLiterals_codec); + if (F16S.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(F16S); + } + if (Bf16S.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(Bf16S); + } + if (U16S.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeBytesSize(U16S); + } + if (S16S.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeBytesSize(S16S); + } + size += sparseIndices_.CalculateSize(_repeated_sparseIndices_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(LiteralProto other) { + if (other == null) { + return; + } + if (other.shape_ != null) { + if (shape_ == null) { + Shape = new global::Xla.ShapeProto(); + } + Shape.MergeFrom(other.Shape); + } + preds_.Add(other.preds_); + if (other.S8S.Length != 0) { + S8S = other.S8S; + } + if (other.U8S.Length != 0) { + U8S = other.U8S; + } + s32S_.Add(other.s32S_); + s64S_.Add(other.s64S_); + u32S_.Add(other.u32S_); + u64S_.Add(other.u64S_); + f32S_.Add(other.f32S_); + f64S_.Add(other.f64S_); + c64S_.Add(other.c64S_); + c128S_.Add(other.c128S_); + tupleLiterals_.Add(other.tupleLiterals_); + if (other.F16S.Length != 0) { + F16S = other.F16S; + } + if (other.Bf16S.Length != 0) { + Bf16S = other.Bf16S; + } + if (other.U16S.Length != 0) { + U16S = other.U16S; + } + if (other.S16S.Length != 0) { + S16S = other.S16S; + } + sparseIndices_.Add(other.sparseIndices_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (shape_ == null) { + Shape = new global::Xla.ShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 18: + case 16: { + preds_.AddEntriesFrom(input, _repeated_preds_codec); + break; + } + case 26: { + U8S = input.ReadBytes(); + break; + } + case 34: + case 32: { + s32S_.AddEntriesFrom(input, _repeated_s32S_codec); + break; + } + case 42: + case 40: { + s64S_.AddEntriesFrom(input, _repeated_s64S_codec); + break; + } + case 50: + case 48: { + u32S_.AddEntriesFrom(input, _repeated_u32S_codec); + break; + } + case 58: + case 56: { + u64S_.AddEntriesFrom(input, _repeated_u64S_codec); + break; + } + case 66: + case 69: { + f32S_.AddEntriesFrom(input, _repeated_f32S_codec); + break; + } + case 74: + case 73: { + f64S_.AddEntriesFrom(input, _repeated_f64S_codec); + break; + } + case 82: { + tupleLiterals_.AddEntriesFrom(input, _repeated_tupleLiterals_codec); + break; + } + case 90: { + F16S = input.ReadBytes(); + break; + } + case 98: + case 101: { + c64S_.AddEntriesFrom(input, _repeated_c64S_codec); + break; + } + case 106: { + Bf16S = input.ReadBytes(); + break; + } + case 114: + case 112: { + sparseIndices_.AddEntriesFrom(input, _repeated_sparseIndices_codec); + break; + } + case 122: { + S8S = input.ReadBytes(); + break; + } + case 130: { + U16S = input.ReadBytes(); + break; + } + case 138: { + S16S = input.ReadBytes(); + break; + } + case 146: + case 145: { + c128S_.AddEntriesFrom(input, _repeated_c128S_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (shape_ == null) { + Shape = new global::Xla.ShapeProto(); + } + input.ReadMessage(Shape); + break; + } + case 18: + case 16: { + preds_.AddEntriesFrom(ref input, _repeated_preds_codec); + break; + } + case 26: { + U8S = input.ReadBytes(); + break; + } + case 34: + case 32: { + s32S_.AddEntriesFrom(ref input, _repeated_s32S_codec); + break; + } + case 42: + case 40: { + s64S_.AddEntriesFrom(ref input, _repeated_s64S_codec); + break; + } + case 50: + case 48: { + u32S_.AddEntriesFrom(ref input, _repeated_u32S_codec); + break; + } + case 58: + case 56: { + u64S_.AddEntriesFrom(ref input, _repeated_u64S_codec); + break; + } + case 66: + case 69: { + f32S_.AddEntriesFrom(ref input, _repeated_f32S_codec); + break; + } + case 74: + case 73: { + f64S_.AddEntriesFrom(ref input, _repeated_f64S_codec); + break; + } + case 82: { + tupleLiterals_.AddEntriesFrom(ref input, _repeated_tupleLiterals_codec); + break; + } + case 90: { + F16S = input.ReadBytes(); + break; + } + case 98: + case 101: { + c64S_.AddEntriesFrom(ref input, _repeated_c64S_codec); + break; + } + case 106: { + Bf16S = input.ReadBytes(); + break; + } + case 114: + case 112: { + sparseIndices_.AddEntriesFrom(ref input, _repeated_sparseIndices_codec); + break; + } + case 122: { + S8S = input.ReadBytes(); + break; + } + case 130: { + U16S = input.ReadBytes(); + break; + } + case 138: { + S16S = input.ReadBytes(); + break; + } + case 146: + case 145: { + c128S_.AddEntriesFrom(ref input, _repeated_c128S_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class WindowDimension : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WindowDimension()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[14]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WindowDimension() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WindowDimension(WindowDimension other) : this() { + size_ = other.size_; + stride_ = other.stride_; + paddingLow_ = other.paddingLow_; + paddingHigh_ = other.paddingHigh_; + windowDilation_ = other.windowDilation_; + baseDilation_ = other.baseDilation_; + windowReversal_ = other.windowReversal_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WindowDimension Clone() { + return new WindowDimension(this); + } + + /// Field number for the "size" field. + public const int SizeFieldNumber = 1; + private long size_; + /// + /// The size of the window in this dimension. For a rectangle, this would be + /// the width or height. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Size { + get { return size_; } + set { + size_ = value; + } + } + + /// Field number for the "stride" field. + public const int StrideFieldNumber = 2; + private long stride_; + /// + /// The stride at which the window moves across the base area in this + /// dimension. In other words, this is the spacing between different + /// positions of the window in this dimension. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Stride { + get { return stride_; } + set { + stride_ = value; + } + } + + /// Field number for the "padding_low" field. + public const int PaddingLowFieldNumber = 3; + private long paddingLow_; + /// + /// If positive, means the amount of padding to add to the base area at the low + /// end of this dimension; if negative, its negative means the number of + /// elements removed from the low end of this dimension. For example, in the + /// horizontal dimension of a rectangle, this would be the number of padding + /// values to pad on the left, given that indices increase when going right. + /// The actual padding value depends upon the context. Convolution pads with + /// zeros. ReduceWindow and SelectAndScatter pads with the reduce function's + /// init value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long PaddingLow { + get { return paddingLow_; } + set { + paddingLow_ = value; + } + } + + /// Field number for the "padding_high" field. + public const int PaddingHighFieldNumber = 4; + private long paddingHigh_; + /// + /// As padding_low, but on the high end of this dimension. For example, in the + /// horizontal dimension of a rectangle, this would be the number of values to + /// pad on the right, given that indices increase when going right. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long PaddingHigh { + get { return paddingHigh_; } + set { + paddingHigh_ = value; + } + } + + /// Field number for the "window_dilation" field. + public const int WindowDilationFieldNumber = 5; + private long windowDilation_; + /// + /// Dilation factor of the sliding window in this dimension. A dilation factor + /// of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are + /// implicitly placed between each kernel element. This value may not be less + /// than 1. See documentation for convolution. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long WindowDilation { + get { return windowDilation_; } + set { + windowDilation_ = value; + } + } + + /// Field number for the "base_dilation" field. + public const int BaseDilationFieldNumber = 6; + private long baseDilation_; + /// + /// Dilation factor of the base area in this dimension. A dilation factor of 1 + /// means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly + /// placed between each base area element. This value may not be less than 1. + /// See documentation for convolution. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long BaseDilation { + get { return baseDilation_; } + set { + baseDilation_ = value; + } + } + + /// Field number for the "window_reversal" field. + public const int WindowReversalFieldNumber = 7; + private bool windowReversal_; + /// + /// Window reversal means that this dimension was logically reversed before the + /// operation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool WindowReversal { + get { return windowReversal_; } + set { + windowReversal_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WindowDimension); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WindowDimension other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Size != other.Size) return false; + if (Stride != other.Stride) return false; + if (PaddingLow != other.PaddingLow) return false; + if (PaddingHigh != other.PaddingHigh) return false; + if (WindowDilation != other.WindowDilation) return false; + if (BaseDilation != other.BaseDilation) return false; + if (WindowReversal != other.WindowReversal) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Size != 0L) hash ^= Size.GetHashCode(); + if (Stride != 0L) hash ^= Stride.GetHashCode(); + if (PaddingLow != 0L) hash ^= PaddingLow.GetHashCode(); + if (PaddingHigh != 0L) hash ^= PaddingHigh.GetHashCode(); + if (WindowDilation != 0L) hash ^= WindowDilation.GetHashCode(); + if (BaseDilation != 0L) hash ^= BaseDilation.GetHashCode(); + if (WindowReversal != false) hash ^= WindowReversal.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Size != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Size); + } + if (Stride != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Stride); + } + if (PaddingLow != 0L) { + output.WriteRawTag(24); + output.WriteInt64(PaddingLow); + } + if (PaddingHigh != 0L) { + output.WriteRawTag(32); + output.WriteInt64(PaddingHigh); + } + if (WindowDilation != 0L) { + output.WriteRawTag(40); + output.WriteInt64(WindowDilation); + } + if (BaseDilation != 0L) { + output.WriteRawTag(48); + output.WriteInt64(BaseDilation); + } + if (WindowReversal != false) { + output.WriteRawTag(56); + output.WriteBool(WindowReversal); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Size != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Size); + } + if (Stride != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Stride); + } + if (PaddingLow != 0L) { + output.WriteRawTag(24); + output.WriteInt64(PaddingLow); + } + if (PaddingHigh != 0L) { + output.WriteRawTag(32); + output.WriteInt64(PaddingHigh); + } + if (WindowDilation != 0L) { + output.WriteRawTag(40); + output.WriteInt64(WindowDilation); + } + if (BaseDilation != 0L) { + output.WriteRawTag(48); + output.WriteInt64(BaseDilation); + } + if (WindowReversal != false) { + output.WriteRawTag(56); + output.WriteBool(WindowReversal); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Size != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Size); + } + if (Stride != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Stride); + } + if (PaddingLow != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(PaddingLow); + } + if (PaddingHigh != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(PaddingHigh); + } + if (WindowDilation != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(WindowDilation); + } + if (BaseDilation != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(BaseDilation); + } + if (WindowReversal != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WindowDimension other) { + if (other == null) { + return; + } + if (other.Size != 0L) { + Size = other.Size; + } + if (other.Stride != 0L) { + Stride = other.Stride; + } + if (other.PaddingLow != 0L) { + PaddingLow = other.PaddingLow; + } + if (other.PaddingHigh != 0L) { + PaddingHigh = other.PaddingHigh; + } + if (other.WindowDilation != 0L) { + WindowDilation = other.WindowDilation; + } + if (other.BaseDilation != 0L) { + BaseDilation = other.BaseDilation; + } + if (other.WindowReversal != false) { + WindowReversal = other.WindowReversal; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Size = input.ReadInt64(); + break; + } + case 16: { + Stride = input.ReadInt64(); + break; + } + case 24: { + PaddingLow = input.ReadInt64(); + break; + } + case 32: { + PaddingHigh = input.ReadInt64(); + break; + } + case 40: { + WindowDilation = input.ReadInt64(); + break; + } + case 48: { + BaseDilation = input.ReadInt64(); + break; + } + case 56: { + WindowReversal = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Size = input.ReadInt64(); + break; + } + case 16: { + Stride = input.ReadInt64(); + break; + } + case 24: { + PaddingLow = input.ReadInt64(); + break; + } + case 32: { + PaddingHigh = input.ReadInt64(); + break; + } + case 40: { + WindowDilation = input.ReadInt64(); + break; + } + case 48: { + BaseDilation = input.ReadInt64(); + break; + } + case 56: { + WindowReversal = input.ReadBool(); + break; + } + } + } + } + #endif + + } + + /// + /// Describes the windowing in an operation such as convolution. + /// + /// The window is moved across a base area and for each position of the + /// window a computation is performed. The field below describes the + /// window and the movement of the window across a base area. + /// + public sealed partial class Window : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Window()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[15]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Window() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Window(Window other) : this() { + dimensions_ = other.dimensions_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public Window Clone() { + return new Window(this); + } + + /// Field number for the "dimensions" field. + public const int DimensionsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_dimensions_codec + = pb::FieldCodec.ForMessage(10, global::Xla.WindowDimension.Parser); + private readonly pbc::RepeatedField dimensions_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Dimensions { + get { return dimensions_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as Window); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(Window other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!dimensions_.Equals(other.dimensions_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= dimensions_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + dimensions_.WriteTo(output, _repeated_dimensions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + dimensions_.WriteTo(ref output, _repeated_dimensions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += dimensions_.CalculateSize(_repeated_dimensions_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(Window other) { + if (other == null) { + return; + } + dimensions_.Add(other.dimensions_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + dimensions_.AddEntriesFrom(input, _repeated_dimensions_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + dimensions_.AddEntriesFrom(ref input, _repeated_dimensions_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Describes the dimension numbers for a gather operation. + /// + /// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for + /// more details. + /// + public sealed partial class GatherDimensionNumbers : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GatherDimensionNumbers()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[16]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GatherDimensionNumbers() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GatherDimensionNumbers(GatherDimensionNumbers other) : this() { + offsetDims_ = other.offsetDims_.Clone(); + collapsedSliceDims_ = other.collapsedSliceDims_.Clone(); + startIndexMap_ = other.startIndexMap_.Clone(); + indexVectorDim_ = other.indexVectorDim_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public GatherDimensionNumbers Clone() { + return new GatherDimensionNumbers(this); + } + + /// Field number for the "offset_dims" field. + public const int OffsetDimsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_offsetDims_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField offsetDims_ = new pbc::RepeatedField(); + /// + /// "Window indices" is a term for a set of indices that index into the + /// interior of a dynamic-slice from the input tensor, the starting indices for + /// which were computed from output_gather_dims (see the operation semantic for + /// how this is defined) and the start_indices tensor. + /// + /// The window indices for a specific output index Out is computed as: + /// + /// i = 0 + /// for (k : [0, input_tensor_shape.rank)) + /// window_indices[k] = + /// if k in collapsed_slice_dims + /// then 0 + /// else Out[offset_dims[i++]] + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OffsetDims { + get { return offsetDims_; } + } + + /// Field number for the "collapsed_slice_dims" field. + public const int CollapsedSliceDimsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_collapsedSliceDims_codec + = pb::FieldCodec.ForInt64(18); + private readonly pbc::RepeatedField collapsedSliceDims_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField CollapsedSliceDims { + get { return collapsedSliceDims_; } + } + + /// Field number for the "start_index_map" field. + public const int StartIndexMapFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_startIndexMap_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField startIndexMap_ = new pbc::RepeatedField(); + /// + /// This is interpreted as a map from i to start_index_map[i]. It + /// transforms the gather index looked up from the start_indices tensor into + /// the starting index in the input space. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField StartIndexMap { + get { return startIndexMap_; } + } + + /// Field number for the "index_vector_dim" field. + public const int IndexVectorDimFieldNumber = 4; + private long indexVectorDim_; + /// + /// The dimension in the start_indices input that contains the starting + /// indices. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long IndexVectorDim { + get { return indexVectorDim_; } + set { + indexVectorDim_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as GatherDimensionNumbers); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(GatherDimensionNumbers other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!offsetDims_.Equals(other.offsetDims_)) return false; + if(!collapsedSliceDims_.Equals(other.collapsedSliceDims_)) return false; + if(!startIndexMap_.Equals(other.startIndexMap_)) return false; + if (IndexVectorDim != other.IndexVectorDim) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= offsetDims_.GetHashCode(); + hash ^= collapsedSliceDims_.GetHashCode(); + hash ^= startIndexMap_.GetHashCode(); + if (IndexVectorDim != 0L) hash ^= IndexVectorDim.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + offsetDims_.WriteTo(output, _repeated_offsetDims_codec); + collapsedSliceDims_.WriteTo(output, _repeated_collapsedSliceDims_codec); + startIndexMap_.WriteTo(output, _repeated_startIndexMap_codec); + if (IndexVectorDim != 0L) { + output.WriteRawTag(32); + output.WriteInt64(IndexVectorDim); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + offsetDims_.WriteTo(ref output, _repeated_offsetDims_codec); + collapsedSliceDims_.WriteTo(ref output, _repeated_collapsedSliceDims_codec); + startIndexMap_.WriteTo(ref output, _repeated_startIndexMap_codec); + if (IndexVectorDim != 0L) { + output.WriteRawTag(32); + output.WriteInt64(IndexVectorDim); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += offsetDims_.CalculateSize(_repeated_offsetDims_codec); + size += collapsedSliceDims_.CalculateSize(_repeated_collapsedSliceDims_codec); + size += startIndexMap_.CalculateSize(_repeated_startIndexMap_codec); + if (IndexVectorDim != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(IndexVectorDim); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(GatherDimensionNumbers other) { + if (other == null) { + return; + } + offsetDims_.Add(other.offsetDims_); + collapsedSliceDims_.Add(other.collapsedSliceDims_); + startIndexMap_.Add(other.startIndexMap_); + if (other.IndexVectorDim != 0L) { + IndexVectorDim = other.IndexVectorDim; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + offsetDims_.AddEntriesFrom(input, _repeated_offsetDims_codec); + break; + } + case 18: + case 16: { + collapsedSliceDims_.AddEntriesFrom(input, _repeated_collapsedSliceDims_codec); + break; + } + case 26: + case 24: { + startIndexMap_.AddEntriesFrom(input, _repeated_startIndexMap_codec); + break; + } + case 32: { + IndexVectorDim = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + offsetDims_.AddEntriesFrom(ref input, _repeated_offsetDims_codec); + break; + } + case 18: + case 16: { + collapsedSliceDims_.AddEntriesFrom(ref input, _repeated_collapsedSliceDims_codec); + break; + } + case 26: + case 24: { + startIndexMap_.AddEntriesFrom(ref input, _repeated_startIndexMap_codec); + break; + } + case 32: { + IndexVectorDim = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + /// + /// Describes the dimension numbers for a scatter operation. + /// + /// All the fields are similar to the corresponding fields in + /// GatherDimensionNumbers. Differences are noted below. + /// + public sealed partial class ScatterDimensionNumbers : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ScatterDimensionNumbers()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[17]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ScatterDimensionNumbers() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ScatterDimensionNumbers(ScatterDimensionNumbers other) : this() { + updateWindowDims_ = other.updateWindowDims_.Clone(); + insertedWindowDims_ = other.insertedWindowDims_.Clone(); + scatterDimsToOperandDims_ = other.scatterDimsToOperandDims_.Clone(); + indexVectorDim_ = other.indexVectorDim_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ScatterDimensionNumbers Clone() { + return new ScatterDimensionNumbers(this); + } + + /// Field number for the "update_window_dims" field. + public const int UpdateWindowDimsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_updateWindowDims_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField updateWindowDims_ = new pbc::RepeatedField(); + /// + /// The set of dimensions in the updates shape that are window dimensions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField UpdateWindowDims { + get { return updateWindowDims_; } + } + + /// Field number for the "inserted_window_dims" field. + public const int InsertedWindowDimsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_insertedWindowDims_codec + = pb::FieldCodec.ForInt64(18); + private readonly pbc::RepeatedField insertedWindowDims_ = new pbc::RepeatedField(); + /// + /// The set of window dimensions that must be inserted into the updates shape. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField InsertedWindowDims { + get { return insertedWindowDims_; } + } + + /// Field number for the "scatter_dims_to_operand_dims" field. + public const int ScatterDimsToOperandDimsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_scatterDimsToOperandDims_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField scatterDimsToOperandDims_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ScatterDimsToOperandDims { + get { return scatterDimsToOperandDims_; } + } + + /// Field number for the "index_vector_dim" field. + public const int IndexVectorDimFieldNumber = 4; + private long indexVectorDim_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long IndexVectorDim { + get { return indexVectorDim_; } + set { + indexVectorDim_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ScatterDimensionNumbers); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ScatterDimensionNumbers other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!updateWindowDims_.Equals(other.updateWindowDims_)) return false; + if(!insertedWindowDims_.Equals(other.insertedWindowDims_)) return false; + if(!scatterDimsToOperandDims_.Equals(other.scatterDimsToOperandDims_)) return false; + if (IndexVectorDim != other.IndexVectorDim) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= updateWindowDims_.GetHashCode(); + hash ^= insertedWindowDims_.GetHashCode(); + hash ^= scatterDimsToOperandDims_.GetHashCode(); + if (IndexVectorDim != 0L) hash ^= IndexVectorDim.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + updateWindowDims_.WriteTo(output, _repeated_updateWindowDims_codec); + insertedWindowDims_.WriteTo(output, _repeated_insertedWindowDims_codec); + scatterDimsToOperandDims_.WriteTo(output, _repeated_scatterDimsToOperandDims_codec); + if (IndexVectorDim != 0L) { + output.WriteRawTag(32); + output.WriteInt64(IndexVectorDim); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + updateWindowDims_.WriteTo(ref output, _repeated_updateWindowDims_codec); + insertedWindowDims_.WriteTo(ref output, _repeated_insertedWindowDims_codec); + scatterDimsToOperandDims_.WriteTo(ref output, _repeated_scatterDimsToOperandDims_codec); + if (IndexVectorDim != 0L) { + output.WriteRawTag(32); + output.WriteInt64(IndexVectorDim); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += updateWindowDims_.CalculateSize(_repeated_updateWindowDims_codec); + size += insertedWindowDims_.CalculateSize(_repeated_insertedWindowDims_codec); + size += scatterDimsToOperandDims_.CalculateSize(_repeated_scatterDimsToOperandDims_codec); + if (IndexVectorDim != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(IndexVectorDim); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ScatterDimensionNumbers other) { + if (other == null) { + return; + } + updateWindowDims_.Add(other.updateWindowDims_); + insertedWindowDims_.Add(other.insertedWindowDims_); + scatterDimsToOperandDims_.Add(other.scatterDimsToOperandDims_); + if (other.IndexVectorDim != 0L) { + IndexVectorDim = other.IndexVectorDim; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + updateWindowDims_.AddEntriesFrom(input, _repeated_updateWindowDims_codec); + break; + } + case 18: + case 16: { + insertedWindowDims_.AddEntriesFrom(input, _repeated_insertedWindowDims_codec); + break; + } + case 26: + case 24: { + scatterDimsToOperandDims_.AddEntriesFrom(input, _repeated_scatterDimsToOperandDims_codec); + break; + } + case 32: { + IndexVectorDim = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + updateWindowDims_.AddEntriesFrom(ref input, _repeated_updateWindowDims_codec); + break; + } + case 18: + case 16: { + insertedWindowDims_.AddEntriesFrom(ref input, _repeated_insertedWindowDims_codec); + break; + } + case 26: + case 24: { + scatterDimsToOperandDims_.AddEntriesFrom(ref input, _repeated_scatterDimsToOperandDims_codec); + break; + } + case 32: { + IndexVectorDim = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + public sealed partial class ConvolutionDimensionNumbers : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ConvolutionDimensionNumbers()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[18]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ConvolutionDimensionNumbers() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ConvolutionDimensionNumbers(ConvolutionDimensionNumbers other) : this() { + inputBatchDimension_ = other.inputBatchDimension_; + inputFeatureDimension_ = other.inputFeatureDimension_; + inputSpatialDimensions_ = other.inputSpatialDimensions_.Clone(); + kernelInputFeatureDimension_ = other.kernelInputFeatureDimension_; + kernelOutputFeatureDimension_ = other.kernelOutputFeatureDimension_; + kernelSpatialDimensions_ = other.kernelSpatialDimensions_.Clone(); + outputBatchDimension_ = other.outputBatchDimension_; + outputFeatureDimension_ = other.outputFeatureDimension_; + outputSpatialDimensions_ = other.outputSpatialDimensions_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ConvolutionDimensionNumbers Clone() { + return new ConvolutionDimensionNumbers(this); + } + + /// Field number for the "input_batch_dimension" field. + public const int InputBatchDimensionFieldNumber = 7; + private long inputBatchDimension_; + /// + /// The number of the dimension that represents batch in the input. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long InputBatchDimension { + get { return inputBatchDimension_; } + set { + inputBatchDimension_ = value; + } + } + + /// Field number for the "input_feature_dimension" field. + public const int InputFeatureDimensionFieldNumber = 8; + private long inputFeatureDimension_; + /// + /// The number of the dimension that represents features in the input. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long InputFeatureDimension { + get { return inputFeatureDimension_; } + set { + inputFeatureDimension_ = value; + } + } + + /// Field number for the "input_spatial_dimensions" field. + public const int InputSpatialDimensionsFieldNumber = 11; + private static readonly pb::FieldCodec _repeated_inputSpatialDimensions_codec + = pb::FieldCodec.ForInt64(90); + private readonly pbc::RepeatedField inputSpatialDimensions_ = new pbc::RepeatedField(); + /// + /// The dimension numbers for the spatial dimensions that the window + /// moves through in the input. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField InputSpatialDimensions { + get { return inputSpatialDimensions_; } + } + + /// Field number for the "kernel_input_feature_dimension" field. + public const int KernelInputFeatureDimensionFieldNumber = 3; + private long kernelInputFeatureDimension_; + /// + /// The number of the dimension that represents input features in the + /// convolutional kernel (rhs). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long KernelInputFeatureDimension { + get { return kernelInputFeatureDimension_; } + set { + kernelInputFeatureDimension_ = value; + } + } + + /// Field number for the "kernel_output_feature_dimension" field. + public const int KernelOutputFeatureDimensionFieldNumber = 4; + private long kernelOutputFeatureDimension_; + /// + /// The number of the dimension that represents output features in + /// the convolutional kernel (rhs). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long KernelOutputFeatureDimension { + get { return kernelOutputFeatureDimension_; } + set { + kernelOutputFeatureDimension_ = value; + } + } + + /// Field number for the "kernel_spatial_dimensions" field. + public const int KernelSpatialDimensionsFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_kernelSpatialDimensions_codec + = pb::FieldCodec.ForInt64(50); + private readonly pbc::RepeatedField kernelSpatialDimensions_ = new pbc::RepeatedField(); + /// + /// The dimension numbers for the spatial dimensions that the window + /// moves through in the kernel (rhs). window.strides(0) is the + /// stride in the kernel_spatial_dimensions(0) dimension. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField KernelSpatialDimensions { + get { return kernelSpatialDimensions_; } + } + + /// Field number for the "output_batch_dimension" field. + public const int OutputBatchDimensionFieldNumber = 9; + private long outputBatchDimension_; + /// + /// The number of the dimension that represents batch in the output. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long OutputBatchDimension { + get { return outputBatchDimension_; } + set { + outputBatchDimension_ = value; + } + } + + /// Field number for the "output_feature_dimension" field. + public const int OutputFeatureDimensionFieldNumber = 10; + private long outputFeatureDimension_; + /// + /// The number of the dimension that represents features in the output. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long OutputFeatureDimension { + get { return outputFeatureDimension_; } + set { + outputFeatureDimension_ = value; + } + } + + /// Field number for the "output_spatial_dimensions" field. + public const int OutputSpatialDimensionsFieldNumber = 12; + private static readonly pb::FieldCodec _repeated_outputSpatialDimensions_codec + = pb::FieldCodec.ForInt64(98); + private readonly pbc::RepeatedField outputSpatialDimensions_ = new pbc::RepeatedField(); + /// + /// The dimension numbers for the spatial dimensions that the window + /// moves through in the output. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OutputSpatialDimensions { + get { return outputSpatialDimensions_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ConvolutionDimensionNumbers); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ConvolutionDimensionNumbers other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (InputBatchDimension != other.InputBatchDimension) return false; + if (InputFeatureDimension != other.InputFeatureDimension) return false; + if(!inputSpatialDimensions_.Equals(other.inputSpatialDimensions_)) return false; + if (KernelInputFeatureDimension != other.KernelInputFeatureDimension) return false; + if (KernelOutputFeatureDimension != other.KernelOutputFeatureDimension) return false; + if(!kernelSpatialDimensions_.Equals(other.kernelSpatialDimensions_)) return false; + if (OutputBatchDimension != other.OutputBatchDimension) return false; + if (OutputFeatureDimension != other.OutputFeatureDimension) return false; + if(!outputSpatialDimensions_.Equals(other.outputSpatialDimensions_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (InputBatchDimension != 0L) hash ^= InputBatchDimension.GetHashCode(); + if (InputFeatureDimension != 0L) hash ^= InputFeatureDimension.GetHashCode(); + hash ^= inputSpatialDimensions_.GetHashCode(); + if (KernelInputFeatureDimension != 0L) hash ^= KernelInputFeatureDimension.GetHashCode(); + if (KernelOutputFeatureDimension != 0L) hash ^= KernelOutputFeatureDimension.GetHashCode(); + hash ^= kernelSpatialDimensions_.GetHashCode(); + if (OutputBatchDimension != 0L) hash ^= OutputBatchDimension.GetHashCode(); + if (OutputFeatureDimension != 0L) hash ^= OutputFeatureDimension.GetHashCode(); + hash ^= outputSpatialDimensions_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (KernelInputFeatureDimension != 0L) { + output.WriteRawTag(24); + output.WriteInt64(KernelInputFeatureDimension); + } + if (KernelOutputFeatureDimension != 0L) { + output.WriteRawTag(32); + output.WriteInt64(KernelOutputFeatureDimension); + } + kernelSpatialDimensions_.WriteTo(output, _repeated_kernelSpatialDimensions_codec); + if (InputBatchDimension != 0L) { + output.WriteRawTag(56); + output.WriteInt64(InputBatchDimension); + } + if (InputFeatureDimension != 0L) { + output.WriteRawTag(64); + output.WriteInt64(InputFeatureDimension); + } + if (OutputBatchDimension != 0L) { + output.WriteRawTag(72); + output.WriteInt64(OutputBatchDimension); + } + if (OutputFeatureDimension != 0L) { + output.WriteRawTag(80); + output.WriteInt64(OutputFeatureDimension); + } + inputSpatialDimensions_.WriteTo(output, _repeated_inputSpatialDimensions_codec); + outputSpatialDimensions_.WriteTo(output, _repeated_outputSpatialDimensions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (KernelInputFeatureDimension != 0L) { + output.WriteRawTag(24); + output.WriteInt64(KernelInputFeatureDimension); + } + if (KernelOutputFeatureDimension != 0L) { + output.WriteRawTag(32); + output.WriteInt64(KernelOutputFeatureDimension); + } + kernelSpatialDimensions_.WriteTo(ref output, _repeated_kernelSpatialDimensions_codec); + if (InputBatchDimension != 0L) { + output.WriteRawTag(56); + output.WriteInt64(InputBatchDimension); + } + if (InputFeatureDimension != 0L) { + output.WriteRawTag(64); + output.WriteInt64(InputFeatureDimension); + } + if (OutputBatchDimension != 0L) { + output.WriteRawTag(72); + output.WriteInt64(OutputBatchDimension); + } + if (OutputFeatureDimension != 0L) { + output.WriteRawTag(80); + output.WriteInt64(OutputFeatureDimension); + } + inputSpatialDimensions_.WriteTo(ref output, _repeated_inputSpatialDimensions_codec); + outputSpatialDimensions_.WriteTo(ref output, _repeated_outputSpatialDimensions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (InputBatchDimension != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(InputBatchDimension); + } + if (InputFeatureDimension != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(InputFeatureDimension); + } + size += inputSpatialDimensions_.CalculateSize(_repeated_inputSpatialDimensions_codec); + if (KernelInputFeatureDimension != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(KernelInputFeatureDimension); + } + if (KernelOutputFeatureDimension != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(KernelOutputFeatureDimension); + } + size += kernelSpatialDimensions_.CalculateSize(_repeated_kernelSpatialDimensions_codec); + if (OutputBatchDimension != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(OutputBatchDimension); + } + if (OutputFeatureDimension != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(OutputFeatureDimension); + } + size += outputSpatialDimensions_.CalculateSize(_repeated_outputSpatialDimensions_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ConvolutionDimensionNumbers other) { + if (other == null) { + return; + } + if (other.InputBatchDimension != 0L) { + InputBatchDimension = other.InputBatchDimension; + } + if (other.InputFeatureDimension != 0L) { + InputFeatureDimension = other.InputFeatureDimension; + } + inputSpatialDimensions_.Add(other.inputSpatialDimensions_); + if (other.KernelInputFeatureDimension != 0L) { + KernelInputFeatureDimension = other.KernelInputFeatureDimension; + } + if (other.KernelOutputFeatureDimension != 0L) { + KernelOutputFeatureDimension = other.KernelOutputFeatureDimension; + } + kernelSpatialDimensions_.Add(other.kernelSpatialDimensions_); + if (other.OutputBatchDimension != 0L) { + OutputBatchDimension = other.OutputBatchDimension; + } + if (other.OutputFeatureDimension != 0L) { + OutputFeatureDimension = other.OutputFeatureDimension; + } + outputSpatialDimensions_.Add(other.outputSpatialDimensions_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 24: { + KernelInputFeatureDimension = input.ReadInt64(); + break; + } + case 32: { + KernelOutputFeatureDimension = input.ReadInt64(); + break; + } + case 50: + case 48: { + kernelSpatialDimensions_.AddEntriesFrom(input, _repeated_kernelSpatialDimensions_codec); + break; + } + case 56: { + InputBatchDimension = input.ReadInt64(); + break; + } + case 64: { + InputFeatureDimension = input.ReadInt64(); + break; + } + case 72: { + OutputBatchDimension = input.ReadInt64(); + break; + } + case 80: { + OutputFeatureDimension = input.ReadInt64(); + break; + } + case 90: + case 88: { + inputSpatialDimensions_.AddEntriesFrom(input, _repeated_inputSpatialDimensions_codec); + break; + } + case 98: + case 96: { + outputSpatialDimensions_.AddEntriesFrom(input, _repeated_outputSpatialDimensions_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 24: { + KernelInputFeatureDimension = input.ReadInt64(); + break; + } + case 32: { + KernelOutputFeatureDimension = input.ReadInt64(); + break; + } + case 50: + case 48: { + kernelSpatialDimensions_.AddEntriesFrom(ref input, _repeated_kernelSpatialDimensions_codec); + break; + } + case 56: { + InputBatchDimension = input.ReadInt64(); + break; + } + case 64: { + InputFeatureDimension = input.ReadInt64(); + break; + } + case 72: { + OutputBatchDimension = input.ReadInt64(); + break; + } + case 80: { + OutputFeatureDimension = input.ReadInt64(); + break; + } + case 90: + case 88: { + inputSpatialDimensions_.AddEntriesFrom(ref input, _repeated_inputSpatialDimensions_codec); + break; + } + case 98: + case 96: { + outputSpatialDimensions_.AddEntriesFrom(ref input, _repeated_outputSpatialDimensions_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class DotDimensionNumbers : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DotDimensionNumbers()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[19]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DotDimensionNumbers() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DotDimensionNumbers(DotDimensionNumbers other) : this() { + lhsContractingDimensions_ = other.lhsContractingDimensions_.Clone(); + rhsContractingDimensions_ = other.rhsContractingDimensions_.Clone(); + lhsBatchDimensions_ = other.lhsBatchDimensions_.Clone(); + rhsBatchDimensions_ = other.rhsBatchDimensions_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public DotDimensionNumbers Clone() { + return new DotDimensionNumbers(this); + } + + /// Field number for the "lhs_contracting_dimensions" field. + public const int LhsContractingDimensionsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_lhsContractingDimensions_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField lhsContractingDimensions_ = new pbc::RepeatedField(); + /// + /// The dimension numbers that represent the 'lhs' contracting dimensions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField LhsContractingDimensions { + get { return lhsContractingDimensions_; } + } + + /// Field number for the "rhs_contracting_dimensions" field. + public const int RhsContractingDimensionsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_rhsContractingDimensions_codec + = pb::FieldCodec.ForInt64(18); + private readonly pbc::RepeatedField rhsContractingDimensions_ = new pbc::RepeatedField(); + /// + /// The dimension numbers that represent the 'rhs' contracting dimensions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField RhsContractingDimensions { + get { return rhsContractingDimensions_; } + } + + /// Field number for the "lhs_batch_dimensions" field. + public const int LhsBatchDimensionsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_lhsBatchDimensions_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField lhsBatchDimensions_ = new pbc::RepeatedField(); + /// + /// The dimension numbers that represent the 'lhs' batch dimensions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField LhsBatchDimensions { + get { return lhsBatchDimensions_; } + } + + /// Field number for the "rhs_batch_dimensions" field. + public const int RhsBatchDimensionsFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_rhsBatchDimensions_codec + = pb::FieldCodec.ForInt64(34); + private readonly pbc::RepeatedField rhsBatchDimensions_ = new pbc::RepeatedField(); + /// + /// The dimension numbers that represent the 'rhs' batch dimensions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField RhsBatchDimensions { + get { return rhsBatchDimensions_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as DotDimensionNumbers); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(DotDimensionNumbers other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!lhsContractingDimensions_.Equals(other.lhsContractingDimensions_)) return false; + if(!rhsContractingDimensions_.Equals(other.rhsContractingDimensions_)) return false; + if(!lhsBatchDimensions_.Equals(other.lhsBatchDimensions_)) return false; + if(!rhsBatchDimensions_.Equals(other.rhsBatchDimensions_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= lhsContractingDimensions_.GetHashCode(); + hash ^= rhsContractingDimensions_.GetHashCode(); + hash ^= lhsBatchDimensions_.GetHashCode(); + hash ^= rhsBatchDimensions_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + lhsContractingDimensions_.WriteTo(output, _repeated_lhsContractingDimensions_codec); + rhsContractingDimensions_.WriteTo(output, _repeated_rhsContractingDimensions_codec); + lhsBatchDimensions_.WriteTo(output, _repeated_lhsBatchDimensions_codec); + rhsBatchDimensions_.WriteTo(output, _repeated_rhsBatchDimensions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + lhsContractingDimensions_.WriteTo(ref output, _repeated_lhsContractingDimensions_codec); + rhsContractingDimensions_.WriteTo(ref output, _repeated_rhsContractingDimensions_codec); + lhsBatchDimensions_.WriteTo(ref output, _repeated_lhsBatchDimensions_codec); + rhsBatchDimensions_.WriteTo(ref output, _repeated_rhsBatchDimensions_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += lhsContractingDimensions_.CalculateSize(_repeated_lhsContractingDimensions_codec); + size += rhsContractingDimensions_.CalculateSize(_repeated_rhsContractingDimensions_codec); + size += lhsBatchDimensions_.CalculateSize(_repeated_lhsBatchDimensions_codec); + size += rhsBatchDimensions_.CalculateSize(_repeated_rhsBatchDimensions_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(DotDimensionNumbers other) { + if (other == null) { + return; + } + lhsContractingDimensions_.Add(other.lhsContractingDimensions_); + rhsContractingDimensions_.Add(other.rhsContractingDimensions_); + lhsBatchDimensions_.Add(other.lhsBatchDimensions_); + rhsBatchDimensions_.Add(other.rhsBatchDimensions_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + lhsContractingDimensions_.AddEntriesFrom(input, _repeated_lhsContractingDimensions_codec); + break; + } + case 18: + case 16: { + rhsContractingDimensions_.AddEntriesFrom(input, _repeated_rhsContractingDimensions_codec); + break; + } + case 26: + case 24: { + lhsBatchDimensions_.AddEntriesFrom(input, _repeated_lhsBatchDimensions_codec); + break; + } + case 34: + case 32: { + rhsBatchDimensions_.AddEntriesFrom(input, _repeated_rhsBatchDimensions_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + lhsContractingDimensions_.AddEntriesFrom(ref input, _repeated_lhsContractingDimensions_codec); + break; + } + case 18: + case 16: { + rhsContractingDimensions_.AddEntriesFrom(ref input, _repeated_rhsContractingDimensions_codec); + break; + } + case 26: + case 24: { + lhsBatchDimensions_.AddEntriesFrom(ref input, _repeated_lhsBatchDimensions_codec); + break; + } + case 34: + case 32: { + rhsBatchDimensions_.AddEntriesFrom(ref input, _repeated_rhsBatchDimensions_codec); + break; + } + } + } + } + #endif + + } + + public sealed partial class TriangularSolveOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TriangularSolveOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[20]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TriangularSolveOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TriangularSolveOptions(TriangularSolveOptions other) : this() { + leftSide_ = other.leftSide_; + lower_ = other.lower_; + unitDiagonal_ = other.unitDiagonal_; + transposeA_ = other.transposeA_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public TriangularSolveOptions Clone() { + return new TriangularSolveOptions(this); + } + + /// Field number for the "left_side" field. + public const int LeftSideFieldNumber = 1; + private bool leftSide_; + /// + /// If true, solves ax = b. If false, solves xa = b. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool LeftSide { + get { return leftSide_; } + set { + leftSide_ = value; + } + } + + /// Field number for the "lower" field. + public const int LowerFieldNumber = 2; + private bool lower_; + /// + /// If true, 'a' is lower triangular. If false, 'a' is upper triangular. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Lower { + get { return lower_; } + set { + lower_ = value; + } + } + + /// Field number for the "unit_diagonal" field. + public const int UnitDiagonalFieldNumber = 3; + private bool unitDiagonal_; + /// + /// If true, the diagonal elements of 'a' are assumed to be 1 and not accessed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool UnitDiagonal { + get { return unitDiagonal_; } + set { + unitDiagonal_ = value; + } + } + + /// Field number for the "transpose_a" field. + public const int TransposeAFieldNumber = 4; + private global::Xla.TriangularSolveOptions.Types.Transpose transposeA_ = global::Xla.TriangularSolveOptions.Types.Transpose.Invalid; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.TriangularSolveOptions.Types.Transpose TransposeA { + get { return transposeA_; } + set { + transposeA_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as TriangularSolveOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(TriangularSolveOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (LeftSide != other.LeftSide) return false; + if (Lower != other.Lower) return false; + if (UnitDiagonal != other.UnitDiagonal) return false; + if (TransposeA != other.TransposeA) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (LeftSide != false) hash ^= LeftSide.GetHashCode(); + if (Lower != false) hash ^= Lower.GetHashCode(); + if (UnitDiagonal != false) hash ^= UnitDiagonal.GetHashCode(); + if (TransposeA != global::Xla.TriangularSolveOptions.Types.Transpose.Invalid) hash ^= TransposeA.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (LeftSide != false) { + output.WriteRawTag(8); + output.WriteBool(LeftSide); + } + if (Lower != false) { + output.WriteRawTag(16); + output.WriteBool(Lower); + } + if (UnitDiagonal != false) { + output.WriteRawTag(24); + output.WriteBool(UnitDiagonal); + } + if (TransposeA != global::Xla.TriangularSolveOptions.Types.Transpose.Invalid) { + output.WriteRawTag(32); + output.WriteEnum((int) TransposeA); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (LeftSide != false) { + output.WriteRawTag(8); + output.WriteBool(LeftSide); + } + if (Lower != false) { + output.WriteRawTag(16); + output.WriteBool(Lower); + } + if (UnitDiagonal != false) { + output.WriteRawTag(24); + output.WriteBool(UnitDiagonal); + } + if (TransposeA != global::Xla.TriangularSolveOptions.Types.Transpose.Invalid) { + output.WriteRawTag(32); + output.WriteEnum((int) TransposeA); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (LeftSide != false) { + size += 1 + 1; + } + if (Lower != false) { + size += 1 + 1; + } + if (UnitDiagonal != false) { + size += 1 + 1; + } + if (TransposeA != global::Xla.TriangularSolveOptions.Types.Transpose.Invalid) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) TransposeA); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(TriangularSolveOptions other) { + if (other == null) { + return; + } + if (other.LeftSide != false) { + LeftSide = other.LeftSide; + } + if (other.Lower != false) { + Lower = other.Lower; + } + if (other.UnitDiagonal != false) { + UnitDiagonal = other.UnitDiagonal; + } + if (other.TransposeA != global::Xla.TriangularSolveOptions.Types.Transpose.Invalid) { + TransposeA = other.TransposeA; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + LeftSide = input.ReadBool(); + break; + } + case 16: { + Lower = input.ReadBool(); + break; + } + case 24: { + UnitDiagonal = input.ReadBool(); + break; + } + case 32: { + TransposeA = (global::Xla.TriangularSolveOptions.Types.Transpose) input.ReadEnum(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + LeftSide = input.ReadBool(); + break; + } + case 16: { + Lower = input.ReadBool(); + break; + } + case 24: { + UnitDiagonal = input.ReadBool(); + break; + } + case 32: { + TransposeA = (global::Xla.TriangularSolveOptions.Types.Transpose) input.ReadEnum(); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the TriangularSolveOptions message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + /// + /// Should we transpose or use the adjoint of 'a'? + /// + public enum Transpose { + [pbr::OriginalName("TRANSPOSE_INVALID")] Invalid = 0, + /// + /// Don't transpose 'a'. + /// + [pbr::OriginalName("NO_TRANSPOSE")] NoTranspose = 1, + /// + /// Transpose 'a'. + /// + [pbr::OriginalName("TRANSPOSE")] Transpose = 2, + /// + /// Complex conjugate and transpose 'a'. + /// + [pbr::OriginalName("ADJOINT")] Adjoint = 3, + } + + } + #endregion + + } + + public sealed partial class CholeskyOptions : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CholeskyOptions()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[21]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CholeskyOptions() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CholeskyOptions(CholeskyOptions other) : this() { + lower_ = other.lower_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CholeskyOptions Clone() { + return new CholeskyOptions(this); + } + + /// Field number for the "lower" field. + public const int LowerFieldNumber = 1; + private bool lower_; + /// + /// If true, uses the lower triangle of `a`. If false, uses the upper triangle + /// of `a`. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Lower { + get { return lower_; } + set { + lower_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CholeskyOptions); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CholeskyOptions other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Lower != other.Lower) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Lower != false) hash ^= Lower.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Lower != false) { + output.WriteRawTag(8); + output.WriteBool(Lower); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Lower != false) { + output.WriteRawTag(8); + output.WriteBool(Lower); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Lower != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CholeskyOptions other) { + if (other == null) { + return; + } + if (other.Lower != false) { + Lower = other.Lower; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Lower = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Lower = input.ReadBool(); + break; + } + } + } + } + #endif + + } + + /// + /// Generic map of attributes used to pass hints / configuration options from + /// the Python frontend to the XLA backend. + /// + public sealed partial class FrontendAttributes : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FrontendAttributes()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[22]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FrontendAttributes() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FrontendAttributes(FrontendAttributes other) : this() { + map_ = other.map_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public FrontendAttributes Clone() { + return new FrontendAttributes(this); + } + + /// Field number for the "map" field. + public const int MapFieldNumber = 1; + private static readonly pbc::MapField.Codec _map_map_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForString(18, ""), 10); + private readonly pbc::MapField map_ = new pbc::MapField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::MapField Map { + get { return map_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as FrontendAttributes); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(FrontendAttributes other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!Map.Equals(other.Map)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= Map.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + map_.WriteTo(output, _map_map_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + map_.WriteTo(ref output, _map_map_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += map_.CalculateSize(_map_map_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(FrontendAttributes other) { + if (other == null) { + return; + } + map_.Add(other.map_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + map_.AddEntriesFrom(input, _map_map_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + map_.AddEntriesFrom(ref input, _map_map_codec); + break; + } + } + } + } + #endif + + } + + /// + /// LINT.IfChange + /// + public sealed partial class OpSharding : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OpSharding()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[23]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OpSharding() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OpSharding(OpSharding other) : this() { + type_ = other.type_; + tileShape_ = other.tileShape_ != null ? other.tileShape_.Clone() : null; + tileAssignmentDimensions_ = other.tileAssignmentDimensions_.Clone(); + tileAssignmentDevices_ = other.tileAssignmentDevices_.Clone(); + tupleShardings_ = other.tupleShardings_.Clone(); + replicateOnLastTileDim_ = other.replicateOnLastTileDim_; + metadata_ = other.metadata_.Clone(); + lastTileDims_ = other.lastTileDims_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OpSharding Clone() { + return new OpSharding(this); + } + + /// Field number for the "type" field. + public const int TypeFieldNumber = 1; + private global::Xla.OpSharding.Types.Type type_ = global::Xla.OpSharding.Types.Type.Replicated; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.OpSharding.Types.Type Type { + get { return type_; } + set { + type_ = value; + } + } + + /// Field number for the "tile_shape" field. + public const int TileShapeFieldNumber = 2; + private global::Xla.ShapeProto tileShape_; + /// + /// The shape of the sharded tile. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.ShapeProto TileShape { + get { return tileShape_; } + set { + tileShape_ = value; + } + } + + /// Field number for the "tile_assignment_dimensions" field. + public const int TileAssignmentDimensionsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_tileAssignmentDimensions_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField tileAssignmentDimensions_ = new pbc::RepeatedField(); + /// + /// The shape of the tile assignment tensor - this must be the same rank as + /// tile_shape and the product of its dimensions must equal + /// tile_assignment_devices.size(). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField TileAssignmentDimensions { + get { return tileAssignmentDimensions_; } + } + + /// Field number for the "tile_assignment_devices" field. + public const int TileAssignmentDevicesFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_tileAssignmentDevices_codec + = pb::FieldCodec.ForInt64(34); + private readonly pbc::RepeatedField tileAssignmentDevices_ = new pbc::RepeatedField(); + /// + /// Flattened list of device IDs. The order of flattening is the same as used + /// by IndexUtil::MultiToLinearIndex(tile_assignment_shape). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField TileAssignmentDevices { + get { return tileAssignmentDevices_; } + } + + /// Field number for the "tuple_shardings" field. + public const int TupleShardingsFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_tupleShardings_codec + = pb::FieldCodec.ForMessage(42, global::Xla.OpSharding.Parser); + private readonly pbc::RepeatedField tupleShardings_ = new pbc::RepeatedField(); + /// + /// If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape, + /// in pre-order. The tuple shape could be nested; here we store just a + /// flattened list of all leaves in the tuple shape. Note that the tuple shape + /// is not stored here; shardings do not store the shapes to which they are + /// applied, this is inferred from the instruction this sharding gets attached + /// to. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField TupleShardings { + get { return tupleShardings_; } + } + + /// Field number for the "replicate_on_last_tile_dim" field. + public const int ReplicateOnLastTileDimFieldNumber = 6; + private bool replicateOnLastTileDim_; + /// + /// Only used for OTHER type. If true, data is sharded according to other + /// dimensions of tile_assignment(), but replicated across devices along the + /// last dimension. (Experimental) + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool ReplicateOnLastTileDim { + get { return replicateOnLastTileDim_; } + set { + replicateOnLastTileDim_ = value; + } + } + + /// Field number for the "metadata" field. + public const int MetadataFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_metadata_codec + = pb::FieldCodec.ForMessage(58, global::Xla.OpMetadata.Parser); + private readonly pbc::RepeatedField metadata_ = new pbc::RepeatedField(); + /// + /// This field is used to track the source of this sharding, usually derived + /// from instructions. Multple metadata may be populated if sharding is + /// combined with other shardings. Metadata are to not be populated when + /// type == TUPLE and instead metadata should be set on individual tuple + /// elements. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Metadata { + get { return metadata_; } + } + + /// Field number for the "last_tile_dims" field. + public const int LastTileDimsFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_lastTileDims_codec + = pb::FieldCodec.ForEnum(66, x => (int) x, x => (global::Xla.OpSharding.Types.Type) x); + private readonly pbc::RepeatedField lastTileDims_ = new pbc::RepeatedField(); + /// + /// This field is used to represented the sharding type of each subgroup. + /// For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={ + /// replicate, manual, unreduced}} means that each of the last 3 dimensions + /// in [2,2,2,2] represents a subgrouping in replicate, manual, + /// unreduced sharding type respectively. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField LastTileDims { + get { return lastTileDims_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as OpSharding); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(OpSharding other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Type != other.Type) return false; + if (!object.Equals(TileShape, other.TileShape)) return false; + if(!tileAssignmentDimensions_.Equals(other.tileAssignmentDimensions_)) return false; + if(!tileAssignmentDevices_.Equals(other.tileAssignmentDevices_)) return false; + if(!tupleShardings_.Equals(other.tupleShardings_)) return false; + if (ReplicateOnLastTileDim != other.ReplicateOnLastTileDim) return false; + if(!metadata_.Equals(other.metadata_)) return false; + if(!lastTileDims_.Equals(other.lastTileDims_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Type != global::Xla.OpSharding.Types.Type.Replicated) hash ^= Type.GetHashCode(); + if (tileShape_ != null) hash ^= TileShape.GetHashCode(); + hash ^= tileAssignmentDimensions_.GetHashCode(); + hash ^= tileAssignmentDevices_.GetHashCode(); + hash ^= tupleShardings_.GetHashCode(); + if (ReplicateOnLastTileDim != false) hash ^= ReplicateOnLastTileDim.GetHashCode(); + hash ^= metadata_.GetHashCode(); + hash ^= lastTileDims_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Type != global::Xla.OpSharding.Types.Type.Replicated) { + output.WriteRawTag(8); + output.WriteEnum((int) Type); + } + if (tileShape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(TileShape); + } + tileAssignmentDimensions_.WriteTo(output, _repeated_tileAssignmentDimensions_codec); + tileAssignmentDevices_.WriteTo(output, _repeated_tileAssignmentDevices_codec); + tupleShardings_.WriteTo(output, _repeated_tupleShardings_codec); + if (ReplicateOnLastTileDim != false) { + output.WriteRawTag(48); + output.WriteBool(ReplicateOnLastTileDim); + } + metadata_.WriteTo(output, _repeated_metadata_codec); + lastTileDims_.WriteTo(output, _repeated_lastTileDims_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Type != global::Xla.OpSharding.Types.Type.Replicated) { + output.WriteRawTag(8); + output.WriteEnum((int) Type); + } + if (tileShape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(TileShape); + } + tileAssignmentDimensions_.WriteTo(ref output, _repeated_tileAssignmentDimensions_codec); + tileAssignmentDevices_.WriteTo(ref output, _repeated_tileAssignmentDevices_codec); + tupleShardings_.WriteTo(ref output, _repeated_tupleShardings_codec); + if (ReplicateOnLastTileDim != false) { + output.WriteRawTag(48); + output.WriteBool(ReplicateOnLastTileDim); + } + metadata_.WriteTo(ref output, _repeated_metadata_codec); + lastTileDims_.WriteTo(ref output, _repeated_lastTileDims_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Type != global::Xla.OpSharding.Types.Type.Replicated) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Type); + } + if (tileShape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TileShape); + } + size += tileAssignmentDimensions_.CalculateSize(_repeated_tileAssignmentDimensions_codec); + size += tileAssignmentDevices_.CalculateSize(_repeated_tileAssignmentDevices_codec); + size += tupleShardings_.CalculateSize(_repeated_tupleShardings_codec); + if (ReplicateOnLastTileDim != false) { + size += 1 + 1; + } + size += metadata_.CalculateSize(_repeated_metadata_codec); + size += lastTileDims_.CalculateSize(_repeated_lastTileDims_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(OpSharding other) { + if (other == null) { + return; + } + if (other.Type != global::Xla.OpSharding.Types.Type.Replicated) { + Type = other.Type; + } + if (other.tileShape_ != null) { + if (tileShape_ == null) { + TileShape = new global::Xla.ShapeProto(); + } + TileShape.MergeFrom(other.TileShape); + } + tileAssignmentDimensions_.Add(other.tileAssignmentDimensions_); + tileAssignmentDevices_.Add(other.tileAssignmentDevices_); + tupleShardings_.Add(other.tupleShardings_); + if (other.ReplicateOnLastTileDim != false) { + ReplicateOnLastTileDim = other.ReplicateOnLastTileDim; + } + metadata_.Add(other.metadata_); + lastTileDims_.Add(other.lastTileDims_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Type = (global::Xla.OpSharding.Types.Type) input.ReadEnum(); + break; + } + case 18: { + if (tileShape_ == null) { + TileShape = new global::Xla.ShapeProto(); + } + input.ReadMessage(TileShape); + break; + } + case 26: + case 24: { + tileAssignmentDimensions_.AddEntriesFrom(input, _repeated_tileAssignmentDimensions_codec); + break; + } + case 34: + case 32: { + tileAssignmentDevices_.AddEntriesFrom(input, _repeated_tileAssignmentDevices_codec); + break; + } + case 42: { + tupleShardings_.AddEntriesFrom(input, _repeated_tupleShardings_codec); + break; + } + case 48: { + ReplicateOnLastTileDim = input.ReadBool(); + break; + } + case 58: { + metadata_.AddEntriesFrom(input, _repeated_metadata_codec); + break; + } + case 66: + case 64: { + lastTileDims_.AddEntriesFrom(input, _repeated_lastTileDims_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Type = (global::Xla.OpSharding.Types.Type) input.ReadEnum(); + break; + } + case 18: { + if (tileShape_ == null) { + TileShape = new global::Xla.ShapeProto(); + } + input.ReadMessage(TileShape); + break; + } + case 26: + case 24: { + tileAssignmentDimensions_.AddEntriesFrom(ref input, _repeated_tileAssignmentDimensions_codec); + break; + } + case 34: + case 32: { + tileAssignmentDevices_.AddEntriesFrom(ref input, _repeated_tileAssignmentDevices_codec); + break; + } + case 42: { + tupleShardings_.AddEntriesFrom(ref input, _repeated_tupleShardings_codec); + break; + } + case 48: { + ReplicateOnLastTileDim = input.ReadBool(); + break; + } + case 58: { + metadata_.AddEntriesFrom(ref input, _repeated_metadata_codec); + break; + } + case 66: + case 64: { + lastTileDims_.AddEntriesFrom(ref input, _repeated_lastTileDims_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the OpSharding message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum Type { + /// + /// This sharding is replicated across all devices (implies maximal, + /// all other fields are unused). + /// + [pbr::OriginalName("REPLICATED")] Replicated = 0, + /// + /// This sharding is maximal - one device runs the entire operation. + /// + [pbr::OriginalName("MAXIMAL")] Maximal = 1, + /// + /// This sharding is a tuple - only the tuple_shardings field is valid. + /// + [pbr::OriginalName("TUPLE")] Tuple = 2, + /// + /// None of the above; tile_shape and tile_assignment are both used. + /// + [pbr::OriginalName("OTHER")] Other = 3, + /// + /// This op is manually sharded: the shapes are already partitioned and the + /// partitioner should not change this op. + /// + [pbr::OriginalName("MANUAL")] Manual = 4, + } + + } + #endregion + + } + + /// + /// Describes the replica groups in a cross replica op (e.g., all-reduce and + /// all-to-all). + /// + public sealed partial class ReplicaGroup : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ReplicaGroup()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[24]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReplicaGroup() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReplicaGroup(ReplicaGroup other) : this() { + replicaIds_ = other.replicaIds_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ReplicaGroup Clone() { + return new ReplicaGroup(this); + } + + /// Field number for the "replica_ids" field. + public const int ReplicaIdsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_replicaIds_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField replicaIds_ = new pbc::RepeatedField(); + /// + /// The ids of the replicas that belongs to the same group. The ordering of the + /// ids matters in some ops (e.g., all-to-all). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ReplicaIds { + get { return replicaIds_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ReplicaGroup); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ReplicaGroup other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!replicaIds_.Equals(other.replicaIds_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= replicaIds_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + replicaIds_.WriteTo(output, _repeated_replicaIds_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + replicaIds_.WriteTo(ref output, _repeated_replicaIds_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += replicaIds_.CalculateSize(_repeated_replicaIds_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ReplicaGroup other) { + if (other == null) { + return; + } + replicaIds_.Add(other.replicaIds_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + replicaIds_.AddEntriesFrom(input, _repeated_replicaIds_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + replicaIds_.AddEntriesFrom(ref input, _repeated_replicaIds_codec); + break; + } + } + } + } + #endif + + } + + /// + /// Describes the source target pair in the collective permute op. + /// + public sealed partial class SourceTarget : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SourceTarget()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[25]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SourceTarget() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SourceTarget(SourceTarget other) : this() { + source_ = other.source_; + target_ = other.target_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SourceTarget Clone() { + return new SourceTarget(this); + } + + /// Field number for the "source" field. + public const int SourceFieldNumber = 1; + private long source_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Source { + get { return source_; } + set { + source_ = value; + } + } + + /// Field number for the "target" field. + public const int TargetFieldNumber = 2; + private long target_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Target { + get { return target_; } + set { + target_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SourceTarget); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SourceTarget other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Source != other.Source) return false; + if (Target != other.Target) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Source != 0L) hash ^= Source.GetHashCode(); + if (Target != 0L) hash ^= Target.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Source != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Source); + } + if (Target != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Target); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Source != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Source); + } + if (Target != 0L) { + output.WriteRawTag(16); + output.WriteInt64(Target); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Source != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Source); + } + if (Target != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Target); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SourceTarget other) { + if (other == null) { + return; + } + if (other.Source != 0L) { + Source = other.Source; + } + if (other.Target != 0L) { + Target = other.Target; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Source = input.ReadInt64(); + break; + } + case 16: { + Target = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + Source = input.ReadInt64(); + break; + } + case 16: { + Target = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + /// + /// Used to indicate the precision configuration. It has backend specific + /// meaning. + /// + public sealed partial class PrecisionConfig : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new PrecisionConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[26]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PrecisionConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PrecisionConfig(PrecisionConfig other) : this() { + operandPrecision_ = other.operandPrecision_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public PrecisionConfig Clone() { + return new PrecisionConfig(this); + } + + /// Field number for the "operand_precision" field. + public const int OperandPrecisionFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_operandPrecision_codec + = pb::FieldCodec.ForEnum(10, x => (int) x, x => (global::Xla.PrecisionConfig.Types.Precision) x); + private readonly pbc::RepeatedField operandPrecision_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OperandPrecision { + get { return operandPrecision_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as PrecisionConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(PrecisionConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!operandPrecision_.Equals(other.operandPrecision_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= operandPrecision_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + operandPrecision_.WriteTo(output, _repeated_operandPrecision_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + operandPrecision_.WriteTo(ref output, _repeated_operandPrecision_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += operandPrecision_.CalculateSize(_repeated_operandPrecision_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(PrecisionConfig other) { + if (other == null) { + return; + } + operandPrecision_.Add(other.operandPrecision_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + operandPrecision_.AddEntriesFrom(input, _repeated_operandPrecision_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + operandPrecision_.AddEntriesFrom(ref input, _repeated_operandPrecision_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the PrecisionConfig message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum Precision { + [pbr::OriginalName("DEFAULT")] Default = 0, + [pbr::OriginalName("HIGH")] High = 1, + [pbr::OriginalName("HIGHEST")] Highest = 2, + /// + /// Each U8/S8 value in a tensor actually represents 2 nibble values. + /// + [pbr::OriginalName("PACKED_NIBBLE")] PackedNibble = 3, + } + + } + #endregion + + } + + /// + /// Describes whether all data-parallelism replicas will receive the same + /// parameter data at each buffer. + /// + public sealed partial class ParameterReplication : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ParameterReplication()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[27]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ParameterReplication() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ParameterReplication(ParameterReplication other) : this() { + replicatedAtLeafBuffers_ = other.replicatedAtLeafBuffers_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public ParameterReplication Clone() { + return new ParameterReplication(this); + } + + /// Field number for the "replicated_at_leaf_buffers" field. + public const int ReplicatedAtLeafBuffersFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_replicatedAtLeafBuffers_codec + = pb::FieldCodec.ForBool(10); + private readonly pbc::RepeatedField replicatedAtLeafBuffers_ = new pbc::RepeatedField(); + /// + /// A list of boolean values for the flattened leaf buffers. Each value + /// indicates whether the corresponding leaf buffer is replicated. + /// + /// If this field is empty, it means no buffer is replicated. Otherwise, the + /// number of elements in this field must match the number of leaf buffers in + /// the HLO instruction's shape. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField ReplicatedAtLeafBuffers { + get { return replicatedAtLeafBuffers_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as ParameterReplication); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(ParameterReplication other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!replicatedAtLeafBuffers_.Equals(other.replicatedAtLeafBuffers_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= replicatedAtLeafBuffers_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + replicatedAtLeafBuffers_.WriteTo(output, _repeated_replicatedAtLeafBuffers_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + replicatedAtLeafBuffers_.WriteTo(ref output, _repeated_replicatedAtLeafBuffers_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += replicatedAtLeafBuffers_.CalculateSize(_repeated_replicatedAtLeafBuffers_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(ParameterReplication other) { + if (other == null) { + return; + } + replicatedAtLeafBuffers_.Add(other.replicatedAtLeafBuffers_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + replicatedAtLeafBuffers_.AddEntriesFrom(input, _repeated_replicatedAtLeafBuffers_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + replicatedAtLeafBuffers_.AddEntriesFrom(ref input, _repeated_replicatedAtLeafBuffers_codec); + break; + } + } + } + } + #endif + + } + + /// + /// A backend-config for kWhile loops that stores the loop's trip count, if it is + /// known. + /// + /// This is useful for backends that can implement a `for i in 0..N` loop more + /// efficiently than a `while` loop. For example, on GPUs, we can implement a + /// `for i in 0..N` loop by enqueueing the kernels for the loop body N times, + /// whereas implementing a `while` loop requires a host-device sync on each + /// iteration. + /// + public sealed partial class WhileLoopBackendConfig : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WhileLoopBackendConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[28]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WhileLoopBackendConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WhileLoopBackendConfig(WhileLoopBackendConfig other) : this() { + knownTripCount_ = other.knownTripCount_ != null ? other.knownTripCount_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public WhileLoopBackendConfig Clone() { + return new WhileLoopBackendConfig(this); + } + + /// Field number for the "known_trip_count" field. + public const int KnownTripCountFieldNumber = 1; + private global::Xla.WhileLoopBackendConfig.Types.KnownTripCount knownTripCount_; + /// + /// This indirection lets us distinguish between known-trip-count == 0 and + /// unknown-trip-count. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Xla.WhileLoopBackendConfig.Types.KnownTripCount KnownTripCount { + get { return knownTripCount_; } + set { + knownTripCount_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as WhileLoopBackendConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(WhileLoopBackendConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(KnownTripCount, other.KnownTripCount)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (knownTripCount_ != null) hash ^= KnownTripCount.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (knownTripCount_ != null) { + output.WriteRawTag(10); + output.WriteMessage(KnownTripCount); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (knownTripCount_ != null) { + output.WriteRawTag(10); + output.WriteMessage(KnownTripCount); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (knownTripCount_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(KnownTripCount); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(WhileLoopBackendConfig other) { + if (other == null) { + return; + } + if (other.knownTripCount_ != null) { + if (knownTripCount_ == null) { + KnownTripCount = new global::Xla.WhileLoopBackendConfig.Types.KnownTripCount(); + } + KnownTripCount.MergeFrom(other.KnownTripCount); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (knownTripCount_ == null) { + KnownTripCount = new global::Xla.WhileLoopBackendConfig.Types.KnownTripCount(); + } + input.ReadMessage(KnownTripCount); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + if (knownTripCount_ == null) { + KnownTripCount = new global::Xla.WhileLoopBackendConfig.Types.KnownTripCount(); + } + input.ReadMessage(KnownTripCount); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the WhileLoopBackendConfig message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public sealed partial class KnownTripCount : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new KnownTripCount()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.WhileLoopBackendConfig.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KnownTripCount() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KnownTripCount(KnownTripCount other) : this() { + n_ = other.n_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public KnownTripCount Clone() { + return new KnownTripCount(this); + } + + /// Field number for the "n" field. + public const int NFieldNumber = 1; + private long n_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long N { + get { return n_; } + set { + n_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as KnownTripCount); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(KnownTripCount other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (N != other.N) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (N != 0L) hash ^= N.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (N != 0L) { + output.WriteRawTag(8); + output.WriteInt64(N); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (N != 0L) { + output.WriteRawTag(8); + output.WriteInt64(N); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (N != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(N); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(KnownTripCount other) { + if (other == null) { + return; + } + if (other.N != 0L) { + N = other.N; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + N = input.ReadInt64(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 8: { + N = input.ReadInt64(); + break; + } + } + } + } + #endif + + } + + } + #endregion + + } + + /// + /// Specifies a pair of output/operand buffers for kCustomCall that alias each + /// other. + /// + public sealed partial class CustomCallOutputOperandAliasing : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CustomCallOutputOperandAliasing()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.XlaDataReflection.Descriptor.MessageTypes[29]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CustomCallOutputOperandAliasing() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CustomCallOutputOperandAliasing(CustomCallOutputOperandAliasing other) : this() { + outputShapeIndex_ = other.outputShapeIndex_.Clone(); + operandIndex_ = other.operandIndex_; + operandShapeIndex_ = other.operandShapeIndex_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public CustomCallOutputOperandAliasing Clone() { + return new CustomCallOutputOperandAliasing(this); + } + + /// Field number for the "output_shape_index" field. + public const int OutputShapeIndexFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_outputShapeIndex_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField outputShapeIndex_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OutputShapeIndex { + get { return outputShapeIndex_; } + } + + /// Field number for the "operand_index" field. + public const int OperandIndexFieldNumber = 2; + private long operandIndex_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long OperandIndex { + get { return operandIndex_; } + set { + operandIndex_ = value; + } + } + + /// Field number for the "operand_shape_index" field. + public const int OperandShapeIndexFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_operandShapeIndex_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField operandShapeIndex_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OperandShapeIndex { + get { return operandShapeIndex_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as CustomCallOutputOperandAliasing); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(CustomCallOutputOperandAliasing other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!outputShapeIndex_.Equals(other.outputShapeIndex_)) return false; + if (OperandIndex != other.OperandIndex) return false; + if(!operandShapeIndex_.Equals(other.operandShapeIndex_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= outputShapeIndex_.GetHashCode(); + if (OperandIndex != 0L) hash ^= OperandIndex.GetHashCode(); + hash ^= operandShapeIndex_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + outputShapeIndex_.WriteTo(output, _repeated_outputShapeIndex_codec); + if (OperandIndex != 0L) { + output.WriteRawTag(16); + output.WriteInt64(OperandIndex); + } + operandShapeIndex_.WriteTo(output, _repeated_operandShapeIndex_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + outputShapeIndex_.WriteTo(ref output, _repeated_outputShapeIndex_codec); + if (OperandIndex != 0L) { + output.WriteRawTag(16); + output.WriteInt64(OperandIndex); + } + operandShapeIndex_.WriteTo(ref output, _repeated_operandShapeIndex_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += outputShapeIndex_.CalculateSize(_repeated_outputShapeIndex_codec); + if (OperandIndex != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(OperandIndex); + } + size += operandShapeIndex_.CalculateSize(_repeated_operandShapeIndex_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(CustomCallOutputOperandAliasing other) { + if (other == null) { + return; + } + outputShapeIndex_.Add(other.outputShapeIndex_); + if (other.OperandIndex != 0L) { + OperandIndex = other.OperandIndex; + } + operandShapeIndex_.Add(other.operandShapeIndex_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + outputShapeIndex_.AddEntriesFrom(input, _repeated_outputShapeIndex_codec); + break; + } + case 16: { + OperandIndex = input.ReadInt64(); + break; + } + case 26: + case 24: { + operandShapeIndex_.AddEntriesFrom(input, _repeated_operandShapeIndex_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + outputShapeIndex_.AddEntriesFrom(ref input, _repeated_outputShapeIndex_codec); + break; + } + case 16: { + OperandIndex = input.ReadInt64(); + break; + } + case 26: + case 24: { + operandShapeIndex_.AddEntriesFrom(ref input, _repeated_operandShapeIndex_codec); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/XlaFramework.cs b/src/TensorFlowNET.Core/Protobuf/XlaFramework.cs new file mode 100644 index 000000000..1cad3ef3b --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/XlaFramework.cs @@ -0,0 +1,360 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/compiler/xla/service/cpu/xla_framework.proto +// +#pragma warning disable 1591, 0612, 3021, 8981 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Xla.Cpu { + + /// Holder for reflection information generated from tensorflow/compiler/xla/service/cpu/xla_framework.proto + public static partial class XlaFrameworkReflection { + + #region Descriptor + /// File descriptor for tensorflow/compiler/xla/service/cpu/xla_framework.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static XlaFrameworkReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cjd0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9zZXJ2aWNlL2NwdS94bGFfZnJh", + "bWV3b3JrLnByb3RvEgd4bGEuY3B1InoKGFhsYUZyYW1ld29ya01hcHBpbmdQ", + "cm90bxISCgZpbnB1dHMYASADKANCAhABEh0KEWZsYXR0ZW5lZF9vdXRwdXRz", + "GAIgAygDQgIQARISCgZyZXN1bHQYAyABKAM6Ai0xEhcKD291dHB1dF9pc190", + "dXBsZRgEIAEoCA==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Xla.Cpu.XlaFrameworkMappingProto), global::Xla.Cpu.XlaFrameworkMappingProto.Parser, new[]{ "Inputs", "FlattenedOutputs", "Result", "OutputIsTuple" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class XlaFrameworkMappingProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new XlaFrameworkMappingProto()); + private pb::UnknownFieldSet _unknownFields; + private int _hasBits0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Xla.Cpu.XlaFrameworkReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public XlaFrameworkMappingProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public XlaFrameworkMappingProto(XlaFrameworkMappingProto other) : this() { + _hasBits0 = other._hasBits0; + inputs_ = other.inputs_.Clone(); + flattenedOutputs_ = other.flattenedOutputs_.Clone(); + result_ = other.result_; + outputIsTuple_ = other.outputIsTuple_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public XlaFrameworkMappingProto Clone() { + return new XlaFrameworkMappingProto(this); + } + + /// Field number for the "inputs" field. + public const int InputsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_inputs_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField inputs_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Inputs { + get { return inputs_; } + } + + /// Field number for the "flattened_outputs" field. + public const int FlattenedOutputsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_flattenedOutputs_codec + = pb::FieldCodec.ForInt64(18); + private readonly pbc::RepeatedField flattenedOutputs_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField FlattenedOutputs { + get { return flattenedOutputs_; } + } + + /// Field number for the "result" field. + public const int ResultFieldNumber = 3; + private readonly static long ResultDefaultValue = -1L; + + private long result_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public long Result { + get { if ((_hasBits0 & 1) != 0) { return result_; } else { return ResultDefaultValue; } } + set { + _hasBits0 |= 1; + result_ = value; + } + } + /// Gets whether the "result" field is set + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool HasResult { + get { return (_hasBits0 & 1) != 0; } + } + /// Clears the value of the "result" field + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void ClearResult() { + _hasBits0 &= ~1; + } + + /// Field number for the "output_is_tuple" field. + public const int OutputIsTupleFieldNumber = 4; + private readonly static bool OutputIsTupleDefaultValue = false; + + private bool outputIsTuple_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool OutputIsTuple { + get { if ((_hasBits0 & 2) != 0) { return outputIsTuple_; } else { return OutputIsTupleDefaultValue; } } + set { + _hasBits0 |= 2; + outputIsTuple_ = value; + } + } + /// Gets whether the "output_is_tuple" field is set + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool HasOutputIsTuple { + get { return (_hasBits0 & 2) != 0; } + } + /// Clears the value of the "output_is_tuple" field + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void ClearOutputIsTuple() { + _hasBits0 &= ~2; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as XlaFrameworkMappingProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(XlaFrameworkMappingProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!inputs_.Equals(other.inputs_)) return false; + if(!flattenedOutputs_.Equals(other.flattenedOutputs_)) return false; + if (Result != other.Result) return false; + if (OutputIsTuple != other.OutputIsTuple) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + hash ^= inputs_.GetHashCode(); + hash ^= flattenedOutputs_.GetHashCode(); + if (HasResult) hash ^= Result.GetHashCode(); + if (HasOutputIsTuple) hash ^= OutputIsTuple.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + inputs_.WriteTo(output, _repeated_inputs_codec); + flattenedOutputs_.WriteTo(output, _repeated_flattenedOutputs_codec); + if (HasResult) { + output.WriteRawTag(24); + output.WriteInt64(Result); + } + if (HasOutputIsTuple) { + output.WriteRawTag(32); + output.WriteBool(OutputIsTuple); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + inputs_.WriteTo(ref output, _repeated_inputs_codec); + flattenedOutputs_.WriteTo(ref output, _repeated_flattenedOutputs_codec); + if (HasResult) { + output.WriteRawTag(24); + output.WriteInt64(Result); + } + if (HasOutputIsTuple) { + output.WriteRawTag(32); + output.WriteBool(OutputIsTuple); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + size += inputs_.CalculateSize(_repeated_inputs_codec); + size += flattenedOutputs_.CalculateSize(_repeated_flattenedOutputs_codec); + if (HasResult) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Result); + } + if (HasOutputIsTuple) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(XlaFrameworkMappingProto other) { + if (other == null) { + return; + } + inputs_.Add(other.inputs_); + flattenedOutputs_.Add(other.flattenedOutputs_); + if (other.HasResult) { + Result = other.Result; + } + if (other.HasOutputIsTuple) { + OutputIsTuple = other.OutputIsTuple; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + inputs_.AddEntriesFrom(input, _repeated_inputs_codec); + break; + } + case 18: + case 16: { + flattenedOutputs_.AddEntriesFrom(input, _repeated_flattenedOutputs_codec); + break; + } + case 24: { + Result = input.ReadInt64(); + break; + } + case 32: { + OutputIsTuple = input.ReadBool(); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: + case 8: { + inputs_.AddEntriesFrom(ref input, _repeated_inputs_codec); + break; + } + case 18: + case 16: { + flattenedOutputs_.AddEntriesFrom(ref input, _repeated_flattenedOutputs_codec); + break; + } + case 24: { + Result = input.ReadInt64(); + break; + } + case 32: { + OutputIsTuple = input.ReadBool(); + break; + } + } + } + } + #endif + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/op_list_proto_array.bin b/src/TensorFlowNET.Core/Protobuf/op_list_proto_array.bin deleted file mode 100644 index 62d31e670..000000000 Binary files a/src/TensorFlowNET.Core/Protobuf/op_list_proto_array.bin and /dev/null differ diff --git a/src/TensorFlowNET.Core/Protobuf/op_list_proto_math.bin b/src/TensorFlowNET.Core/Protobuf/op_list_proto_math.bin deleted file mode 100644 index c94552c83..000000000 Binary files a/src/TensorFlowNET.Core/Protobuf/op_list_proto_math.bin and /dev/null differ diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 485213204..3dab4ec71 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -1,139 +1,286 @@ -using NumSharp.Core; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.InteropServices; -using System.Text; - -namespace Tensorflow +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public class BaseSession : IDisposable { - public class BaseSession : IDisposable + protected SafeSessionHandle _handle; + protected Graph _graph; + protected Status _status; + public Graph graph => _graph; + + public BaseSession(SafeSessionHandle handle, Graph g) { - private Graph _graph; - private bool _opened; - private bool _closed; - private int _current_version; - private byte[] _target; - private IntPtr _session; + _handle = handle; + _graph = g ?? ops.get_default_graph(); + _status = tf.Status; + } - public BaseSession(string target = "", Graph graph = null) + public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) + { + _graph = g ?? ops.get_default_graph(); + if (!_graph.building_function) { - if(graph is null) - { - _graph = ops.get_default_graph(); - } - else - { - _graph = graph; - } + if (ops.get_default_graph() != _graph) + _graph.as_default(); + } + + var opts = new SessionOptions(target, config); + _status = status ?? tf.Status; + _handle = c_api.TF_NewSession(_graph, opts, _status); + _status.Check(true); + } - _target = UTF8Encoding.UTF8.GetBytes(target); - var opts = c_api.TF_NewSessionOptions(); - var status = new Status(); - _session = c_api.TF_NewSession(_graph.Handle, opts, status.Handle); + public virtual void run(Operation op, params FeedItem[] feed_dict) + { + _run(op, feed_dict); + } - c_api.TF_DeleteSessionOptions(opts); - } + public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) + { + return _run(fetche, feed_dict)[0]; + } - public void Dispose() + public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) + { + var results = _run(fetche, feed_dict); + return fetche is Tensor ? results[0] : null; + } + + public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run( + (ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, + params FeedItem[] feed_dict) + { + var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4, fetches.Item5 }, feed_dict); + return (results[0], results[1], results[2], results[3], results[4]); + } + + public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) + { + var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); + return (results[0], results[1], results[2], results[3]); + } + + public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) + { + var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); + return (results[0], results[1], results[2]); + } + + public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) + { + var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); + return (results[0], results[1]); + } + + public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) + { + return _run(fetches, feed_dict); + } + + public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) + { + var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); + return _run(fetches, feed_items); + } + + private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) + { + var feed_dict_tensor = new Dictionary(); + //var feed_map = new Dictionary(); + + // Validate and process feed_dict. + if (feed_dict != null) { - + foreach (var subfeed in feed_dict) + { + var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); + //var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used + feed_dict_tensor[subfeed_t] = subfeed.Value; + //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); + } } - public virtual object run(Tensor fetches, Dictionary feed_dict = null) - { - var result = _run(fetches, feed_dict); + // Create a fetch handler to take care of the structure of fetches. + var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); - return result; - } + // Run request and get response. + // We need to keep the returned movers alive for the following _do_run(). + // These movers are no longer needed when _do_run() completes, and + // are deleted when `movers` goes out of scope when this _run() ends. + var _ = _update_with_movers(); + var final_fetches = fetch_handler.fetches(); + var final_targets = fetch_handler.targets(); - private unsafe object _run(Tensor fetches, Dictionary feed_dict = null) - { - var feed_dict_tensor = new Dictionary(); + // We only want to really perform the run if fetches or targets are provided, + // or if the call is a partial run that specifies feeds. + var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); + + return fetch_handler.build_results(this, results); + } - if (feed_dict != null) + /// + /// Runs a step based on the given fetches and feeds. + /// + /// A list of operations to be run, but not fetched. + /// + /// + /// + /// A list of numpy ndarrays, corresponding to the elements of + /// `fetch_list`. If the ith element of `fetch_list` contains the + /// name of an operation, the first Tensor output of that operation + /// will be returned for that element. + /// + private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) + { + var feeds = new KeyValuePair[feed_dict.Count]; + int i = 0; + foreach (var x in feed_dict) + { + if (x.Key is Tensor key) { - NDArray np_val = null; - foreach (var feed in feed_dict) + switch (x.Value) { - switch (feed.Value) - { - case float value: - np_val = np.asarray(value); - break; - } - - feed_dict_tensor[feed.Key] = np_val; + case Tensor v: + if (v.dtype != key.dtype) + throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}"); + feeds[i++] = new KeyValuePair(key._as_tf_output(), v); + break; + case SafeTensorHandle v: + var tensor = new Tensor(v); + if (tensor.dtype != key.dtype) + throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}"); + feeds[i++] = new KeyValuePair(key._as_tf_output(), tensor); + break; + case bool v: + feeds[i++] = new KeyValuePair(key._as_tf_output(), new Tensor(v)); + break; + case byte v: + feeds[i++] = new KeyValuePair(key._as_tf_output(), new Tensor(v)); + break; + case int v: + feeds[i++] = new KeyValuePair(key._as_tf_output(), new Tensor(v)); + break; + case long v: + feeds[i++] = new KeyValuePair(key._as_tf_output(), new Tensor(v)); + break; + case float v: + feeds[i++] = new KeyValuePair(key._as_tf_output(), new Tensor(v)); + break; + case double v: + feeds[i++] = new KeyValuePair(key._as_tf_output(), new Tensor(v)); + break; + case string v: + feeds[i++] = new KeyValuePair(key._as_tf_output(), new Tensor(v)); + break; + case Array v: + feeds[i++] = new KeyValuePair(key._as_tf_output(), new Tensor(v, v.GetShape())); + break; + default: + throw new NotImplementedException(""); } } + else + throw new NotImplementedException(""); + } - // Create a fetch handler to take care of the structure of fetches. - var fetch_handler = new _FetchHandler(_graph, fetches); + var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); + //var targets = target_list; + return _call_tf_sessionrun(feeds, fetches, target_list); + } - // Run request and get response. - // We need to keep the returned movers alive for the following _do_run(). - // These movers are no longer needed when _do_run() completes, and - // are deleted when `movers` goes out of scope when this _run() ends. - var _ = _update_with_movers(); - var final_fetches = fetch_handler.fetches(); - var final_targets = fetch_handler.targets(); - // We only want to really perform the run if fetches or targets are provided, - // or if the call is a partial run that specifies feeds. - var results = _do_run(final_fetches); + private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list, List target_list) + { + // Ensure any changes to the graph are reflected in the runtime. + _extend_graph(); - return fetch_handler.build_results(null, results); - } + var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); - private object[] _do_run(List fetch_list) - { - var fetches = fetch_list.Select(x => (x as Tensor)._as_tf_output()).ToArray(); + c_api.TF_SessionRun(_handle, + run_options: null, + inputs: feed_dict.Select(f => f.Key).ToArray(), + input_values: feed_dict.Select(f => f.Value.Handle.DangerousGetHandle()).ToArray(), + ninputs: feed_dict.Length, + outputs: fetch_list, + output_values: output_values, + noutputs: fetch_list.Length, + target_opers: target_list.Select(f => (IntPtr)f).ToArray(), + ntargets: target_list.Count, + run_metadata: IntPtr.Zero, + status: _status); - return _call_tf_sessionrun(fetches); - } + _status.Check(true); - private unsafe object[] _call_tf_sessionrun(TF_Output[] fetch_list) - { - // Ensure any changes to the graph are reflected in the runtime. - _extend_graph(); - - var status = new Status(); - - var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); - - c_api.TF_SessionRun(_session, - run_options: IntPtr.Zero, - inputs: new TF_Output[] { }, - input_values: new IntPtr[] { }, - ninputs: 0, - outputs: fetch_list, - output_values: output_values, - noutputs: fetch_list.Length, - target_opers: new IntPtr[] { }, - ntargets: 0, - run_metadata: IntPtr.Zero, - status: status.Handle); - - var result = output_values.Select(x => c_api.TF_TensorData(x)) - .Select(x => (object)*(float*)x) - .ToArray(); - - return result; - } + var result = new NDArray[fetch_list.Length]; - /// - /// If a tensor handle that is fed to a device incompatible placeholder, - /// we move the tensor to the right device, generate a new tensor handle, - /// and update feed_dict to use the new handle. - /// - private List _update_with_movers() - { - return new List { }; - } + for (int i = 0; i < fetch_list.Length; i++) + result[i] = fetchValue(new SafeTensorHandle(output_values[i])); - private void _extend_graph() - { + return result; + } - } + public unsafe Tensor eval(Tensor tensor) + { + var output_values = new IntPtr[1]; + var fetch_list = new[] { tensor._as_tf_output() }; + + c_api.TF_SessionRun(_handle, + run_options: null, + inputs: new TF_Output[0], + input_values: new IntPtr[0], + ninputs: 0, + outputs: fetch_list, + output_values: output_values, + noutputs: 1, + target_opers: new IntPtr[0], + ntargets: 0, + run_metadata: IntPtr.Zero, + status: _status); + + _status.Check(true); + + return new Tensor(new SafeTensorHandle(output_values[0])); + } + + private static unsafe NDArray fetchValue(SafeTensorHandle output) + { + var tensor = new Tensor(output); + return tensor.numpy(); + } + + /// + /// If a tensor handle that is fed to a device incompatible placeholder, + /// we move the tensor to the right device, generate a new tensor handle, + /// and update feed_dict to use the new handle. + /// + private List _update_with_movers() + { + return new List { }; + } + + private void _extend_graph() + { } + + public void Dispose() + { + } } diff --git a/src/TensorFlowNET.Core/Sessions/FeedDict.cs b/src/TensorFlowNET.Core/Sessions/FeedDict.cs new file mode 100644 index 000000000..f39a761db --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/FeedDict.cs @@ -0,0 +1,8 @@ +using System.Collections; + +namespace Tensorflow.Sessions +{ + public class FeedDict : Hashtable + { + } +} diff --git a/src/TensorFlowNET.Core/Sessions/FeedItem.cs b/src/TensorFlowNET.Core/Sessions/FeedItem.cs new file mode 100644 index 000000000..c3a3dc675 --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/FeedItem.cs @@ -0,0 +1,26 @@ +namespace Tensorflow +{ + /// + /// Feed dictionary item + /// + public class FeedItem + { + public object Key { get; } + public object Value { get; } + + public FeedItem(object key, object val) + { + Key = key; + Value = val; + } + + public static implicit operator FeedItem((object, object) feed) + => new FeedItem(feed.Item1, feed.Item2); + + public void Deconstruct(out object key, out object value) + { + key = Key; + value = Value; + } + } +} diff --git a/src/TensorFlowNET.Core/Sessions/SafeSessionHandle.cs b/src/TensorFlowNET.Core/Sessions/SafeSessionHandle.cs new file mode 100644 index 000000000..4e4b013c1 --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/SafeSessionHandle.cs @@ -0,0 +1,46 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Net.NetworkInformation; +using Tensorflow.Util; + +namespace Tensorflow +{ + public sealed class SafeSessionHandle : SafeTensorflowHandle + { + private SafeSessionHandle() + { + } + + public SafeSessionHandle(IntPtr handle) + : base(handle) + { + } + + public override string ToString() + => $"0x{handle:x16}"; + + protected override bool ReleaseHandle() + { + var status = new Status(); + // c_api.TF_CloseSession(handle, tf.Status.Handle); + c_api.TF_DeleteSession(handle, status); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Sessions/SafeSessionOptionsHandle.cs b/src/TensorFlowNET.Core/Sessions/SafeSessionOptionsHandle.cs new file mode 100644 index 000000000..00f2e35bd --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/SafeSessionOptionsHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow +{ + public sealed class SafeSessionOptionsHandle : SafeTensorflowHandle + { + private SafeSessionOptionsHandle() + { + } + + public SafeSessionOptionsHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteSessionOptions(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index 20c31f1f9..3b91b4898 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -1,10 +1,62 @@ -using System; -using System.Collections.Generic; -using System.Text; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. -namespace Tensorflow + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow; + +public class Session : BaseSession { - public class Session : BaseSession + public Session(string target = "", Graph g = null) : base(target, g, null) + { } + + public Session(SafeSessionHandle handle, Graph g = null) : base(handle, g) + { } + + public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) + { } + + public Session as_default() { + return ops.set_default_session(this); } + + public static Session LoadFromSavedModel(string path) + { + var graph = new Graph(); + var status = new Status(); + using var opt = c_api.TF_NewSessionOptions(); + + var tags = new string[] { "serve" }; + + var sess = c_api.TF_LoadSessionFromSavedModel(opt, + IntPtr.Zero, + path, + tags, + tags.Length, + graph, + IntPtr.Zero, + status); + status.Check(true); + + // load graph bytes + // var data = new byte[buffer.length]; + // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); + // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ + return new Session(sess, g: graph); + } + + public static implicit operator SafeSessionHandle(Session session) => session._handle; + public static implicit operator Session(SafeSessionHandle handle) => new Session(handle); } diff --git a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs new file mode 100644 index 000000000..4a11a7f91 --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs @@ -0,0 +1,51 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using System; + +namespace Tensorflow +{ + internal sealed class SessionOptions + { + SafeSessionOptionsHandle _handle { get; } + + public SessionOptions(string target = "", ConfigProto config = null) + { + _handle = c_api.TF_NewSessionOptions(); + c_api.TF_SetTarget(_handle, target); + if (config != null) + SetConfig(config); + } + + private unsafe void SetConfig(ConfigProto config) + { + var bytes = config.ToByteArray(); + + fixed (byte* proto2 = bytes) + { + var status = new Status(); + c_api.TF_SetConfig(_handle, (IntPtr)proto2, (ulong)bytes.Length, status); + status.Check(false); + } + } + + public static implicit operator SafeSessionOptionsHandle(SessionOptions opt) + { + return opt._handle; + } + } +} diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index 908f516c1..4086713a6 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -1,6 +1,22 @@ -using System; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; using System.Collections.Generic; -using System.Text; namespace Tensorflow { @@ -9,26 +25,56 @@ namespace Tensorflow /// public class _ElementFetchMapper : _FetchMapper { - private List _unique_fetches = new List(); - private Action _contraction_fn; + private Func, object> _contraction_fn; - public _ElementFetchMapper(List fetches, Action contraction_fn) + public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn, Graph graph = null) { - foreach(var tensor in fetches) + var g = graph ?? ops.get_default_graph(); + + foreach (var fetch in fetches) { - var fetch = ops.get_default_graph().as_graph_element(tensor, allow_tensor: true, allow_operation: true); - _unique_fetches.Add(fetch); + var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); + _unique_fetches.Add(el); } - } - public object build_results(object[] values) - { - return values[0]; + _contraction_fn = contraction_fn; } - public List unique_fetches() + /// + /// Build results matching the original fetch shape. + /// + /// + /// + public override NDArray[] build_results(List values) { - return _unique_fetches; + NDArray[] result = null; + + if (values.Count > 0) + { + var ret = _contraction_fn(values); + switch (ret) + { + case NDArray value: + result = new[] { value }; + break; + case bool value: + result = new[] { NDArray.Scalar(value) }; + break; + case byte value: + result = new[] { NDArray.Scalar(value) }; + break; + case int value: + result = new[] { NDArray.Scalar(value) }; + break; + case float value: + result = new[] { NDArray.Scalar(value) }; + break; + default: + break; + } + } + + return result; } } } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 0ec355d54..93656cf7e 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -1,6 +1,22 @@ -using System; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; using System.Collections.Generic; -using System.Text; namespace Tensorflow { @@ -9,24 +25,31 @@ namespace Tensorflow /// public class _FetchHandler { - private _ElementFetchMapper _fetch_mapper; - private List _fetches = new List(); + private _FetchMapper _fetch_mapper; + private List _fetches = new List(); private List _ops = new List(); - private List _final_fetches = new List(); + private List _final_fetches = new List(); private List _targets = new List(); - public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) + public _FetchHandler(Graph graph, object fetches, Dictionary feeds = null, Action feed_handles = null) { - _fetch_mapper = new _FetchMapper().for_fetch(fetches); - foreach(var fetch in _fetch_mapper.unique_fetches()) + _fetch_mapper = _FetchMapper.for_fetch(fetches, graph: graph); + foreach (var fetch in _fetch_mapper.unique_fetches()) { switch (fetch) { + case Operation val: + _assert_fetchable(graph, val); + _targets.Add(val); + _ops.Add(true); + break; case Tensor val: _assert_fetchable(graph, val.op); - _fetches.Add(fetch); + _fetches.Add(val); _ops.Add(false); break; + default: + throw new NotImplementedException("_FetchHandler fetch"); } } @@ -34,9 +57,36 @@ public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object fe _final_fetches = _fetches; } - public object build_results(Session session, object[] results) + public NDArray[] build_results(BaseSession session, NDArray[] tensor_values) { - return _fetch_mapper.build_results(results); + var full_values = new List(); + if (_final_fetches.Count != tensor_values.Length) + throw new InvalidOperationException("_final_fetches mismatch tensor_values"); + + int i = 0; + int j = 0; + foreach (var is_op in _ops) + { + if (is_op) + { + if (tensor_values.Length > 0) + full_values.Add(float.NaN); + else + full_values.Add(null); + } + else + { + var value = tensor_values[j]; + j += 1; + full_values.Add(value); + } + i += 1; + } + + if (j != tensor_values.Length) + throw new InvalidOperationException("j mismatch tensor_values"); + + return _fetch_mapper.build_results(full_values); } private void _assert_fetchable(Graph graph, Operation op) @@ -47,12 +97,12 @@ private void _assert_fetchable(Graph graph, Operation op) } } - public List fetches() + public List fetches() { return _final_fetches; } - public List targets() + public List targets() { return _targets; } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs index 763b67a0e..eb72dfc9c 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs @@ -1,16 +1,48 @@ -using System; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; using System.Collections.Generic; -using System.Text; namespace Tensorflow { public class _FetchMapper { - public _ElementFetchMapper for_fetch(Tensor fetch) + protected List _unique_fetches = new List(); + protected List _value_indices = new List(); + public static _FetchMapper for_fetch(object fetch, Graph graph = null) + { + var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch }; + + if (fetch is List fetches1) + return new _ListFetchMapper(fetches1.ToArray()); + if (fetch.GetType().IsArray) + return new _ListFetchMapper(fetches); + else + return new _ElementFetchMapper(fetches, (List fetched_vals) => fetched_vals[0], graph: graph); + } + + public virtual NDArray[] build_results(List values) { - var fetches = new List { fetch }; + return values.ToArray(); + } - return new _ElementFetchMapper(fetches, null); + public virtual List unique_fetches() + { + return _unique_fetches; } } } diff --git a/src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs new file mode 100644 index 000000000..f7b25ea58 --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs @@ -0,0 +1,72 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow +{ + public class _ListFetchMapper : _FetchMapper + { + private _FetchMapper[] _mappers; + + public _ListFetchMapper(object[] fetches) + { + _mappers = fetches.Select(fetch => _FetchMapper.for_fetch(fetch)).ToArray(); + (_unique_fetches, _value_indices) = _uniquify_fetches(_mappers); + } + + private (List, List) _uniquify_fetches(_FetchMapper[] fetch_mappers) + { + var unique_fetches = new List(); + var value_indices = new List(); + var seen_fetches = new Dictionary(); + + foreach (var m in fetch_mappers) + { + var m_value_indices = new List(); + foreach (var uf in m.unique_fetches()) + { + switch (uf) + { + case Tensor f: + if (!seen_fetches.ContainsKey(f)) + { + seen_fetches[f] = seen_fetches.Count; + unique_fetches.Add(f); + } + m_value_indices.Add(seen_fetches.Count - 1); + break; + case Operation f: + if (!seen_fetches.ContainsKey(f)) + { + seen_fetches[f] = seen_fetches.Count; + unique_fetches.Add(f); + } + m_value_indices.Add(seen_fetches.Count - 1); + break; + default: + throw new NotImplementedException("_uniquify_fetches"); + } + } + value_indices.Add(m_value_indices.ToArray()); + } + + return (unique_fetches, value_indices); + } + } +} diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs index 49c32890b..a26ab56d7 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs @@ -1,59 +1,136 @@ -using System; -using System.Collections.Generic; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; using System.Runtime.InteropServices; -using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { + /// + /// Close a session. + /// + /// Contacts any other processes associated with the session, if applicable. + /// May not be called after TF_DeleteSession(). + /// + /// + /// + + [DllImport(TensorFlowLibName)] + public static extern void TF_CloseSession(IntPtr session, SafeStatusHandle status); + + /// + /// Destroy a session object. + /// + /// Even if error information is recorded in *status, this call discards all + /// local resources associated with the session. The session may not be used + /// during or after this call (and the session drops its reference to the + /// corresponding graph). + /// + /// TF_Session* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteSession(IntPtr session, SafeStatusHandle status); + /// /// Destroy an options object. /// - /// + /// TF_SessionOptions* [DllImport(TensorFlowLibName)] - public static unsafe extern void TF_DeleteSessionOptions(IntPtr opts); + public static extern void TF_DeleteSessionOptions(IntPtr opts); /// /// Return a new execution session with the associated graph, or NULL on /// error. Does not take ownership of any input parameters. /// - /// - /// - /// - /// + /// TF_Graph* + /// const TF_SessionOptions* + /// TF_Status* + /// TF_Session* [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_NewSession(IntPtr graph, IntPtr opts, IntPtr status); + public static extern SafeSessionHandle TF_NewSession(SafeGraphHandle graph, SafeSessionOptionsHandle opts, SafeStatusHandle status); /// /// Return a new options object. /// - /// + /// TF_SessionOptions* [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_NewSessionOptions(); + public static extern SafeSessionOptionsHandle TF_NewSessionOptions(); /// /// Run the graph associated with the session starting with the supplied inputs /// (inputs[0,ninputs-1] with corresponding values in input_values[0,ninputs-1]). + /// + /// Any NULL and non-NULL value combinations for (`run_options`, + /// `run_metadata`) are valid. + /// + /// - `run_options` may be NULL, in which case it will be ignored; or + /// non-NULL, in which case it must point to a `TF_Buffer` containing the + /// serialized representation of a `RunOptions` protocol buffer. + /// - `run_metadata` may be NULL, in which case it will be ignored; or + /// non-NULL, in which case it must point to an empty, freshly allocated + /// `TF_Buffer` that may be updated to contain the serialized representation + /// of a `RunMetadata` protocol buffer. + /// + /// The caller retains ownership of `input_values` (which can be deleted using + /// TF_DeleteTensor). The caller also retains ownership of `run_options` and/or + /// `run_metadata` (when not NULL) and should manually call TF_DeleteBuffer on + /// them. + /// + /// On success, the tensors corresponding to outputs[0,noutputs-1] are placed in + /// output_values[]. Ownership of the elements of output_values[] is transferred + /// to the caller, which must eventually call TF_DeleteTensor on them. + /// + /// On failure, output_values[] contains NULLs. /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// + /// TF_Session* + /// const TF_Buffer* + /// const TF_Output* + /// TF_Tensor* const* + /// int + /// const TF_Output* + /// TF_Tensor** + /// int + /// const TF_Operation* const* + /// int + /// TF_Buffer* + /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SessionRun(IntPtr session, IntPtr run_options, + public static extern unsafe void TF_SessionRun(SafeSessionHandle session, TF_Buffer* run_options, TF_Output[] inputs, IntPtr[] input_values, int ninputs, TF_Output[] outputs, IntPtr[] output_values, int noutputs, IntPtr[] target_opers, int ntargets, IntPtr run_metadata, - IntPtr status); + SafeStatusHandle status); + + /// + /// Set the config in TF_SessionOptions.options. + /// config should be a serialized tensorflow.ConfigProto proto. + /// If config was not parsed successfully as a ConfigProto, record the + /// error information in *status. + /// + /// TF_SessionOptions* + /// const void* + /// size_t + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TF_SetConfig(SafeSessionOptionsHandle options, IntPtr proto, ulong proto_len, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_SetTarget(SafeSessionOptionsHandle options, string target); } } diff --git a/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs b/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs new file mode 100644 index 000000000..4077efa98 --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs @@ -0,0 +1,43 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Runtime.InteropServices; + +namespace Tensorflow +{ + public partial class c_api + { + public static string[] TF_OperationOutputConsumers_wrapper(TF_Output oper_out) + { + int num_consumers = TF_OperationOutputNumConsumers(oper_out); + int size = Marshal.SizeOf(); + var handle = Marshal.AllocHGlobal(size * num_consumers); + int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers); + var consumers = new string[num_consumers]; + unsafe + { + var inputptr = (TF_Input*)handle; + for (int i = 0; i < num; i++) + { + var oper = (inputptr + i)->oper; + consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(oper)); + } + } + Marshal.FreeHGlobal(handle); + return consumers; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Status/SafeStatusHandle.cs b/src/TensorFlowNET.Core/Status/SafeStatusHandle.cs new file mode 100644 index 000000000..d20a9d572 --- /dev/null +++ b/src/TensorFlowNET.Core/Status/SafeStatusHandle.cs @@ -0,0 +1,39 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow +{ + public sealed class SafeStatusHandle : SafeTensorflowHandle + { + private SafeStatusHandle() + { + } + + public SafeStatusHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteStatus(handle); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index dc3693866..12b6fba2b 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -1,37 +1,106 @@ -using System; -using System.Collections.Generic; -using System.Text; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using Tensorflow.Exceptions; +using Tensorflow.Util; +using static Tensorflow.c_api; namespace Tensorflow { - public class Status : IDisposable + /// + /// TF_Status holds error information. It either has an OK code, or + /// else an error code with an associated error message. + /// + public sealed class Status { - private readonly IntPtr _handle; - public IntPtr Handle => _handle; - /// /// Error message /// - public string Message => c_api.TF_Message(_handle); + public string Message + { + get + { + using (_handle.Lease()) + { + return StringPiece(TF_Message(_handle)); + } + } + } /// /// Error code /// - public TF_Code Code => c_api.TF_GetCode(_handle); + public TF_Code Code => TF_GetCode(_handle); + + SafeStatusHandle _handle { get; } public Status() { - _handle = c_api.TF_NewStatus(); + _handle = TF_NewStatus(); + } + + public Status(SafeStatusHandle handle) + { + _handle = handle ?? throw new ArgumentNullException(nameof(handle)); } public void SetStatus(TF_Code code, string msg) { - c_api.TF_SetStatus(_handle, code, msg); + TF_SetStatus(_handle, code, msg); } - public void Dispose() + public bool ok() => Code == TF_Code.TF_OK; + + /// + /// Check status + /// Throw exception with error message if code != TF_OK + /// + /// When the returned check is not TF_Code.TF_OK + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [DebuggerHidden] + public void Check(bool throwException = false) + { + if (Code != TF_Code.TF_OK) + { + var message = Message; + + if (throwException) + { + switch (Code) + { + case TF_Code.TF_OUT_OF_RANGE: + throw new OutOfRangeError(message); + case TF_Code.TF_INVALID_ARGUMENT: + throw new InvalidArgumentError(message); + default: + throw new NotOkStatusException(message); + } + } + } + } + + public override string ToString() + => $"{Code} 0x{_handle.DangerousGetHandle():x16}"; + + public static implicit operator SafeStatusHandle(Status status) { - c_api.TF_DeleteStatus(_handle); + return status._handle; } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Status/TF_Code.cs b/src/TensorFlowNET.Core/Status/TF_Code.cs index 5e8e3c8dd..bb26c370e 100644 --- a/src/TensorFlowNET.Core/Status/TF_Code.cs +++ b/src/TensorFlowNET.Core/Status/TF_Code.cs @@ -1,8 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow +namespace Tensorflow { public enum TF_Code { diff --git a/src/TensorFlowNET.Core/Status/c_api.status.cs b/src/TensorFlowNET.Core/Status/c_api.status.cs index efd2a959c..7854481d6 100644 --- a/src/TensorFlowNET.Core/Status/c_api.status.cs +++ b/src/TensorFlowNET.Core/Status/c_api.status.cs @@ -1,18 +1,32 @@ -using System; -using System.Collections.Generic; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; using System.Runtime.InteropServices; -using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { /// /// Delete a previously created status object. /// /// [DllImport(TensorFlowLibName)] - public static unsafe extern void TF_DeleteStatus(IntPtr s); + public static extern void TF_DeleteStatus(IntPtr s); /// /// Return the code record in *s. @@ -20,7 +34,7 @@ public static partial class c_api /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe TF_Code TF_GetCode(IntPtr s); + public static extern TF_Code TF_GetCode(SafeStatusHandle s); /// /// Return a pointer to the (null-terminated) error message in *s. @@ -30,23 +44,23 @@ public static partial class c_api /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe string TF_Message(IntPtr s); + public static extern IntPtr TF_Message(SafeStatusHandle s); /// /// Return a new status object. /// /// [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_NewStatus(); + public static extern SafeStatusHandle TF_NewStatus(); /// - /// Record in *s. Any previous information is lost. + /// Record <code, msg> in *s. Any previous information is lost. /// A common use is to clear a status: TF_SetStatus(s, TF_OK, ""); /// /// /// /// [DllImport(TensorFlowLibName)] - public static extern void TF_SetStatus(IntPtr s, TF_Code code, string msg); + public static extern void TF_SetStatus(SafeStatusHandle s, TF_Code code, string msg); } } diff --git a/src/TensorFlowNET.Core/Summaries/EventFileWriter.cs b/src/TensorFlowNET.Core/Summaries/EventFileWriter.cs new file mode 100644 index 000000000..8a6d9bb08 --- /dev/null +++ b/src/TensorFlowNET.Core/Summaries/EventFileWriter.cs @@ -0,0 +1,56 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.IO; + +namespace Tensorflow.Summaries +{ + /// + /// Creates a `EventFileWriter` and an event file to write to. + /// + public class EventFileWriter + { + string _logdir; + // Represents a first-in, first-out collection of objects. + Queue _event_queue; + EventsWriter _ev_writer; + int _flush_secs; + Event _sentinel_event; +#pragma warning disable CS0414 // The field 'EventFileWriter._closed' is assigned but its value is never used + bool _closed; +#pragma warning restore CS0414 // The field 'EventFileWriter._closed' is assigned but its value is never used + EventLoggerThread _worker; + + public EventFileWriter(string logdir, int max_queue = 10, int flush_secs = 120, + string filename_suffix = null) + { + _logdir = logdir; + Directory.CreateDirectory(_logdir); + _event_queue = new Queue(max_queue); + _ev_writer = new EventsWriter(Path.Combine(_logdir, "events")); + _flush_secs = flush_secs; + _sentinel_event = new Event(); + if (!string.IsNullOrEmpty(filename_suffix)) + // self._ev_writer.InitWithSuffix(compat.as_bytes(filename_suffix))) + throw new NotImplementedException("EventFileWriter filename_suffix is not null"); + _closed = false; + _worker = new EventLoggerThread(_event_queue, _ev_writer, _flush_secs, _sentinel_event); + _worker.start(); + } + } +} diff --git a/src/TensorFlowNET.Core/Summaries/EventLoggerThread.cs b/src/TensorFlowNET.Core/Summaries/EventLoggerThread.cs new file mode 100644 index 000000000..cbe9665da --- /dev/null +++ b/src/TensorFlowNET.Core/Summaries/EventLoggerThread.cs @@ -0,0 +1,67 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Tensorflow.Summaries +{ + /// + /// Thread that logs events. + /// + public class EventLoggerThread + { + Queue _queue; +#pragma warning disable CS0414 // The field 'EventLoggerThread.daemon' is assigned but its value is never used + bool daemon; +#pragma warning restore CS0414 // The field 'EventLoggerThread.daemon' is assigned but its value is never used + EventsWriter _ev_writer; + int _flush_secs; + Event _sentinel_event; + + public EventLoggerThread(Queue queue, EventsWriter ev_writer, int flush_secs, Event sentinel_event) + { + daemon = true; + _queue = queue; + _ev_writer = ev_writer; + _flush_secs = flush_secs; + _sentinel_event = sentinel_event; + } + + public void start() => run(); + + public void run() + { + Task.Run(delegate + { + while (true) + { + if (_queue.Count == 0) + { + Thread.Sleep(_flush_secs * 1000); + continue; + } + + var @event = _queue.Dequeue(); + _ev_writer._WriteSerializedEvent(@event.ToByteArray()); + Thread.Sleep(1000); + } + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Summaries/EventsWriter.cs b/src/TensorFlowNET.Core/Summaries/EventsWriter.cs new file mode 100644 index 000000000..ae6a94f1f --- /dev/null +++ b/src/TensorFlowNET.Core/Summaries/EventsWriter.cs @@ -0,0 +1,35 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.IO; + +namespace Tensorflow.Summaries +{ + public class EventsWriter + { + string _file_prefix; + + public EventsWriter(string file_prefix) + { + _file_prefix = file_prefix; + } + + public void _WriteSerializedEvent(byte[] event_str) + { + File.WriteAllBytes(_file_prefix, event_str); + } + } +} diff --git a/src/TensorFlowNET.Core/Summaries/FileWriter.cs b/src/TensorFlowNET.Core/Summaries/FileWriter.cs new file mode 100644 index 000000000..68bba4db1 --- /dev/null +++ b/src/TensorFlowNET.Core/Summaries/FileWriter.cs @@ -0,0 +1,36 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow.Summaries +{ + /// + /// Writes `Summary` protocol buffers to event files. + /// + public class FileWriter : SummaryToEventTransformer + { + EventFileWriter event_writer; + + public FileWriter(string logdir, Graph graph, + int max_queue = 10, int flush_secs = 120, string filename_suffix = null, + Session session = null) + { + if (session == null) + { + event_writer = new EventFileWriter(logdir, max_queue, flush_secs, filename_suffix); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Summaries/Summary.cs b/src/TensorFlowNET.Core/Summaries/Summary.cs new file mode 100644 index 000000000..a1f47bc02 --- /dev/null +++ b/src/TensorFlowNET.Core/Summaries/Summary.cs @@ -0,0 +1,103 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Summaries +{ + public class Summary + { + public FileWriter FileWriter(string logdir, Graph graph, + int max_queue = 10, int flush_secs = 120, string filename_suffix = null, + Session session = null) + => new FileWriter(logdir, graph, max_queue: max_queue, + flush_secs: flush_secs, filename_suffix: filename_suffix, + session: session); + + public Tensor histogram(string name, Tensor tensor, string[] collections = null, string family = null) + { + var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }, default_name: "HistogramSummary"); + var val = gen_logging_ops.histogram_summary(tag: tag, values: tensor, name: scope); + collect(val, collections?.ToList(), new List { tf.GraphKeys.SUMMARIES }); + return val; + } + + public Tensor merge_all(string key = "summaries", string scope = null, string name = null) + { + var summary_ops = ops.get_collection(key, scope: scope); + if (summary_ops == null) + return null; + else + return merge((summary_ops as List).Select(x => x as Tensor).ToArray(), name: name); + } + + /// + /// Merges summaries. + /// + /// + /// + /// + /// + public Tensor merge(Tensor[] inputs, string[] collections = null, string name = null) + { + return tf_with(ops.name_scope(name, "Merge", inputs), delegate + { + var val = gen_logging_ops.merge_summary(inputs: inputs, name: name); + collect(val, collections?.ToList(), new List()); + return val; + }); + } + + public Tensor scalar(string name, Tensor tensor, string[] collections = null, string family = null) + { + var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }); + var val = gen_logging_ops.scalar_summary(tags: tag, values: tensor, name: scope); + collect(val, collections?.ToList(), new List { tf.GraphKeys.SUMMARIES }); + return val; + } + + /// + /// Adds keys to a collection. + /// + /// The value to add per each key. + /// A collection of keys to add. + /// Used if collections is None. + public void collect(ITensorOrOperation val, List collections, List default_collections) + { + if (collections == null) + collections = default_collections; + foreach (var key in collections) + ops.add_to_collection(key, val); + } + + public (string, string) summary_scope(string name, string family = null, string default_name = null, Tensor[] values = null) + { + string scope_base_name = string.IsNullOrEmpty(family) ? name : $"{family}/{name}"; + return tf_with(ops.name_scope(scope_base_name, default_name: default_name, values), scope => + { + var tag = scope.scope_name; + if (string.IsNullOrEmpty(family)) + tag = tag.Remove(tag.Length - 1); + else + tag = $"{family}/{tag.Remove(tag.Length - 1)}"; + + return (tag, scope.scope_name); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs b/src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs new file mode 100644 index 000000000..2a395e19b --- /dev/null +++ b/src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs @@ -0,0 +1,32 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Text; + +namespace Tensorflow.Summaries +{ + /// + /// Abstractly implements the SummaryWriter API. + /// + public abstract class SummaryToEventTransformer + { + public void add_summary(string summary, int global_step = 0) + { + var bytes = UTF8Encoding.Unicode.GetBytes(summary); + // var summ = Tensorflow.Summary.Parser.ParseFrom(bytes); + } + } +} diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj deleted file mode 100644 index 8689e8c9e..000000000 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ /dev/null @@ -1,38 +0,0 @@ - - - - netstandard2.0 - TensorFlow.NET - Tensorflow - - - - true - DEBUG;TRACE - - - - - - - - - - - - - - - - - PreserveNewest - - - PreserveNewest - - - PreserveNewest - - - - diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj.DotSettings b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj.DotSettings new file mode 100644 index 000000000..d31d2eab1 --- /dev/null +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj.DotSettings @@ -0,0 +1,2 @@ + + True \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj new file mode 100644 index 000000000..42c0399da --- /dev/null +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -0,0 +1,192 @@ + + + + netstandard2.0;net6.0 + Tensorflow.Binding + Tensorflow + 2.15.0 + 0.150.0 + 10.0 + enable + Haiping Chen, Eli Belash, Yaohui Liu, Meinrad Recheis + SciSharp STACK + False + Apache 2.0, Haiping Chen since 2018 + https://github.com/SciSharp/TensorFlow.NET + git + http://scisharpstack.org + https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + TensorFlow, SciSharp, Machine Learning, TensorFlow.NET, TF.NET, AI + Google's TensorFlow full binding in .NET Standard. +Building, training and infering deep learning models. +https://tensorflownet.readthedocs.io + 0.150.0.0 + + tf.net 0.150.x and above are based on tensorflow native 2.15.0 + * Support BERT model. + + tf.net 0.110.x and above are based on tensorflow native 2.11.0 + * Support RNN, LSTM model. + * Support Transformer model. + * Added IMDB dataset. + + tf.net 0.100.x and above are based on tensorflow native 2.10.0 + + * Eager Mode is added finally. + * tf.keras is partially working. + * tf.data is added. + * Autograph works partially. + * Improve memory usage. + + TensorFlow .NET v0.3x is focused on making more Keras API works. + Keras API is a separate package released as TensorFlow.Keras. + + tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library. + tf.net 0.6x.x aligns with TensorFlow v2.6.x native library. + tf.net 0.7x.x aligns with TensorFlow v2.7.x native library. + tf.net 0.10x.x aligns with TensorFlow v2.10.x native library. + tf.net 0.11x.x aligns with TensorFlow v2.11.x native library. + tf.net 0.15x.x aligns with TensorFlow v2.15.x native library. + + 0.150.0.0 + LICENSE + true + packages + true + AnyCPU;x64 + TensorFlow.NET + Debug;Release;GPU + + + + true + TRACE;DEBUG;TRACK_TENSOR_LIFE_1 + AnyCPU + + + + true + TRACE;DEBUG;TRACK_TENSOR_LIFE_1 + AnyCPU + + + + true + TRACE;DEBUG;TRACK_TENSOR_LIFE1 + x64 + TensorFlow.NET.xml + + + + true + TRACE;DEBUG;TRACK_TENSOR_LIFE1 + x64 + TensorFlow.NET.xml + + + + true + + + + true + + + + + 1 + $(NoWarn),1570,1573,1591,1712,8603,8604,8625,CS0612 + + + + 1 + $(NoWarn),1570,1573,1591,1712,8603,8604,8625,CS0612 + + + + 1 + $(NoWarn),1570,1573,1591,1712,8603,8604,8625,CS0612 + + + + 1 + $(NoWarn),1570,1573,1591,1712,8603,8604,8625,CS0612 + + + + 1 + $(NoWarn),1570,1573,1591,1712,8603,8604,8625,CS0612 + + + + 1 + $(NoWarn),1570,1573,1591,1712,8603,8604,8625,CS0612 + + + + 1 + $(NoWarn),1570,1573,1591,1712,8603,8604,8625,CS0612 + + + + 1 + $(NoWarn),1570,1573,1591,1712,8603,8604,8625,CS0612 + + + + 1 + $(NoWarn),1570,1573,1591,1712,8603,8604,8625,CS0612 + + + + 1 + $(NoWarn),1570,1573,1591,1712,8603,8604,8625,CS0612 + + + + 1 + $(NoWarn),1570,1573,1591,1712,8603,8604,8625,CS0612 + + + + 1 + $(NoWarn),1570,1573,1591,1712,8603,8604,8625,CS0612 + + + + + + + + + + + + + + + True + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/TensorFlowNET.Core/Tensors/Dimension.cs b/src/TensorFlowNET.Core/Tensors/Dimension.cs new file mode 100644 index 000000000..1bf551948 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Dimension.cs @@ -0,0 +1,29 @@ +namespace Tensorflow +{ + public class Dimension + { + long _value; + public long value => _value; + + public Dimension(long value) + { + _value = value; + } + + public Dimension merge_with(Dimension other) + { + if (_value == -1) + return new Dimension(other.value); + else + return new Dimension(_value); + } + + public static implicit operator Dimension(long value) + => new Dimension(value); + + public static implicit operator long(Dimension dimension) + => dimension.value; + + public override string ToString() => $"Dimension({_value})"; + } +} diff --git a/src/TensorFlowNET.Core/Tensors/ParsedSliceArgs.cs b/src/TensorFlowNET.Core/Tensors/ParsedSliceArgs.cs new file mode 100644 index 000000000..c6404c3fa --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/ParsedSliceArgs.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class ParsedSliceArgs + { + public int[] Begin { get; set; } + public Tensor PackedBegin { get; set; } + public int[] End { get; set; } + public Tensor PackedEnd { get; set; } + public int[] Strides { get; set; } + public Tensor PackedStrides { get; set; } + public int BeginMask { get; set; } + public int EndMask { get; set; } + public int ShrinkAxisMask { get; set; } + public int NewAxisMask { get; set; } + public int EllipsisMask { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs new file mode 100644 index 000000000..0f09d4128 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs @@ -0,0 +1,200 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq; +using Tensorflow.Framework; +using static Tensorflow.Binding; +using Tensorflow.NumPy; + +namespace Tensorflow +{ + /// + /// Represents a ragged tensor. + /// + public class RaggedTensor : CompositeTensor + { + Tensor _values; + RowPartition _row_partition; + Tensor _row_splits => _row_partition.row_splits; + + public TF_DataType dtype => _values.dtype; + public Shape shape + { + get + { + var nrows = _row_partition.static_nrows; + var ncols = _row_partition.static_uniform_row_length; + return new Shape(nrows, ncols); + } + } + + public Tensor this[int index] + { + get + { + return tf_with(ops.name_scope(null, "RaggedGetItem"), scope => + { + string name = scope; + return _ragged_getitem(index); + }); + } + } + + public RaggedTensor this[params Slice[] slices] + { + get + { + var row_key = slices[0]; + var inner_keys = slices.Skip(1).ToArray(); + + var args = tensor_util.ParseSlices(slices); + + return tf_with(ops.name_scope(null, "RaggedGetItem", args), scope => + { + string name = scope; + return _ragged_getitem_inner_dimensions(this, inner_keys); + }); + } + } + + Tensor _ragged_getitem(int row_key) + { + var starts = _row_splits[":-1"]; + var limits = _row_splits["1:"]; + var row = _values[starts[row_key], limits[row_key]]; + return row; + } + + RaggedTensor _ragged_getitem_inner_dimensions(RaggedTensor input, Slice[] slices) + { + return input; + } + + public RaggedTensor(Tensor values, + bool @internal = true, + RowPartition row_partition = null) + { + _values = values; + _row_partition = row_partition; + } + + public static RaggedTensor from_row_partition(Tensor values, RowPartition row_partition, bool validate = true) + { + return new RaggedTensor(values, @internal: true, row_partition: row_partition); + } + + /// + /// Creates a `RaggedTensor` with rows partitioned by `value_rowids`. + /// + /// + /// + /// + /// + /// + /// + public static RaggedTensor from_value_rowids(Tensor values, Tensor value_rowids, + Tensor nrows = null, string name = null, bool validate = true) + { + return tf_with(ops.name_scope(name, "RaggedFromValueRowIds"), scope => + { + var row_partition = RowPartition.from_value_rowids(value_rowids, + nrows: nrows, + validate: validate); + return from_row_partition(values, row_partition, validate: validate); + }); + } + + public static RaggedTensor from_row_splits(Tensor values, Tensor row_splits, + string name = null, bool validate = true) + { + return tf_with(ops.name_scope(name, "RaggedFromRowSplits"), scope => + { + var row_partition = RowPartition.from_row_splits(row_splits, + validate: validate); + return from_row_partition(values, row_partition, validate: validate); + }); + } + + Tensor _to_variant(bool batched_input = false, string name = null) + => tf_with(ops.name_scope(name, "RaggedToVariant"), scope => + { + return tf.Context.ExecuteOp("RaggedTensorToVariant", name, + new ExecuteOpArgs(nested_row_splits, flat_values) + { + GetGradientAttrs = op => new + { + RAGGED_RANK = op.get_attr("RAGGED_RANK"), + Tvalues = op.get_attr("Tvalues"), + Tsplits = op.get_attr("Tsplits"), + batched_input = op.get_attr("batched_input") + } + }.SetAttributes(new { batched_input })); + }); + + Tensor flat_values + => _values; + + Tensor[] nested_row_splits + => new[] { _row_splits }; + + public override string ToString() + => $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]"; + + public static implicit operator Tensor(RaggedTensor indexedSlices) + => indexedSlices._to_variant(); + + public static implicit operator RaggedTensor(Tensor tensor) + { + return tensor.Tag as RaggedTensor; + } + public Tensor nrows(TF_DataType out_type, string name = null) + { + tf_with(ops.name_scope(name, "RaggedNRows"), scope => + { + return math_ops.cast(this._row_partition.nrows(), dtype: out_type); + }); + return null; + } + public RaggedTensor row_lengths(int axis=-1, string name=null) + { + if (axis == 0) return this._row_partition.nrows(); + if (axis == 1) return this._row_partition.row_lengths(); + var values = (RaggedTensor)this._values; + axis = array_ops.get_positive_axis( + axis, this.shape.rank, ndims_name: "rank(this)"); + if (axis == 0) return this.nrows(this._row_partition.GetDataType()); + else if (axis == 1) + { + var splits = this._row_partition.row_splits; + return splits[new Slice(start: 1)] - splits[new Slice(stop: -1)]; + + } + else if (this._values is RaggedTensor) + { + return values.row_lengths(axis - 1); + } + else + { + var shape = array_ops.shape(values, out_type: this._row_partition.GetDataType()); + return array_ops.ones(shape[new Slice(stop:axis - 1)], this._row_partition.GetDataType()) * + shape[axis - 1]; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs new file mode 100644 index 000000000..9e242ff38 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs @@ -0,0 +1,158 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Serilog.Debugging; +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +//using System.ComponentModel.DataAnnotations; +using System.Text; +using System.Xml.Linq; +using Tensorflow.Framework; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Partitioning of a sequence of values into contiguous subsequences ("rows"). + /// + public class RowPartition : CompositeTensor + { + Tensor _row_splits; + public Tensor row_splits => _row_splits; + Tensor _row_lengths; + Tensor _value_rowids; + Tensor _nrows; + + public int static_nrows + { + get + { + return (int)_row_splits.shape[0] - 1; + } + } + + public int static_uniform_row_length + { + get + { + return -1; + } + } + + public RowPartition(Tensor row_splits, + Tensor row_lengths = null, Tensor value_rowids = null, Tensor nrows = null, + Tensor uniform_row_length = null) + { + _row_splits = row_splits; + _row_lengths = row_lengths; + _value_rowids = value_rowids; + _nrows = nrows; + } + + /// + /// Creates a `RowPartition` with rows partitioned by `value_rowids`. + /// + /// + /// + /// + /// + /// + public static RowPartition from_value_rowids(Tensor value_rowids, + Tensor nrows = null, bool validate = true, TF_DataType preferred_dtype = TF_DataType.DtInvalid) + { + return tf_with(ops.name_scope(null, "RowPartitionFromValueRowIds"), scope => + { + var value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32); + var nrows_int32 = math_ops.cast(nrows, dtypes.int32); + var row_lengths = tf.math.bincount(value_rowids_int32, + minlength: nrows_int32, + maxlength: nrows_int32, + dtype: value_rowids.dtype); + var row_splits = array_ops.concat(new Tensor[] + { + ops.convert_to_tensor(new long[] { 0 }), + tf.cumsum(row_lengths) + }, axis: 0); + + return new RowPartition(row_splits, + row_lengths: row_lengths, + value_rowids: value_rowids, + nrows: nrows); + }); + } + + public static RowPartition from_row_splits(Tensor row_splits, + bool validate = true, TF_DataType preferred_dtype = TF_DataType.DtInvalid) + { + return tf_with(ops.name_scope(null, "RowPartitionFromRowSplits"), scope => + { + return new RowPartition(row_splits); + }); + } + + public static RowPartition from_row_lengths(Tensor row_lengths, + bool validate=true, + TF_DataType dtype = TF_DataType.TF_INT32, + TF_DataType dtype_hint= TF_DataType.TF_INT32) + { + row_lengths = _convert_row_partition( + row_lengths, "row_lengths", dtype_hint: dtype_hint, dtype: dtype); + Tensor row_limits = math_ops.cumsum(row_lengths, tf.constant(-1)); + Tensor row_splits = array_ops.concat(new Tensor[] { tf.convert_to_tensor(np.array(new int[] { 0 }, TF_DataType.TF_INT64)), row_limits }, axis:0); + return new RowPartition(row_splits: row_splits, row_lengths: row_lengths); + } + + public static Tensor _convert_row_partition(Tensor partition, string name, TF_DataType dtype, + TF_DataType dtype_hint= TF_DataType.TF_INT64) + { + if (partition is NDArray && partition.GetDataType() == np.int32) partition = ops.convert_to_tensor(partition, name: name); + if (partition.GetDataType() != np.int32 && partition.GetDataType() != np.int64) throw new ValueError($"{name} must have dtype int32 or int64"); + return partition; + } + + public Tensor nrows() + { + /*Returns the number of rows created by this `RowPartition*/ + if (this._nrows != null) return this._nrows; + var nsplits = tensor_shape.dimension_at_index(this._row_splits.shape, 0); + if (nsplits == null) return array_ops.shape(this._row_splits, out_type: this.row_splits.dtype)[0] - 1; + else return constant_op.constant(nsplits.value - 1, dtype: this.row_splits.dtype); + } + + public Tensor row_lengths() + { + + if (this._row_splits != null) + { + int nrows_plus_one = tensor_shape.dimension_value(this._row_splits.shape[0]); + return tf.constant(nrows_plus_one - 1); + + } + if (this._row_lengths != null) + { + var nrows = tensor_shape.dimension_value(this._row_lengths.shape[0]); + return tf.constant(nrows); + } + if(this._nrows != null) + { + return tensor_util.constant_value(this._nrows); + } + return tf.constant(-1); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/SparseTensor.cs b/src/TensorFlowNET.Core/Tensors/Ragged/SparseTensor.cs new file mode 100644 index 000000000..54ba2a5f5 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Ragged/SparseTensor.cs @@ -0,0 +1,76 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Framework; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Represents a sparse tensor. + /// + public class SparseTensor : CompositeTensor + { + public Tensor indices; + + public Tensor values; + + public Tensor dense_shape; + + public SparseTensor(Tensor indices, Tensor values, Tensor dense_shape) + { + this.indices = indices; + this.values = values; + this.dense_shape = dense_shape; + _init(); + } + + public SparseTensor(long[,] indices_, Array values_, long[] dense_shape_) + { + tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate + { + indices = ops.convert_to_tensor( + indices_, name: "indices", dtype: dtypes.int64); + values = ops.convert_to_tensor(values_, name: "values"); + dense_shape = ops.convert_to_tensor( + dense_shape_, name: "dense_shape", dtype: dtypes.int64); + }); + _init(); + } + + void _init() + { + var indices_shape = indices.shape.with_rank(2); + var values_shape = values.shape.with_rank(1); + var dense_shape_shape = dense_shape.shape.with_rank(1); + + indices_shape["0"].merge_with(new Shape(values_shape[0])); + indices_shape["1"].merge_with(new Shape(dense_shape_shape[0])); + } + + public static implicit operator Tensor(SparseTensor indexedSlices) + { + return indexedSlices.values; + } + + public static implicit operator SparseTensor(Tensor tensor) + { + return tensor.Tag as SparseTensor; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/RefVariable.cs b/src/TensorFlowNET.Core/Tensors/RefVariable.cs deleted file mode 100644 index 129d06184..000000000 --- a/src/TensorFlowNET.Core/Tensors/RefVariable.cs +++ /dev/null @@ -1,24 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow -{ - public class RefVariable : Variable - { - public bool _in_graph_mode = true; - - public RefVariable(object initial_value, - TF_DataType trainable, - bool validate_shape = true) : - base(initial_value, trainable, validate_shape) - { - - } - - private void _init_from_args() - { - - } - } -} diff --git a/src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs b/src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs new file mode 100644 index 000000000..d7ece8d22 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs @@ -0,0 +1,45 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Util; + +namespace Tensorflow +{ + public sealed class SafeStringTensorHandle : SafeTensorHandle + { + Shape _shape; + SafeTensorHandle _tensorHandle; + const int TF_TSRING_SIZE = 24; + + protected SafeStringTensorHandle() + { + } + + public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape) + : base(handle.DangerousGetHandle()) + { + _tensorHandle = handle; + _shape = shape; + bool success = false; + _tensorHandle.DangerousAddRef(ref success); + } + + protected override bool ReleaseHandle() + { + var _handle = c_api.TF_TensorData(_tensorHandle); +#if TRACK_TENSOR_LIFE + Console.WriteLine($"Delete StringTensorData 0x{_handle.ToString("x16")}"); +#endif + for (int i = 0; i < _shape.size; i++) + { + c_api.TF_StringDealloc(_handle); + _handle += TF_TSRING_SIZE; + } + + SetHandle(IntPtr.Zero); + _tensorHandle.DangerousRelease(); + + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/SafeTensorHandle.cs b/src/TensorFlowNET.Core/Tensors/SafeTensorHandle.cs new file mode 100644 index 000000000..43320e3d4 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/SafeTensorHandle.cs @@ -0,0 +1,44 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class SafeTensorHandle : SafeTensorflowHandle + { + protected SafeTensorHandle() + { + } + + public SafeTensorHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { +#if TRACK_TENSOR_LIFE + print($"Delete TensorHandle 0x{handle.ToString("x16")}"); +#endif + c_api.TF_DeleteTensor(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/TF_BindingArray.cs b/src/TensorFlowNET.Core/Tensors/TF_BindingArray.cs new file mode 100644 index 000000000..535541b82 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/TF_BindingArray.cs @@ -0,0 +1,29 @@ +using System; +using System.Runtime.InteropServices; + +namespace Tensorflow +{ + [StructLayout(LayoutKind.Sequential)] + public struct TF_BindingArray + { + public IntPtr array; + public int length; + + public static implicit operator TF_BindingArray(IntPtr handle) + => handle == IntPtr.Zero ? default : Marshal.PtrToStructure(handle); + + public unsafe IntPtr this[int index] + => array == IntPtr.Zero ? IntPtr.Zero : *((IntPtr*)array + index); + + public unsafe IntPtr[] Data + { + get + { + var results = new IntPtr[length]; + for (int i = 0; i < length; i++) + results[i] = array == IntPtr.Zero ? IntPtr.Zero : *((IntPtr*)array + i); + return results; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs index 543fb4e45..2a6f71147 100644 --- a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs +++ b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs @@ -1,9 +1,13 @@ -using System; -using System.Collections.Generic; -using System.Text; +using Newtonsoft.Json; +using Tensorflow.Keras.Saving.Common; namespace Tensorflow { + /// + /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. + /// The enum values here are identical to corresponding values in types.proto. + /// + [JsonConverter(typeof(CustomizedDTypeJsonConverter))] public enum TF_DataType { DtInvalid = 0, @@ -30,6 +34,30 @@ public enum TF_DataType TF_RESOURCE = 20, TF_VARIANT = 21, TF_UINT32 = 22, - TF_UINT64 = 23 + TF_UINT64 = 23, + + DtFloatRef = 101, // DT_FLOAT_REF + DtDoubleRef = 102, // DT_DOUBLE_REF + DtInt32Ref = 103, // DT_INT32_REF + DtUint8Ref = 104, + DtInt16Ref = 105, + DtInt8Ref = 106, + DtStringRef = 107, + DtComplex64Ref = 108, + DtInt64Ref = 109, // DT_INT64_REF + DtBoolRef = 110, + DtQint8Ref = 111, + DtQuint8Ref = 112, + DtQint32Ref = 113, + DtBfloat16Ref = 114, + DtQint16Ref = 115, + DtQuint16Ref = 116, + DtUint16Ref = 117, + DtComplex128Ref = 118, + DtHalfRef = 119, + DtResourceRef = 120, + DtVariantRef = 121, + DtUint32Ref = 122, + DtUint64Ref = 123, } } diff --git a/src/TensorFlowNET.Core/Tensors/TF_TString_Type.cs b/src/TensorFlowNET.Core/Tensors/TF_TString_Type.cs new file mode 100644 index 000000000..233b16e56 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/TF_TString_Type.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public enum TF_TString_Type + { + TF_TSTR_SMALL = 0, + TF_TSTR_LARGE = 1, + TF_TSTR_OFFSET = 2, + TF_TSTR_VIEW = 3 + } +} diff --git a/src/TensorFlowNET.Core/Tensors/TF_Tensor.cs b/src/TensorFlowNET.Core/Tensors/TF_Tensor.cs new file mode 100644 index 000000000..06c0be8dd --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/TF_Tensor.cs @@ -0,0 +1,21 @@ +using System; + +namespace Tensorflow +{ + public struct TF_Tensor + { + IntPtr _handle; + + public TF_Tensor(IntPtr handle) + => _handle = handle; + + public static implicit operator TF_Tensor(IntPtr handle) + => new TF_Tensor(handle); + + public static implicit operator IntPtr(TF_Tensor tensor) + => tensor._handle; + + public override string ToString() + => $"TF_Tensor 0x{_handle.ToString("x16")}"; + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Assign.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Assign.cs new file mode 100644 index 000000000..1e8bfc8dc --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Assign.cs @@ -0,0 +1,29 @@ +using System; + +namespace Tensorflow +{ + public partial class Tensor + { + /// + /// Used to keep the original variable when slicing + /// + public ResourceVariable OriginalVar { get; set; } + public ParsedSliceArgs OriginalVarSlice { get; set; } + + public ResourceVariable assign(Tensor tensor) + { + if (tensor.dtype != dtype) + throw new ArrayTypeMismatchException(""); + + if (OriginalVar != null) + { + OriginalVar.StridedSliceAssign(tensor, OriginalVarSlice); + return OriginalVar; + } + else + { + throw new RuntimeError($"Operation doesn't support. {this.name} is a constant tensor. Make sure to initiate {this.name} from tf.Variable() and declare {this.name} as ResourceVariable or var."); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs new file mode 100644 index 000000000..fdd62aeed --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs @@ -0,0 +1,23 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow; + +public partial class Tensor +{ + public TensorSpec ToTensorSpec() + => new TensorSpec(shape, dtype, name); +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs new file mode 100644 index 000000000..e7ff9f748 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -0,0 +1,206 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Numerics; +using System.Text; +using static Tensorflow.c_api; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + [SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] + public partial class Tensor + { + public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle); + + protected Tensor() + { + } + + /// + /// Create a Tensor object from an existing TF handle + /// + /// Handle to a object. + public unsafe Tensor(SafeTensorHandle handle, bool clone = false) + { + _handle = handle; + if (clone && handle != null) + _handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer()); + } + + /// + /// Create a new Tensor from the given unmanaged memory pointer (which must be allocated, fixed or pinned by the caller) + /// Note: the caller is responsible for freeing the memory. Calling Dispose on this object will dispose the TensorFlow tensor + /// but not the memory itself! + /// + /// Pointer to unmanaged, fixed or pinned memory which the caller owns + /// Tensor shape + /// TF data type + public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype) + { + _handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer()); + } + + public unsafe Tensor(NDArray nd) + { + _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); + } + + #region scala + public Tensor(bool value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(byte value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(sbyte value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(short value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(int value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(uint value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(long value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(ulong value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(float value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(double value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(string value) => InitTensor(new[] { value }, Shape.Scalar); + #endregion + + #region 1d array + public Tensor(bool[] data, Shape? shape = null) => InitTensor(data, shape); + public Tensor(sbyte[] data, Shape? shape = null) => InitTensor(data, shape); + public Tensor(byte[] data, Shape? shape = null) => InitTensor(data, shape); + public Tensor(short[] data, Shape? shape = null) => InitTensor(data, shape); + public Tensor(ushort[] data, Shape? shape = null) => InitTensor(data, shape); + public Tensor(int[] data, Shape? shape = null) => InitTensor(data, shape); + public Tensor(uint[] data, Shape? shape = null) => InitTensor(data, shape); + public Tensor(long[] data, Shape? shape = null) => InitTensor(data, shape); + public Tensor(ulong[] data, Shape? shape = null) => InitTensor(data, shape); + public Tensor(float[] data, Shape? shape = null) => InitTensor(data, shape); + public Tensor(double[] data, Shape? shape = null) => InitTensor(data, shape); + public Tensor(Complex[] data, Shape? shape = null) => InitTensor(data, shape); + #endregion + + public Tensor(Shape shape, TF_DataType dtype) => InitTensor(shape, dtype); + public Tensor(Array array, Shape? shape = null) => InitTensor(array, shape); + public Tensor(byte[] bytes, Shape shape, TF_DataType dtype) => InitTensor(shape, bytes, dtype); + + public Tensor(Operation op, int value_index, TF_DataType dtype) + { + _op = op; + _value_index = value_index; + _override_dtype = dtype; + _tf_output = null; + _id = ops.uid(); + } + + internal static Tensor _create_with_tf_output(Operation op, int value_index, TF_DataType dtype, TF_Output tf_output) + { + Tensor ret = new Tensor(op, value_index, dtype); + ret._tf_output = tf_output; + return ret; + } + + protected unsafe void InitTensor(Shape shape, TF_DataType dtype) + { + _handle = TF_NewTensor(shape, dtype, null); + _id = ops.uid(); + } + + protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype) + { + if (dtype == TF_DataType.TF_STRING) + _handle = StringTensor(new byte[][] { bytes }, Shape.Scalar); + else + _handle = TF_NewTensor(bytes, shape, dtype); + _id = ops.uid(); + } + + protected unsafe void InitTensor(Array array, Shape? shape = null) + { + shape = shape ?? array.GetShape(); + var dtype = array.GetDataType(); + + if (shape.size == 0 && dtype != TF_DataType.TF_STRING) + { + _handle = TF_NewTensor(shape, dtype, null); + return; + } + + _handle = array switch + { + bool[] val => InitTensor(val, shape, dtype), + bool[,] val => InitTensor(val, shape, dtype), + bool[,,] val => InitTensor(val, shape, dtype), + bool[,,,] val => InitTensor(val, shape, dtype), + byte[] val => InitTensor(val, shape, dtype), + byte[,] val => InitTensor(val, shape, dtype), + byte[,,] val => InitTensor(val, shape, dtype), + byte[,,,] val => InitTensor(val, shape, dtype), + short[] val => InitTensor(val, shape, dtype), + short[,] val => InitTensor(val, shape, dtype), + short[,,] val => InitTensor(val, shape, dtype), + short[,,,] val => InitTensor(val, shape, dtype), + int[] val => InitTensor(val, shape, dtype), + int[,] val => InitTensor(val, shape, dtype), + int[,,] val => InitTensor(val, shape, dtype), + int[,,,] val => InitTensor(val, shape, dtype), + long[] val => InitTensor(val, shape, dtype), + long[,] val => InitTensor(val, shape, dtype), + long[,,] val => InitTensor(val, shape, dtype), + long[,,,] val => InitTensor(val, shape, dtype), + ulong[] val => InitTensor(val, shape, dtype), + ulong[,] val => InitTensor(val, shape, dtype), + ulong[,,] val => InitTensor(val, shape, dtype), + ulong[,,,] val => InitTensor(val, shape, dtype), + float[] val => InitTensor(val, shape, dtype), + float[,] val => InitTensor(val, shape, dtype), + float[,,] val => InitTensor(val, shape, dtype), + float[,,,] val => InitTensor(val, shape, dtype), + double[] val => InitTensor(val, shape, dtype), + double[,] val => InitTensor(val, shape, dtype), + double[,,] val => InitTensor(val, shape, dtype), + double[,,,] val => InitTensor(val, shape, dtype), + string[] val => StringTensor(val, shape), + _ => throw new NotImplementedException("") + }; + + _id = ops.uid(); + } + + unsafe SafeTensorHandle InitTensor(T[] array, Shape shape, TF_DataType dtype) where T : unmanaged + { + fixed (T* addr = &array[0]) + return TF_NewTensor(shape, dtype, addr); + } + + unsafe SafeTensorHandle InitTensor(T[,] array, Shape shape, TF_DataType dtype) where T : unmanaged + { + fixed (T* addr = &array[0, 0]) + return TF_NewTensor(shape, dtype, addr); + } + + unsafe SafeTensorHandle InitTensor(T[,,] array, Shape shape, TF_DataType dtype) where T : unmanaged + { + fixed (T* addr = &array[0, 0, 0]) + return TF_NewTensor(shape, dtype, addr); + } + + unsafe SafeTensorHandle InitTensor(T[,,,] array, Shape shape, TF_DataType dtype) where T : unmanaged + { + fixed (T* addr = &array[0, 0, 0, 0]) + return TF_NewTensor(shape, dtype, addr); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Equal.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Equal.cs new file mode 100644 index 000000000..ee587b2a4 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Equal.cs @@ -0,0 +1,13 @@ +using System; +using System.Runtime.CompilerServices; + +namespace Tensorflow +{ + public partial class Tensor + { + public static Tensor operator !=(Tensor x, int y) + => gen_math_ops.not_equal(x, constant_op.constant(y, dtype: x.dtype)); + public static Tensor operator ==(Tensor x, int y) + => gen_math_ops.equal(x, constant_op.constant(y, dtype: x.dtype)); + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs new file mode 100644 index 000000000..d20c48aba --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs @@ -0,0 +1,113 @@ +using System; +using System.Runtime.CompilerServices; + +namespace Tensorflow +{ + public partial class Tensor + { + public unsafe static explicit operator bool(Tensor tensor) + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_BOOL); + return *(bool*)tensor.buffer; + } + + public unsafe static explicit operator sbyte(Tensor tensor) + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT8); + return *(sbyte*)tensor.buffer; + } + + public unsafe static explicit operator byte(Tensor tensor) + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT8); + return *(byte*)tensor.buffer; + } + + public unsafe static explicit operator ushort(Tensor tensor) + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT16); + return *(ushort*)tensor.buffer; + } + + public unsafe static explicit operator short(Tensor tensor) + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT16); + return *(short*)tensor.buffer; + } + + public unsafe static explicit operator int(Tensor tensor) + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT32); + return *(int*)tensor.buffer; + } + + public unsafe static explicit operator uint(Tensor tensor) + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT32); + return *(uint*)tensor.buffer; + } + + public unsafe static explicit operator long(Tensor tensor) + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT64); + return *(long*)tensor.buffer; + } + + public unsafe static explicit operator ulong(Tensor tensor) + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT64); + return *(ulong*)tensor.buffer; + } + + public unsafe static explicit operator float(Tensor tensor) + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_FLOAT); + return *(float*)tensor.buffer; + } + + public unsafe static explicit operator double(Tensor tensor) + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_DOUBLE); + return *(double*)tensor.buffer; + } + + public unsafe static explicit operator string(Tensor tensor) + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_STRING); + return new string((char*)tensor.buffer, 0, (int)tensor.size); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void EnsureDType(Tensor tensor, TF_DataType @is) + { + if (tensor.dtype != @is) + throw new InvalidCastException($"Unable to cast scalar tensor {tensor.dtype} to {@is}"); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void EnsureScalar(Tensor tensor) + { + if (tensor == null) + throw new ArgumentNullException(nameof(tensor)); + + if (tensor.shape.ndim != 0) + throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); + + if (tensor.shape.size != 1) + throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); + } + + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs new file mode 100644 index 000000000..80d8b5f2d --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Flatten.cs @@ -0,0 +1,10 @@ +namespace Tensorflow +{ + public partial class Tensor + { + public object[] Flatten() + { + return new Tensor[] { this }; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs new file mode 100644 index 000000000..f51b097a0 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs @@ -0,0 +1,18 @@ +using System; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Tensor + { + public static implicit operator SafeTensorHandle(Tensor tensor) + => tensor._handle; + + public static implicit operator Operation(Tensor tensor) + => tensor?.op; + + public static implicit operator Tensor(SafeTensorHandle handle) + => new Tensor(handle); + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs new file mode 100644 index 000000000..51062cf3b --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs @@ -0,0 +1,199 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Tensor + { + public Tensor this[int idx] => slice(idx); + + public Tensor this[params Slice[] slices] + { + get + { + var args = tensor_util.ParseSlices(slices); + + return tf_with(ops.name_scope(null, "strided_slice", args), scope => + { + string name = scope; + if (args.Begin != null) + { + var (packed_begin, packed_end, packed_strides) = + (array_ops.stack(args.Begin), + array_ops.stack(args.End), + array_ops.stack(args.Strides)); + + return array_ops.strided_slice( + this, + packed_begin, + packed_end, + packed_strides, + begin_mask: args.BeginMask, + end_mask: args.EndMask, + shrink_axis_mask: args.ShrinkAxisMask, + new_axis_mask: args.NewAxisMask, + ellipsis_mask: args.EllipsisMask, + name: name); + } + + throw new NotImplementedException(""); + }); + } + } + + public Tensor this[params string[] slices] + => this[slices.Select(x => new Slice(x)).ToArray()]; + + public Tensor slice(Slice slice) + { + var slice_spec = new int[] { slice.Start.Value }; + var begin = new List(); + var end = new List(); + var strides = new List(); + + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; + + foreach (var s in slice_spec) + { + begin.Add(s); + if (slice.Stop.HasValue) + { + end.Add(slice.Stop.Value); + } + else + { + end.Add(0); + end_mask |= (1 << index); + } + + strides.Add(slice.Step); + + index += 1; + } + + return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + { + string name = scope; + if (begin != null) + { + var (packed_begin, packed_end, packed_strides) = + (array_ops.stack(begin.ToArray()), + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); + + return gen_array_ops.strided_slice( + this, + packed_begin, + packed_end, + packed_strides, + begin_mask: begin_mask, + end_mask: end_mask, + shrink_axis_mask: shrink_axis_mask, + new_axis_mask: new_axis_mask, + ellipsis_mask: ellipsis_mask, + name: name); + } + + throw new NotImplementedException(""); + }); + } + + public Tensor this[Tensor start, Tensor stop = null, Tensor step = null] + { + get + { + var args = tensor_util.ParseSlices(start, stop: stop, step: step); + + return tf_with(ops.name_scope(null, "strided_slice", args), scope => + { + string name = scope; + + var tensor = gen_array_ops.strided_slice( + this, + args.PackedBegin, + args.PackedEnd, + args.PackedStrides, + begin_mask: args.BeginMask, + end_mask: args.EndMask, + shrink_axis_mask: args.ShrinkAxisMask, + new_axis_mask: args.NewAxisMask, + ellipsis_mask: args.EllipsisMask, + name: name); + + tensor.OriginalVarSlice = args; + + return tensor; + }); + } + } + + public Tensor slice(int start) + { + var slice_spec = new int[] { start }; + var begin = new List(); + var end = new List(); + var strides = new List(); + + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; + + foreach (var s in slice_spec) + { + begin.Add(s); + end.Add(s + 1); + strides.Add(1); + shrink_axis_mask |= (1 << index); + index += 1; + } + + return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + { + string name = scope; + if (begin != null) + { + var (packed_begin, packed_end, packed_strides) = + (array_ops.stack(begin.ToArray()), + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); + + return array_ops.strided_slice(this, + packed_begin, + packed_end, + packed_strides, + begin_mask: begin_mask, + end_mask: end_mask, + shrink_axis_mask: shrink_axis_mask, + new_axis_mask: new_axis_mask, + ellipsis_mask: ellipsis_mask, + name: name); + } + + throw new NotImplementedException(""); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Keras.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Keras.cs new file mode 100644 index 000000000..ca946ca48 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Keras.cs @@ -0,0 +1,27 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow; + +public partial class Tensor +{ + public bool IsFromKerasTensor { get; set; } + + /// + /// Keras History: (Layer, (node_index, tensor_index)) + /// + public KerasHistory KerasHistory { get; set; } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs new file mode 100644 index 000000000..c7a631d8b --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -0,0 +1,314 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Numerics; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Tensor + { + public static Tensor operator +(Tensor lhs, ResourceVariable rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, sbyte rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(sbyte lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, byte rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(byte lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, short rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(short lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, ushort rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(ushort lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, int rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(int lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, uint rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(uint lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, ulong rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(ulong lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, long rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(long lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, float rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(float lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, double rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(double lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, Complex rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Complex lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator -(Tensor lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, NDArray rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(NDArray lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, sbyte rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(sbyte lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, byte rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(byte lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, short rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(short lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, ushort rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(ushort lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, int rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(int lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, uint rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(uint lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, ulong rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(ulong lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, long rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(long lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, float rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(float lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, double rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(double lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, Complex rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Complex lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator *(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, NDArray rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(NDArray lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, sbyte rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(sbyte lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, byte rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(byte lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, short rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(short lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, ushort rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(ushort lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, int rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(int lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, uint rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(uint lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, ulong rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(ulong lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, long rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(long lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, float rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(float lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, double rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(double lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, Complex rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Complex lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); + public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(sbyte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, byte rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(byte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, short rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(short lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, ushort rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(ushort lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, int rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(int lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, uint rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(uint lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, ulong rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(ulong lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, long rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(long lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, float rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(float lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, double rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(double lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, Complex rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Complex lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator %(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, NDArray rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(NDArray lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, sbyte rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(sbyte lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, byte rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(byte lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, short rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(short lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, ushort rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(ushort lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, int rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(int lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, uint rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(uint lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, ulong rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(ulong lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, long rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(long lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, float rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(float lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, double rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(double lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, Complex rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Complex lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + + public static Tensor operator >(Tensor lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, NDArray rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(NDArray lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, sbyte rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >(sbyte lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >(Tensor lhs, byte rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >(byte lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >(Tensor lhs, short rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >(short lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >(Tensor lhs, ushort rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >(ushort lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >(Tensor lhs, int rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >(int lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >(Tensor lhs, uint rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >(uint lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >(Tensor lhs, ulong rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >(ulong lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), ops.convert_to_tensor(rhs)); + public static Tensor operator >(Tensor lhs, long rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >(long lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >(Tensor lhs, float rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >(float lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >(Tensor lhs, double rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >(double lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >(Tensor lhs, Complex rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >(Complex lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <(Tensor lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, NDArray rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(NDArray lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, sbyte rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <(sbyte lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <(Tensor lhs, byte rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <(byte lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <(Tensor lhs, short rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <(short lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <(Tensor lhs, ushort rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <(ushort lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <(Tensor lhs, int rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <(int lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <(Tensor lhs, uint rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <(uint lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <(Tensor lhs, ulong rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <(ulong lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <(Tensor lhs, long rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <(long lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <(Tensor lhs, float rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <(float lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <(Tensor lhs, double rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <(double lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <(Tensor lhs, Complex rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <(Complex lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >=(Tensor lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, NDArray rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(NDArray lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, sbyte rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >=(sbyte lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >=(Tensor lhs, byte rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >=(byte lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >=(Tensor lhs, short rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >=(short lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >=(Tensor lhs, ushort rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >=(ushort lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >=(Tensor lhs, int rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >=(int lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >=(Tensor lhs, uint rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >=(uint lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >=(Tensor lhs, ulong rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >=(ulong lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >=(Tensor lhs, long rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >=(long lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >=(Tensor lhs, float rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >=(float lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >=(Tensor lhs, double rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >=(double lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator >=(Tensor lhs, Complex rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator >=(Complex lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <=(Tensor lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, NDArray rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(NDArray lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, sbyte rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <=(sbyte lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <=(Tensor lhs, byte rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <=(byte lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <=(Tensor lhs, short rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <=(short lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <=(Tensor lhs, ushort rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <=(ushort lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <=(Tensor lhs, int rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <=(int lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <=(Tensor lhs, uint rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <=(uint lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <=(Tensor lhs, ulong rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <=(ulong lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <=(Tensor lhs, long rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <=(long lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <=(Tensor lhs, float rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <=(float lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <=(Tensor lhs, double rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <=(double lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); + public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); + public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); + + + private static readonly TF_DataType[] _intTfDataTypes = { + TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64, + TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, + TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 + }; + + private static string div_or_truediv(string name, Tx x, Ty y) + { + bool is_floating = false; + var types = new List(); + + if (x is Tensor t1) + types.add(t1.dtype.is_floating()); + + if (y is Tensor t2) + types.add(t2.dtype.is_floating()); + + is_floating = types.Contains(true); + + return is_floating ? "truediv" : name; + } + + protected static Tensor BinaryOpWrapper(string name, Tx x, Ty y) + { + return tf_with(ops.name_scope(null, name, new { x, y }), scope => + { + var dtype = GetBestDType(x, y); + var x1 = ops.convert_to_tensor(x, name: "x", dtype: dtype); + var y1 = ops.convert_to_tensor(y, name: "y", dtype: dtype); + string newname = scope; + + return name.ToLowerInvariant() switch + { + "add" => math_ops.add_v2(x1, y1, name: newname), + "div" => math_ops.div(x1, y1, name: newname), + "floordiv" => gen_math_ops.floor_div(x1, y1, name: newname), + "truediv" => math_ops.truediv(x1, y1, name: newname), + "mul" => math_ops.multiply(x1, y1, name: newname), + "sub" => gen_math_ops.sub(x1, y1, name: newname), + "mod" => gen_math_ops.floor_mod(x1, y1, name: newname), + _ => throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}") + }; + }); + } + + static TF_DataType GetBestDType(Tx x, Ty y) + { + var dtype1 = x.GetDataType(); + var dtype2 = y.GetDataType(); + if (dtype1.is_integer() && dtype2.is_floating()) + return dtype2; + else if (dtype1.is_floating() && dtype2.is_integer()) + return dtype1; + else + return dtype1; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Pack.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Pack.cs new file mode 100644 index 000000000..15c2a8826 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Pack.cs @@ -0,0 +1,10 @@ +namespace Tensorflow +{ + public partial class Tensor + { + public Tensor Pack(object[] sequences) + { + return sequences[0] as Tensor; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs new file mode 100644 index 000000000..5048d5a58 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs @@ -0,0 +1,114 @@ +using System; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Tensor + { + const int TF_TSRING_SIZE = 24; + + public SafeStringTensorHandle StringTensor(string[] strings, Shape shape) + { + // convert string array to byte[][] + var buffer = new byte[strings.Length][]; + for (var i = 0; i < strings.Length; i++) + buffer[i] = Encoding.UTF8.GetBytes(strings[i]); + + return StringTensor(buffer, shape); + } + + public SafeStringTensorHandle StringTensor(byte[][] buffer, Shape shape) + { + var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, + shape.dims, + shape.ndim, + (ulong)shape.size * TF_TSRING_SIZE); + + var tstr = c_api.TF_TensorData(handle); +#if TRACK_TENSOR_LIFE + print($"New StringTensor {handle} Data: 0x{tstr.ToString("x16")}"); +#endif + for (int i = 0; i < buffer.Length; i++) + { + c_api.TF_StringInit(tstr); + c_api.TF_StringCopy(tstr, buffer[i], buffer[i].Length); + // var data = c_api.TF_StringGetDataPointer(tstr); + tstr += TF_TSRING_SIZE; + } + + return new SafeStringTensorHandle(handle, shape); + } + + public string[] StringData() + { + var buffer = StringBytes(); + + var _str = new string[buffer.Length]; + for (int i = 0; i < _str.Length; i++) + _str[i] = Encoding.UTF8.GetString(buffer[i]); + + return _str; + } + + public string StringData(int index) + { + var bytes = StringBytes(index); + return Encoding.UTF8.GetString(bytes); + } + + public byte[] StringBytes(int index) + { + if (dtype != TF_DataType.TF_STRING) + throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); + + byte[] buffer = new byte[0]; + var tstrings = TensorDataPointer; + for (int i = 0; i < shape.size; i++) + { + if(index == i) + { + var data = c_api.TF_StringGetDataPointer(tstrings); + var len = c_api.TF_StringGetSize(tstrings); + buffer = new byte[len]; + // var capacity = c_api.TF_StringGetCapacity(tstrings); + // var type = c_api.TF_StringGetType(tstrings); + Marshal.Copy(data, buffer, 0, Convert.ToInt32(len)); + break; + } + tstrings += TF_TSRING_SIZE; + } + return buffer; + } + + public byte[][] StringBytes() + { + if (dtype != TF_DataType.TF_STRING) + throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); + + // + // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. + // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] + // + long size = 1; + foreach (var s in shape.dims) + size *= s; + + var buffer = new byte[size][]; + var tstrings = TensorDataPointer; + for (int i = 0; i < buffer.Length; i++) + { + var data = c_api.TF_StringGetDataPointer(tstrings); + var len = c_api.TF_StringGetSize(tstrings); + buffer[i] = new byte[len]; + // var capacity = c_api.TF_StringGetCapacity(tstrings); + // var type = c_api.TF_StringGetType(tstrings); + Marshal.Copy(data, buffer[i], 0, Convert.ToInt32(len)); + tstrings += TF_TSRING_SIZE; + } + return buffer; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs new file mode 100644 index 000000000..5a9771420 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -0,0 +1,75 @@ +using Tensorflow.NumPy; +using System; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Tensor + { + /// + /// + /// + /// + /// + public virtual unsafe T[] ToArray() where T : unmanaged + { + //Are the types matching? + if (typeof(T).as_tf_dtype() != dtype) + throw new ArrayTypeMismatchException($"Required dtype {dtype} mismatch with {typeof(T).as_tf_dtype()}."); + + if (ndim == 0 && size == 1) //is it a scalar? + { + unsafe + { + return new T[] { *(T*)buffer }; + } + } + + //types match, no need to perform cast + var ret = new T[size]; + var len = (long)(size * dtypesize); + var src = (T*)buffer; + + fixed (T* dst = ret) + System.Buffer.MemoryCopy(src, dst, len, len); + + return ret; + } + + /// + /// Copy of the contents of this Tensor into a NumPy array or scalar. + /// + /// + /// A NumPy array of the same shape and dtype or a NumPy scalar, if this + /// Tensor has rank 0. + /// + public NDArray numpy() + => GetNDArray(dtype); + + protected NDArray GetNDArray(TF_DataType dtype) + { + if (dtype == TF_DataType.TF_STRING) + { + var str= StringData(); + return new NDArray(str, shape); + } + + return new NDArray(this, clone: true); + } + + /// + /// Copies the memory of current buffer onto newly allocated array. + /// + /// + public unsafe byte[] BufferToArray() + { + // ReSharper disable once LocalVariableHidesMember + var data = new byte[bytesize]; + fixed (byte* dst = data) + System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize); + + return data; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 74ed8e5f6..65e1c8576 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -1,9 +1,26 @@ -using NumSharp.Core; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; using System; -using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; -using System.Runtime.InteropServices; -using System.Text; +using Tensorflow.Eager; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; namespace Tensorflow { @@ -11,129 +28,225 @@ namespace Tensorflow /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// - public class Tensor + [SuppressMessage("ReSharper", "ConvertToAutoProperty")] + public partial class Tensor : DisposableObject, + ITensorOrOperation, + ITensorOrTensorArray, + IPackable, + ICanBeFlattened { - public Operation op { get; } - public int value_index { get; } + protected long _id; + private readonly Operation _op; + private readonly int _value_index; + private TF_Output? _tf_output; + private readonly TF_DataType _override_dtype; + public long Id => _id; - public Graph graph => op.graph; + /// + /// The Graph that contains this tensor. + /// + public Graph graph => op?.graph; - public string name; + /// + /// The Operation that produces this tensor as an output. + /// + public Operation op => _op; + public Tensor[] outputs => op?.outputs; - public TF_DataType dtype { get; } - public IntPtr handle { get; } - public ulong bytesize { get; } - public ulong dataTypeSize { get;} - public ulong size => bytesize / dataTypeSize; - public IntPtr buffer { get; } - public long[] shape { get; } - /// - /// number of dimensions - /// 0 Scalar (magnitude only) - /// 1 Vector (magnitude and direction) - /// 2 Matrix (table of numbers) - /// 3 3-Tensor (cube of numbers) - /// n n-Tensor (you get the idea) + /// The string name of this tensor.
+ /// Tensor.name is meaningless when eager execution is enabled. ///
- public int rank; + public virtual string name => $"{(op == null ? "" : $"{op.name}:{_value_index}")}"; /// - /// if original buffer is free. + /// The index of this tensor in the outputs of its Operation. /// - private bool deallocator_called; + public int value_index => _value_index; - public Tensor(IntPtr handle) + /// + /// The DType of elements in this tensor. + /// + public virtual TF_DataType dtype => _handle == null ? _override_dtype : c_api.TF_TensorType(_handle); + public virtual ulong bytesize => _handle == null ? 0 : c_api.TF_TensorByteSize(_handle); + public ulong dtypesize => (ulong)dtype.get_datatype_size(); + public ulong size => _handle == null ? 0 : bytesize / dtypesize; + public virtual IntPtr buffer => _handle == null ? IntPtr.Zero : c_api.TF_TensorData(_handle); + public int num_consumers(TF_Output oper_out) => _handle == null ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); + public int ndim => rank; + + /// + /// The name of the device on which this tensor will be produced, or null. + /// + public virtual string Device => op?.Device; + public long[] dims => shape.dims; + + /// + /// Used for keep other pointer when do implicit operating + /// + public object Tag { get; set; } + protected new SafeTensorHandle _handle; + public virtual SafeTensorHandle Handle => _handle; + public Tensorflow.CppShapeInferenceResult.Types.HandleData HandleData { get; internal set; } + + protected SafeEagerTensorHandle _eagerTensorHandle; + /// + /// TFE_TensorHandle + /// + public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle; + + /// + /// Returns the shape of a tensor. + /// + /// https://www.tensorflow.org/api_docs/python/tf/shape + public Shape shape { - this.handle = handle; - dtype = c_api.TF_TensorType(handle); - rank = c_api.TF_NumDims(handle); - bytesize = c_api.TF_TensorByteSize(handle); - buffer = c_api.TF_TensorData(handle); - dataTypeSize = c_api.TF_DataTypeSize(dtype); - - shape = new long[rank]; - for (int i = 0; i < rank; i++) - shape[i] = c_api.TF_Dim(handle, i); + get + { + if (rank < 0) + return Shape.Null; + + return GetShapeInternal(); + } + + set + { + SetShapeInternal(value); + tf.Status.Check(true); + } } - public Tensor(NDArray nd) + protected virtual Shape GetShapeInternal() { - var data = Marshal.AllocHGlobal(sizeof(float) * nd.size); - Marshal.Copy(nd.Data(), 0, data, nd.size); - var dataType = ToTFDataType(nd.dtype); - - var handle = c_api.TF_NewTensor(dataType, - nd.shape.Select(x => (long)x).ToArray(), // shape - nd.ndim, - data, - (UIntPtr)(nd.size * sizeof(float)), - (IntPtr values, IntPtr len, ref bool closure) => - { - // Free the original buffer and set flag - Marshal.FreeHGlobal(data); - closure = true; - }, - ref deallocator_called); - - this.handle = handle; - dtype = c_api.TF_TensorType(handle); - rank = c_api.TF_NumDims(handle); - bytesize = c_api.TF_TensorByteSize(handle); - buffer = c_api.TF_TensorData(handle); - dataTypeSize = c_api.TF_DataTypeSize(dtype); - - shape = new long[rank]; - for (int i = 0; i < rank; i++) - shape[i] = c_api.TF_Dim(handle, i); + var dims = new Shape(new long[rank]); + + if (_handle == null) + { + c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status); + } + else + { + for (int i = 0; i < rank; i++) + dims[i] = c_api.TF_Dim(_handle, i); + } + + return dims; } - public Tensor(Operation op, int value_index, TF_DataType dtype) + protected virtual void SetShapeInternal(Shape value) { - this.op = op; - this.value_index = value_index; - this.dtype = dtype; + if (value is null || value.ndim == 0 || value.ndim == -1) + c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), null, -1, tf.Status); + else + c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), value.dims, value.ndim, tf.Status); } - public TF_Output _as_tf_output() + public int[] _shape_tuple() { - return c_api_util.tf_output(op._c_op, value_index); + return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); } - public T[] Data() + /// + /// Updates the shape of this tensor. + /// + public void set_shape(Tensor shape) { - // Column major order - // https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg - // matrix:[[1, 2, 3], [4, 5, 6]] - // index: 0 2 4 1 3 5 - // result: 1 4 2 5 3 6 - var data = new T[size]; - - for (ulong i = 0; i < size; i++) + // ReSharper disable once MergeConditionalExpression + this.shape = shape is null ? null : shape.shape; + } + + /// + /// number of dimensions

+ /// -1 Unknown

+ /// 0 Scalar (magnitude only)

+ /// 1 Vector (magnitude and direction)

+ /// 2 Matrix (table of numbers)

+ /// 3 3-Tensor (cube of numbers)

+ /// n n-Tensor (you get the idea) + ///
+ /// https://www.tensorflow.org/api_docs/python/tf/rank + public virtual int rank + { + get { - data[i] = Marshal.PtrToStructure(buffer + (int)(i * dataTypeSize)); + if (_handle == null) + { + var output = _as_tf_output(); + Status status = new(); + int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); + status.Check(true); + return ndim; + } + + return c_api.TF_NumDims(_handle); } + } + + /// + /// Returns a list of Operations that consume this tensor. + /// + /// + public Operation[] consumers() + { + var output = _as_tf_output(); + var consumer_names = c_api.TF_OperationOutputConsumers_wrapper(output); + return consumer_names.Select(x => graph.OperationByName(x)).ToArray(); + } - return data; + public TF_Output _as_tf_output() + { + if (!_tf_output.HasValue) + _tf_output = new TF_Output(op, _value_index); + + return _tf_output.Value; + } + + public Tensor MaybeMove() + { + var tensor = c_api.TF_TensorMaybeMove(_handle); + return tensor; } - public byte[] Data() + /// + /// Evaluates this tensor in a `Session`. + /// + /// A dictionary that maps `Tensor` objects to feed values. + /// A array corresponding to the value of this tensor. + public NDArray eval(params FeedItem[] feed_dict) { - var data = new byte[bytesize]; - Marshal.Copy(buffer, data, 0, (int)bytesize); + return ops._eval_using_default_session(this, feed_dict, graph); + } - return data; + /// + /// Evaluates this tensor in a `Session`. + /// + /// A dictionary that maps `Tensor` objects to feed values. + /// The `Session` to be used to evaluate this tensor. + /// A array corresponding to the value of this tensor. + public NDArray eval(Session session, params FeedItem[] feed_dict) + { + return ops._eval_using_default_session(this, feed_dict, graph, session); } - public TF_DataType ToTFDataType(Type type) + public override string ToString() { - switch (type.Name) + // this can throw IndexOutOfRangeException + switch (rank) { - case "Single": - return TF_DataType.TF_FLOAT; + case -1: + return $"tf.Tensor '{name}' shape={shape} dtype={dtype.as_numpy_name()}"; + case 0: + return $"tf.Tensor '{name}' shape={shape} dtype={dtype.as_numpy_name()}"; + default: + return $"tf.Tensor '{name}' shape={shape} dtype={dtype.as_numpy_name()}"; } + } + + protected override void DisposeUnmanagedResources(IntPtr handle) + { - return TF_DataType.DtInvalid; } + + public bool IsDisposed => _disposed; } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/TensorArray.cs b/src/TensorFlowNET.Core/Tensors/TensorArray.cs new file mode 100644 index 000000000..ff74956ac --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/TensorArray.cs @@ -0,0 +1,72 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Common.Types; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// TensorArray is designed to hide an underlying implementation object + /// and as such accesses many of that object's hidden fields. + /// + /// "Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays. + /// This class is meant to be used with dynamic iteration primitives such as + /// `while_loop` and `map_fn`. It supports gradient back-propagation via special + /// "flow" control flow dependencies. + /// + public abstract class TensorArray : ITensorOrTensorArray + { + public virtual TF_DataType dtype { get; } + public virtual Tensor handle { get; } + public virtual Tensor flow { get; } + public virtual bool infer_shape { get; } + public virtual bool colocate_with_first_write_call { get; } + + public abstract TensorArray unstack(Tensor value, string name = null); + + public abstract Tensor read(T index, string name = null); + + public abstract TensorArray write(int index, T value, string name = null); + public abstract TensorArray write(Tensor index, Tensor value, string name = null); + + public abstract Tensor stack(string name = null); + public abstract Tensor gather(Tensor indices, string name = null); + + internal bool _dynamic_size; + internal Tensor _size; + internal List _colocate_with; + internal Shape _element_shape; + + public static TensorArray Create(TF_DataType dtype, Tensor size = null, bool dynamic_size = false, + bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, + bool infer_shape = true, Shape? element_shape = null, + bool colocate_with_first_write_call = true, string name = null) + { + if (tf.Context.executing_eagerly() && (flow is null || flow.dtype != dtypes.variant)) + { + return new _EagerTensorArray(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow, + infer_shape, element_shape, colocate_with_first_write_call, name); + } + else + { + return new _GraphTensorArrayV2(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow, + infer_shape, element_shape, colocate_with_first_write_call, name); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/TensorBuffer.cs b/src/TensorFlowNET.Core/Tensors/TensorBuffer.cs index 6027e835c..52c128d9b 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorBuffer.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorBuffer.cs @@ -1,8 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow +namespace Tensorflow { public class TensorBuffer { diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs deleted file mode 100644 index 5c3cf87bf..000000000 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ /dev/null @@ -1,28 +0,0 @@ -using Google.Protobuf.Collections; -using NumSharp.Core; -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow -{ - public class TensorShape : Shape - { - public TensorShape(params int[] shape) : base(shape) - { - - } - - public TensorShape as_shape() - { - return this; - } - - public TensorShapeProto as_proto() - { - TensorShapeProto dim = new TensorShapeProto(); - - return new TensorShapeProto(dim); - } - } -} diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs new file mode 100644 index 000000000..2838b000d --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -0,0 +1,350 @@ +using Tensorflow.NumPy; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Common.Types; +using Tensorflow.Operations; +using Tensorflow.Common.Extensions; + +namespace Tensorflow +{ + /// + /// Tensors is used to represent a Tensor or a array of Tensor. + /// It will simplify the API interface, it converts Tensor + /// and Tensor[] to Tensors implicitily. And parse back to Tensor + /// and Tensor[] from Tensors implicitily. + /// It works for tuple and scalar as well. + /// + public sealed class Tensors : Nest, IDisposable + { + public TF_DataType dtype => this.First().dtype; + public Shape shape => this.First().shape; + public int rank => this.First().rank; + public Graph graph => this.First().graph; + public bool IsList { get; set; } + public int Length => this.Count(); + /// + /// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception. + /// + public Tensor Single + { + get + { + if (Length != 1) + { + throw new ValueError("Tensors with more than one tensor cannot be " + + "implicitly converted to Tensor."); + } + return this.First(); + } + } + + /// + /// Return a Tensor if `Tensors` has only one tensor, and return null when `Tensors` is empty, + /// otherwise throw an exception. + /// + public Tensor? SingleOrNull + { + get + { + if (Length > 1) + { + throw new ValueError($"Tensors with {Length} tensor cannot be " + + "implicitly converted to Tensor."); + } + return this.FirstOrDefault(); + } + } + + public Tensor this[params string[] slices] + => this.First()[slices]; + + internal Tensors(Nest nested) : base(nested) + { + + } + + public Tensors(params Tensor[] tensors): base(DealWithConstructorArrayInput(tensors)) + { + + } + + public Tensors(IList tensors) : base(tensors.Select(x => new Nest(x))) + { + + } + + public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) + { + + } + + /// + /// Get the element in shallow level. For example, for ts = [1, [2, 3], 4], + /// common indexer has ts[1] = 2. Shallow indexer has ts[1] = [2, 3] + /// + /// + /// + public Tensors GetShallow(int index) + { + if(NestType == NestType.Node) + { + if(index > 0) + { + throw new IndexOutOfRangeException(); + } + return this; + } + else if(NestType == NestType.List) + { + return ListValue![index].AsNest().ToTensors(); + } + else + { + throw new NotImplementedException(); + } + } + + private static Nest DealWithConstructorArrayInput(Tensor[] tensors) + { + if (tensors.Length == 0) + { + return Nest.Empty; + } + else if(tensors.Length == 1) + { + return new Nest(tensors[0]); + } + else + { + return new Nest(tensors.Select(x => new Nest(x))); + } + } + + public bool IsSingle() + { + return Length == 1; + } + + public new Tensors MergeWith(Nest? other) + { + return FromNest(base.MergeWith(other)); + } + + [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + + "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] + public void Add(Tensor tensor) + { + if(NestType == NestType.Dictionary) + { + throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); + } + else if(NestType == NestType.Node) + { + NestType = NestType.List; + ListValue = new() { new Nest(NodeValue), new Nest(tensor) }; + NodeValue = null; + } + else if(NestType == NestType.List) + { + ListValue!.Add(new Nest(tensor)); + } + else //Empty + { + NestType = NestType.Node; + NodeValue = tensor; + } + } + + [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + + "some tensors to `Tensors`, creating a new instance with your newly added tensors is a better choice.")] + public void AddRange(IEnumerable tensors) + { + if (NestType == NestType.Dictionary) + { + throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); + } + else if (NestType == NestType.Node) + { + NestType = NestType.List; + ListValue = new() { new Nest(NodeValue) }; + ListValue.AddRange(tensors.Select(x => new Nest(x))); + NodeValue = null; + } + else if(NestType == NestType.List) + { + ListValue!.AddRange(tensors.Select(x => new Nest(x))); + } + else // empty + { + NestType = NestType.List; + ListValue = tensors.Select(x => new Nest(x) as INestStructure).ToList(); + } + } + + [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to insert " + + "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] + public void Insert(int index, Tensor tensor) + { + if (NestType == NestType.List) + { + ListValue.Insert(index, new Nest(tensor)); + } + else if(NestType == NestType.Node) + { + NestType = NestType.List; + ListValue = new() { new Nest(NodeValue) }; + ListValue.Insert(index, new Nest(tensor)); + NodeValue = null; + } + else + { + throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); + } + } + + public string[] StringData() + { + return Single.StringData(); + } + + public string StringData(int index) + { + return Single.StringData(index); + } + + public NDArray numpy() + { + return Single.numpy(); + } + + [Obsolete] + public T[] ToArray() where T: unmanaged + { + return Single.ToArray(); + } + + #region Explicit Conversions + public static explicit operator bool(Tensors tensor) + { + return (bool)tensor.Single; + } + + public static explicit operator sbyte(Tensors tensor) + { + return (sbyte)tensor.Single; + } + + public static explicit operator byte(Tensors tensor) + { + return (byte)tensor.Single; + } + + public static explicit operator ushort(Tensors tensor) + { + return (ushort)tensor.Single; + } + + public static explicit operator short(Tensors tensor) + { + return (short)tensor.Single; + } + + public static explicit operator int(Tensors tensor) + { + return (int)tensor.Single; + } + + public static explicit operator uint(Tensors tensor) + { + return (uint)tensor.Single; + } + + public static explicit operator long(Tensors tensor) + { + return (long)tensor.Single; + } + + public static explicit operator ulong(Tensors tensor) + { + return (ulong)tensor.Single; + } + + public static explicit operator float(Tensors tensor) + { + return (byte)tensor.Single; + } + + public static explicit operator double(Tensors tensor) + { + return (double)tensor.Single; + } + + public static explicit operator string(Tensors tensor) + { + return (string)tensor.Single; + } + + public static explicit operator object[](Tensors tensors) + => tensors.Flatten().ToArray(); + #endregion + + #region Implicit Conversions + public static implicit operator Tensors(Tensor tensor) + => new Tensors(tensor); + + public static implicit operator Tensors((Tensor, Tensor) tuple) + => new Tensors(tuple.Item1, tuple.Item2); + + [AutoNumPy] + public static implicit operator Tensors(NDArray nd) + => new Tensors(nd); + + public static implicit operator Tensors(Tensor[] tensors) + => new Tensors(tensors); + + public static implicit operator Tensors(List tensors) + => new Tensors(tensors.ToArray()); + + public static implicit operator Tensor(Tensors? tensors) + => tensors?.SingleOrNull; + + public static implicit operator Tensor[](Tensors tensors) + => tensors.Flatten().ToArray(); + #endregion + + public static Tensors? FromNest(Nest nested) + { + if(nested == Nest.Empty) + { + return null; + } + return new Tensors(nested); + } + + public void Deconstruct(out Tensor a, out Tensors? b) + { + a = this.First(); + b = Length == 1? null : new Tensors(this.Skip(1).ToArray()); + } + + public override string ToString() + { + if(Length == 1) + { + return this.First().ToString(); + } + else + { + return $"Totally {Length} tensors: {base.ToString()}"; + } + } + + public void Dispose() + { + foreach (var tensor in this) + tensor.Dispose(); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Variable.cs b/src/TensorFlowNET.Core/Tensors/Variable.cs deleted file mode 100644 index 19253ce19..000000000 --- a/src/TensorFlowNET.Core/Tensors/Variable.cs +++ /dev/null @@ -1,24 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow -{ - /// - /// A variable maintains state in the graph across calls to `run()`. You add a - /// variable to the graph by constructing an instance of the class `Variable`. - /// - /// The `Variable()` constructor requires an initial value for the variable, - /// which can be a `Tensor` of any type and shape. The initial value defines the - /// type and shape of the variable. After construction, the type and shape of - /// the variable are fixed. The value can be changed using one of the assign methods. - /// https://tensorflow.org/guide/variables - /// - public class Variable - { - public Variable(object initial_value, TF_DataType trainable, bool validate_shape = true) - { - - } - } -} diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index a99d14bb0..3779ddcfd 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -1,36 +1,63 @@ -using System; -using System.Collections.Generic; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; -using System.Text; +using Tensorflow.NumPy; namespace Tensorflow { - public static partial class c_api + public partial class c_api { + /// + /// Allocate and return a new Tensor. + /// + /// TF_DataType + /// const int64_t* + /// int + /// size_t + /// + [DllImport(TensorFlowLibName)] + public static extern SafeTensorHandle TF_AllocateTensor(TF_DataType dtype, long[] dims, int num_dims, ulong len); + /// /// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value. /// /// /// [DllImport(TensorFlowLibName)] - public static unsafe extern ulong TF_DataTypeSize(TF_DataType dt); + public static extern ulong TF_DataTypeSize(TF_DataType dt); /// /// Destroy a tensor. /// /// [DllImport(TensorFlowLibName)] - public static unsafe extern void TF_DeleteTensor(IntPtr tensor); + public static extern void TF_DeleteTensor(IntPtr tensor); /// /// Return the length of the tensor in the "dim_index" dimension. - /// REQUIRES: 0 <= dim_index < TF_NumDims(tensor) + /// REQUIRES: 0 <= dim_index < TF_NumDims(tensor) /// /// /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe long TF_Dim(IntPtr tensor, int dim_index); + public static extern long TF_Dim(SafeTensorHandle tensor, int dim_index); /// /// Return a new tensor that holds the bytes data[0,len-1] @@ -44,7 +71,40 @@ public static partial class c_api /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref bool deallocator_arg); + public static extern SafeTensorHandle TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, DeallocatorV2 deallocator, IntPtr deallocator_arg); + + public static unsafe SafeTensorHandle TF_NewTensor(byte[] data, Shape shape, TF_DataType dtype) + { + var length = data.Length; + var handle = TF_AllocateTensor(dtype, shape.dims, shape.ndim, (ulong)length); + var tensor = TF_TensorData(handle); + if (tensor == IntPtr.Zero) + throw new TensorflowException("AllocateTensor failed."); + fixed (void* addr = &data[0]) + System.Buffer.MemoryCopy(addr, tensor.ToPointer(), length, length); + return handle; + } + + public static unsafe SafeTensorHandle TF_NewTensor(Shape shape, TF_DataType dtype, void* data) + { + var length = shape.size * dtype.get_datatype_size(); + var handle = TF_AllocateTensor(dtype, shape.dims, shape.ndim, (ulong)length); + var tensor = TF_TensorData(handle); + if (tensor == IntPtr.Zero) + throw new TensorflowException("AllocateTensor failed."); + if (data != null) + System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length); + return handle; + } + + public static unsafe SafeTensorHandle TF_NewTensor(T value) + where T : unmanaged + { + var dtype = value.GetType().as_tf_dtype(); + var handle = TF_AllocateTensor(dtype, new long[0], 0, (ulong)dtype.get_datatype_size()); + *(T*)TF_TensorData(handle) = value; + return handle; + } /// /// Return the number of dimensions that the tensor has. @@ -52,7 +112,7 @@ public static partial class c_api /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe int TF_NumDims(IntPtr tensor); + public static extern int TF_NumDims(SafeTensorHandle tensor); /// /// Return the size of the underlying data in bytes. @@ -60,7 +120,7 @@ public static partial class c_api /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe ulong TF_TensorByteSize(IntPtr tensor); + public static extern ulong TF_TensorByteSize(SafeTensorHandle tensor); /// /// Return a pointer to the underlying data buffer. @@ -68,7 +128,16 @@ public static partial class c_api /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe IntPtr TF_TensorData(IntPtr tensor); + public static extern IntPtr TF_TensorData(SafeTensorHandle tensor); + + /// + /// Deletes `tensor` and returns a new TF_Tensor with the same content if + /// possible. Returns nullptr and leaves `tensor` untouched if not. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern SafeTensorHandle TF_TensorMaybeMove(SafeTensorHandle tensor); /// /// Return the type of a tensor element. @@ -76,6 +145,105 @@ public static partial class c_api /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe TF_DataType TF_TensorType(IntPtr tensor); + public static extern TF_DataType TF_TensorType(SafeTensorHandle tensor); + + /// + /// Set a new shape for the Tensor. Note that this API only works after tf2.11. + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_SetShape(SafeTensorHandle tensor, long[] dims, int num_dims); + + /// + /// Return the size in bytes required to encode a string `len` bytes long into a + /// TF_STRING tensor. + /// + /// size_t + /// + [DllImport(TensorFlowLibName)] + public static extern ulong TF_StringEncodedSize(ulong len); + + /// + /// Encode the string `src` (`src_len` bytes long) into `dst` in the format + /// required by TF_STRING tensors. Does not write to memory more than `dst_len` + /// bytes beyond `*dst`. `dst_len` should be at least + /// TF_StringEncodedSize(src_len). + /// + /// const char* + /// size_t + /// char* + /// size_t + /// TF_Status* + /// On success returns the size in bytes of the encoded string. + [DllImport(TensorFlowLibName)] + public static extern unsafe ulong TF_StringEncode(byte* src, ulong src_len, byte* dst, ulong dst_len, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_StringInit(IntPtr t); + + [DllImport(TensorFlowLibName)] + public static extern void TF_StringCopy(IntPtr dst, byte[] text, long size); + + [DllImport(TensorFlowLibName)] + public static extern void TF_StringCopy(IntPtr dst, string text, long size); + + [DllImport(TensorFlowLibName)] + public static extern void TF_StringAssignView(IntPtr dst, IntPtr text, long size); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_StringGetDataPointer(IntPtr tst); + + [DllImport(TensorFlowLibName)] + public static extern TF_TString_Type TF_StringGetType(SafeTensorHandle tst); + + [DllImport(TensorFlowLibName)] + public static extern ulong TF_StringGetSize(IntPtr tst); + + [DllImport(TensorFlowLibName)] + public static extern ulong TF_StringGetCapacity(IntPtr tst); + + [DllImport(TensorFlowLibName)] + public static extern void TF_StringDealloc(IntPtr tst); + + /// + /// Decode a string encoded using TF_StringEncode. + /// + /// const char* + /// size_t + /// const char** + /// size_t* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status); + + + public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator; + + [MonoPInvokeCallback(typeof(c_api.Deallocator))] + private static void FreeNothingDeallocator(IntPtr dataPtr, IntPtr len, ref c_api.DeallocatorArgs args) + { } + + /// + /// This attribute can be applied to callback functions that will be invoked + /// from unmanaged code to managed code. + /// + /// + /// + /// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))] + /// internal static void MyFreeFunc (IntPtr data, IntPtr length){..} + /// + /// + public sealed class MonoPInvokeCallbackAttribute : Attribute + { + /// + /// Use this constructor to annotate the type of the callback function that + /// will be invoked from unmanaged code. + /// + /// T. + public MonoPInvokeCallbackAttribute(Type t) { } + } } } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs new file mode 100644 index 000000000..1a825e0cb --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -0,0 +1,258 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class constant_op + { + /// + /// Creates a constant tensor. + /// + /// The resulting tensor is populated with values of type `dtype`, as + /// specified by arguments `value` and (optionally) `shape` + /// + /// A constant value (or list) of output type `dtype`. + /// The type of the elements of the resulting tensor. + /// Optional dimensions of resulting tensor. + /// Optional name for the tensor. + /// + public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, + Shape shape = null, bool verify_shape = false, + bool allow_broadcast = true, string name = "Const") + { + if (value == null) + return null; + + if(tf.executing_eagerly()) + return convert_to_eager_tensor(value, dtype, shape, name, verify_shape: verify_shape, allow_broadcast: allow_broadcast); + else + return convert_to_graph_tensor(value, dtype, shape, name, verify_shape: verify_shape, allow_broadcast: allow_broadcast); + } + + private static Tensor _eager_reshape(Tensor tensor, int[] shape, Context ctx) + { + var attr_t = tensor.dtype.as_datatype_enum(); + var dims_t = convert_to_eager_tensor(shape, ctx, dtypes.int32); + var inputs_flat = new[] { tensor, dims_t }; + var attrs = new object[] { "T", attr_t, "Tshape", TF_DataType.TF_INT32 }; + var result = tf.Runner.Execute(ctx, "Reshape", 1, inputs_flat, attrs); + return result[0]; + } + + private static Tensor _eager_fill(int[] dims, Tensor value, Context ctx) + { + var attr_t = value.dtype.as_datatype_enum(); + var dims_t = convert_to_eager_tensor(dims, ctx, dtypes.int32); + var inputs_flat = new[] { dims_t, value }; + var attrs = new object[] { "T", attr_t, "index_type", TF_DataType.TF_INT32 }; + var result = tf.Runner.Execute(ctx, "Fill", 1, inputs_flat, attrs); + return result[0]; + } + + private static Tensor convert_to_eager_tensor(object value, Context ctx, TF_DataType dtype = TF_DataType.DtInvalid) + { + ctx.ensure_initialized(); + // convert data type + if (dtype != TF_DataType.DtInvalid && + value.GetType().Name != "NDArray" && + value.GetType().BaseType.Name != "Array" && + dtype != value.GetDataType()) + { + switch (dtype) + { + case TF_DataType.TF_DOUBLE: + value = Convert.ToDouble(value); + break; + case TF_DataType.TF_FLOAT: + value = Convert.ToSingle(value); + break; + case TF_DataType.TF_INT64: + value = Convert.ToInt64(value); + break; + case TF_DataType.TF_INT32: + value = Convert.ToInt32(value); + break; + default: + break; + } + } + else if (dtype != TF_DataType.DtInvalid && + value is NDArray nd && + nd.dtype != dtype) + { + value = math_ops.cast(nd, dtype); + } + + // non ascii char + if (dtype == TF_DataType.TF_STRING && value is byte[] bytes) + return new EagerTensor(bytes, Shape.Scalar, TF_DataType.TF_STRING); + + switch (value) + { + case EagerTensor val: + return val; + case NDArray val: + return val; + case Shape val: + return new EagerTensor(val.dims, new Shape(val.ndim)); + case Axis val: + return new EagerTensor(val.axis, val.IsScalar ? Shape.Scalar : new Shape(val.size)); + case string val: + return new EagerTensor(new[] { val }, Shape.Scalar); + case string[] val: + return new EagerTensor(val, new Shape(val.Length)); + case bool val: + return new EagerTensor(new[] { val }, Shape.Scalar); + case byte val: + return new EagerTensor(new[] { val }, Shape.Scalar); + case int val: + return new EagerTensor(new[] { val }, Shape.Scalar); + case long val: + return new EagerTensor(new[] { val }, Shape.Scalar); + case ulong val: + return new EagerTensor(new[] { val }, Shape.Scalar); + case float val: + return new EagerTensor(new[] { val }, Shape.Scalar); + case double val: + return new EagerTensor(new[] { val }, Shape.Scalar); + case IEnumerable val: + return ops.convert_to_tensor(val); + case Array val: + return new EagerTensor(val, val.GetShape()); + default: + throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}"); + } + } + + static Tensor convert_to_eager_tensor(object value, + TF_DataType dtype, + Shape shape, + string name, + bool verify_shape, + bool allow_broadcast) + { + var t = convert_to_eager_tensor(value, tf.Context, dtype: dtype); + if (dtype != TF_DataType.DtInvalid && dtype != t.dtype) + { + t = math_ops.cast(t, dtype); + } + if (shape is null || shape.IsNull) + return t; + + if (t.shape.Equals(shape)) + return t; + + if (verify_shape) + throw new TypeError($"Expected Tensor's shape: {shape}, got {t.shape}."); + + var num_t = t.shape.size; + if (num_t == shape.size) + return _eager_reshape(t, shape, tf.Context); + if (num_t == 1) + { + if (t.dtype == dtypes.@bool) + throw new NotImplementedException(""); + else + return _eager_fill(shape, t, tf.Context); + } + + throw new NotImplementedException(""); + } + + static Tensor convert_to_graph_tensor(object value, + TF_DataType dtype, + Shape shape, + string name, + bool verify_shape, + bool allow_broadcast) + { + Graph g = ops.get_default_graph(); + var tensor_value = new AttrValue(); + tensor_value.Tensor = tensor_util.make_tensor_proto(value, + dtype: dtype, + shape: shape, + verify_shape: verify_shape, + allow_broadcast: allow_broadcast); + + var dtype_value = new AttrValue + { + Type = tensor_value.Tensor.Dtype, + }; + + var attrs = new Dictionary(); + attrs["value"] = tensor_value; + attrs["dtype"] = dtype_value; + + var op = g.create_op("Const", + new Tensor[0], + new TF_DataType[] { dtype_value.Type.as_tf_dtype() }, + attrs: attrs, + name: name); + + return op.outputs[0]; + } + + /// + /// Function to convert Shape to Tensor. + /// + /// + /// + /// + /// + /// + public static Tensor _tensor_shape_tensor_conversion_function(Shape s, + TF_DataType dtype = TF_DataType.DtInvalid, + string name = null, + bool as_ref = false) + { + var s_list = s.dims; + var int64_value = 0L; + foreach (var dim in s_list) + { + if (dim > Math.Pow(2, 31)) + { + int64_value = dim; + break; + } + } + + dtype = int64_value > 0 ? TF_DataType.TF_INT64 : TF_DataType.TF_INT32; + + if (string.IsNullOrEmpty(name)) + name = "shape_as_tensor"; + + return constant_op.constant(s_list, dtype: dtype, name: name); + } + + public static bool is_constant(ITensorOrOperation tensor_or_op) + { + if (tensor_or_op is Tensor tensor) + return tensor.op.type == "Const"; + else if (tensor_or_op is Operation op) + return op.type == "Const"; + else + throw new ValueError("is_constant"); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs new file mode 100644 index 000000000..5b4db53b9 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -0,0 +1,348 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Numerics; +using System.Diagnostics; + +namespace Tensorflow +{ + public static class dtypes + { + public static TF_DataType @bool = TF_DataType.TF_BOOL; + public static TF_DataType int8 = TF_DataType.TF_INT8; + public static TF_DataType int32 = TF_DataType.TF_INT32; + public static TF_DataType int64 = TF_DataType.TF_INT64; + public static TF_DataType uint8 = TF_DataType.TF_UINT8; + public static TF_DataType uint32 = TF_DataType.TF_UINT32; + public static TF_DataType uint64 = TF_DataType.TF_UINT64; + public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? + public static TF_DataType float16 = TF_DataType.TF_HALF; + public static TF_DataType float64 = TF_DataType.TF_DOUBLE; + public static TF_DataType complex = TF_DataType.TF_COMPLEX; + public static TF_DataType complex64 = TF_DataType.TF_COMPLEX64; + public static TF_DataType complex128 = TF_DataType.TF_COMPLEX128; + public static TF_DataType variant = TF_DataType.TF_VARIANT; + public static TF_DataType resource = TF_DataType.TF_RESOURCE; + + /// + /// + /// + /// + /// equivalent to , if none exists, returns null. + public static Type as_system_dtype(this TF_DataType type) + { + switch (type.as_base_dtype()) + { + case TF_DataType.TF_BOOL: + return typeof(bool); + case TF_DataType.TF_UINT8: + return typeof(byte); + case TF_DataType.TF_INT8: + return typeof(sbyte); + case TF_DataType.TF_INT64: + return typeof(long); + case TF_DataType.TF_UINT64: + return typeof(ulong); + case TF_DataType.TF_INT32: + return typeof(int); + case TF_DataType.TF_UINT32: + return typeof(uint); + case TF_DataType.TF_INT16: + return typeof(short); + case TF_DataType.TF_UINT16: + return typeof(ushort); + case TF_DataType.TF_FLOAT: + return typeof(float); + case TF_DataType.TF_DOUBLE: + return typeof(double); + case TF_DataType.TF_STRING: + return typeof(string); + case TF_DataType.TF_COMPLEX128: + case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX + return typeof(Complex); + default: + throw new NotSupportedException($"Unable to convert {type} to a system data type."); + } + } + + /// + /// + /// + /// + /// + /// When has no equivalent + public static TF_DataType as_tf_dtype(this Type type) + { + while (type.IsArray) + type = type.GetElementType(); + + TF_DataType dtype = TF_DataType.DtInvalid; + + switch (type.Name) + { + case "Char": + dtype = TF_DataType.TF_UINT8; + break; + case "SByte": + dtype = TF_DataType.TF_INT8; + break; + case "Byte": + dtype = TF_DataType.TF_UINT8; + break; + case "Int16": + dtype = TF_DataType.TF_INT16; + break; + case "UInt16": + dtype = TF_DataType.TF_UINT16; + break; + case "Int32": + dtype = TF_DataType.TF_INT32; + break; + case "UInt32": + dtype = TF_DataType.TF_UINT32; + break; + case "Int64": + dtype = TF_DataType.TF_INT64; + break; + case "UInt64": + dtype = TF_DataType.TF_UINT64; + break; + case "Single": + dtype = TF_DataType.TF_FLOAT; + break; + case "Double": + dtype = TF_DataType.TF_DOUBLE; + break; + case "Complex": + dtype = TF_DataType.TF_COMPLEX128; + break; + case "String": + dtype = TF_DataType.TF_STRING; + break; + case "Boolean": + dtype = TF_DataType.TF_BOOL; + break; + default: + dtype = TF_DataType.DtInvalid; + break; + } + + return dtype; + } + + public static TF_DataType tf_dtype_from_name(string name) + { + TF_DataType dtype = name.ToLower() switch + { + "char" => TF_DataType.TF_UINT8, + "boolean" => TF_DataType.TF_BOOL, + "sbyte" => TF_DataType.TF_INT8, + "byte" => TF_DataType.TF_UINT8, + "int16" => TF_DataType.TF_INT16, + "uint16" => TF_DataType.TF_UINT16, + "int32" => TF_DataType.TF_INT32, + "uint32" => TF_DataType.TF_UINT32, + "int64" => TF_DataType.TF_INT64, + "uint64" => TF_DataType.TF_UINT64, + "float16" => TF_DataType.TF_BFLOAT16, + "float32" => TF_DataType.TF_FLOAT, + "single" => TF_DataType.TF_FLOAT, + "float64" => TF_DataType.TF_DOUBLE, + "double" => TF_DataType.TF_DOUBLE, + "complex" => TF_DataType.TF_COMPLEX128, + "string" => TF_DataType.TF_STRING, + _ => TF_DataType.DtInvalid + }; + + return dtype; + } + + public static DataType as_datatype_enum(this TF_DataType type) + { + return (DataType)type; + } + + public static TF_DataType as_base_dtype(this TF_DataType type) + { + return (int)type > 100 ? (TF_DataType)((int)type - 100) : type; + } + + public static int name(this TF_DataType type) + { + return (int)type; + } + + public static string as_numpy_name(this TF_DataType type) + => type switch + { + TF_DataType.TF_STRING => "string", + TF_DataType.TF_UINT8 => "uint8", + TF_DataType.TF_INT8 => "int8", + TF_DataType.TF_UINT32 => "uint32", + TF_DataType.TF_INT32 => "int32", + TF_DataType.TF_UINT64 => "uint64", + TF_DataType.TF_INT64 => "int64", + TF_DataType.TF_FLOAT => "float32", + TF_DataType.TF_DOUBLE => "float64", + TF_DataType.TF_BOOL => "bool", + TF_DataType.TF_RESOURCE => "resource", + TF_DataType.TF_VARIANT => "variant", + _ => type.ToString() + }; + + public static string as_python_name(this TF_DataType type) + => type switch + { + TF_DataType.TF_STRING => "str", + TF_DataType.TF_UINT8 => "uint8", + TF_DataType.TF_INT8 => "int8", + TF_DataType.TF_UINT32 => "uint32", + TF_DataType.TF_INT32 => "int32", + TF_DataType.TF_UINT64 => "uint64", + TF_DataType.TF_INT64 => "int64", + TF_DataType.TF_FLOAT => "float32", + TF_DataType.TF_DOUBLE => "float64", + TF_DataType.TF_BOOL => "bool", + TF_DataType.TF_RESOURCE => "resource", + TF_DataType.TF_VARIANT => "variant", + _ => type.ToString() + }; + + public static int get_datatype_size(this TF_DataType type) + => type.as_base_dtype() switch + { + TF_DataType.TF_BOOL => sizeof(bool), + TF_DataType.TF_UINT8 => sizeof(byte), + TF_DataType.TF_INT8 => sizeof(sbyte), + TF_DataType.TF_UINT16 => sizeof(ushort), + TF_DataType.TF_INT16 => sizeof(short), + TF_DataType.TF_UINT32 => sizeof(uint), + TF_DataType.TF_INT32 => sizeof(int), + TF_DataType.TF_UINT64 => sizeof(ulong), + TF_DataType.TF_INT64 => sizeof(long), + TF_DataType.TF_FLOAT => sizeof(float), + TF_DataType.TF_DOUBLE => sizeof(double), + _ => throw new NotImplementedException("") + }; + + public static Type as_numpy_dtype(this DataType type) + { + return type.as_tf_dtype().as_system_dtype(); + } + + public static DataType as_base_dtype(this DataType type) + { + return (int)type > 100 ? (DataType)((int)type - 100) : type; + } + + [DebuggerStepThrough] + public static TF_DataType as_tf_dtype(this DataType type) + { + return (TF_DataType)type; + } + + public static TF_DataType as_ref(this TF_DataType type) + { + return (int)type < 100 ? (TF_DataType)((int)type + 100) : type; + } + + public static long min(this TF_DataType type) + { + throw new NotImplementedException($"min {type.name()}"); + } + + public static long max(this TF_DataType type) + { + switch (type) + { + case TF_DataType.TF_INT8: + return sbyte.MaxValue; + case TF_DataType.TF_INT16: + return short.MaxValue; + case TF_DataType.TF_INT32: + return int.MaxValue; + case TF_DataType.TF_INT64: + return long.MaxValue; + case TF_DataType.TF_UINT8: + return byte.MaxValue; + case TF_DataType.TF_UINT16: + return ushort.MaxValue; + case TF_DataType.TF_UINT32: + return uint.MaxValue; + default: + throw new NotImplementedException($"max {type.name()}"); + } + } + + public static bool is_complex(this TF_DataType type) + { + return type == TF_DataType.TF_COMPLEX || type == TF_DataType.TF_COMPLEX64 || type == TF_DataType.TF_COMPLEX128; + } + + public static bool is_integer(this TF_DataType type) + { + return type == TF_DataType.TF_INT8 || type == TF_DataType.TF_INT16 || type == TF_DataType.TF_INT32 || type == TF_DataType.TF_INT64 || + type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64 || + type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref; + } + + public static bool is_unsigned(this TF_DataType type) + { + return type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || + type == TF_DataType.TF_UINT64; + } + + public static bool is_bool(this TF_DataType type) + { + return type == TF_DataType.TF_BOOL; + } + + public static bool is_floating(this TF_DataType type) + { + return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE; + } + + public static bool is_ref_dtype(this TF_DataType type) + { + return (int)type > 100; + } + + public static bool is_compatible_with(this TF_DataType self, TF_DataType other) + { + return self.as_datatype_enum() == other.as_datatype_enum(); + } + + public static TF_DataType real_dtype(this TF_DataType self) + { + TF_DataType base_ = self.as_base_dtype(); + if (base_ == complex64) + return float32; + else if (base_ == complex128) + return float64; + else + return self; + } + + public static bool is_value_dtype(this TF_DataType type) + { + return ((int)type >= 1 && (int)type <= 19) + || type == TF_DataType.TF_UINT32 + || type == TF_DataType.TF_UINT64; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/shape_utils.cs b/src/TensorFlowNET.Core/Tensors/shape_utils.cs new file mode 100644 index 000000000..a77dd34ce --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/shape_utils.cs @@ -0,0 +1,44 @@ +using System; +using System.Linq; +using Tensorflow.Eager; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class shape_utils + { + public static Tensor static_or_dynamic_map_fn(Func fn, Tensor elems, TF_DataType[] dtypes = null, + int parallel_iterations = 32, bool back_prop = true) + { + var outputs = tf.unstack(elems).Select(arg => fn(arg)).ToArray(); + + throw new NotImplementedException(""); + } + + public static Shape from_object_array(object[] shape) + { + var dims = shape.Select(x => + { + if (x is KerasTensor kt && kt.inferred_value != null) + { + return kt.inferred_value.as_int_list()[0]; + } + else if (x is EagerTensor et && et.dtype == TF_DataType.TF_INT32) + { + return et.ToArray()[0]; + } + else if (x is int i) + { + return i; + } + else if (x is long l) + { + return l; + } + throw new NotImplementedException(); + }).ToArray(); + + return new Shape(dims); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 1c6da7bc2..6e5024efd 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -1,51 +1,725 @@ -using NumSharp.Core; +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; using System; using System.Collections.Generic; +using System.Linq; using System.Text; -using tensor_pb2 = Tensorflow; +using Tensorflow.Eager; +using Tensorflow.Graphs; +using static Tensorflow.Binding; +using System.Diagnostics; namespace Tensorflow { public static class tensor_util { - public static TensorProto make_tensor_proto(object values, Type dtype = null) - { - NDArray nparray; - TensorProto tensor_proto = null; - TensorShape tensor_shape = new TensorShape(0); - - switch (values) - { - case float val: - nparray = np.array(new float[] { val }, np.float32); - tensor_proto = new tensor_pb2.TensorProto - { - Dtype = DataType.DtFloat, - TensorShape = tensor_shape.as_shape().as_proto() - }; - tensor_proto.FloatVal.Add(val); - break; - case double val: - nparray = np.array(new double[] { val }, np.float64); - tensor_proto = new tensor_pb2.TensorProto - { - Dtype = DataType.DtDouble, - TensorShape = tensor_shape.as_shape().as_proto() - }; - tensor_proto.DoubleVal.Add(val); - break; - case string val: - nparray = np.array(new string[] { val }, np.chars); - tensor_proto = new tensor_pb2.TensorProto - { - Dtype = DataType.DtString, - TensorShape = tensor_shape.as_shape().as_proto() - }; - tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFrom(val, Encoding.UTF8)); - break; + /// + /// Returns the constant value of the given tensor, if efficiently calculable. + /// + /// + /// + /// + public static NDArray constant_value(Tensor tensor, bool partial = false) + { + if (tensor is NDArray nd) + return nd; + else if (tensor is EagerTensor) + return tensor.numpy(); + + NDArray ret = _ConstantValue(tensor, partial); + if (!(ret is null)) + tensor.graph.prevent_feeding(tensor); + + return ret; + } + + private static NDArray _ConstantValue(Tensor tensor, bool partial) + { + switch (tensor.op.type) + { + case "Const": + return MakeNdarray(tensor.op.get_attr("value") as TensorProto); + default: + return null; + } + } + + public static NDArray MakeNdarray(TensorProto tensor) + { + var shape = new Shape(tensor.TensorShape.Dim.Select(x => x.Size).ToArray()); + var num_elements = shape.size; + var tensor_dtype = tensor.Dtype.as_tf_dtype(); + + T[] ExpandArrayToSize(IList src) + { + if (src.Count == 0) + { + return new T[0]; + } + var pad_count = num_elements - src.Count; + var pre = pad_count / 2; + var after = pad_count - pre; + var first_elem = src[0]; + var last_elem = src[src.Count - 1]; + T[] res = new T[num_elements]; + for (long i = 0; i < num_elements; i++) + { + if (i < pre) res[i] = first_elem; + else if (i >= num_elements - after) res[i] = last_elem; + else res[i] = src[(int)(i - pre)]; + } + return res; + } + + if (shape.ndim > 0 && tensor.TensorContent.Length > 0) + { + return np.frombuffer(tensor.TensorContent.ToByteArray(), shape, tensor_dtype); + } + NDArray values; + if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) + { + values = np.array(ExpandArrayToSize(tensor.HalfVal)); + } + else if (tensor.Dtype == DataType.DtFloat) + { + values = np.array(ExpandArrayToSize(tensor.FloatVal)); + } + else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype)) + { + values = np.array(ExpandArrayToSize(tensor.IntVal)); + } + else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype)) + { + values = np.array(ExpandArrayToSize(tensor.Int64Val)); + } + else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype)) + { + values = np.array(ExpandArrayToSize(tensor.Uint64Val)); + } + else if (tensor.Dtype == DataType.DtBool) + { + values = np.array(ExpandArrayToSize(tensor.BoolVal)); + } + else + { + throw new TypeError($"Unsupported tensor type: {tensor.Dtype}. See " + + $"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes."); + } + + if (values.size == 0) + { + return np.zeros(shape, tensor_dtype); + } + + return values.reshape(shape); + } + + private static readonly TF_DataType[] quantized_types = new TF_DataType[] + { + TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, + TF_DataType.TF_QINT32 + }; + + private static Array ConvertArray(Array inputArray, Func converter) + { + if (inputArray == null) + throw new ArgumentNullException(nameof(inputArray)); + + var elementType = typeof(TOut); + var lengths = new int[inputArray.Rank]; + for (var i = 0; i < inputArray.Rank; i++) + { + lengths[i] = inputArray.GetLength(i); + } + + var outputArray = Array.CreateInstance(elementType, lengths); + + FillArray(inputArray, outputArray, converter, new int[inputArray.Rank], 0); + + return outputArray; + } + + private static void FillArray(Array inputArray, Array outputArray, Func converter, int[] indices, int dimension) + { + if (dimension == inputArray.Rank - 1) + { + for (int i = 0; i < inputArray.GetLength(dimension); i++) + { + indices[dimension] = i; + var inputValue = (TIn)inputArray.GetValue(indices); + var convertedValue = converter(inputValue); + outputArray.SetValue(convertedValue, indices); + } + } + else + { + for (int i = 0; i < inputArray.GetLength(dimension); i++) + { + indices[dimension] = i; + FillArray(inputArray, outputArray, converter, indices, dimension + 1); + } + } + } + + /// + /// Create a TensorProto, invoked in graph mode + /// + /// + /// + /// + /// + /// + /// + public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, Shape? shape = null, bool verify_shape = false, bool allow_broadcast = false) + { + if (allow_broadcast && verify_shape) + throw new ValueError("allow_broadcast and verify_shape are not both allowed."); + if (values is TensorProto tp) + return tp; + + var origin_dtype = values.GetDataType(); + if (dtype == TF_DataType.DtInvalid) + dtype = origin_dtype; + else if (origin_dtype != dtype) + { + var new_system_dtype = dtype.as_system_dtype(); + + if (dtype != TF_DataType.TF_STRING && dtype != TF_DataType.TF_VARIANT && dtype != TF_DataType.TF_RESOURCE) + { + if (values is Array arrayValues) + { + values = dtype switch + { + TF_DataType.TF_INT32 => ConvertArray(arrayValues, Convert.ToInt32), + TF_DataType.TF_FLOAT => ConvertArray(arrayValues, Convert.ToSingle), + TF_DataType.TF_DOUBLE => ConvertArray(arrayValues, Convert.ToDouble), + _ => values, + }; + } else + { + values = Convert.ChangeType(values, new_system_dtype); + } + + } else + { + + } + dtype = values.GetDataType(); + } + + shape = shape ?? values.GetShape(); + var tensor_proto = new TensorProto + { + Dtype = dtype.as_datatype_enum(), + TensorShape = shape.as_shape_proto() + }; + + if (values is NDArray nd) + { + // scalar + if (nd.shape.IsScalar) + { + switch (nd.dtype) + { + case TF_DataType.TF_BOOL: + tensor_proto.BoolVal.AddRange(nd.ToArray()); + break; + case TF_DataType.TF_UINT8: + tensor_proto.IntVal.AddRange(nd.ToArray().Select(x => (int)x).ToArray()); + break; + case TF_DataType.TF_INT32: + tensor_proto.IntVal.AddRange(nd.ToArray()); + break; + case TF_DataType.TF_INT64: + tensor_proto.Int64Val.AddRange(nd.ToArray()); + break; + case TF_DataType.TF_FLOAT: + tensor_proto.FloatVal.AddRange(nd.ToArray()); + break; + case TF_DataType.TF_DOUBLE: + tensor_proto.DoubleVal.AddRange(nd.ToArray()); + break; + default: + throw new Exception("make_tensor_proto Not Implemented"); + } + } + else + { + var len = nd.dtypesize * nd.size; + byte[] bytes = nd.ToByteArray(); + tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); + } + } + else if (dtype == TF_DataType.TF_STRING && !(values is NDArray)) + { + if (values is string str) + tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); + else if (values is string[] str_values) + tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); + else if (values is byte[] byte_values) + tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values); + } + else if (values is Array array) + { + // array + var len = dtype.get_datatype_size() * (int)shape.size; + byte[] bytes = new byte[len]; + System.Buffer.BlockCopy(array, 0, bytes, 0, len); + tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); + } + else + { + switch (values) + { + case Axis val: + tensor_proto.IntVal.AddRange(val.axis); + break; + case Shape val: + tensor_proto.Int64Val.AddRange(val.dims); + break; + case bool val: + tensor_proto.BoolVal.AddRange(new[] { val }); + break; + case sbyte val: + tensor_proto.IntVal.AddRange(new[] { (int)val }); + break; + case byte val: + tensor_proto.IntVal.AddRange(new[] { (int)val }); + break; + case int val: + tensor_proto.IntVal.AddRange(new[] { val }); + break; + case long val: + tensor_proto.Int64Val.AddRange(new[] { val }); + break; + case float val: + tensor_proto.FloatVal.AddRange(new[] { val }); + break; + case double val: + tensor_proto.DoubleVal.AddRange(new[] { val }); + break; + default: + throw new Exception($"make_tensor_proto Not Implemented {values.GetType().Name}"); + } } return tensor_proto; } + + public static Shape constant_value_as_shape(Tensor tensor) + { + bool hasattr(Graph property, string attr) + { + var t = property.GetType().GetProperties(); + foreach (System.Reflection.PropertyInfo pi in t) + { + if (pi.Name == attr) + return true; + } + return false; + } + + if (tensor is EagerTensor eagerTensor) + { + if (tensor.dtype == tf.int64) + return new Shape(tensor.ToArray()); + else + return new Shape(tensor.ToArray()); + } + + if (tensor.shape.ndim == 0) + { + var value_ = constant_value(tensor); + if (value_ == null) + throw new ValueError( + @"Received a scalar with unknown value as shape; require a statically +known scalar with value '-1' to describe an unknown shape."); + if ((int)value_ != -1) + throw new ValueError( + String.Format(@"Received a scalar value {0} as shape; require a statically known +scalar with value '-1' to describe an unknown shape.", value_)); + return tensor.shape.unknown_shape(-1); + } + + var shape = tensor.shape.with_rank(1); + if (shape == new Shape(new int[] { 1 })) + { + return new Shape(new int[] { }); + } + else if (tensor.op.type == "Cast") + { + var pre_cast = constant_value_as_shape(tensor.op.inputs[0]); + if (pre_cast.dims == null) + return pre_cast; + var cast_dtype = dtypes.as_tf_dtype((Type)tensor.op.get_attr("DstT")); + if (!Array.Exists(new[] { dtypes.int32, dtypes.int64 }, cast_dtype_ => cast_dtype_ == cast_dtype)) + return tensor.shape.unknown_shape((int)shape.dims[0]); + + long[] x_ = { }; + foreach (var x in pre_cast.dims) + if (x != -1) + x_[x_.Length] = x; + else + x_[x_.Length] = -1; + var dest_dtype_shape_array = np.array(x_).astype(cast_dtype); + + long[] y_ = { }; + foreach (int y in dest_dtype_shape_array.ToArray()) + if (y >= 0) + y_[y_.Length] = y; + else + y_[y_.Length] = -1; + return new Shape(y_); + } + else if (tensor.op.type == "Shape") + { + return tensor.op.inputs[0].shape; + } + else if (tensor.op.type == "Pack") + { + var ret_ = new Shape(new int[] { }); + if ((int)tensor.op.get_attr("axis") != 0) + throw new ValueError(String.Format( + @"Since rank 1 inputs are expected, Pack's axis: {0} must be 0, otherwise it +would not be rank 1.", tensor.op.get_attr("axis"))); + foreach (Tensor pack_input in tensor.op.inputs) + { + var pack_input_val = (int)constant_value(pack_input); + Dimension new_dim; + if (pack_input_val < 0) + { + new_dim = new Dimension(-1); + } + else if (pack_input_val == null) + { + new_dim = new Dimension(-1); + } + else + { + new_dim = new Dimension(pack_input_val); + } + ret_ = ret_.concatenate(new long[] { new_dim }); + } + return ret_; + } + else if (tensor.op.type == "Concat") + { + var ret_ = new Shape(new int[] { }); + + var inputlist_ = new ArraySegment(tensor.op.inputs, 1, + tensor.op.inputs.Length - 1); + foreach (var concat_input in inputlist_) + { + ret_ = ret_.concatenate(constant_value_as_shape(concat_input)); + } + return ret_; + } + else if (tensor.op.type == "StridedSlice") + { + try + { + var begin = constant_value(tensor.op.inputs[1]); + var end = constant_value(tensor.op.inputs[2]); + var strides = constant_value(tensor.op.inputs[3]); + if (new[] { begin, end, strides }.All(x => x == null)) + { + begin = begin[0]; + end = end[0]; + strides = strides[0]; + var begin_mask = tensor.op.get_attr("begin_mask"); + if ((int)begin_mask == 1) + { + begin = null; + } + var end_mask = tensor.op.get_attr("end_mask"); + if ((int)end_mask == 1) + { + end = null; + } + + var ellipsis_mask = tensor.op.get_attr("ellipsis_mask"); + var new_axis_mask = tensor.op.get_attr("new_axis_mask"); + var shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask"); + + bool valid_attributes; + if (!(bool)ellipsis_mask && !(bool)new_axis_mask && + !(bool)shrink_axis_mask && !((bool)begin_mask || (int)begin_mask == 1) && + !((bool)end_mask || (int)end_mask == 1)) + { + valid_attributes = true; + } + else { valid_attributes = false; } + if (valid_attributes) + { + // sorry for the mess here, but this hacky solution was the best way + // i could come up with to implement the things done in python in c# + var prev_ = constant_value_as_shape(tensor.op.inputs[0]).dims; + var prev = prev_.Skip((int)begin).Take((int)end - (int)begin).ToArray(); + // 100 being the comparison doesn't really matter here; it's going to break anyway + for (int iter = 0; iter != 100; iter = iter + (int)strides) + { + prev[prev.Length] = prev_[iter]; + if ((iter + (int)strides) > prev_.Length) + break; + } + var ret_ = new Shape(prev); + return ret_; + } + } + } + catch (Exception ex) + { + if (ex is ValueError || ex is TypeError) { } + } + } + else if (tensor.op.type == "Placeholder" && + tensor.op.graph.building_function && + tensor.op.graph is FuncGraph func_graph) + { + int i = 0; + foreach (Tensor capture in func_graph.internal_captures) + { + if (capture.GetType() == typeof(Tensor)) + { + var external_capture = func_graph.external_captures[i]; + return constant_value_as_shape(external_capture); + } + + i++; + } + } + + var ret = tensor.shape.unknown_shape((int)shape.dims[0]); + var value = constant_value(tensor); + if (value is not null) + { + var d_ = new int[value.size]; + foreach (var (index, d) in enumerate(value.ToArray())) + d_[index] = d >= 0 ? d : -1; + + ret = ret.merge_with(new Shape(d_)); + } + return ret; + } + + public static TensorShapeProto as_shape(T[] dims) + { + TensorShapeProto shape = new TensorShapeProto(); + + for (int i = 0; i < dims.Length; i++) + { + var dim = new TensorShapeProto.Types.Dim(); + switch (dims[i]) + { + case int n: + dim.Size = n; + break; + case long l: + dim.Size = l; + break; + default: + throw new NotImplementedException("as_shape Not Implemented"); + } + // dim.Name = $"dim_{i}"; + + shape.Dim.Add(dim); + } + + return shape; + } + + public static Shape to_shape(long[] dims) + { + return new Shape(dims.Select(x => (int)x).ToArray()); + } + + public static Shape to_shape(int[] dims) + { + return new Shape(dims); + } + + public static TensorShapeProto as_shape_proto(this Shape tshape) + { + TensorShapeProto shape = new TensorShapeProto(); + + for (int i = 0; i < tshape.ndim; i++) + { + var dim = new TensorShapeProto.Types.Dim(); + dim.Size = tshape.dims[i]; + //dim.Name = $"dim_{i}"; + + shape.Dim.Add(dim); + } + + return shape; + } + + public static Shape reshape(this Shape shape, int[] dims) + { + return new Shape(dims); + } + + public static TensorShapeProto as_proto(this Shape tshape) + { + TensorShapeProto shape = new TensorShapeProto(); + + for (int i = 0; i < tshape.ndim; i++) + { + var dim = new TensorShapeProto.Types.Dim(); + dim.Size = tshape.dims[i]; + //dim.Name = $"dim_{i}"; + + shape.Dim.Add(dim); + } + + return shape; + } + + public static Tensor shape_tensor(int[] shape) + { + return ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape"); + } + + public static ParsedSliceArgs ParseSlices(Slice[] slices) + { + var begin = new List(); + var end = new List(); + var strides = new List(); + + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; + + foreach (var s in slices) + { + if (s.IsNewAxis) + { + begin.Add(0); + end.Add(0); + strides.Add(1); + new_axis_mask |= (1 << index); + } + else if (s.IsEllipsis) + { + begin.Add(0); + end.Add(0); + strides.Add(1); + ellipsis_mask |= (1 << index); + } + else + { + if (s.Start.HasValue) + { + begin.Add(s.Start.Value); + } + else + { + begin.Add(0); + begin_mask |= (1 << index); + } + + if (s.Stop.HasValue) + { + end.Add(s.Stop.Value); + } + else + { + end.Add(0); + end_mask |= (1 << index); + } + + strides.Add(s.Step); + if (s.IsIndex) + shrink_axis_mask |= (1 << index); + } + + index += 1; + } + + return new ParsedSliceArgs + { + Begin = begin.ToArray(), + End = end.ToArray(), + Strides = strides.ToArray(), + BeginMask = begin_mask, + EndMask = end_mask, + EllipsisMask = ellipsis_mask, + ShrinkAxisMask = shrink_axis_mask, + NewAxisMask = new_axis_mask + }; + } + + public static ParsedSliceArgs ParseSlices(Tensor start, Tensor stop = null, Tensor step = null) + { + var begin = new List(); + var end = new List(); + var strides = new List(); + + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; + + begin.Add(start); + + if (stop == null) + end.Add(start + 1); + else + end.Add(stop); + + shrink_axis_mask |= (1 << index); + + if (step == null) + strides.Add(tf.constant(1, dtype: start.dtype)); + else + strides.Add(step); + + return new ParsedSliceArgs + { + PackedBegin = array_ops.stack(begin), + PackedEnd = array_ops.stack(end), + PackedStrides = array_ops.stack(strides), + BeginMask = begin_mask, + EndMask = end_mask, + EllipsisMask = ellipsis_mask, + ShrinkAxisMask = shrink_axis_mask, + NewAxisMask = new_axis_mask + }; + } + + /// + /// Warning: this method is an extremely dangerous method. It directly changes the dtype inside the tensor + /// and security is not guaranteed at all. Currently this method is only used for some conditions to reuse + /// the existing memory. Any other usage should be prevented. If you are sure you want to use it when + /// developing tensorflow.net, please ask @Oceanic2018 or @AsakusaRinne first. + /// + /// + /// + internal static unsafe void DangerousManuallySetTensorDType(SafeTensorHandle handle, TF_DataType dtype) + { + long tf_tensor_address = handle.DangerousGetHandle().ToInt64(); + long interface_address = *(long*)(tf_tensor_address); + long tensor_shape_address = interface_address + 8; + long tensor_dtype_address = tensor_shape_address + 13; + byte* dtype_pointer = (byte*)tensor_dtype_address; + *dtype_pointer = (byte)dtype; + Debug.Assert(c_api.TF_TensorType(handle) == dtype); + } } } diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs new file mode 100644 index 000000000..ac26b3da3 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs @@ -0,0 +1,59 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor constant(object value, + TF_DataType dtype = TF_DataType.DtInvalid, + Shape shape = null, + string name = "Const") + => constant_op.constant(value, + dtype: dtype, + shape: shape, + name: name, + verify_shape: false, + allow_broadcast: true); + + public Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + => array_ops.zeros(shape, dtype, name); + + public Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + => array_ops.zeros(shape, dtype, name); + + public Tensor ones(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + => array_ops.ones(shape, dtype, name); + + public Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + => array_ops.ones(shape, dtype, name); + + public Tensor size(Tensor input, + string name = null, + TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input, + name, + optimize: true, + out_type: out_type); + } +} diff --git a/src/TensorFlowNET.Core/Trackables/AssetResource.cs b/src/TensorFlowNET.Core/Trackables/AssetResource.cs new file mode 100644 index 000000000..6e8d05a8c --- /dev/null +++ b/src/TensorFlowNET.Core/Trackables/AssetResource.cs @@ -0,0 +1,18 @@ +using Google.Protobuf.Collections; +using System.IO; +using Tensorflow.Train; + +namespace Tensorflow.Trackables; + +public class AssetResource : Trackable +{ + public static (Trackable, Action) deserialize_from_proto(SavedObject object_proto, + string export_dir, + RepeatedField asset_file_def, + Dictionary> operation_attributes) + { + var proto = object_proto.Asset; + var filename = Path.Combine(export_dir, asset_file_def[proto.AssetFileDefIndex].Filename); + return (new AssetResource(), null); + } +} diff --git a/src/TensorFlowNET.Core/Trackables/CapturableResource.cs b/src/TensorFlowNET.Core/Trackables/CapturableResource.cs new file mode 100644 index 000000000..d93f786dc --- /dev/null +++ b/src/TensorFlowNET.Core/Trackables/CapturableResource.cs @@ -0,0 +1,7 @@ +using Tensorflow.Train; + +namespace Tensorflow.Trackables; + +public class CapturableResource : Trackable +{ +} diff --git a/src/TensorFlowNET.Core/Trackables/RestoredResource.cs b/src/TensorFlowNET.Core/Trackables/RestoredResource.cs new file mode 100644 index 000000000..cb9f6aa0b --- /dev/null +++ b/src/TensorFlowNET.Core/Trackables/RestoredResource.cs @@ -0,0 +1,13 @@ +using Google.Protobuf.Collections; +using Tensorflow.Train; + +namespace Tensorflow.Trackables; + +public class RestoredResource : TrackableResource +{ + public static (Trackable, Action) deserialize_from_proto(SavedObject object_proto, + Dictionary> operation_attributes) + { + return (new RestoredResource(), null); + } +} diff --git a/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs b/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs new file mode 100644 index 000000000..d65446f3d --- /dev/null +++ b/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs @@ -0,0 +1,34 @@ +using Google.Protobuf.Collections; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow.Trackables; + +public class TrackableConstant : Trackable +{ + Tensor _constant; + public TrackableConstant(Tensor constant) + { + _constant = constant; + } + + public static (Tensor, Action) deserialize_from_proto(SavedObject object_proto, + Dictionary> operation_attributes) + { + var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor; + var ndarray = tensor_util.MakeNdarray(tensor_proto); + Tensor imported_constant; + if (tensor_proto.Dtype == DataType.DtString) + { + imported_constant = tf_with(ops.device("CPU"), _ => + { + return constant_op.constant(ndarray); + }); + } + else + { + imported_constant = constant_op.constant(ndarray); + } + return (imported_constant, null); + } +} diff --git a/src/TensorFlowNET.Core/Trackables/TrackableResource.cs b/src/TensorFlowNET.Core/Trackables/TrackableResource.cs new file mode 100644 index 000000000..43cbc5a20 --- /dev/null +++ b/src/TensorFlowNET.Core/Trackables/TrackableResource.cs @@ -0,0 +1,5 @@ +namespace Tensorflow.Trackables; + +public class TrackableResource : CapturableResource +{ +} diff --git a/src/TensorFlowNET.Core/Training/AdamOptimizer.cs b/src/TensorFlowNET.Core/Training/AdamOptimizer.cs new file mode 100644 index 000000000..c64154e56 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/AdamOptimizer.cs @@ -0,0 +1,174 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Framework; +using static Tensorflow.Binding; + +namespace Tensorflow.Train +{ + /// + /// Optimizer that implements the Adam algorithm. + /// http://arxiv.org/abs/1412.6980 + /// + public class AdamOptimizer : Optimizer + { + float _beta1; + float _beta2; + float _epsilon; + Tensor _beta1_t, _beta2_t, _epsilon_t; + TF_DataType _dtype; + + public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "Adam") + : base(learning_rate, use_locking, name) + { + _beta1 = beta1; + _beta2 = beta2; + _epsilon = epsilon; + _dtype = dtype; + } + + public AdamOptimizer(Tensor learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "Adam") + : base(learning_rate, use_locking, name) + { + _beta1 = beta1; + _beta2 = beta2; + _epsilon = epsilon; + _dtype = dtype; + } + + public override Operation _apply_sparse(IndexedSlices grad, ResourceVariable var) + { + return _apply_sparse_shared(grad.values, var, grad.indices, (x, i, v) => + { + return state_ops.scatter_add(x, i, v, use_locking: _use_locking); + }); + } + + public override Operation _apply_sparse(IndexedSlices grad, RefVariable var) + { + return _apply_sparse_shared(grad.values, var, grad.indices, (x, i, v) => + { + return state_ops.scatter_add(x, i, v, use_locking: _use_locking); + }); + } + + public override Operation _apply_dense(Tensor grad, ResourceVariable var) + { + var m = get_slot(var, "m"); + var v = get_slot(var, "v"); + var (beta1_power, beta2_power) = _get_beta_accumulators(); + return gen_training_ops.apply_adam( + var.Handle, + m.Handle, + v.Handle, + math_ops.cast(beta1_power.Handle, var.dtype.as_base_dtype()), + math_ops.cast(beta2_power.Handle, var.dtype.as_base_dtype()), + math_ops.cast(_lr_t, var.dtype.as_base_dtype()), + math_ops.cast(_beta1_t, var.dtype.as_base_dtype()), + math_ops.cast(_beta2_t, var.dtype.as_base_dtype()), + math_ops.cast(_epsilon_t, var.dtype.as_base_dtype()), + grad, + use_locking: _use_locking).op; + } + + private Operation _apply_sparse_shared(Tensor grad, IVariableV1 var, Tensor indices, Func scatter_add) + { + var (beta1_power_v, beta2_power_v) = _get_beta_accumulators(); + Tensor beta1_power = math_ops.cast(beta1_power_v, var.dtype.as_base_dtype()); + Tensor beta2_power = math_ops.cast(beta2_power_v, var.dtype.as_base_dtype()); + var lr_t = math_ops.cast(_lr_t, var.dtype.as_base_dtype()); + var beta1_t = math_ops.cast(_beta1_t, var.dtype.as_base_dtype()); + var beta2_t = math_ops.cast(_beta2_t, var.dtype.as_base_dtype()); + var epsilon_t = math_ops.cast(_epsilon_t, var.dtype.as_base_dtype()); + var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)); + var m = get_slot(var, "m"); + var m_scaled_g_values = grad * (1 - beta1_t); + var m_t = state_ops.assign(m, m.AsTensor() * beta1_t, use_locking: _use_locking); + tf_with(ops.control_dependencies(new[] { m_t }), delegate + { + m_t = scatter_add(m, indices, m_scaled_g_values); + }); + + var v = get_slot(var, "v"); + var v_scaled_g_values = (grad * grad) * (1 - beta2_t); + var v_t = state_ops.assign(v, v.AsTensor() * beta2_t, use_locking: _use_locking); + tf_with(ops.control_dependencies(new[] { v_t }), delegate + { + v_t = scatter_add(v, indices, v_scaled_g_values); + }); + var v_sqrt = math_ops.sqrt(v_t); + var var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking: _use_locking); + return control_flow_ops.group(new[] { var_update, m_t, v_t }); + } + + protected override void _create_slots(IVariableV1[] var_list) + { + var first_var = var_list.OrderBy(x => x.Name).First(); + _create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var); + _create_non_slot_variable(initial_value: _beta2, name: "beta2_power", colocate_with: first_var); + + // Create slots for the first and second moments. + foreach (var v in var_list) + { + _zeros_slot(v, "m", Name); + _zeros_slot(v, "v", Name); + } + } + + public override Operation _finish(Operation[] update_ops, string name_scope) + { + var operations = new List(); + operations.AddRange(update_ops); + + tf_with(ops.control_dependencies(update_ops), delegate + { + var (beta1_power, beta2_power) = _get_beta_accumulators(); + ops.colocate_with(beta1_power); + var update_beta1 = beta1_power.assign(beta1_power.AsTensor() * _beta1_t, use_locking: _use_locking); + var update_beta2 = beta2_power.assign(beta2_power.AsTensor() * _beta2_t, use_locking: _use_locking); + + operations.Add(update_beta1); + operations.Add(update_beta2); + }); + + return control_flow_ops.group(operations.ToArray(), name: name_scope); + } + + private (IVariableV1, IVariableV1) _get_beta_accumulators() + { + ops.init_scope(); + var graph = ops.get_default_graph(); + return (_get_non_slot_variable("beta1_power", graph: graph), + _get_non_slot_variable("beta2_power", graph: graph)); + } + + public override void _prepare() + { + var lr = _call_if_callable(_lr); + var beta1 = _call_if_callable(_beta1); + var beta2 = _call_if_callable(_beta2); + var epsilon = _call_if_callable(_epsilon); + + _lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate", dtype: _dtype); + _beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1", dtype: _dtype); + _beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2", dtype: _dtype); + _epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon", dtype: _dtype); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs new file mode 100644 index 000000000..20631ce82 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -0,0 +1,90 @@ +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Operations.Activation; +using Tensorflow.Training; +using static Tensorflow.Binding; + +namespace Tensorflow.Train +{ + public class AutoTrackable : Trackable + { + public void _delete_tracking(string name) + { + _maybe_initialize_trackable(); + if (_unconditional_dependency_names.ContainsKey(name)) + { + _unconditional_dependency_names.Remove(name); + for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--) + { + if (_unconditional_checkpoint_dependencies[i].Name == name) + { + _unconditional_checkpoint_dependencies.RemoveAt(i); + } + } + } + } + + public override void SetAttr(string name, object value) + { + // TODO(Rinne): deal with `self_setattr_tracking`. + value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); + base.SetAttr(name, value); + } + + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? cache = null) + { + if(save_type != SaveType.SAVEDMODEL) + { + return base._trackable_children(save_type, cache); + } + + Dictionary functions = new(); + // TODO: process of logs. + // TODO(Rinne): deal with members. + var properties = this.GetType().GetProperties(); + foreach ( var property in properties ) + { + if(property.PropertyType == typeof(Function) || property.PropertyType == typeof(ConcreteFunction)) + { + string name = property.Name; + object value = property.GetValue(this, null); + functions[name] = (Trackable)value; + } + } + + foreach(var item in CustomizedFields) + { + var name = item.Key; + var value = item.Value; + if (value is Function or ConcreteFunction) + { + functions[name] = (Trackable)value; + } + } + + // TODO: process the type `core_types.GenericFunction`. + + Dictionary children = new(); + foreach(var pair in CheckpointDependencies) + { + var name = pair.Name; + var child = pair.Refer; + if(child is ConcreteFunction) // or Generic function + { + continue; + } + if(functions.ContainsKey(name) && functions[name] != child) + { + throw new ValueError($"Can't save object because it has multiple children with the same " + + $"name. Object: {this}, attribute name: {name}, child 1: " + + $"{child}, child 2: {functions[name]}"); + } + children[name] = child; + } + + return children.Concat(functions).ToDictionary(x => x.Key, x => x.Value); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Checkpointable/CheckpointableBase.cs b/src/TensorFlowNET.Core/Training/Checkpointable/CheckpointableBase.cs new file mode 100644 index 000000000..c1738f7f3 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Checkpointable/CheckpointableBase.cs @@ -0,0 +1,9 @@ +using Tensorflow.Train; + +namespace Tensorflow +{ + public abstract class CheckpointableBase : AutoTrackable + { + + } +} diff --git a/src/TensorFlowNET.Core/Training/Coordinator.cs b/src/TensorFlowNET.Core/Training/Coordinator.cs new file mode 100644 index 000000000..b00ef3deb --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Coordinator.cs @@ -0,0 +1,9 @@ +namespace Tensorflow.Training +{ + /// + /// A coordinator for threads + /// + public class Coordinator + { + } +} diff --git a/src/TensorFlowNET.Core/Training/Distribute.cs b/src/TensorFlowNET.Core/Training/Distribute.cs new file mode 100644 index 000000000..3edc4761e --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Distribute.cs @@ -0,0 +1,10 @@ +namespace Tensorflow +{ + public static class Distribute + { + public static VariableAggregationType get_loss_reduction() + { + return VariableAggregationType.MEAN; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/ExponentialMovingAverage.cs b/src/TensorFlowNET.Core/Training/ExponentialMovingAverage.cs new file mode 100644 index 000000000..e3f454bc3 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/ExponentialMovingAverage.cs @@ -0,0 +1,77 @@ +using System; +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow.Train +{ + public class ExponentialMovingAverage + { + float _decay; + int? _num_updates; + bool _zero_debias; + string _name; + public string name => _name; + Dictionary _averages; + + public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false, + string name = "ExponentialMovingAverage") + { + _decay = decay; + _num_updates = num_updates; + _zero_debias = zero_debias; + _name = name; + _averages = new Dictionary(); + } + + /// + /// Maintains moving averages of variables. + /// + /// + /// + public Operation apply(RefVariable[] var_list = null) + { + if (var_list == null) + var_list = variables.trainable_variables() as RefVariable[]; + + foreach (var var in var_list) + { + if (!_averages.ContainsKey(var)) + { + ops.init_scope(); + var slot_creator = new SlotCreator(); + var value = var.initialized_value(); + var avg = slot_creator.create_slot(var, + value, + name, + colocate_with_primary: true); + ops.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var); + _averages[var] = avg; + } + else + { + // avg = slot_creator.create_zeros_slot( + throw new NotImplementedException(""); + } + } + + return tf_with(ops.name_scope(name), scope => + { + var decay = ops.convert_to_tensor(_decay, name: "decay"); + if (_num_updates.HasValue) + { + throw new NotImplementedException("ExponentialMovingAverage.apply"); + } + + var updates = new List(); + foreach (var var in var_list) + { + var zero_debias = false;// _averages[var] in zero_debias_true + var ama = moving_averages.assign_moving_average(_averages[var], var, decay, zero_debias: zero_debias); + updates.Add(ama); + } + + return control_flow_ops.group(updates.ToArray(), name: scope); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/GateGradientType.cs b/src/TensorFlowNET.Core/Training/GateGradientType.cs new file mode 100644 index 000000000..cdb1d3964 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/GateGradientType.cs @@ -0,0 +1,9 @@ +namespace Tensorflow +{ + public enum GateGradientType + { + GATE_NONE = 0, + GATE_OP = 1, + GATE_GRAPH = 2 + } +} diff --git a/src/TensorFlowNET.Core/Training/GradientDescentOptimizer.cs b/src/TensorFlowNET.Core/Training/GradientDescentOptimizer.cs new file mode 100644 index 000000000..9173d6baa --- /dev/null +++ b/src/TensorFlowNET.Core/Training/GradientDescentOptimizer.cs @@ -0,0 +1,64 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow.Train +{ + /// + /// Optimizer that implements the gradient descent algorithm. + /// + public class GradientDescentOptimizer : Optimizer + { + private bool _useTensor; + + /// + /// Construct a new gradient descent optimizer. + /// + /// A Tensor or a floating point value. The learning + /// rate to use. + /// If true use locks for update operations. + /// Optional name prefix for the operations created when applying + /// gradients.Defaults to "GradientDescent". + /// + /// When eager execution is enabled, `learning_rate` can be a callable that + /// takes no arguments and returns the actual value to use.This can be useful + /// for changing these values across different invocations of optimizer + /// functions. + /// + public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent") + : base(learning_rate, use_locking, name) + { + _lr = learning_rate; + _useTensor = false; + } + + public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent") + : base(learning_rate, use_locking, name) + { + _lr_t = learning_rate; + _useTensor = true; + } + + public override void _prepare() + { + if (!_useTensor) + { + var lr = _call_if_callable(_lr); + _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); + } + + } + } +} diff --git a/src/TensorFlowNET.Core/Training/IWithTrackable.cs b/src/TensorFlowNET.Core/Training/IWithTrackable.cs new file mode 100644 index 000000000..87eda8795 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/IWithTrackable.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Training +{ + public interface IWithTrackable + { + Trackable GetTrackable(); + } +} diff --git a/src/TensorFlowNET.Core/Training/LayerUtils.cs b/src/TensorFlowNET.Core/Training/LayerUtils.cs new file mode 100644 index 000000000..211419651 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/LayerUtils.cs @@ -0,0 +1,9 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Training +{ + +} diff --git a/src/TensorFlowNET.Core/Training/Optimizer.cs b/src/TensorFlowNET.Core/Training/Optimizer.cs new file mode 100644 index 000000000..e656fe96d --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Optimizer.cs @@ -0,0 +1,479 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Framework; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Base class for optimizers. + /// This class defines the API to add Ops to train a model. You never use this + /// class directly, but instead instantiate one of its subclasses such as + /// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`. + /// + public abstract class Optimizer : Trackable + { + // Values for gate_gradients. + public static int GATE_NONE = 0; + public static int GATE_OP = 1; + public static int GATE_GRAPH = 2; + + string _name; + public string Name => _name; + protected float _lr; + public float LearningRate => _lr; + protected Tensor _lr_t; + public Tensor LearningRateTensor => _lr_t; + public bool _use_locking; + public Dictionary> _slots; + public Dictionary _non_slot_dict; + public Dictionary _deferred_slot_restorations; + SlotCreator slot_creator = new SlotCreator(); + + public Optimizer(float learning_rate, bool use_locking, string name = null) + { + if (String.IsNullOrEmpty(name)) + throw new NotImplementedException("Must specify the optimizer name"); + + _name = name; + _use_locking = use_locking; + _lr = learning_rate; + // Dictionary of slots. + _slots = new Dictionary>(); + _non_slot_dict = new Dictionary(); + _deferred_slot_restorations = new Dictionary(); + } + + public Optimizer(Tensor learning_rate, bool use_locking, string name = null) + { + if (String.IsNullOrEmpty(name)) + throw new NotImplementedException("Must specify the optimizer name"); + + _name = name; + _use_locking = use_locking; + _lr_t = learning_rate; + // Dictionary of slots. + _slots = new Dictionary>(); + _non_slot_dict = new Dictionary(); + _deferred_slot_restorations = new Dictionary(); + } + + /// + /// Add operations to minimize `loss` by updating `var_list` + /// + /// This method simply combines calls `compute_gradients()` and + /// `apply_gradients()`. If you want to process the gradient before applying + /// them call `compute_gradients()` and `apply_gradients()` explicitly instead + /// of using this function. + /// + /// A `Tensor` containing the value to minimize. + /// Optional `Variable` to increment by one after the + /// variables have been updated. + /// Optional list or tuple of `Variable` objects to update to + /// minimize `loss`. Defaults to the list of variables collected in + /// the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. + /// + /// How to gate the computation of gradients. Can be + /// `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. + /// + /// + /// Specifies the method used to combine gradient terms. + /// Valid values are defined in the class `AggregationMethod`. + /// + /// + /// Optional name for the returned operation. + /// Optional. A `Tensor` holding the gradient computed for `loss`. + /// + /// An Operation that updates the variables in `var_list`. If `global_step` + /// was not `None`, that operation also increments `global_step`. + /// + public Operation minimize(Tensor loss, + IVariableV1 global_step = null, + List var_list = null, + GateGradientType gate_gradients = GateGradientType.GATE_OP, + int? aggregation_method = null, + bool colocate_gradients_with_ops = false, string name = null, Tensor grad_loss = null) + { + // TODO: strongly type aggregation_method + var grads_and_vars = compute_gradients(loss, var_list: var_list, + gate_gradients: gate_gradients, + aggregation_method: aggregation_method, + colocate_gradients_with_ops: colocate_gradients_with_ops, + grad_loss: grad_loss); + + var vars_with_grad = grads_and_vars.Where(x => x.Item1 != null).Select(x => x.Item2).ToArray(); + if (vars_with_grad.Length == 0) + throw new ValueError($"No gradients provided for any variable, check your graph for ops" + + $" that do not support gradients, between variables {string.Join(",", vars_with_grad.Select(x => x.Name))} and loss {loss}."); + + return apply_gradients(grads_and_vars, global_step: global_step, name: name); + } + + /// + /// Apply gradients to variables. + /// + /// This is the second part of `minimize()`. It returns an `Operation` that + /// applies gradients. + /// + /// List of (gradient, variable) pairs as returned by + /// `compute_gradients()`. + /// Optional `Variable` to increment by one after the + /// variables have been updated. + /// Optional name for the returned operation. Default to the + /// name passed to the `Optimizer` constructor. + /// + /// An `Operation` that applies the specified gradients. If `global_step` + /// was not None, that operation also increments `global_step`. + public Operation apply_gradients(Tuple[] grads_and_vars, IVariableV1 global_step = null, string name = null) + { + // No DistributionStrategy case. + var converted_grads_and_vars = new List<(Tensor, IVariableV1, _OptimizableVariable)>(); + foreach (var (g, v) in grads_and_vars) + { + if (g != null) + { + // Convert the grad to Tensor or IndexedSlices if necessary. + var gR = ops.convert_to_tensor_or_indexed_slices(g); + var p = optimizer._get_processor(v as ResourceVariable); + converted_grads_and_vars.Add((gR, v, p)); + } + } + + var var_list = converted_grads_and_vars.Where(x => x.Item1 != null).Select(x => x.Item2).ToArray(); + if (var_list.Length == 0) + throw new ValueError($"No gradients provided for any variable"); + + ops.init_scope(); + _create_slots(var_list); + + var update_ops = new List(); + return tf_with(ops.name_scope(name, Name), scope => + { + name = scope; + _prepare(); + + foreach (var (grad, var, processor) in converted_grads_and_vars) + { + if (grad == null) + continue; + + var scope_name = var.Op.name; + tf_with(ops.name_scope("update_" + scope_name), scope2 => + { + var op = processor.update_op(this, grad); + update_ops.Add(op); + }); + } + + Operation apply_updates = null; + if (global_step == null) + { + apply_updates = _finish(update_ops.ToArray(), name); + } + else + { + tf_with(ops.control_dependencies(new object[] { _finish(update_ops.ToArray(), "update") }), dep => + { + // ops.colocate_with(global_step); + // TODO: port this if branch once ResourceVariable has been ported! + //if (global_step is ResourceVariable) + //{ + // # TODO(apassos): the implicit read in assign_add is slow; consider + // # making it less so. + // apply_updates = resource_variable_ops.assign_add_variable_op( + // global_step.handle, + // ops.convert_to_tensor(1, dtype = global_step.dtype), + // name = name) + //} + //else + { + apply_updates = state_ops.assign_add(global_step, + ops.convert_to_tensor(1, dtype: global_step.dtype), + name: name); + } + }); + } + + if (!tf.Context.executing_eagerly()) + { + var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP); + if (train_op != null && train_op.Contains(apply_updates)) + train_op.Add(apply_updates); + } + + return apply_updates; + }); + } + + /// + /// Create the beta1 and beta2 accumulators on the same device as the first + /// variable. Sort the var_list to make sure this device is consistent across + /// workers (these need to go on the same PS, otherwise some updates are + /// silently ignored). + /// + /// + protected virtual void _create_slots(IVariableV1[] var_list) + { + + } + + /// + /// Add an extra variable, not associated with a slot. + /// + /// + /// + /// + protected IVariableV1 _create_non_slot_variable(float initial_value, string name, IVariableV1 colocate_with) + { + // Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. + var graph = colocate_with.Graph; + var key = $"{name}.{graph.graph_key}"; + var v = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; + if (v == null) + { + _maybe_initialize_trackable(); + v = variable_scope.default_variable_creator( + initial_value, + name: name, + dtype: colocate_with.dtype.as_base_dtype(), + trainable: false, + use_resource: resource_variable_ops.is_resource_variable( + colocate_with)); + + // Restore this variable by name if necessary, but don't add a + // Trackable dependency. Optimizers return the current graph's + // non-slot variables from _checkpoint_dependencies explicitly rather + // than unconditionally adding dependencies (since there may be multiple + // non-slot variables with the same name in different graphs, trying to + // save all of them would result in errors). + _handle_deferred_dependencies(name, v); + _non_slot_dict[key] = v; + } + + return v; + } + + public virtual Operation _finish(Operation[] update_ops, string name_scope) + { + return control_flow_ops.group(update_ops, name_scope); + } + + public virtual Operation _apply_dense(Tensor grad, ResourceVariable var) + { + if (tf.executing_eagerly()) + { + var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); + return gen_training_ops.resource_apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op; + } + else + { + var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); + return gen_training_ops.apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op; + } + } + + public virtual Operation _apply_dense(Tensor grad, RefVariable var) + { + var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); + return gen_training_ops.apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op; + } + + /// + /// Add ops to apply sparse gradients to `var`, with repeated sparse indices. + /// + /// + /// + /// + public virtual Operation _apply_sparse_duplicate_indices(IndexedSlices grad, RefVariable var) + { + var (summed_values, unique_indices) = _deduplicate_indexed_slices(values: grad.values, indices: grad.indices); + var gradient_no_duplicate_indices = new IndexedSlices( + indices: unique_indices, + values: summed_values, + dense_shape: grad.dense_shape); + return _apply_sparse(gradient_no_duplicate_indices, var); + } + + public virtual Operation _apply_sparse_duplicate_indices(IndexedSlices grad, ResourceVariable var) + { + var (summed_values, unique_indices) = _deduplicate_indexed_slices(values: grad.values, indices: grad.indices); + var gradient_no_duplicate_indices = new IndexedSlices( + indices: unique_indices, + values: summed_values, + dense_shape: grad.dense_shape); + return _apply_sparse(gradient_no_duplicate_indices, var); + } + + public virtual Operation _apply_sparse(IndexedSlices grad, ResourceVariable var) + { + throw new NotImplementedException("_apply_sparse"); + } + + public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var) + { + throw new NotImplementedException("_apply_sparse"); + } + + public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices) + { + var (unique_indices, new_index_positions) = array_ops.unique(indices); + var shape = array_ops.shape(unique_indices).slice(0); + var summed_values = math_ops.unsorted_segment_sum(values, new_index_positions, shape); + return (summed_values, unique_indices); + } + + public virtual void _prepare() + { + + } + + /// + /// Return a slot named `name` created for `var` by the Optimizer. + /// + /// + /// + /// + internal IVariableV1 get_slot(IVariableV1 var, string name) + { + var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; + if (named_slots == null) + return null; + + return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; + } + + internal IEnumerable get_slot_names() + { + return _slots.Keys; + } + + private string _var_key(IVariableV1 var) + { + return $"{var.Op.graph.graph_key}.{var.Op.name}"; + } + + protected IVariableV1 _get_non_slot_variable(string name, Graph graph = null) + { + var key = $"{name}.{graph.graph_key}"; + var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; + + return non_slot; + } + + /// + /// Compute gradients of `loss` for the variables in `var_list`. + /// + /// + /// + /// + /// A list of (gradient, variable) pairs. Variable is always present, but + /// gradient can be `None`. + /// + public Tuple[] compute_gradients(Tensor loss, + List var_list = null, + int? aggregation_method = null, + GateGradientType gate_gradients = GateGradientType.GATE_OP, + bool colocate_gradients_with_ops = false, + Tensor grad_loss = null) + { + // Scale loss if using a "mean" loss reduction and multiple replicas. + loss = _scale_loss(loss); + + if (var_list == null) + { + var vars = ops.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); + var tmp = variables.trainable_variables(); + var_list = (tmp as List).Concat(vars).ToList(); + } + + var_list = var_list.Concat(ops.get_collection(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); + var processors = var_list.Select(v => optimizer._get_processor(v as ResourceVariable)).ToList(); + var var_refs = processors.Select(x => x.target()).ToArray(); + + var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss == null ? null : new Tensor[] { grad_loss }, + gate_gradients: gate_gradients == GateGradientType.GATE_OP, + aggregation_method: aggregation_method, + colocate_gradients_with_ops: colocate_gradients_with_ops); + + if ((int)gate_gradients == Optimizer.GATE_GRAPH) + grads = control_flow_ops.tuple(grads); + + var grads_and_vars = zip(grads, var_list) + .Select(x => new Tuple(x.Item1, x.Item2)) + .ToArray(); + + return grads_and_vars; + } + + private Tensor _scale_loss(Tensor loss_value) + { + ops.get_default_graph()._is_loss_scaled_by_optimizer = false; + // TODO + // if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN: + return loss_value; + } + + protected T _call_if_callable(T param) + { + return param; + } + + /// + /// Find or create a slot initialized with 0.0. + /// + /// + /// + /// + /// + protected IVariableV1 _zeros_slot(IVariableV1 var, string slot_name, string op_name) + { + var named_slots = _slot_dict(slot_name); + if (!named_slots.ContainsKey(_var_key(var))) + { + var new_slot_variable = slot_creator.create_zeros_slot(var, op_name); + _restore_slot_variable(slot_name: slot_name, variable: var, slot_variable: new_slot_variable); + named_slots[_var_key(var)] = new_slot_variable; + } + return named_slots[_var_key(var)]; + } + + /// + /// Restore a newly created slot variable's value. + /// + protected void _restore_slot_variable(string slot_name, IVariableV1 variable, IVariableV1 slot_variable) + { + var variable_key = _var_key(variable); + // TODO + } + + protected Dictionary _slot_dict(string slot_name) + { + var named_slots = _slots.ContainsKey(slot_name) ? _slots[slot_name] : null; + if (named_slots == null) + { + named_slots = new Dictionary(); + _slots[slot_name] = named_slots; + } + + return named_slots; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/QueueRunner.cs b/src/TensorFlowNET.Core/Training/QueueRunner.cs new file mode 100644 index 000000000..30d3af5fd --- /dev/null +++ b/src/TensorFlowNET.Core/Training/QueueRunner.cs @@ -0,0 +1,33 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Queues; + +namespace Tensorflow.Train +{ + /// + /// Holds a list of enqueue operations for a queue, each to be run in a thread. + /// + public class QueueRunner + { + public QueueRunner(QueueBase queue, Operation[] enqueue_ops) + { + + } + + + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs new file mode 100644 index 000000000..e16f82c05 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs @@ -0,0 +1,232 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class BaseSaverBuilder + { + protected SaverDef.Types.CheckpointFormatVersion _write_version; + + public BaseSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2) + { + _write_version = write_version; + } + + /// + /// Create an Op to save 'saveables'. + /// + /// + /// + /// + public virtual Operation save_op(Tensor filename_tensor, MySaveableObject[] saveables) + { + var tensor_names = new List(); + var tensors = new List(); + var tensor_slices = new List(); + + foreach (var saveable in saveables) + { + foreach (var spec in saveable.specs) + { + tensor_names.Add(spec.name); + tensors.Add(spec.tensor); + tensor_slices.Add(spec.slice_spec); + } + } + + if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) + { + return tf.io.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); + } + else + { + throw new NotImplementedException("_write_version v1"); + } + } + + public virtual Tensor[] bulk_restore(Tensor filename_tensor, MySaveableObject[] saveables, int preferred_shard, bool restore_sequentially) + { + var names = new List(); + var slices = new List(); + var dtypes = new List(); + foreach (var saveable in saveables) + foreach (var spec in saveable.specs) + { + names.Add(spec.name); + slices.Add(spec.slice_spec); + dtypes.Add(spec.dtype); + } + + return tf.io.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); + } + + public virtual SaverDef _build_internal(IVariableV1[] names_to_saveables, + bool reshape = false, + bool sharded = false, + int max_to_keep = 5, + float keep_checkpoint_every_n_hours = 10000, + string name = null, + bool restore_sequentially = false, + string filename = "model", + bool build_save = true, + bool build_restore = true) + { + if (!build_save || !build_restore) + throw new ValueError("save and restore operations need to be built together " + + " when eager execution is not enabled."); + + var saveables = saveable_object_util.validate_and_slice_inputs(names_to_saveables); + + if (max_to_keep < 0) + max_to_keep = 0; + + Tensor save_tensor = null; + Operation restore_op = null; + + return tf_with(ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => + { + name = scope; + + // Add a placeholder string tensor for the filename. + var filename_tensor = array_ops.placeholder_with_default(tf.convert_to_tensor(string.IsNullOrEmpty(filename) ? "model" : filename), shape: new int[0], name: "filename"); + // Keep the name "Const" for backwards compatibility. + filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const"); + + // Add the save ops. + if (sharded) + { + + } + else + { + if (build_save) + save_tensor = _AddSaveOps(filename_tensor, saveables); + + if (build_restore) + restore_op = _AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); + } + + var graph = ops.get_default_graph(); + // Do some sanity checking on collections containing + // PartitionedVariables. If a saved collection has a PartitionedVariable, + // the GraphDef needs to include concat ops to get the value (or there'll + // be a lookup error on load). + var check_collection_list = graph.get_all_collection_keys(); + foreach (var collection_type in check_collection_list) + { + /*var cols = graph.get_collection(collection_type); + switch (cols) + { + case List values: + foreach (var element in values) ; + break; + case List values: + foreach (var element in values) ; + break; + case List values: + foreach (var element in values) ; + break; + case List values: + foreach (var element in values) ; + break; + case List values: + foreach (var element in values) ; + break; + case List values: + foreach (var element in values) ; + break; + case List values: + foreach (var element in values) ; + break; + case List values: + foreach (var element in values) ; + break; + case List values: + foreach (var element in values) ; + break; + default: + throw new NotImplementedException("_build_internal.check_collection_list"); + }*/ + + } + + return new SaverDef() + { + FilenameTensorName = filename_tensor.name, + SaveTensorName = save_tensor.name, + RestoreOpName = restore_op.name, + MaxToKeep = max_to_keep, + Sharded = sharded, + KeepCheckpointEveryNHours = keep_checkpoint_every_n_hours, + Version = _write_version + }; + }); + } + + public Tensor _AddSaveOps(Tensor filename_tensor, MySaveableObject[] saveables) + { + var save = save_op(filename_tensor, saveables); + return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor); + } + + /// + /// Add operations to restore saveables. + /// + /// + /// + /// + /// + /// + /// + /// An Operation that restores the variables. + public Operation _AddRestoreOps(Tensor filename_tensor, + MySaveableObject[] saveables, + bool restore_sequentially, + bool reshape, + int preferred_shard = -1, + string name = "restore_all") + { + var all_tensors = bulk_restore(filename_tensor, saveables, preferred_shard, restore_sequentially); + var assign_ops = new List(); + int idx = 0; + + // Load and optionally reshape on the CPU, as string tensors are not + // available on the GPU. + // TODO(touts): Re-enable restore on GPU when we can support annotating + // string tensors as "HostMemory" inputs. + foreach (var saveable in saveables) + { + List shapes = null; + if (reshape) + { + throw new NotImplementedException("_AddRestoreOps"); + } + + var saveable_tensors = all_tensors.Skip(idx).Take(saveable.specs.Length); + idx += saveable.specs.Length; + var restored = saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray()); + assign_ops.Add(restored); + } + + return control_flow_ops.group(assign_ops.ToArray(), name: name); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/BulkSaverBuilder.cs b/src/TensorFlowNET.Core/Training/Saving/BulkSaverBuilder.cs new file mode 100644 index 000000000..27f89f382 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/BulkSaverBuilder.cs @@ -0,0 +1,10 @@ +namespace Tensorflow +{ + public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder + { + public BulkSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2) : base(write_version) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/ISaverBuilder.cs b/src/TensorFlowNET.Core/Training/Saving/ISaverBuilder.cs new file mode 100644 index 000000000..c3275dd25 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/ISaverBuilder.cs @@ -0,0 +1,36 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public interface ISaverBuilder + { + Operation save_op(Tensor filename_tensor, MySaveableObject[] saveables); + + Tensor[] bulk_restore(Tensor filename_tensor, MySaveableObject[] saveables, int preferred_shard, bool restore_sequentially); + + SaverDef _build_internal(IVariableV1[] names_to_saveables, + bool reshape = false, + bool sharded = false, + int max_to_keep = 5, + float keep_checkpoint_every_n_hours = 10000, + string name = null, + bool restore_sequentially = false, + string filename = "model", + bool build_save = true, + bool build_restore = true); + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/ReferenceVariableSaveable.cs b/src/TensorFlowNET.Core/Training/Saving/ReferenceVariableSaveable.cs new file mode 100644 index 000000000..963227f07 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/ReferenceVariableSaveable.cs @@ -0,0 +1,31 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public class ReferenceVariableSaveable : MySaveableObject + { + private SaveSpec _spec; + + public ReferenceVariableSaveable(Tensor var, string slice_spec, string name) + { + _spec = new SaveSpec(var, slice_spec, name, dtype: var.dtype); + op = var; + specs = new SaveSpec[] { _spec }; + this.name = name; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs new file mode 100644 index 000000000..587dede40 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs @@ -0,0 +1,84 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class ResourceVariableSaveable : MySaveableObject + { + string _var_device; + int[] _var_shape; + Tensor handle_op; + + public ResourceVariableSaveable(Tensor var, string slice_spec, string name) + { + _var_device = var.Device; + _var_shape = var.shape; + handle_op = var.op.inputs[0]; + var tensor = var; + var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype); + + op = var; + specs = new SaveSpec[] { spec }; + this.name = name; + } + + public ResourceVariableSaveable(BaseResourceVariable var, string slice_spec, string name) + { + _var_device = var.Device; + _var_shape = var.shape; + + Func _read_variable_closure(BaseResourceVariable v) + { + return () => + { + return tf_with(ops.device(v.Device), _ => + { + if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) + { + return null; + } + var x = v.read_value_no_copy(); + return tf_with(ops.device("/device:CPU:0"), _ => + { + return array_ops.identity(x); + }); + }); + }; + } + + this.handle_op = var.Handle; + var tensor_creator = _read_variable_closure(var); + + var spec = new SaveSpec(tensor_creator, slice_spec, name, dtype: var.dtype, device: var.Device); + _op = var; + specs = new SaveSpec[] { spec }; + this.name = name; + } + + public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) + { + var restored_tensor = restored_tensors[0]; + return tf_with(ops.device(_var_device), _ => + { + restored_tensor = array_ops.identity(restored_tensor); + return resource_variable_ops.shape_safe_assign_variable_handle( + handle_op, _var_shape, restored_tensor); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs b/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs new file mode 100644 index 000000000..2b300c2a9 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs @@ -0,0 +1,85 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Exceptions; + +namespace Tensorflow +{ + /// + /// Class used to describe tensor slices that need to be saved. + /// + public class SaveSpec + { + private Tensor _tensor = null; + private Func _tensor_creator = null; + public Tensor tensor + { + get + { + if(_tensor is not null || _tensor_creator is null) + { + return _tensor; + } + else + { + return _tensor_creator(); + } + } + } + + internal Func TensorCreator => _tensor_creator; + + private string _slice_spec; + public string slice_spec => _slice_spec; + + private string _name; + public string name { get => _name; set => _name = value; } + + private TF_DataType _dtype; + public TF_DataType dtype => _dtype; + private string _device; + public string device => _device; + + public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid, string device = null) + { + _tensor = tensor; + _slice_spec = slice_spec; + _name = name; + _dtype = dtype; + if(device is not null) + { + _device = device; + } + else + { + _device = tensor.Device; + } + } + + public SaveSpec(Func tensor_creator, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid, string device = null) + { + _tensor_creator = tensor_creator; + _slice_spec = slice_spec; + _name = name; + if(dtype == TF_DataType.DtInvalid || device is null) + { + throw new AssertionError("When passing a callable `tensor` to a SaveSpec, an explicit dtype and device must be provided."); + } + _dtype = dtype; + _device = device; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs new file mode 100644 index 000000000..f8c979757 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -0,0 +1,104 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using OneOf; +using Tensorflow.Checkpoint; + +namespace Tensorflow +{ + public class MySaveableObject + { + protected OneOf _op; + public Tensor op + { + get + { + if(_op.TryPickT0(out var tensor, out var _)) + { + return tensor; + } + else + { + throw new TypeError("The _op is not a tensor."); + } + } + set + { + _op = value; + } + } + public BaseResourceVariable variable + { + get + { + if (_op.TryPickT1(out var v, out var _)) + { + return v; + } + else + { + throw new TypeError("The _op is not a variable."); + } + } + set + { + _op = value; + } + } + public SaveSpec[] specs; + public string name; + public string device; + + public MySaveableObject() + { + + } + + public MySaveableObject(Tensor var, string slice_spec, string name) + { + + } + + public MySaveableObject(Tensor op, SaveSpec[] specs, string name) + { + this._op = op; + this.specs = specs; + this.name = name; + } + + public virtual Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) + { + var restored_tensor = restored_tensors[0]; + return gen_state_ops.assign(op, + restored_tensor, + validate_shape: restored_shapes == null && op.shape.IsFullyDefined); + } + } + + public class NoRestoreSaveable: MySaveableObject + { + public NoRestoreSaveable(Tensor tensor, string name, TF_DataType dtype = TF_DataType.DtInvalid, string? device = null) : base(tensor, + new SaveSpec[] { new SaveSpec(tensor, "", name, dtype) }, name) + { + + } + + public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) + { + return control_flow_ops.no_op(); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs new file mode 100644 index 000000000..d10257822 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs @@ -0,0 +1,11 @@ +using System.Collections.Generic; + +namespace Tensorflow; + +public record class AssetInfo +( + List asset_defs, + Dictionary asset_initializers_by_resource, + Dictionary asset_filename_map, + Dictionary asset_index +); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs new file mode 100644 index 000000000..9d0b3f001 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -0,0 +1,129 @@ +using System; +using Tensorflow.Checkpoint; +using Tensorflow.Train; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; + +namespace Tensorflow; + +public class AugmentedGraphView: ObjectGraphView +{ + private Dictionary> _children_cache; + private Dictionary> _serialization_cache; + private List _untraces_functions; + private Dictionary _wrapped_functions; + public AugmentedGraphView(Trackable root): base(root) + { + _children_cache= new Dictionary>(); + _serialization_cache = new Dictionary>(); + _untraces_functions = new List(); + _wrapped_functions = new Dictionary(); + } + + public void set_signature(SignatureMap signature_map, IDictionary wrapped_functions) + { + list_children(Root); + var name = SignatureSerializationUtils.SIGNATURE_ATTRIBUTE_NAME; + if (!_children_cache.ContainsKey(Root)) + { + _children_cache[Root] = new Dictionary(); + } + _children_cache[Root][name] = signature_map; + _wrapped_functions = _wrapped_functions.Concat(wrapped_functions).ToDictionary(x => x.Key, x => x.Value); + } + + public override List list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL, IDictionary>? serialization_cache = null) + { + if(serialization_cache is not null) + { + throw new ValueError("Serialization cache should not be passed to `AugmentedGraphView.list_children`, please either remove the parameter or use `ObjectGraphView.list_children`."); + } + + if (!_children_cache.ContainsKey(obj)) + { + Dictionary children = new Dictionary(); + _children_cache[obj] = children; + foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL, _serialization_cache)) + { + var name = pair.Name; + var child = pair.Refer; + if(child is ConcreteFunction) + { + child = maybe_uncache_variable_captures((ConcreteFunction)child); + } + children[name] = child; + } + + if (obj is Function && children.Count == 0) + { + _untraces_functions.Add(((Function)obj).Name); + } + } + + List res = new(); + foreach(var pair in _children_cache[obj]) + { + res.Add(new TrackableReference(pair.Key, pair.Value)); + } + + return res; + } + + private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concrete_function) + { + if (_wrapped_functions.ContainsKey(concrete_function)) + { + return _wrapped_functions[concrete_function]; + } + // skip the process here because of lack of feature. + // In the future, we may add an attribute which could specify if the variable is supposed to be cached. + //foreach(var capture in concrete_function.CapturedInputs) + //{ + + //} + return concrete_function; + } + + public override (IList, IDictionary>) breadth_first_traversal() + { + void merged_trackable(Trackable x) + { + // TODO: complete it with new definitions `Asset` and `TrackableConstant`. + } + + var trackable_objects = base.breadth_first_traversal(); + + foreach(var obj in _children_cache.Keys) + { + // skip the deletion of cache (maybe do it later). + foreach(var pair in _children_cache[obj]) + { + merged_trackable(pair.Value); + } + } + + return base.breadth_first_traversal(); + } + + public List<(string, Trackable)> list_dependencies(Trackable obj) + { + if (!_children_cache.TryGetValue(obj, out var children)) + { + children= new Dictionary(); + } + + List<(string, Trackable)> res = new(); + foreach(var pair in obj.deserialization_dependencies(children)) + { + res.Add((pair.Key, pair.Value)); + } + return res; + } + + public Trackable get_child(Trackable obj, string name) + { + return _children_cache[obj][name]; + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs new file mode 100644 index 000000000..726f6cfd4 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs @@ -0,0 +1,33 @@ +namespace Tensorflow; + +public static class Constants +{ + public static readonly string ASSETS_DIRECTORY = "assets"; + public static readonly string ASSETS_KEY = "saved_model_assets"; + + public static readonly string DEBUG_DIRECTORY = "debug"; + + public static readonly string DEBUG_INFO_FILENAME_PB = "saved_model_debug_info.pb"; + + public static readonly string EXTRA_ASSETS_DIRECTORY = "assets.extra"; + + public static readonly string FINGERPRINT_FILENAME = "fingerprint.pb"; + + public static readonly string INIT_OP_SIGNATURE_KEY = "__saved_model_init_op"; + + public static readonly string LEGACY_INIT_OP_KEY = "legacy_init_op"; + + public static readonly string MAIN_OP_KEY = "saved_model_main_op"; + + public static readonly string SAVED_MODEL_FILENAME_PB = "saved_model.pb"; + public static readonly string SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt"; + + public static readonly int SAVED_MODEL_SCHEMA_VERSION = 1; + + public static readonly string TRAIN_OP_KEY = "saved_model_train_op"; + + public static readonly string TRAIN_OP_SIGNATURE_KEY = "__saved_model_train_op"; + + public static readonly string VARIABLES_DIRECTORY = "variables"; + public static readonly string VARIABLES_FILENAME = "variables"; +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs new file mode 100644 index 000000000..df9bdc1b5 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public record class LoadOptions + { + public bool allow_partial_checkpoint; + public string experimental_io_device; + public bool experimental_skip_checkpoint; + public VariablePolicy experimental_variable_policy; + + public LoadOptions(bool allow_partial_checkpoint = false, string experimental_io_device = null, + bool experimental_skip_checkpoint = false, string experimental_variable_policy = null) + { + this.allow_partial_checkpoint = allow_partial_checkpoint; + this.experimental_io_device = experimental_io_device; + this.experimental_skip_checkpoint = experimental_skip_checkpoint; + this.experimental_variable_policy = VariablePolicy.from_obj(experimental_variable_policy); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs new file mode 100644 index 000000000..ab6adc30f --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs @@ -0,0 +1,54 @@ +using System; +using System.Diagnostics; +using Tensorflow.Train; +using Tensorflow.Training; + +namespace Tensorflow; + +public class RevivedTypes +{ + private static Dictionary _registered_revived_creator = new(); + static RevivedTypes() + { + var list_wrapper = new ListWrapper(new Trackable[] { }); + _registered_revived_creator[list_wrapper.Identifier] = list_wrapper; + var dict_wrapper = new DictWrapper(new Dictionary()); + _registered_revived_creator[dict_wrapper.Identifier] = dict_wrapper; + } + /// + /// Create a SavedUserObject from a trackable object. + /// + /// + /// + public static SavedUserObject? serialize(Trackable obj) + { + // TODO(Rinne): complete the implementation. + return null; + } + + public static (Trackable, Action) deserialize(SavedUserObject proto) + { + if(_registered_revived_creator.TryGetValue(proto.Identifier, out var wrapper)) + { + return (wrapper.FromProto(proto), (x, y, z) => + { + if (x is not ITrackableWrapper trackable) + { + throw new TypeError($"The type is expected to be `ITrackableWrapper`, but got {x.GetType()}."); + } + Debug.Assert(y is string); + trackable.SetValue(y, z); + } + ); + } + else + { + return (null, null); + } + } + + public static void RegisterRevivedTypeCreator(string identifier, ITrackableWrapper obj) + { + _registered_revived_creator[identifier] = obj; + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs new file mode 100644 index 000000000..d42f52535 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs @@ -0,0 +1,60 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + /// + /// Options for saving to SavedModel. + /// + public class SaveOptions + { + public bool save_debug_info = false; + public IList? namespace_white_list { get; set; } = null; + public IDictionary? function_aliases { get; set; } = null; + public string? experimental_io_device { get; set; } = null; + // TODO: experimental + public VariablePolicy experimental_variable_policy { get; set; } = VariablePolicy.None; + public bool experimental_custom_gradients { get; set; } = true; + public SaveOptions(bool save_debug_info = false) + { + this.save_debug_info = save_debug_info; + } + } + + public class VariablePolicy + { + public string Policy { get; } + private VariablePolicy(string policy) + { + Policy = policy; + } + public static VariablePolicy None = new(null); + public static VariablePolicy SAVE_VARIABLE_DEVICES = new("save_variable_devices"); + public static VariablePolicy EXPAND_DISTRIBUTED_VARIABLES = new("expand_distributed_variables"); + + public bool save_variable_devices() + { + return this != None; + } + + /// + /// Tries to convert `obj` to a VariablePolicy instance. + /// + /// + /// + public static VariablePolicy from_obj(object obj) + { + if (obj is null) return None; + if (obj is VariablePolicy) return (VariablePolicy)obj; + var key = obj.ToString().ToLower(); + return key switch + { + null => None, + "save_variable_devices" => SAVE_VARIABLE_DEVICES, + "expand_distributed_variables" => EXPAND_DISTRIBUTED_VARIABLES, + _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs new file mode 100644 index 000000000..8dd4f008f --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs @@ -0,0 +1,9 @@ +using System; + +namespace Tensorflow; + +public enum SaveType +{ + SAVEDMODEL, + CHECKPOINT +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs new file mode 100644 index 000000000..44a627b67 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -0,0 +1,299 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Contexts; +using Tensorflow.Functions; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; +using static Tensorflow.Binding; +using Tensorflow.Training.Saving.SavedModel; + +namespace Tensorflow; + +public class SaveableView +{ + private AugmentedGraphView _augmented_graph_view; + private SaveOptions _options; + private IList _trackable_objects; + private List _nodes; + private IDictionary> _node_paths; + private IDictionary _node_ids; + private IDictionary> + _slot_variables; + private IDictionary _object_names; + private List _gradient_functions; // to be completed + private List _gradient_defs; // to be completed + private List _concrete_functions; + private Dictionary _captured_tensor_node_ids; + private Dictionary> _saveable_objects_map; + private Dictionary _obj_to_registered_saver; + + public AugmentedGraphView AugmentedGraphView + { + get => _augmented_graph_view; + } + + public Trackable Root + { + get => _nodes[0]; + } + public List Nodes + { + get => _nodes; + } + public IDictionary NodeIds + { + get => _node_ids; + } + public List GradientDefs + { + get => _gradient_defs; + } + public IDictionary> NodePaths + { + get => _node_paths; + } + public SaveableView(AugmentedGraphView augmented_graph_view, SaveOptions options) + { + _augmented_graph_view = augmented_graph_view; + _options = options; + + (_trackable_objects, _node_paths, _node_ids, _slot_variables, _object_names) = + CheckPointUtils.objects_ids_and_slot_variables_and_paths(_augmented_graph_view); + + // TODO: deal with untraced functions. + + initialize_save_and_restore_functions(); + initialize_nodes_and_concrete_functions(); + + _captured_tensor_node_ids = new(); + } + + private void initialize_save_and_restore_functions() + { + // TODO: deal with the return value of `get_checkpoint_factories_and_keys`. + var (checkpoint_factory_map, registered_savers) = SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); + // skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver. + _obj_to_registered_saver = new(); + _saveable_objects_map = new(); + } + + private void initialize_nodes_and_concrete_functions() + { + _nodes = _trackable_objects.ToList().ConvertAll(x => x); // deep copy + _gradient_functions = new(); + _gradient_defs = new(); + + // TODO: deal with the condition that obj in `_saveable_objects_map`. + // foreach (var obj in _nodes) + // { + // + // } + + //_concrete_functions = new(); + //foreach (var obj in _nodes) + //{ + // if (obj is ConcreteFunction) + // { + // _concrete_functions.Add((ConcreteFunction)obj); + // } + //} + } + + public List get_concrete_resource_initializers() + { + // TODO: complete the implementation. + return new List(); + } + + public (Dictionary, Dictionary, AssetInfo) map_resources() + { + Debug.Assert(!tf.Context.executing_eagerly()); + + Dictionary object_map = new(); + Dictionary tensor_map = new(); + + AssetInfo assetInfo = new(new List(), new Dictionary(), + new Dictionary(), new Dictionary()); + + foreach (var node_id in dependency_sorted_node_ids()) + { + var obj = _nodes[node_id]; + var tensors = obj.export_to_saved_model_graph(object_map, tensor_map, _options); + // TODO: deal with Asset (if obj is Asset) + foreach (var tensor in tensors) + { + _captured_tensor_node_ids[tensor] = node_id; + } + } + + return (object_map, tensor_map, assetInfo); + } + + /// + /// Returns topologically sorted nodes, sorted by dependencies. + /// + public List dependency_sorted_node_ids() + { + Dictionary> dependency_map = new(); + foreach (var node in _nodes) + { + var node_id = _node_ids[node]; + List deps = new List(); + dependency_map.Add(node_id, deps); + + // TODO: deal with captured tensor. + + foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) + { + if (!_node_ids.ContainsKey(dep)) + { + var node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); + throw new ValueError( + $"Found an untracked dependency. Object {node_path} depends on {dep}, " + + $"but this dependency isn't listed as a child. Please track this child by " + + $"overriding `_trackable_children` or use `._track_trackable`."); + } + deps.Add(_node_ids[dep]); + } + } + + try + { + return TrackableUtils.order_by_dependency(dependency_map); + } + catch (TrackableUtils.CyclicDependencyError err) + { + List pretty_printed_nodes = new(); + List pretty_printed_dependencies = new(); + + foreach (var pair in err.LeftOverDependencyMap) + { + var x = pair.Key; + var deps = pair.Value; + var node_path = TrackableUtils.pretty_print_node_path(_node_paths[_nodes[x]]); + pretty_printed_nodes.Add($"\tNode {x.ToString()} = {node_path} (type {_nodes[x]})"); + pretty_printed_dependencies.Add( + $"\tNode {x.ToString()} depends on nodes [{string.Join(", ", deps.Select(x => x.ToString()))}]"); + } + + throw new ValueError($"There is one or more dependency cycle in the saved Trackable object. " + + $"Saving cannot continue until this cycle is resolved." + + $"\n>> Unresolved nodes:\n{string.Join("\n", pretty_printed_nodes)}" + + $"\n>> Unresolved cyclic dependencies:\n{string.Join("\n", pretty_printed_dependencies)}"); + } + } + + /// + /// Corresponding to tensorflow/python/saved_model/save.py/_serialize_object_graph + /// + /// + /// + public SavedObjectGraph serialize_object_graph(IDictionary asset_file_def_index) + { + SavedObjectGraph proto = new(); + fill_object_graph_proto(proto); + + // TODO: complete the process of concrete functions. + + int cnt = Math.Min(_nodes.Count, proto.Nodes.Count); + for (int i = 0; i < cnt; i++) + { + var obj = _nodes[i]; + var obj_proto = proto.Nodes[i]; + write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x)); + } + + return proto; + } + + private static void write_object_proto(Trackable obj, SavedObject proto, + IDictionary asset_file_def_index, Func> list_children_fn) + { + // skip the process of type Asset + if (resource_variable_ops.is_resource_variable(obj)) + { + var options = SaveContext.get_save_options(); + (obj as BaseResourceVariable).write_object_proto(proto, options); + } + else if (obj is Function) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if (obj is ConcreteFunction) + { + // TODO(Rinne): complete it. + // throw new NotImplementedException(); + } + // skip the process of type `_CapturedTensor` and `CapturableResource`. + else + { + var registered_type_proto = RevivedTypes.serialize(obj); + if (registered_type_proto is null) + { + registered_type_proto = new SavedUserObject() + { + Identifier = obj.ObjectIdentifier, + Version = new VersionDef() + { + Producer = 1, + MinConsumer = 1, + BadConsumers = { } + } + }; + } + + proto.UserObject = new SavedUserObject(registered_type_proto); + } + + // TODO: try get the registered_name from `registration`. + } + + public void fill_object_graph_proto(SavedObjectGraph proto) + { + for (int node_id = 0; node_id < _nodes.Count; node_id++) + { + var node = _nodes[node_id]; + Debug.Assert(_node_ids[node] == node_id); + SavedObject object_proto = new(); + if (_slot_variables.TryGetValue(node, out var value)) + { + object_proto.SlotVariables.AddRange(value); + } + // skip the check of type `_CapturedTensor` + foreach (var child in _augmented_graph_view.list_children(node)) + { + var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); + child_proto.NodeId = _node_ids[child.Refer]; + child_proto.LocalName = child.Name; + object_proto.Children.Add(child_proto); + } + + foreach (var pair in _augmented_graph_view.list_dependencies(node)) + { + var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); + child_proto.NodeId = _node_ids[pair.Item2]; + child_proto.LocalName = pair.Item1; + object_proto.Dependencies.Add(child_proto); + } + + if (_saveable_objects_map.ContainsKey(node)) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if(_obj_to_registered_saver.ContainsKey(node)) + { + // TODO: complete it. + // We now skip it for the lack of `SavedObject.registered_saver` API. + throw new NotImplementedException(); + } + + proto.Nodes.Add(object_proto); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs new file mode 100644 index 000000000..6aa1fbde1 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs @@ -0,0 +1,10 @@ +namespace Tensorflow; + +public static class TagConstants +{ + public static readonly string SERVING = "serve"; + public static readonly string TRAINING = "train"; + public static readonly string EVAL = "eval"; + public static readonly string GPU = "gpu"; + public static readonly string TPU = "tpu"; +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs new file mode 100644 index 000000000..695eadfd3 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Functions; + +namespace Tensorflow.Training.Saving.SavedModel +{ + /// + /// A class wraps a concrete function to handle different distributed contexts. + /// + internal class WrapperFunction: ConcreteFunction + { + public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph) + { + throw new NotImplementedException(); + //this.forward_backward = concrete_function.forward_backward; + //this.Outputs = concrete_function.Outputs; + //this.ReturnType = concrete_function.ReturnType; + //this.OutputStructure = concrete_function.OutputStructure; + //this.ArgKeywords = concrete_function.ArgKeywords; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs new file mode 100644 index 000000000..dbbab91d8 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public class BuilderUtils +{ + public static void copy_assets_to_destination_dir(IDictionary asset_filename_map, + string destination_dir, HashSet? saved_files = null) + { + if (saved_files is null) saved_files = new HashSet(); + + var asset_destination_dir = SavedModelUtils.get_or_create_assets_dir(destination_dir); + + // TODO: complete the implementation of this function. + if (asset_filename_map is not null && asset_filename_map.Count > 0) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs new file mode 100644 index 000000000..77b115a46 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs @@ -0,0 +1,494 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.RegularExpressions; +using Tensorflow.Framework; +using Tensorflow.Functions; +using Tensorflow.Gradients; +using Tensorflow.Graphs; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow.Training.Saving.SavedModel +{ + public static class function_deserialization + { + private static string _INFERENCE_PREFIX = "__inference_"; + private static string _FUNCTION_WRAPPER_NAME_REGEX = $@"^{_INFERENCE_PREFIX}(.*)_\d+$"; + /// + /// Creates a `Function` from a `SavedFunction`. + /// + /// + /// + /// + public static Function recreate_function(SavedFunction saved_function, + IDictionary concrete_functions) + { + var function_spec = _deserialize_function_spec_as_nonmethod(saved_function.FunctionSpec); + + Tensor[] restored_function_body(Tensor[] inputs) + { + if(saved_function.ConcreteFunctions is null || saved_function.ConcreteFunctions.Count == 0) + { + throw new ValueError("Found zero restored functions for caller function."); + } + foreach(var function_name in saved_function.ConcreteFunctions) + { + var function = concrete_functions[function_name]; + if(function.CapturedInputs.Any(x => x is null)) + { + throw new ValueError("Looks like you are trying to run a loaded " + + "non-Keras model that was trained using tf.distribute.experimental.ParameterServerStrategy " + + "with variable partitioning, which is not currently supported. Try using Keras to define your model " + + "if possible."); + } + if(_concrete_function_callable_with(function, inputs, false)) + { + return _call_concrete_function(function, inputs); + } + } + throw new ValueError("Unexpected runtime behavior, please submit an issue to " + + "https://github.com/SciSharp/TensorFlow.NET/issues"); + } + + List concrete_function_objects = new(); + foreach(var concrete_function_name in saved_function.ConcreteFunctions) + { + concrete_function_objects.Add(concrete_functions[concrete_function_name]); + } + foreach(var cf in concrete_function_objects) + { + cf._set_function_spec(function_spec); + } + + var restored_function = new RestoredFunction(restored_function_body, nameof(restored_function_body), + function_spec, concrete_function_objects); + + return restored_function; + } + + public static Dictionary load_function_def_library(FunctionDefLibrary library, + SavedObjectGraph saved_object_graph = null, string load_shared_name_suffix = null, object? wrapper_function = null) + { + var library_function_names = library.Function.Select(x => x.Signature.Name).Distinct(); + Dictionary functions = new(); + Dictionary renamed_functions = new(); + + Graph graph; + if (ops.executing_eagerly_outside_functions()) + { + graph = new Graph(); + } + else + { + graph = ops.get_default_graph(); + } + + if(load_shared_name_suffix is null) + { + load_shared_name_suffix = $"_load_{ops.uid()}"; + } + + Dictionary library_gradient_names = new(); + Dictionary new_gradient_op_types = new(); + Dictionary gradients_to_register = new(); + foreach (var gdef in library.RegisteredGradients) + { + if(gdef.RegisteredOpType is not null) + { + var new_op_type = custom_gradient.generate_name(); + var old_op_type = tf.compat.as_bytes(gdef.RegisteredOpType); + + library_gradient_names[old_op_type] = gdef.GradientFunc; + new_gradient_op_types[old_op_type] = new_op_type; + gradients_to_register[gdef.GradientFunc] = new_op_type; + } + } + + Dictionary> function_deps = new(); + foreach(var fdef in library.Function) + { + function_deps[fdef.Signature.Name] = _list_function_deps(fdef, library_function_names, library_gradient_names); + } + + Dictionary loaded_gradients = new(); + foreach (var fdef in _sort_function_defs(library, function_deps)) + { + var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); + + object structured_input_signature = null; + object structured_outputs = null; + if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) + { + // TODO(Rinne): deal with structured_input_signature and structured_outputs. + + //var proto = saved_object_graph.ConcreteFunctions[orig_name]; + //structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature); + //structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature); + } + + graph.as_default(); + var func_graph = function_def_lib.function_def_to_graph(fdef, structured_input_signature, structured_outputs); + graph.Exit(); + + _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients); + + foreach(var dep in function_deps[orig_name]) + { + functions[dep].AddTograph(func_graph); + } + + if (fdef.Attr.ContainsKey("_input_shapes")) + { + fdef.Attr.Remove("_input_shapes"); + } + var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value)); + if(wrapper_function is not null) + { + throw new NotImplementedException(); + } + func.AddTograph(graph); + + functions[orig_name] = func; + renamed_functions[func.Name] = func; + if(func_graph.get_operations().Any(op => op.op.type == "TRTEngineOp")) + { + func.AddTograph(ops.get_default_graph()); + } + + if (gradients_to_register.ContainsKey(orig_name)) + { + var gradient_op_type = gradients_to_register[orig_name]; + loaded_gradients[gradient_op_type] = func; + ops.RegisterGradientFunction(gradient_op_type, _gen_gradient_func(func)); + } + } + return functions; + } + + public static void fix_node_def(NodeDef node_def, IDictionary functions, string shared_name_suffix) + { + if (functions.ContainsKey(node_def.Op)) + { + node_def.Op = functions[node_def.Op].Name; + } + foreach(var attr_value in node_def.Attr.Values) + { + if(attr_value.ValueCase == AttrValue.ValueOneofCase.Func) + { + attr_value.Func.Name = functions[attr_value.Func.Name].Name; + } + else if(attr_value.ValueCase == AttrValue.ValueOneofCase.List) + { + foreach(var fn in attr_value.List.Func) + { + fn.Name = functions[fn.Name].Name; + } + } + } + + if(node_def.Op == "HashTableV2") + { + if(!node_def.Attr.ContainsKey("use_node_name_sharing") || !node_def.Attr["use_node_name_sharing"].B) + { + node_def.Attr["use_node_name_sharing"].B = true; + shared_name_suffix += $"_{ops.uid()}"; + } + } + + var op_def = op_def_registry.GetOpDef(node_def.Op); + if(op_def is not null) + { + var attr = op_def.Attr.Where(x => x.Name == "shared_name").FirstOrDefault(); + if(attr is not null) + { + ByteString shared_name = null; + if(node_def.Attr.ContainsKey("shared_name") && node_def.Attr["shared_name"].S is not null) + { + shared_name = node_def.Attr["shared_name"].S; + } + else if(attr.DefaultValue.S is not null) + { + shared_name = tf.compat.as_bytes(attr.DefaultValue.S); + } + if(shared_name is null) + { + shared_name = tf.compat.as_bytes(node_def.Name); + } + node_def.Attr["shared_name"].S = ByteString.CopyFrom(shared_name.Concat(tf.compat.as_bytes(node_def.Name)).ToArray()); + } + } + } + + private static Func _gen_gradient_func(ConcreteFunction func) + { + return (unused_op, result_grads) => + { + result_grads = zip(result_grads, func.func_graph.Inputs) + .Select((item) => item.Item1 is null ? default_gradient.zeros_like(item.Item2) : item.Item1).ToArray(); + return func.CallFlat(result_grads, func.CapturedInputs); + }; + } + + private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary renamed_functions, Dictionary loaded_gradients) + { + if(loaded_gradients is null || loaded_gradients.Count == 0) + { + foreach (var op in func_graph.get_operations()) + { + if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") + { + var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; + op.op._gradient_function = function._get_gradient_function(); + } + } + } + else + { + foreach (var op in func_graph.get_operations()) + { + if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") + { + var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; + op.op._gradient_function = function._get_gradient_function(); + } + string gradient_op_type = null; + try + { + gradient_op_type = op.op.get_attr("_gradient_op_type") as string; + } + catch (InvalidArgumentError) + { + continue; + } + if (loaded_gradients.ContainsKey(gradient_op_type)) + { + var grad_fn = loaded_gradients[gradient_op_type]; + grad_fn.NumPositionArgs = op.op.inputs.Length; + grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name); + } + } + } + } + + private static string _fix_fdef_in_place(FunctionDef fdef, IDictionary functions, string shared_name_suffix, + IDictionary new_gradient_op_types) + { + var orig_name = fdef.Signature.Name; + bool contains_unsaved_custom_gradients = false; + + foreach(var node_def in fdef.NodeDef) + { + fix_node_def(node_def, functions, shared_name_suffix); + var op_type = _get_gradient_op_type(node_def); + if(op_type is not null) + { + if (new_gradient_op_types.ContainsKey(op_type)) + { + node_def.Attr["_gradient_op_type"].S = tf.compat.as_bytes(new_gradient_op_types[op_type]); + } + else + { + contains_unsaved_custom_gradients = true; + } + } + } + if (contains_unsaved_custom_gradients) + { + // TODO(Rinne): log warnings. + } + + fdef.Signature.Name = _clean_function_name(fdef.Signature.Name); + return orig_name; + } + + private static string _clean_function_name(string name) + { + var match = Regex.Match(name, _FUNCTION_WRAPPER_NAME_REGEX); + if(match.Success) + { + return match.Groups[1].Value; + } + else + { + return name; + } + } + + /// + /// Return a topologic sort of FunctionDefs in a library. + /// + /// + /// + private static IEnumerable _sort_function_defs(FunctionDefLibrary library, Dictionary> function_deps) + { + Dictionary> edges = new(); + Dictionary in_count = new(); + foreach(var item in function_deps) + { + var fname = item.Key; + var deps = item.Value; + if(deps is null || deps.Count() == 0) + { + in_count[fname] = 0; + continue; + } + foreach(var dep in deps) + { + edges.SetDefault(dep, new List()).Add(fname); + if (in_count.ContainsKey(fname)) + { + in_count[fname]++; + } + else + { + in_count[fname] = 1; + } + } + } + var ready = new Stack(library.Function. + Where(x => in_count[x.Signature.Name] == 0) + .Select(x => x.Signature.Name).ToList()); + List output = new(); + while(ready.Count > 0) + { + var node = ready.Pop(); + output.Add(node); + if (!edges.ContainsKey(node)) + { + continue; + } + foreach(var dest in edges[node]) + { + in_count[dest] -= 1; + if (in_count[dest] == 0) + { + ready.Push(dest); + } + } + } + + if(output.Count != library.Function.Count) + { + var failed_to_resolve = in_count.Keys.Except(output); + throw new ValueError($"There is a cyclic dependency between functions. " + + $"Could not resolve ({string.Join(", ", failed_to_resolve)})."); + } + + var reverse = library.Function.ToDictionary(x => x.Signature.Name, x => x); + return output.Select(x => reverse[x]); + } + + private static IEnumerable _list_function_deps(FunctionDef fdef, IEnumerable library_function_names, IDictionary library_gradient_names) + { + HashSet deps = new HashSet(); + foreach(var node_def in fdef.NodeDef) + { + var grad_op_type = _get_gradient_op_type(node_def); + if (library_function_names.Contains(node_def.Op)) + { + deps.Add(node_def.Op); + } + else if(grad_op_type is not null && library_gradient_names.TryGetValue(grad_op_type, out var gradient_name)) + { + deps.Add(gradient_name); + } + else + { + foreach(var attr_value in node_def.Attr.Values) + { + if(attr_value.ValueCase == AttrValue.ValueOneofCase.Func) + { + deps.Add(attr_value.Func.Name); + } + else if(attr_value.ValueCase == AttrValue.ValueOneofCase.List) + { + foreach(var fn in attr_value.List.Func) + { + deps.Add(fn.Name); + } + } + } + } + } + return deps.AsEnumerable(); + } + + private static ByteString _get_gradient_op_type(NodeDef node_def) + { + if(node_def.Attr.ContainsKey("_gradient_op_type") && node_def.Op != "StatefulPartitionedCall" && node_def.Op != "PartitionedCall") + { + return node_def.Attr["_gradient_op_type"].S; + } + return null; + } + + public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, + IDictionary concrete_functions) + { + var concrete_function = concrete_functions[saved_bare_concrete_function.ConcreteFunctionName]; + concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList(); + concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; + + //var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); + // TODO(Rinne): set the functiona spec. + concrete_function.AddTograph(); + return concrete_function; + } + + private static FunctionSpec _deserialize_function_spec_as_nonmethod(FunctionSpec function_spec_proto) + { + // TODO(Rinne); revise the implementation. + return new FunctionSpec() + { + Fullargspec = function_spec_proto.Fullargspec, + IsMethod = function_spec_proto.IsMethod, + InputSignature = function_spec_proto.InputSignature, + JitCompile = function_spec_proto.JitCompile + }; + } + + private static Tensors _call_concrete_function(ConcreteFunction function, Tensors inputs) + { + // TODO(Rinne): var expected_structure = function.func_graph.structured_input_signature + return function.CallFlat(inputs, function.CapturedInputs); + } + + private static bool _concrete_function_callable_with(ConcreteFunction function, Tensor[] inputs, bool allow_conversion) + { + // TODO(Rinne): revise it. + return function.CapturedInputs.Length + inputs.Length == function.Inputs.Length; + //var expected_inputs = function.func_graph.Inputs; + //foreach(var (arg, expected) in zip(inputs, expected_inputs)) + //{ + // if(arg.Id != expected.Id) + // { + // return false; + // } + //} + //return true; + } + } + + public class RestoredFunction : Function + { + IEnumerable _concrete_functions; + FunctionSpec _function_spec; + public IEnumerable ConcreteFunctions => _concrete_functions; + public RestoredFunction(Func function, string name, FunctionSpec function_spec, + IEnumerable concrete_functions): base(function, name, auto_graph: false) + { + _concrete_functions = concrete_functions; + _function_spec = function_spec; + } + + protected override bool _run_functions_eagerly() + { + return false; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs new file mode 100644 index 000000000..727d18a81 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -0,0 +1,700 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Net.Sockets; +using System.Text; +using Tensorflow.Checkpoint; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; +using static Tensorflow.Binding; +using System.Runtime.CompilerServices; +using Tensorflow.Variables; +using Tensorflow.Functions; +using Tensorflow.Training.Saving.SavedModel; +using Tensorflow.Trackables; +using OneOf; +using Tensorflow.Keras.Engine; + +namespace Tensorflow +{ + /// + /// Helper class to load an object-based SavedModel. + /// + public partial class Loader + { + private pbc::RepeatedField _asset_file_def; + private Dictionary> _operation_attributes; + private SavedObjectGraph _proto; + private string _export_dir; + private CheckpointOptions _checkpoint_options; + private LoadOptions _save_options; + private IDictionary)> _node_filters; + private Dictionary? _node_path_to_id; + private List? _filtered_nodes; + private List _ordered_node_ids; + private Dictionary)> _loaded_nodes; + private List _nodes; + private Dictionary> _node_setters; + private Dictionary _concrete_functions; + private HashSet _restored_concrete_functions; + public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir, + CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary)> filters) + { + var meta_graph = saved_model_proto.MetaGraphs[0]; + _asset_file_def = meta_graph.AssetFileDef; + _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); + _proto = object_graph_proto; + _export_dir = export_dir; + // TODO(Rinne): This method is a bit slow (especially under debug mode), may need to be accelareted. + _concrete_functions = function_deserialization.load_function_def_library( + meta_graph.GraphDef.Library, _proto); + _restored_concrete_functions = new HashSet(); + _checkpoint_options = ckpt_options; + _save_options = save_options; + + // TODO: `this._pretty_printer` + + _node_filters = filters; + _node_path_to_id = _convert_node_paths_to_ints(); + _loaded_nodes = new Dictionary)>(); + + if (filters != null) + { + foreach (var filter in filters) + { + _loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value; + } + } + + _filtered_nodes = _retrieve_all_filtered_nodes(); + + _ordered_node_ids = _generate_ordered_node_ids(); + + _load_all(); + + + if (!save_options.experimental_skip_checkpoint) + { + _restore_checkpoint(); + } + foreach(var node in _nodes) + { + // skip the process of `CapturableResource`. + } + } + + /// + /// Maps all string node paths in node_filters to the int node ids. + /// + /// + private Dictionary? _convert_node_paths_to_ints() + { + if( _node_filters is null) + { + return null; + } + Dictionary path_to_int = new(); + foreach(var node_id in _node_filters.Keys) + { + int int_node_id; + var node_path = node_id.Split('.'); + if (node_path[0] != "root") + { + throw new ValueError($"When passing string identifiers to node_filters, the first name" + + $" must be root. Received {node_path[0]}."); + } + int_node_id = 0; + for(int i = 0; i < node_path.Length - 1; i++) + { + var name = node_path[i + 1]; + int_node_id = _find_node_child(int_node_id, name, String.Join(".", node_path.Take(i + 1))); + } + path_to_int[node_id] = int_node_id; + } + return path_to_int; + } + + private int _find_node_child(int node_id, string child_name, string path) + { + foreach(var refer in _proto.Nodes[node_id].Children) + { + if(refer.LocalName == child_name) + { + return refer.NodeId; + } + } + throw new ValueError($"Unable to find node {path}."); + } + + private List? _retrieve_all_filtered_nodes() + { + if(_node_filters is null) + { + return null; + } + + HashSet all_filtered_nodes = new(); + Queue nodes_to_visit = new Queue(_node_filters.Keys); + + while(nodes_to_visit.Count > 0) + { + var node_path = nodes_to_visit.Dequeue(); + var node_id = _node_path_to_id[node_path]; + if (all_filtered_nodes.Contains(node_id)) + { + continue; + } + all_filtered_nodes.Add(node_id); + Trackable node = null; + Action setter = null; + if(_loaded_nodes.TryGetValue(node_id, out var res)) + { + (node, setter) = res; + } + if(node is not null) + { + node._maybe_initialize_trackable(); + } + + foreach(var refer in _proto.Nodes[node_id].Children) + { + Trackable children_object = null; + if(_loaded_nodes.TryGetValue(refer.NodeId, out var result)) + { + children_object = result.Item1; + } + // See if node already tracks the child reference, in which case add the child to the loaded_nodes dict. + if(children_object is null && node is not null) + { + children_object = node._lookup_dependency(refer.LocalName); + if(children_object is TrackableDataStructure) + { + // TODO: set setter as lambda. + + _loaded_nodes[refer.NodeId] = (children_object, setter); + } + } + string child_path = $"{node_path}.{refer.LocalName}"; + _node_path_to_id[child_path] = refer.NodeId; + nodes_to_visit.Enqueue(child_path); + } + } + + if (all_filtered_nodes.Contains(0)) + { + return null; + } + return all_filtered_nodes.ToList(); + } + + /// + /// Orders the node ids so that dependencies appear first. + /// + /// + private List _generate_ordered_node_ids() + { + List unordered_ids; + if(_filtered_nodes is null) + { + unordered_ids = Enumerable.Range(0, _proto.Nodes.Count).ToList(); + } + else + { + unordered_ids = new List(_filtered_nodes); + } + + Dictionary> dependency_map = new(); + foreach(var node_id in unordered_ids) + { + var deps = dependency_map.SetDefault(node_id, new List()); + if (_loaded_nodes.ContainsKey(node_id)) + { + continue; + } + var proto = _proto.Nodes[node_id]; + foreach (var dep in _get_node_dependencies(proto).Values.Distinct()) + { + deps.Add(dep); + if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep)) + { + // TODO: add info with `_pretty_printer`. + throw new ValueError($"Unable to partially load SavedModel since the specified filter " + + $"does not include all required objects for loading (e.g. " + + $"variables used in functions or deserialization dependencies). " + + $"Please include this path in the filter: {dep}"); + } + } + int? prev_slot = null; + foreach(var slot_variable_proto in proto.SlotVariables) + { + var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId; + // The optimizer and original variable must be created before the slot + // variable, since the slot variable is generated using the Optimizer's + // add_slot API. + var slot_deps = dependency_map.SetDefault(slot_variable_node_id, new List()); + slot_deps.Add(node_id); + slot_deps.Add(slot_variable_proto.OriginalVariableNodeId); + + if(prev_slot is not null) + { + slot_deps.Add(prev_slot.Value); + } + prev_slot = slot_variable_node_id; + } + } + try + { + int total = 0; + foreach(var v in dependency_map.Values) + { + total += v.Count; + } + return TrackableUtils.order_by_dependency(dependency_map); + } + catch (TrackableUtils.CyclicDependencyError ex) + { + throw new ValueError("Encountered a cycle in the deserialization dependencies" + + "in the SavedModel. This is extremely unexpected, please" + + "file a bug and make sure you are not manually modifying the SavedModel."); + } + } + + /// + /// Returns a dictionary of all dependencies of an object. + /// + /// + /// + private Dictionary, int> _get_node_dependencies(SavedObject proto) + { + Dictionary, int> dependencies = new(); + foreach(var refer in proto.Dependencies) + { + dependencies[refer.LocalName] = refer.NodeId; + } + if(proto.KindCase == SavedObject.KindOneofCase.Function) + { + var concreete_functions = proto.Function.ConcreteFunctions; + foreach(var fn_name in concreete_functions) + { + foreach(var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs) + { + dependencies[bound_input] = bound_input; + } + } + } + else if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction) + { + var fn_name = proto.BareConcreteFunction.ConcreteFunctionName; + foreach(var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs) + { + dependencies[bound_input] = bound_input; + } + } + else if(proto.KindCase == SavedObject.KindOneofCase.Resource) + { + foreach(var child in proto.Children) + { + if(child.LocalName == "_create_resource") + { + dependencies["_create_resource"] = child.NodeId; + } + } + } + return dependencies; + } + + /// + /// Loads all nodes and functions from the SavedModel and their edges. + /// + private void _load_all() + { + _load_nodes(); + _load_edges(); + + _setup_remaining_functions(); + _load_checkpoint_save_and_restore_functions(); + } + + /// + /// Restores the checkpoint-related save/restore functions to all nodes. + /// + private void _load_checkpoint_save_and_restore_functions() + { + foreach(var (node_id, proto) in _iter_all_nodes()) + { + var node = get(node_id); + if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) + { + // Restore Trackable serialize- and restore-from-tensor functions. + Debug.Assert(proto.SaveableObjects.Count == 1); + var saveable_object_proto = proto.SaveableObjects.Values.First(); + var save_fn_id = saveable_object_proto.SaveFunction; + var restore_fn_id = saveable_object_proto.RestoreFunction; + + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + else + { + // Restore legacy SaveableObject functions. + Dictionary saveable_fn_by_name = new(); + foreach(var item in proto.SaveableObjects) + { + var name = item.Key; + var saveable_object_proto = item.Value; + var save_fn_id = saveable_object_proto.SaveFunction; + var restore_fn_id = saveable_object_proto.RestoreFunction; + saveable_fn_by_name[name] = ((Trackable)get(save_fn_id), (Trackable)get(restore_fn_id)); + } + var saveable_objects = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null); + if (saveable_objects is not null && saveable_objects.Count > 0) + { + if(node is Trackable trackable) + { + trackable.SelfSaveableObjectFactories = saveable_objects; + } + else + { + throw new TypeError(); + } + } + } + } + } + + /// + /// Load all saved objects. + /// + private void _load_nodes() + { + // `nodes` maps from node ids to recreated objects + // `node_setters` maps from node ids to setter functions + // (same signature as setattr) for setting children. + var (nodes, node_setters) = _initialize_loaded_nodes(); + + Dictionary + slot_variable_node_ids = new(); + + foreach(var (node_id, proto) in _iter_all_nodes()) + { + foreach(var slot_variable_proto in proto.SlotVariables) + { + var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId; + slot_variable_node_ids[slot_variable_node_id] = (node_id, slot_variable_proto); + } + } + + // Re-create everything. + foreach (var (node_id, proto) in _iter_all_nodes()) + { + if (nodes.ContainsKey(node_id)) + { + continue; + } + else if (slot_variable_node_ids.ContainsKey(node_id)) + { + // Use the public Optimizer interface when creating slot variables. + var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id]; + var optimizer_object = nodes[optimizer_node_id] as IOptimizer; + var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; + + var slot_variable = optimizer_object.add_slot(optimizer_variable as IVariableV1, slot_variable_proto.SlotName); + nodes[slot_variable_proto.SlotVariableNodeId] = slot_variable as Trackable; + node_setters[slot_variable_proto.SlotVariableNodeId] = setattr; + } + else + { + var (node, setter) = _recreate(proto, node_id, nodes); + nodes[node_id] = node; + node_setters[node_id] = setter; + } + } + + if (!nodes.ContainsKey(0)) + { + nodes[0] = _recreate_base_user_object().Item1; + } + _nodes = new List(); + for(int i = 0; i < _proto.Nodes.Count; i++) + { + _nodes.Add(nodes[i]); + } + _node_setters = node_setters; + } + + /// + /// Load state from checkpoint into the deserialized objects. + /// + private void _restore_checkpoint() + { + var variables_path = SavedModelUtils.get_variables_path(_export_dir); + var saver = new TrackableSaver(new ObjectGraphView((Trackable)get(0))); + tf_with(ops.device("CPU"), _ => + { + saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); + }); + LoadStatus load_status; + if (_save_options.allow_partial_checkpoint) + { + load_status = saver.restore(variables_path, _checkpoint_options).expect_partial(); + load_status.assert_nontrivial_match(); + } + else + { + load_status = saver.restore(variables_path, _checkpoint_options); + load_status.assert_existing_objects_matched(); + } + var ckpt = (load_status as CheckpointLoadStatus).Checkpoint; + + if (!tf.Context.executing_eagerly()) + { + throw new NotImplementedException("The checkpoint restore has not supported graph mode. " + + "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + } + + /// + /// Adds edges from objects to other objects and functions. + /// + private void _load_edges() + { + foreach(var (node_id, object_proto) in _iter_all_nodes()) + { + _add_object_graph_edges(object_proto, node_id); + } + + if(_filtered_nodes is not null && _filtered_nodes.Contains(0)) + { + var root = get(0); + foreach(var node_path in _node_filters.Keys) + { + var loaded_node = _nodes[_node_path_to_id[node_path]]; + + var path = node_path.Split('.'); + var current_node = root; + foreach(var name in path.Skip(1).Take(path.Length - 2)) + { + // `hasattr` and `setattr` is used here + throw new NotImplementedException(); + } + // `hasattr` and `setattr` is used here + throw new NotImplementedException(); + } + } + } + + private void _setup_function_captures(string concrete_function_name, IDictionary, object> nodes) + { + if (_restored_concrete_functions.Contains(concrete_function_name)) + { + return; + } + _restored_concrete_functions.Add(concrete_function_name); + var concrete_function = _concrete_functions[concrete_function_name]; + var proto = _proto.ConcreteFunctions[concrete_function_name]; + var inputs = proto.BoundInputs.Select(x => nodes[x]); + function_saved_model_utils.restore_captures(concrete_function, inputs); + } + + private void _setup_remaining_functions() + { + // TODO: implement it with concrete functions. + } + + public object get(int node_id) + { + return _nodes[node_id]; + } + + public object get(string node_id) + { + return get(_node_path_to_id[node_id]); + } + + /// + /// Adds edges from an object to its children. + /// + /// + /// + private void _add_object_graph_edges(SavedObject proto, int node_id) + { + var obj = _nodes[node_id]; + var setter = _node_setters[node_id]; + + foreach(var refer in proto.Children) + { + setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); + // TODO(Rinne): deal with "__call__" + } + } + + private (Dictionary, Dictionary>) _initialize_loaded_nodes() + { + Dictionary nodes = new(); + Dictionary> node_setters = new(); + foreach(var item in _loaded_nodes) + { + var node_id = item.Key; + var (node, setter) = item.Value; + nodes[node_id] = node; + node_setters[node_id] = setter; + } + return (nodes, node_setters); + } + + private IEnumerable<(int, SavedObject)> _iter_all_nodes() + { + foreach(var node_id in _ordered_node_ids) + { + yield return (node_id, _proto.Nodes[node_id]); + } + } + + private (object, Action) _recreate(SavedObject proto, int node_id, IDictionary nodes) + { + // skip the registered classes. + Dictionary, object> dependencies = new(); + foreach(var item in _get_node_dependencies(proto)) + { + dependencies[item.Key] = nodes[item.Value]; + } + + return proto.KindCase switch + { + SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(proto, _operation_attributes), + SavedObject.KindOneofCase.Asset => AssetResource.deserialize_from_proto(proto, _export_dir, _asset_file_def, _operation_attributes), + SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(proto, _operation_attributes), + _ => _recreate_default(proto, node_id, dependencies) + }; + } + + /// + /// Creates a Python object from a SavedObject protocol buffer. + /// + /// + /// + /// + private (Trackable, Action) _recreate_default(SavedObject proto, int node_id, IDictionary, object> dependencies) + { + return proto.KindCase switch + { + SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), + SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, dependencies), + SavedObject.KindOneofCase.BareConcreteFunction => _recreate_bare_concrete_function(proto.BareConcreteFunction, dependencies), + SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), + SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(), + _ => throw new NotImplementedException() + }; + } + + private (Trackable, Action) _recreate_user_object(SavedUserObject? proto, int node_id) + { + // skip the check of proto identifier because of lack of property. + var (trackable, setter) = RevivedTypes.deserialize(proto); + if(trackable is null) + { + return _recreate_base_user_object(proto, node_id); + } + return (trackable, setter); + } + + private (Trackable, Action) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null) + { + return (new _UserObject(), setattr); + } + + private (BaseResourceVariable, Action) _recreate_variable(SavedVariable proto) + { + string name = proto.Name; + string dbg_name = !string.IsNullOrEmpty(name) ? name : ""; + + // TODO(Rinne): `validate_synchronization_aggregation_trainable` + + var (synchronization, aggregation, trainable) = ResourceVariable.validate_synchronization_aggregation_trainable( + proto.Synchronization, proto.Aggregation, proto.Trainable, dbg_name); + + var saved_device = proto.Device; + var load_with_device = _save_options.experimental_variable_policy.save_variable_devices() && !string.IsNullOrEmpty(saved_device); + + if (load_with_device) + { + return tf_with(ops.device(saved_device), _ => + { + return (new UninitializedVariable( + shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), + dtype: (TF_DataType)proto.Dtype, + name: name, + trainable: trainable, + aggregation: aggregation + ), setattr); + }); + } + else + { + return (new UninitializedVariable( + shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), + dtype: (TF_DataType)proto.Dtype, + name: name, + trainable: trainable, + aggregation: aggregation + ), setattr); + } + } + + private (Function, Action) _recreate_function(SavedFunction proto, + IDictionary, object> dependencies) + { + var fn = function_deserialization.recreate_function(proto, _concrete_functions); + foreach (var name in proto.ConcreteFunctions) + { + _setup_function_captures(name, dependencies); + } + return (fn, setattr); + } + + private (ConcreteFunction, Action) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, + IDictionary, object> dependencies) + { + var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); + _setup_function_captures(proto.ConcreteFunctionName, dependencies); + return (fn, setattr); + } + + private (Tensor, Action) _get_tensor_from_fn(CapturedTensor proto) + { + var outer_graph = _concrete_functions[proto.ConcreteFunction].func_graph; + var captured_tensor = outer_graph.get_tensor_by_name(proto.Name); + return (captured_tensor, setattr); + } + + // TODO: remove this to a common class. + public static Action setattr = (x, y, z) => + { + Debug.Assert(y is string); + if(x is Trackable trackable) + { + trackable.SetAttr(y as string, z); + } + else + { + var properties = x.GetType().GetProperties(); + foreach (var p in properties) + { + if ((string)y == p.Name) + { + p.SetValue(x, z); + return; + } + } + } + // TODO(Rinne): check if the property has been set successfully. + //throw new ValueError($"Cannot find the property {y} of {x}."); + }; + + public class _UserObject: AutoTrackable + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs new file mode 100644 index 000000000..d1c0170c8 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs @@ -0,0 +1,122 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using Tensorflow.Checkpoint; +using Tensorflow.Operations; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Loader + { + public static SavedModel parse_saved_model(string export_dir) + { + var path_to_pbtxt = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PBTXT); + var path_to_pb = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PB); + + SavedModel saved_model = new SavedModel(); + if (File.Exists(path_to_pb)) + { + byte[] file_content; + using(var f = new FileStream(path_to_pb, FileMode.Open, FileAccess.Read)) + { + file_content = new byte[f.Length]; + Debug.Assert(f.Length <= int.MaxValue); + f.Read(file_content, 0, (int)f.Length); + } + // TODO: change to stream mode. + saved_model.MergeFrom(file_content); + return saved_model; + } + else if (File.Exists(path_to_pbtxt)) + { + throw new NotImplementedException(); + } + else + { + throw new IOException($"SavedModel file does not exist at: {export_dir}{Path.PathSeparator}" + + $"{{{Constants.SAVED_MODEL_FILENAME_PBTXT}|{Constants.SAVED_MODEL_FILENAME_PB}}}"); + } + } + + // TODO: revise the type of `tags` + public static Trackable load(string export_dir, object? tags = null, LoadOptions? options = null) + { + return load_partial(export_dir, null, tags, options)["root"]; + } + + public static IDictionary load_partial(string export_dir, IDictionary)>? filters, object? tags = null, LoadOptions? options = null) + { + if (options is null) + { + options = new LoadOptions(); + } + if (tags is not null) + { + throw new NotImplementedException(); + } + var (saved_model_proto, debug_info) = Loader.parse_saved_model_with_debug_info(export_dir); + + Trackable root = null; + Loader loader = null; + if (saved_model_proto.MetaGraphs.Count == 1 && saved_model_proto.MetaGraphs[0].ObjectGraphDef is not null) + { + // skip python code: `metrics.IncrementReadApi(_LOAD_V2_LABEL)` + var meta_graph_def = saved_model_proto.MetaGraphs[0]; + if (!BitConverter.IsLittleEndian) + { + SavedModelUtils.swap_function_tensor_content(meta_graph_def); + } + + var object_graph_proto = meta_graph_def.ObjectGraphDef; + var ckpt_options = new CheckpointOptions(options.experimental_io_device); + tf_with(ops.init_scope(), x => + { + loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters); + root = (Trackable)loader.get(0); + // skip the assignment of `graph_debug_info`. + }); + // skip the assignment of `tensorflow_version` + // skip the assignment of `tensorflow_git_version` + // skip the process of `metrics`. + } + else + { + if(filters is not null && filters.Count > 0) + { + throw new ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any" + + " version) cannot be loaded with node filters."); + } + tf_with(ops.init_scope(), x => + { + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + }); + } + if(filters != null && filters.Count > 0) + { + return filters.Keys.ToDictionary(x => x, x => (Trackable)loader.get(x)); + } + else + { + var res = new Dictionary(); + res["root"] = root; + return res; + } + } + + public static (SavedModel, object?) parse_saved_model_with_debug_info(string export_dir) + { + var saved_model = parse_saved_model(export_dir); + + // TODO: implement debug info. + + return (saved_model, null); + } + + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/nested_structure_coder.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/nested_structure_coder.cs new file mode 100644 index 000000000..c81dc29eb --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/nested_structure_coder.cs @@ -0,0 +1,268 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace Tensorflow.Training.Saving.SavedModel +{ + internal interface ICodec + { + //bool CanEncode(StructuredValue value); + bool CanDecode(StructuredValue value); + //StructuredValue DoEecode(object value, Func encode_fn); + object DoDecode(StructuredValue value, Func decode_fn); + } + public class nested_structure_coder + { + private static Dictionary _codecs = null; + public static object decode_proto(StructuredValue proto) + { + if(_codecs is null) + { + _codecs = new Dictionary(); + _codecs[StructuredValue.KindOneofCase.ListValue] = new ListCodec(); + _codecs[StructuredValue.KindOneofCase.TupleValue] = new TupleCodec(); + _codecs[StructuredValue.KindOneofCase.DictValue] = new DictCodec(); + _codecs[StructuredValue.KindOneofCase.NamedTupleValue] = new NamedTupleCodec(); + _codecs[StructuredValue.KindOneofCase.Float64Value] = new Float64Codec(); + _codecs[StructuredValue.KindOneofCase.Int64Value] = new Int64Codec(); + _codecs[StructuredValue.KindOneofCase.StringValue] = new StringCodec(); + _codecs[StructuredValue.KindOneofCase.NoneValue] = new NoneCodec(); + _codecs[StructuredValue.KindOneofCase.BoolValue] = new BoolCodec(); + _codecs[StructuredValue.KindOneofCase.TensorShapeValue] = new TensorShapeCodec(); + _codecs[StructuredValue.KindOneofCase.TensorDtypeValue] = new TensorTypeCodec(); + _codecs[StructuredValue.KindOneofCase.TensorSpecValue] = new TensorSpecCodec(); + _codecs[StructuredValue.KindOneofCase.BoundedTensorSpecValue] = new BoundedTensorSpecCodec(); + _codecs[StructuredValue.KindOneofCase.TypeSpecValue] = new TypeSpecCodec(); + } + + return decode_proto_internal(proto, x => decode_proto(x)); + } + + public static object decode_proto_internal(StructuredValue proto, Func encode_fn) + { + Debug.Assert(_codecs[proto.KindCase].CanDecode(proto)); + return _codecs[proto.KindCase].DoDecode(proto, encode_fn); + } + } + + internal class ListCodec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.ListValue is not null; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + return value.ListValue.Values.Select(x => decode_fn(x)).ToList(); + } + } + + internal class TupleCodec: ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.TupleValue is not null; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + return value.TupleValue.Values.Select(x => decode_fn(x)).ToArray(); + } + } + + internal class DictCodec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.DictValue is not null; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + return value.DictValue.Fields.ToDictionary(x => x.Key, x => decode_fn(x.Value)); + } + } + + internal class NamedTupleCodec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.NamedTupleValue is not null; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + var key_value_pairs = value.NamedTupleValue.Values; + var items = key_value_pairs.ToDictionary(x => x.Key, x => decode_fn(x.Value)); + return new Common.Types.NamedTuple() + { + Name = value.NamedTupleValue.Name, + ValueDict = items + }; + } + } + + internal class Float64Codec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.KindCase == StructuredValue.KindOneofCase.Float64Value; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + return value.Float64Value; + } + } + + internal class Int64Codec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.KindCase == StructuredValue.KindOneofCase.Int64Value; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + return (int)value.Int64Value; + } + } + + internal class StringCodec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.StringValue is not null; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + return tf.compat.as_str(value.StringValue); + } + } + + internal class NoneCodec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.NoneValue is not null; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + return null; + } + } + + internal class BoolCodec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.KindCase == StructuredValue.KindOneofCase.BoolValue; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + return value.BoolValue; + } + } + + internal class TensorShapeCodec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.TensorShapeValue is not null; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + return new Shape(value.TensorShapeValue); + } + } + + internal class TensorTypeCodec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.KindCase == StructuredValue.KindOneofCase.TensorDtypeValue; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + return value.TensorDtypeValue.as_tf_dtype(); + } + } + + internal class TensorSpecCodec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.TensorSpecValue is not null; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + var name = value.TensorSpecValue.Name; + var shape = decode_fn(new StructuredValue() + { + TensorShapeValue = value.TensorSpecValue.Shape + }); + Debug.Assert(shape is Shape); + var dtype = decode_fn(new StructuredValue() + { + TensorDtypeValue = value.TensorSpecValue.Dtype + }); + Debug.Assert(dtype is TF_DataType); + return new Framework.Models.TensorSpec(shape as Shape, (TF_DataType)dtype, + string.IsNullOrEmpty(name) ? null : name); + } + } + + internal class BoundedTensorSpecCodec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.BoundedTensorSpecValue is not null; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + var btsv = value.BoundedTensorSpecValue; + var name = btsv.Name; + var shape = decode_fn(new StructuredValue() + { + TensorShapeValue = btsv.Shape + }); + Debug.Assert(shape is Shape); + var dtype = decode_fn(new StructuredValue() + { + TensorDtypeValue = btsv.Dtype + }); + Debug.Assert(dtype is TF_DataType); + throw new NotImplementedException("The `BoundedTensorSpec` has not been supported, " + + "please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + } + + internal class TypeSpecCodec : ICodec + { + public bool CanDecode(StructuredValue value) + { + return value.TypeSpecValue is not null; + } + + public object DoDecode(StructuredValue value, Func decode_fn) + { + var type_spec_proto = value.TypeSpecValue; + var type_spec_class_enum = type_spec_proto.TypeSpecClass; + var class_name = type_spec_proto.TypeSpecClassName; + + throw new NotImplementedException("The `TypeSpec` analysis has not been supported, " + + "please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs new file mode 100644 index 000000000..23e0a9295 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -0,0 +1,268 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Google.Protobuf; +using Tensorflow.Checkpoint; +using Tensorflow.Functions; +using Tensorflow.Train; +using Tensorflow.Exceptions; +using static Tensorflow.Binding; +using Tensorflow.Training.Saving.SavedModel; + +namespace Tensorflow; + +public static partial class SavedModelUtils +{ + private static readonly IEnumerable byte_swappable = new List() + { + dtypes.float16, dtypes.float32, dtypes.float64, TF_DataType.TF_BFLOAT16, + dtypes.complex64, dtypes.complex128, TF_DataType.TF_UINT16, dtypes.uint32, + dtypes.uint64, TF_DataType.TF_INT16, dtypes.int32, dtypes.int64, TF_DataType.TF_QINT16, + TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32 + }.Select(x => (int)x); + + public static (IList, IDictionary>) save_and_return_nodes(Trackable obj, + string export_dir, ConcreteFunction? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) + { + if (options is null) + { + options = new SaveOptions(); + } + + var saved_model = new Tensorflow.SavedModel(); + var meta_graph_def = new MetaGraphDef(); + saved_model.MetaGraphs.Add(meta_graph_def); + + var (_, exported_graph, object_saver, asset_info, saved_nodes, node_paths) = + _build_meta_graph(obj, signatures, options, meta_graph_def); + saved_model.SavedModelSchemaVersion = Tensorflow.Constants.SAVED_MODEL_SCHEMA_VERSION; + + if (!experimental_skip_checkpoint) + { + SavedModelUtils.get_or_create_variables_dir(export_dir); + CheckpointOptions ckpt_options = new(options.experimental_io_device); + object_saver.save(SavedModelUtils.get_variables_path(export_dir), options:ckpt_options); + } + BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); + + if (tf.Context.executing_eagerly()) + { + // tensorflow python has a check of `context.async_wait()` here. + } + + // TODO: deal with `pywrap_saved_model.Save(export_dir)`. + + var saved_model_serialized = saved_model.ToString(); + + // This is a state depending on some py-c APIs. Here we temporarily set it as `true`. + if (true) + { + var fingerprint_path = Path.Combine(tf.compat.as_str(export_dir), + tf.compat.as_str(Constants.FINGERPRINT_FILENAME)); + // TODO: add c api and complete the fingerprint def. + var fingerprint_proto = ""; + File.WriteAllText(fingerprint_path, fingerprint_proto); + } + + var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); + File.WriteAllBytes(path, saved_model.ToByteArray()); + //File.WriteAllText(path, saved_model.ToString()); + + if (options.save_debug_info) + { + throw new NotImplementedException(); + } + + ops.dismantle_graph(exported_graph); + + return (saved_nodes, node_paths); + } + + private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, IList, + IDictionary>) _build_meta_graph(Trackable obj, + ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) + { + using (SaveContext.save_context(options)) + { + if (ops.inside_function()) + { + throw new AssertionError("`tf.saved_model.save` is not supported inside a traced [AutoGraph]. " + + "Move the call to the outer eagerly-executed context."); + } + + if (meta_graph_def is null) + { + meta_graph_def = new MetaGraphDef(); + } + + AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); + if (signatures is null) + { + signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); + } + + // TODO: process of aignatures and wrapped_functions + + SaveableView saveable_view = new SaveableView(augmented_graph_view, options); + TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); + var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, + options.namespace_white_list, options.experimental_custom_gradients); + if (options.function_aliases is not null) + { + var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; + foreach (var pair in options.function_aliases) + { + var alias = pair.Key; + var func = pair.Value; + // TODO: complete it. + throw new NotImplementedException(); + } + } + + var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index); + meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); + + return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); + } + } + + private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, + ConcreteFunction signatures, IEnumerable namespace_whitelist, + bool save_custom_gradients) + { + var resource_initializers = saveable_view.get_concrete_resource_initializers(); + var exported_graph = new Graph(); + + Dictionary object_map; + Dictionary tensor_map; + AssetInfo asset_info; + var g = exported_graph.as_default(); + (object_map, tensor_map, asset_info) = saveable_view.map_resources(); + // TODO: deal with signatures. + if (save_custom_gradients) + { + // TODO: trace gradient functions. + } + + foreach (var resource_initializer_function in resource_initializers) + { + // List asset_dependencies = new(); + // TODO: deal with initializers + } + + // using(ops.control_dependencies(...)) + var init_op = control_flow_ops.no_op(); + if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY)) + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name); + } + else + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef(); + } + // Lack `CopyFrom` API + // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + + g.Exit(); + + foreach (var obj in object_map.Values) + { + obj._maybe_initialize_trackable(); + } + + // TODO: add the implementation of `call_with_mapped_functions`. + var (named_saveable_objects, registered_savers) = + SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false); + var saver = MultiDeviceSaver.from_saveables(named_saveable_objects, registered_savers, false); + + var eg = exported_graph.as_default(); + var saver_def = saver.to_proto(); + meta_graph_def.SaverDef = saver_def; + eg.Exit(); + + + saveable_view.dependency_sorted_node_ids(); + + var graph_def = exported_graph.as_graph_def(true); + graph_def.Library.RegisteredGradients.AddRange(saveable_view.GradientDefs); + verify_ops(graph_def, namespace_whitelist); + + meta_graph_def.GraphDef = new GraphDef(graph_def); + meta_graph_def.MetaInfoDef = new(); + meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING); + meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION; + // TODO: add git version. + meta_graph_def.MetaInfoDef.TensorflowGitVersion = ""; + meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + meta_graph_def.MetaInfoDef.StrippedOpList = new(); + meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef)); + meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs); + + // TODO: deal with signatures here. + + meta_graph.strip_graph_default_valued_attrs(meta_graph_def); + + if (!BitConverter.IsLittleEndian) + { + swap_function_tensor_content(meta_graph_def); + } + + return (asset_info, exported_graph); + } + + private static void verify_ops(GraphDef graph_def, IEnumerable? namespace_whitelist) + { + return; + // if (namespace_whitelist is null || !namespace_whitelist.Any()) + // { + // return; + // } + + // skip the check for the lack of `meta_graph.ops_used_by_graph_def`. + } + + public static void swap_function_tensor_content(MetaGraphDef meta_graph_def) + { + var functions = meta_graph_def.GraphDef.Library.Function; + foreach (var function in functions) + { + var node_def = function.NodeDef; + foreach (var node in node_def) + { + if (node.Op == "Const") + { + var tensor = node.Attr["value"].Tensor; + byte_swap_tensor_content(tensor); + } + } + } + } + + public static void byte_swap_tensor_content(TensorProto tensor) + { + if (byte_swappable.Contains((int)tensor.Dtype)) + { + var tshape = tensor.TensorShape.Dim; + var tensor_bytes = tensor.TensorContent; + if (tensor_bytes is not null && !tensor_bytes.IsEmpty) + { + long tensor_size = 1; + foreach (var sz in tshape) + { + tensor_size *= sz.Size; + } + + var chunksize = tensor_bytes.Length / tensor_size; + List reversed_bytes = new(); + for (int i = 0; i < tensor_bytes.Length; i += (int)chunksize) + { + var current = tensor_bytes.Skip(i).Take((int)chunksize).Reverse(); + reversed_bytes.AddRange(current); + } + tensor.TensorContent = ByteString.CopyFrom(reversed_bytes.ToArray()); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs new file mode 100644 index 000000000..47d8cbab9 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Training.Saving.SavedModel +{ + /// + /// A context for building a graph of SavedModel. + /// + public static class SaveContext + { + // TODO: make it thead safe. + private static bool _in_save_context = false; + private static SaveOptions _save_options = null; + + public static bool in_save_context() => _in_save_context; + public static SaveOptions get_save_options() + { + if (!in_save_context()) + { + throw new ValueError("Not in a SaveContext."); + } + return _save_options; + } + public static SaveContextHandler save_context(SaveOptions options) + { + return new SaveContextHandler(options); + } + + public class SaveContextHandler: IDisposable + { + private bool _old_in_save_context; + private SaveOptions _old_save_options; + public SaveContextHandler(SaveOptions options) + { + if (SaveContext.in_save_context()) + { + throw new ValueError("Already in a SaveContext."); + } + _old_in_save_context = SaveContext._in_save_context; + SaveContext._in_save_context = true; + _old_save_options = SaveContext._save_options; + SaveContext._save_options = options; + } + public void Dispose() + { + SaveContext._in_save_context = _old_in_save_context; + SaveContext._save_options = _old_save_options; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs new file mode 100644 index 000000000..d3ffebc9f --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs @@ -0,0 +1,107 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Train; + +namespace Tensorflow; + +public static class SignatureSerializationUtils +{ + internal static readonly string DEFAULT_SIGNATURE_ATTR = "_default_save_signature"; + internal static readonly string SIGNATURE_ATTRIBUTE_NAME = "signatures"; + internal static readonly int _NUM_DISPLAY_NORMALIZED_SIGNATURES = 5; + public static SignatureMap create_signature_map(IDictionary signatures) + { + var signature_map = new SignatureMap(); + foreach (var pair in signatures) + { + var name = pair.Key; + var func = pair.Value; + Debug.Assert(func is ConcreteFunction); + // TODO: assert the `func.structured_outputs` and arg_keywords. + signature_map._add_signature(name, (ConcreteFunction)func); + } + + return signature_map; + } + + public static ConcreteFunction find_function_to_export(AugmentedGraphView graph_view) + { + var children = graph_view.list_children(graph_view.Root); + List possible_signatures = new(); + foreach (var item in children) + { + var name = item.Name; + var child = item.Refer; + if(child is not (Function or ConcreteFunction)) + { + continue; + } + if(name == DEFAULT_SIGNATURE_ATTR) + { + Debug.Assert(child is ConcreteFunction); + return (ConcreteFunction)child; + } + ConcreteFunction concrete = get_signature(child); + if(concrete is not null && valid_signature(concrete)) + { + possible_signatures.Add(concrete); + } + } + + if(possible_signatures.Count == 1) + { + var signature = get_signature(possible_signatures[0]); + if(signature is not null && valid_signature(signature)) + { + return signature; + } + } + return null; + } + + private static ConcreteFunction get_signature(Trackable function) + { + // TODO: implement it. + return null; + } + + private static bool valid_signature(ConcreteFunction concreate_function) + { + // TODO: implement it. + return false; + } +} + +public class SignatureMap: Trackable +{ + private Dictionary _signatures; + + public SignatureMap() + { + _signatures = new(); + } + + public void _add_signature(string name, ConcreteFunction concrete_function) + { + _signatures[name] = concrete_function; + } + + public void _add_signature(string name, Function concrete_function) + { + _signatures[name] = concrete_function; + } + + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? cache = null) + { + if (save_type != SaveType.SAVEDMODEL) + { + return new Dictionary(); + } + + return _signatures.Where(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs new file mode 100644 index 000000000..b0e6411c9 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs @@ -0,0 +1,57 @@ +using System.IO; +using System.Security.Cryptography.X509Certificates; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static partial class SavedModelUtils +{ + /// + /// Return variables sub-directory, or create one if it doesn't exist. + /// + /// + public static string get_or_create_variables_dir(string export_dir) + { + var variables_dir = get_variables_dir(export_dir); + Directory.CreateDirectory(variables_dir); + return variables_dir; + } + + /// + /// Return variables sub-directory in the SavedModel. + /// + /// + /// + public static string get_variables_dir(string export_dir) + { + return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY)); + } + + public static string get_variables_path(string export_dir) + { + return Path.Combine(tf.compat.as_text(get_variables_dir(export_dir)), tf.compat.as_text(Constants.VARIABLES_FILENAME)); + } + + /// + /// Return assets sub-directory, or create one if it doesn't exist. + /// + /// + /// + public static string get_or_create_assets_dir(string export_dir) + { + var assets_destination_dir = get_assets_dir(export_dir); + Directory.CreateDirectory(assets_destination_dir); + return assets_destination_dir; + } + + /// + /// Return path to asset directory in the SavedModel. + /// + /// + /// + public static string get_assets_dir(string export_dir) + { + return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/Saver.cs b/src/TensorFlowNET.Core/Training/Saving/Saver.cs new file mode 100644 index 000000000..85a3ee7d4 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/Saver.cs @@ -0,0 +1,351 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Saves and restores variables. + /// + public class Saver + { + private IVariableV1[] _var_list; + private bool _reshape; + private bool _sharded; + private int _max_to_keep; + private float _keep_checkpoint_every_n_hours; + private string _name; + private bool _restore_sequentially; + private SaverDef _saver_def; + private ISaverBuilder _builder; + private bool _allow_empty; + private bool _is_built; + private SaverDef.Types.CheckpointFormatVersion _write_version; + private bool _pad_step_number; +#pragma warning disable CS0649 // Field 'Saver._filename' is never assigned to, and will always have its default value null + private string _filename; +#pragma warning restore CS0649 // Field 'Saver._filename' is never assigned to, and will always have its default value null + private bool _is_empty; + private float _next_checkpoint_time; + private bool _save_relative_paths; +#pragma warning disable CS0414 // The field 'Saver._object_restore_saver' is assigned but its value is never used + private bool? _object_restore_saver; +#pragma warning restore CS0414 // The field 'Saver._object_restore_saver' is assigned but its value is never used + private Dictionary _last_checkpoints; + private Dictionary _checkpoints_to_be_deleted; + + public Saver(IVariableV1[] var_list = null, + bool reshape = false, + bool sharded = false, + int max_to_keep = 5, + float keep_checkpoint_every_n_hours = 10000, + string name = null, + bool restore_sequentially = false, + SaverDef saver_def = null, + ISaverBuilder builder = null, + bool defer_build = false, + bool allow_empty = false, + SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2, + bool pad_step_number = false, + bool save_relative_paths = false, + string filename = "") + { + _var_list = var_list; + _reshape = reshape; + _sharded = sharded; + _max_to_keep = max_to_keep; + _keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours; + _name = name; + _restore_sequentially = restore_sequentially; + _saver_def = saver_def; + _builder = builder; + _is_built = false; + _allow_empty = allow_empty; + _write_version = write_version; + _pad_step_number = pad_step_number; + + if (!defer_build) + build(); + if (_saver_def != null) + { + _check_saver_def(); + _write_version = _saver_def.Version; + } + + _save_relative_paths = save_relative_paths; + _object_restore_saver = null; + + _last_checkpoints = new Dictionary(); + _checkpoints_to_be_deleted = new Dictionary(); + } + + public void build() + { + _build(_filename, build_save: true, build_restore: true); + } + + private void _build(string checkpoint_path, bool build_save, bool build_restore) + { + if (_is_built) return; + + _is_built = true; + + if (_saver_def == null) + { + if (_builder == null) + _builder = new BulkSaverBuilder(_write_version); + + if (_var_list == null) + _var_list = variables._all_saveable_objects(); + + if (_var_list == null || _var_list.Length == 0) + { + if (_allow_empty) + { + _is_empty = true; + return; + } + else + { + throw new ValueError("No variables to save"); + } + } + _is_empty = false; + + _saver_def = _builder._build_internal(_var_list, + reshape: _reshape, + sharded: _sharded, + max_to_keep: _max_to_keep, + keep_checkpoint_every_n_hours: _keep_checkpoint_every_n_hours, + name: _name, + restore_sequentially: _restore_sequentially, + filename: checkpoint_path, + build_save: build_save, + build_restore: build_restore); + } + else if (_saver_def != null && !string.IsNullOrEmpty(_name)) + { + throw new NotImplementedException("Saver._build"); + } + + _check_saver_def(); + + _next_checkpoint_time = time() + _saver_def.KeepCheckpointEveryNHours * 3600; + } + + private void _check_saver_def() + { + if (!tf.Context.executing_eagerly()) + { + if (string.IsNullOrEmpty(_saver_def.SaveTensorName)) + throw new ValueError($"saver_def must specify the save_tensor_name: {_saver_def}"); + if (string.IsNullOrEmpty(_saver_def.RestoreOpName)) + throw new ValueError($"saver_def must specify the restore_op_name: {_saver_def}"); + } + } + + public string save(Session sess, + string save_path, + int global_step = -1, + string latest_filename = "", + string meta_graph_suffix = "meta", + bool write_meta_graph = true, + bool write_state = true, + bool strip_default_attrs = false, + bool save_debug_info = false) + { + if (string.IsNullOrEmpty(latest_filename)) + latest_filename = "checkpoint"; + NDArray[] model_checkpoint_path = null; + string checkpoint_file = ""; + + if (global_step > 0) + checkpoint_file = $"{save_path}-{global_step}"; + else + checkpoint_file = save_path; + + var save_path_parent = Path.GetDirectoryName(save_path); + + if (!_is_empty) + { + model_checkpoint_path = sess.run(_saver_def.SaveTensorName, + (_saver_def.FilenameTensorName, checkpoint_file)); + + if (write_state) + { + var path = model_checkpoint_path[0].StringData()[0]; + _RecordLastCheckpoint(path); + checkpoint_management.update_checkpoint_state_internal( + save_dir: save_path_parent, + model_checkpoint_path: path, + all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(), + latest_filename: latest_filename, + save_relative_paths: _save_relative_paths); + _MaybeDeleteOldCheckpoints(meta_graph_suffix: meta_graph_suffix); + } + } + + if (write_meta_graph) + { + string meta_graph_filename = checkpoint_management.meta_graph_filename(checkpoint_file, meta_graph_suffix: meta_graph_suffix); + export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info); + } + + return checkpoint_file; + //var x = model_checkpoint_path[0]; + //var str = x.StringData(); + //return _is_empty ? string.Empty : model_checkpoint_path[0].StringData()[0]; + } + + public (Saver, object) import_meta_graph(string meta_graph_or_file, + bool clear_devices = false, + string import_scope = "") + { + return saver._import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope); + } + + /// + /// Restores previously saved variables. + /// + /// This method runs the ops added by the constructor for restoring variables. + /// It requires a session in which the graph was launched. The variables to + /// restore do not have to have been initialized, as restoring is itself a way + /// to initialize variables. + /// + /// A `Session` to use to restore the parameters. None in eager mode. + /// Path where parameters were previously saved. + public void restore(Session sess, string save_path) + { + if (_is_empty) + return; + + if (string.IsNullOrEmpty(save_path)) + throw new ValueError("Can't load save_path when it is None."); + + if (!checkpoint_management.checkpoint_exists(save_path)) + throw new ValueError($"The passed save_path is not a valid checkpoint: {save_path}"); + + Binding.tf_output_redirect.WriteLine($"Restoring parameters from {save_path}"); + + if (tf.Context.executing_eagerly()) +#pragma warning disable CS0642 // Possible mistaken empty statement + ; +#pragma warning restore CS0642 // Possible mistaken empty statement + else + sess.run(_saver_def.RestoreOpName, + new FeedItem(_saver_def.FilenameTensorName, save_path)); + } + + /// + /// Writes `MetaGraphDef` to save_path/filename. + /// + /// + /// + /// + /// + /// + /// + /// + public MetaGraphDef export_meta_graph(string filename = "", + string[] collection_list = null, + string export_scope = "", + bool as_text = false, + bool clear_devices = false, + bool clear_extraneous_savers = false, + bool strip_default_attrs = false, + bool save_debug_info = false) + { + return export_meta_graph( + filename: filename, + graph_def: ops.get_default_graph().as_graph_def(add_shapes: true), + saver_def: _saver_def, + collection_list: collection_list, + as_text: as_text, + export_scope: export_scope, + clear_devices: clear_devices, + clear_extraneous_savers: clear_extraneous_savers, + strip_default_attrs: strip_default_attrs); + } + + public MetaGraphDef export_meta_graph(string filename = "", + byte[] meta_info_def = null, + GraphDef graph_def = null, + SaverDef saver_def = null, + string[] collection_list = null, + bool as_text = false, + bool clear_devices = false, + bool clear_extraneous_savers = false, + bool strip_default_attrs = false, + string export_scope = "") + { + var meta_graph_def = meta_graph.export_scoped_meta_graph( + filename: filename, + meta_info_def: meta_info_def, + graph_def: graph_def, + saver_def: saver_def, + // collection_list: collection_list, + as_text: as_text, + clear_devices: clear_devices, + clear_extraneous_savers: clear_extraneous_savers, + strip_default_attrs: strip_default_attrs); + return meta_graph_def.Item1; + } + + /// + /// Manages the list of the latest checkpoints. + /// + /// + private void _RecordLastCheckpoint(string latest_save_path) + { + if (_saver_def.MaxToKeep <= 0) return; + + // Remove first from list if the same name was used before. + var _existed_checkpoints = _last_checkpoints.FirstOrDefault(p => latest_save_path == _CheckpointFilename((p.Key, p.Value))); + if (_existed_checkpoints.Key != null) + _last_checkpoints.Remove(_existed_checkpoints.Key); + _last_checkpoints.Add(latest_save_path, time()); + + // If more than max_to_keep, remove oldest. + if (_last_checkpoints.Count > _saver_def.MaxToKeep) + { + var first = _last_checkpoints.First(); + _last_checkpoints.Remove(first.Key); + _checkpoints_to_be_deleted[first.Key] = first.Value; + } + } + + private string _CheckpointFilename((string, float) p) + { + return p.Item1; + } + + /// + /// Deletes old checkpoints if necessary. + /// + /// + private void _MaybeDeleteOldCheckpoints(string meta_graph_suffix = "meta") + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/checkpoint_management.py.cs b/src/TensorFlowNET.Core/Training/Saving/checkpoint_management.py.cs new file mode 100644 index 000000000..474336f4f --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/checkpoint_management.py.cs @@ -0,0 +1,213 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Protobuf.Text; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using static Tensorflow.Binding; +using static Tensorflow.SaverDef.Types; + +namespace Tensorflow +{ + public class checkpoint_management + { + /// + /// Updates the content of the 'checkpoint' file. + /// + /// Directory where the model was saved. + /// The checkpoint file. + /// List of strings. + /// + /// + /// + /// + public static void update_checkpoint_state_internal(string save_dir, + string model_checkpoint_path, + List all_model_checkpoint_paths = null, + string latest_filename = "", + bool save_relative_paths = false, + List all_model_checkpoint_timestamps = null, + float? last_preserved_timestamp = null + ) + { + CheckpointState ckpt = null; + // Writes the "checkpoint" file for the coordinator for later restoration. + string coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename); + if (save_relative_paths) + { + throw new NotImplementedException("update_checkpoint_state_internal save_relative_paths"); + } + else + { + ckpt = generate_checkpoint_state_proto(save_dir, + model_checkpoint_path, + all_model_checkpoint_paths, + all_model_checkpoint_timestamps, + last_preserved_timestamp); + } + + if (coord_checkpoint_filename == ckpt.ModelCheckpointPath) + throw new RuntimeError($"Save path '{model_checkpoint_path}' conflicts with path used for " + + "checkpoint state. Please use a different save path."); + + // File.WriteAllText(coord_checkpoint_filename, ckpt.ToString()); + var checkpoints = new List + { + $"model_checkpoint_path: \"{ckpt.ModelCheckpointPath}\"" + }; + checkpoints.AddRange(all_model_checkpoint_paths.Select(x => $"all_model_checkpoint_paths: \"{x}\"")); + + File.WriteAllLines(coord_checkpoint_filename, checkpoints); + } + + /// + /// Returns a filename for storing the CheckpointState. + /// + /// The directory for saving and restoring checkpoints. + /// + /// Name of the file in 'save_dir' that is used + /// to store the CheckpointState. + /// + /// he path of the file that contains the CheckpointState proto. + private static string _GetCheckpointFilename(string save_dir, string latest_filename) + { + if (string.IsNullOrEmpty(latest_filename)) + latest_filename = "checkpoint"; + + return Path.Combine(save_dir, latest_filename); + } + + private static CheckpointState generate_checkpoint_state_proto(string save_dir, + string model_checkpoint_path, + List all_model_checkpoint_paths = null, + List all_model_checkpoint_timestamps = null, + double? last_preserved_timestamp = null) + { + if (all_model_checkpoint_paths == null) + all_model_checkpoint_paths = new List(); + + if (!all_model_checkpoint_paths.Contains(model_checkpoint_path)) + all_model_checkpoint_paths.Add(model_checkpoint_path); + + // Relative paths need to be rewritten to be relative to the "save_dir" + if (model_checkpoint_path.StartsWith(save_dir)) + { + model_checkpoint_path = model_checkpoint_path.Substring(save_dir.Length + 1); + all_model_checkpoint_paths = all_model_checkpoint_paths + .Select(x => x.Substring(save_dir.Length + 1)) + .ToList(); + } + + + var coord_checkpoint_proto = new CheckpointState() + { + ModelCheckpointPath = model_checkpoint_path + }; + + if (last_preserved_timestamp.HasValue) + coord_checkpoint_proto.LastPreservedTimestamp = last_preserved_timestamp.Value; + + coord_checkpoint_proto.AllModelCheckpointPaths.AddRange(all_model_checkpoint_paths); + if (all_model_checkpoint_timestamps != null) + coord_checkpoint_proto.AllModelCheckpointTimestamps.AddRange(all_model_checkpoint_timestamps.Select(x => (double)x)); + + return coord_checkpoint_proto; + } + + /// + /// Returns the meta graph filename. + /// + /// + /// + /// + public static string meta_graph_filename(string checkpoint_filename, string meta_graph_suffix = "meta") + { + string basename = checkpoint_filename; + string suffixed_filename = basename + "." + meta_graph_suffix; + return suffixed_filename; + } + + public static bool checkpoint_exists(string checkpoint_prefix) + { + string pathname = _prefix_to_checkpoint_path(checkpoint_prefix, CheckpointFormatVersion.V2); + if (File.Exists(pathname)) + return true; + else if (File.Exists(checkpoint_prefix)) + return true; + else + return false; + } + + private static string _prefix_to_checkpoint_path(string prefix, CheckpointFormatVersion format_version) + { + if (format_version == CheckpointFormatVersion.V2) + return prefix + ".index"; + return prefix; + } + + /// + /// Finds the filename of latest saved checkpoint file. + /// + /// + /// + /// + public static string latest_checkpoint(string checkpoint_dir, string latest_filename = null) + { + // Pick the latest checkpoint based on checkpoint state. + var ckpt = get_checkpoint_state(checkpoint_dir, latest_filename); + if (ckpt != null && !string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) + { + // Look for either a V2 path or a V1 path, with priority for V2. + var v2_path = _prefix_to_checkpoint_path(ckpt.ModelCheckpointPath, CheckpointFormatVersion.V2); + var v1_path = _prefix_to_checkpoint_path(ckpt.ModelCheckpointPath, CheckpointFormatVersion.V1); + if (File.Exists(v2_path) || File.Exists(v1_path)) + return ckpt.ModelCheckpointPath; + else + throw new ValueError($"Couldn't match files for checkpoint {ckpt.ModelCheckpointPath}"); + } + return null; + } + + public static CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) + { + var coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, latest_filename); + if (File.Exists(coord_checkpoint_filename)) + { + var file_content = File.ReadAllText(coord_checkpoint_filename); + // https://github.com/protocolbuffers/protobuf/issues/6654 + var ckpt = CheckpointState.Parser.ParseText(file_content); + if (string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) + throw new ValueError($"Invalid checkpoint state loaded from {checkpoint_dir}"); + // For relative model_checkpoint_path and all_model_checkpoint_paths, + // prepend checkpoint_dir. + if (!Path.IsPathRooted(ckpt.ModelCheckpointPath)) + ckpt.ModelCheckpointPath = Path.Combine(checkpoint_dir, ckpt.ModelCheckpointPath); + foreach (var i in range(len(ckpt.AllModelCheckpointPaths))) + { + var p = ckpt.AllModelCheckpointPaths[i]; + if (!Path.IsPathRooted(p)) + ckpt.AllModelCheckpointPaths[i] = Path.Combine(checkpoint_dir, p); + } + + return ckpt; + } + + return null; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs new file mode 100644 index 000000000..5f198a4f8 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -0,0 +1,466 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using OneOf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Contexts; +using Tensorflow.Device; +using Tensorflow.Operations.Activation; +using Tensorflow.Train; +using Tensorflow.Training; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// A SaveableObject that defines `Trackable` checkpointing steps. + /// + public class TrackableSaveable : MySaveableObject + { + private string _prefix; + private IEnumerable _local_names; + private Trackable _trackable; + private bool _call_with_mapped_captures; + // TODO: revise the implementation. Currently the parameter of constructor of this class and its base class has conflict. + public TrackableSaveable(Trackable obj, IEnumerable specs, string name, IEnumerable local_names, + string prefix, bool call_with_mapped_captures = false) : base((object)obj as Tensor, specs.ToArray(), name) + { + _prefix = prefix; + _trackable = obj; + _local_names = local_names; + _call_with_mapped_captures = call_with_mapped_captures; + } + + // TODO: complete this class. + } + public static class saveable_object_util + { + public static string NO_SLICE_SPEC_KEY = ""; + private static HashSet _VARIABLE_OPS = new HashSet(new string[] { + "Variable", "VariableV2", "AutoReloadVariable", "VarHandleOp", "ReadVariableOp" + }); + /// + /// Returns the variables and names that will be used for a Saver. + /// + /// + /// + public static MySaveableObject[] validate_and_slice_inputs(IVariableV1[] names_to_saveables) + { + var names_to_saveables_dict = op_list_to_dict(names_to_saveables); + var saveables = new List(); + var seen_ops = new List(); + + foreach (var (name, op) in enumerate(names_to_saveables_dict)) + { + foreach (var converted_saveable_object in saveable_objects_for_op(op, name)) + _add_saveable(saveables, seen_ops, converted_saveable_object); + } + return saveables.ToArray(); + } + + public static MySaveableObject[] validate_and_slice_inputs(Dictionary names_to_saveables) + { + var saveables = new List(); + var seen_ops = new List(); + + foreach (var (name, op) in enumerate(names_to_saveables)) + { + foreach (var converted_saveable_object in saveable_objects_for_op(op, name)) + _add_saveable(saveables, seen_ops, converted_saveable_object); + } + return saveables.ToArray(); + } + + public static MySaveableObject[] validate_and_slice_inputs(Dictionary names_to_saveables) + { + var saveables = new List(); + var seen_ops = new List(); + + foreach(var item in names_to_saveables.OrderBy(x => x.Key)) + { + foreach(var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key)) + { + _add_saveable(saveables, seen_ops, converted_saveable_object); + } + } + return saveables.ToArray(); + } + + private static void _add_saveable(List saveables, List seen_ops, T saveable) where T : MySaveableObject + { + if (seen_ops.Contains(saveable.op)) + throw new ValueError($"The same saveable will be restored with two names: {saveable.name}"); + + saveables.Add(saveable); + seen_ops.Add(saveable.op); + } + + private static void _add_saveable(List saveables, List seen_ops, MySaveableObject saveable) + { + if (seen_ops.Contains(saveable.variable)) + throw new ValueError($"The same saveable will be restored with two names: {saveable.op.OriginalVar.Name}"); + + saveables.Add(saveable); + seen_ops.Add(saveable.variable); + } + + /// + /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. + /// + /// + /// + /// + public static IEnumerable saveable_objects_for_op(Tensor op, string name) + { + ops.init_scope(); + var variable = ops.convert_to_tensor(op, as_ref: true); + if (variable.dtype.is_ref_dtype()) + yield return new ReferenceVariableSaveable(variable, "", name); + else + yield return new ResourceVariableSaveable(variable, "", name); + } + + /// + /// Create `SaveableObject`s from an operation. + /// + /// + /// + /// + public static IEnumerable saveable_objects_for_op(Trackable obj, string name) + { + // The `op` maybe `Variable` or `Trackable`. + if (obj is BaseResourceVariable) + { + var variable = obj as BaseResourceVariable; + if (variable.InGraphMode) + { + yield return new ResourceVariableSaveable(variable.GraphElement, "", name); + } + else + { + yield return new ResourceVariableSaveable(variable, "", name); + } + } + else if(obj is not IVariableV1) + { + foreach(var pair in saveable_objects_from_trackable(obj)) + { + var attr = pair.Key; + var factory = pair.Value; + string full_name; + if(attr == Trackable.Constants.VARIABLE_VALUE_KEY) + { + full_name = name; + } + else + { + full_name = name + "_" + attr; + } + var op = factory(full_name); + if(op.TryPickT0(out var variable, out var saveable)) + { + foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) + { + yield return v; + } + } + else + { + foreach (var v in saveable_objects_for_op(saveable, saveable.name)) + { + yield return v; + } + } + } + } + else + { + // Variable + if (tf.Context.executing_eagerly()) + { + throw new ValueError($"Can only save/restore ResourceVariables when " + + $"executing eagerly, got type: {obj.GetType()}."); + } + var variable = ops.convert_to_tensor(obj, as_ref: true); + if (!_tensor_comes_from_variable(variable)) + { + throw new TypeError($"names_to_saveables must be a dict mapping string " + + $"names to Tensors/Variables. Not a variable: {variable}"); + } + if(variable.op.type == "Variable" || variable.op.type == "VariableV2" || + variable.op.type == "AutoReloadVariable") + { + yield return new ReferenceVariableSaveable(variable, "", name); + } + else + { + yield return new ResourceVariableSaveable(variable, "", name); + } + } + } + + /// + /// Create `SaveableObject`s from an operation. + /// + /// + /// + /// + public static IEnumerable saveable_objects_for_op(MySaveableObject obj, string name) + { + yield return obj; + } + + public static Dictionary op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true) + { + op_list = op_list.OrderBy(x => x.Name).ToArray(); + var names_to_saveables = new Dictionary(); + + foreach (var var in op_list) + { + bool resource_or_ref_variable = var is RefVariable || var is ResourceVariable; + if (false) + { + throw new NotImplementedException("op_list_to_dict"); + } + else + { + // Variables (reference and resource) have an _in_graph_mode property + if (false) // eager + { + + } + else + { + string name = null; + Tensor tensor = null; + + if (convert_variable_to_tensor) + { + if (!var.dtype.is_ref_dtype()) + tensor = var.GraphElement; + else + tensor = ops.convert_to_tensor(var, as_ref: true); + } + + if (tensor.op.type == "ReadVariableOp") + name = tensor.op.inputs[0].op.name; + else + name = var.Op.name; + + if (names_to_saveables.ContainsKey(name)) + throw new ValueError($"At least two variables have the same name: {name}"); + + names_to_saveables[name] = tensor; + } + } + } + + return names_to_saveables; + } + + public static IDictionary>> saveable_objects_from_trackable(Trackable obj) + { + // skip the process of type `PythonState` + + OneOf create_saveable(string name = "") + { + // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. + var tensor_dict = obj.serialize_to_tensors(); + + List specs = new(); + List local_names = new(); + string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; + foreach (var pair in tensor_dict) + { + var tensor_name = pair.Key; + var internal_dict = pair.Value; + local_names.Add(tensor_name); + string spec_name = name + TrackableUtils.escape_local_name(tensor_name); + + foreach (var item in internal_dict) + { + Debug.Assert(item.Value.IsT0); + specs.Add(new SaveSpec(item.Value.AsT0, item.Key, spec_name)); + } + } + return new TrackableSaveable(obj, specs, name, local_names, prefix); + } + + if (trackable_has_serialize_to_tensor(obj)) + { + Dictionary>> res = new(); + res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; + return res; + } + else + { + return obj.gather_saveables_for_checkpoint(); + } + } + + public static bool trackable_has_serialize_to_tensor(Trackable obj) + { + return obj.GetType().GetMethod("serialize_to_tensors").DeclaringType != typeof(Trackable); + } + + internal static string convert_to_string(string x) + { + return tf.compat.as_str(x); + } + + /// + /// Converts a list of SaveableObjects to a tensor dictionary. + /// + /// + public static Dictionary>> saveable_object_to_tensor_dict(IList saveables) + { + Dictionary>> tensor_dict = new(); + foreach (var saveable in saveables) + { + foreach (var spec in saveable.specs) + { + // skip the check that if `spec` is callable. + var name = convert_to_string(spec.name); + var slice_spec = convert_to_string(spec.slice_spec); + if (string.IsNullOrEmpty(slice_spec)) + { + slice_spec = NO_SLICE_SPEC_KEY; + } + tensor_dict.SetDefault(name, new Dictionary>())[slice_spec] = spec.TensorCreator is null ? spec.tensor : spec; + } + } + return tensor_dict; + } + + /// + /// Generates `Trackable._restore_from_tensors` from SaveableObjects. + /// + /// + public static Func>>, IDictionary> saveable_object_to_restore_fn(IList saveables) + { + return (restored_tensors) => + { + Dictionary restored_ops = new(); + + foreach(var saveable in saveables) + { + List saveable_restored_tensors = new(); + foreach(var spec in saveable.specs) + { + var name = TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(spec.name)); + var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); + + var maybe_tensor = restored_tensors[name]; + IDictionary dict; + if(maybe_tensor.TryPickT0(out var tensor, out var dic)) + { + dict = new Dictionary(); + dict[""] = tensor; + } + else + { + dict = dic; + } + saveable_restored_tensors.Add(dict[slice_spec]); + } + restored_ops[saveable.name] = saveable.restore(saveable_restored_tensors.ToArray(), null); + } + return restored_ops; + }; + } + + /// + /// Returns a dict of SaveableObject factories generated from loaded fns. + /// + /// + /// + public static IDictionary>> recreate_saveable_objects( + IDictionary saveable_fn_by_name, IEnumerable? temp_session) + { + if (saveable_fn_by_name.Count > 0) + { + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + var res = new Dictionary>>(); + return res; + } + + public static OneOf create_saveable_object(string name, string key, Func> factory, + bool call_with_mapped_captures = false) + { + return factory(key); + } + + public static string set_cpu0(string device_string) + { + if (tf.Context.is_custom_device(device_string)) + { + return device_string; + } + var parsed_device = DeviceSpec.from_string(device_string); + parsed_device = parsed_device.replace(device_type: "CPU", device_index: 0); + return parsed_device.ToString(); + } + + private static bool _tensor_comes_from_variable(object v) + { + return v is Tensor tensor && _VARIABLE_OPS.Contains(tensor.op.type); + } + } + + public class SaveableCompatibilityConverter: Trackable + { + private object _obj; + private IList _saveables; + public SaveableCompatibilityConverter(object obj, IList saveables) + { + _obj= obj; + _saveables= saveables; + } + + public object Obj => _obj; + public IList mySaveables=> _saveables; + + public override IDictionary>> serialize_to_tensors() + { + return saveable_object_util.saveable_object_to_tensor_dict(_saveables); + } + + /// + /// Returns the restore ops defined in the Saveables. + /// + /// + /// + public override IDictionary _restore_from_tensors(IDictionary>> restored_tensors) + { + List expected_keys = new(); + foreach(var saveable in _saveables) + { + expected_keys.AddRange(saveable.specs.Select(x => TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(x.name)))); + } + if (!expected_keys.Distinct().SequenceEqual(restored_tensors.Keys)) + { + throw new ValueError($"Could not restore object {_obj} because not all expected tensors were in the checkpoint." + + $"\n\tExpected: {expected_keys} \n\tGot: {list(restored_tensors.Keys)}"); + } + return saveable_object_util.saveable_object_to_restore_fn(_saveables).Invoke(restored_tensors); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs new file mode 100644 index 000000000..f94f98940 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs @@ -0,0 +1,118 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class saver + { + public static (Saver, object) _import_meta_graph_with_return_elements(string meta_graph_or_file, + bool clear_devices = false, + string import_scope = "", + string[] return_elements = null) + { + var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file); + + var (imported_vars, imported_return_elements) = meta_graph.import_scoped_meta_graph_with_return_elements( + meta_graph_def, + clear_devices: clear_devices, + import_scope: import_scope, + return_elements: return_elements); + + var saver = _create_saver_from_imported_meta_graph( + meta_graph_def, import_scope, imported_vars); + + return (saver, imported_return_elements); + } + + /// + /// Return a saver for restoring variable values to an imported MetaGraph. + /// + /// + /// + /// + /// + public static Saver _create_saver_from_imported_meta_graph(MetaGraphDef meta_graph_def, + string import_scope, + Dictionary imported_vars) + { + if (meta_graph_def.SaverDef != null) + { + // Infer the scope that is prepended by `import_scoped_meta_graph`. + string scope = import_scope; + var var_names = imported_vars.Keys.ToArray(); + if (var_names.Length > 0) + { + var sample_key = var_names[0]; + var sample_var = imported_vars[sample_key]; + scope = string.Join("", sample_var.Name.Skip(sample_key.Length)); + } + return new Saver(saver_def: meta_graph_def.SaverDef, name: scope); + } + else + { + if (variables._all_saveable_objects(scope: import_scope).Length > 0) + { + // Return the default saver instance for all graph variables. + return new Saver(); + } + else + { + // If no graph variables exist, then a Saver cannot be constructed. + Binding.tf_output_redirect.WriteLine("Saver not created because there are no variables in the" + + " graph to restore"); + return null; + } + } + } + + public static string freeze_graph(string checkpoint_dir, + string output_pb_name, + string[] output_node_names) + { + var checkpoint = checkpoint_management.latest_checkpoint(checkpoint_dir); + if (!File.Exists($"{checkpoint}.meta")) return null; + + string output_pb = Path.GetFullPath(Path.Combine(checkpoint_dir, "../", $"{output_pb_name}.pb")); + + var graph = tf.Graph(); + var sess = tf.Session(graph); + var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true); + saver.restore(sess, checkpoint); + var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, + graph.as_graph_def(), + output_node_names); + Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); + File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); + return output_pb; + } + + public static Graph load_graph(string freeze_graph_pb, string name = "") + { + var bytes = File.ReadAllBytes(freeze_graph_pb); + var graph = tf.Graph().as_default(); + importer.import_graph_def(GraphDef.Parser.ParseFrom(bytes), + name: name); + return graph; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/SecondOrStepTimer.cs b/src/TensorFlowNET.Core/Training/SecondOrStepTimer.cs new file mode 100644 index 000000000..cc5b7488a --- /dev/null +++ b/src/TensorFlowNET.Core/Training/SecondOrStepTimer.cs @@ -0,0 +1,41 @@ +using System; + +namespace Tensorflow.Training +{ + public class SecondOrStepTimer : _HookTimer + { + int _every_secs = 60; + int _every_steps = 0; + int _last_triggered_step = 0; +#pragma warning disable CS0414 // The field 'SecondOrStepTimer._last_triggered_time' is assigned but its value is never used + int _last_triggered_time = 0; +#pragma warning restore CS0414 // The field 'SecondOrStepTimer._last_triggered_time' is assigned but its value is never used + + public SecondOrStepTimer(int every_secs, int every_steps) + { + _every_secs = every_secs; + _every_steps = every_steps; + } + + public override void reset() + { + _last_triggered_step = 0; + _last_triggered_time = 0; + } + + public override int last_triggered_step() + { + return _last_triggered_step; + } + + public override bool should_trigger_for_step(int step) + { + throw new NotImplementedException(); + } + + public override void update_last_triggered_step(int step) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/SessionRunArgs.cs b/src/TensorFlowNET.Core/Training/SessionRunArgs.cs new file mode 100644 index 000000000..f65b45242 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/SessionRunArgs.cs @@ -0,0 +1,6 @@ +namespace Tensorflow.Training +{ + public class SessionRunArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Training/SessionRunContext.cs b/src/TensorFlowNET.Core/Training/SessionRunContext.cs new file mode 100644 index 000000000..c30ee7dc8 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/SessionRunContext.cs @@ -0,0 +1,26 @@ +namespace Tensorflow.Training +{ + public class SessionRunContext + { + SessionRunArgs _original_args; + public SessionRunArgs original_args => _original_args; + + Session _session; + public Session session => _session; + + bool _stop_requested; + public bool stop_requested => _stop_requested; + + public SessionRunContext(SessionRunArgs original_args, Session session) + { + _original_args = original_args; + _session = session; + _stop_requested = false; + } + + public void request_stop() + { + _stop_requested = true; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/SessionRunHook.cs b/src/TensorFlowNET.Core/Training/SessionRunHook.cs new file mode 100644 index 000000000..28552fa52 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/SessionRunHook.cs @@ -0,0 +1,46 @@ +namespace Tensorflow.Training +{ + /// + /// Hook to extend calls to MonitoredSession.run(). + /// + public abstract class SessionRunHook + { + /// + /// Called once before using the session. + /// + public virtual void begin() + { + } + + /// + /// Called when new TensorFlow session is created. + /// + /// + /// + public virtual void after_create_session(Session session, Coordinator coord) + { + } + + /// + /// Called before each call to run(). + /// + /// + public virtual void before_run(SessionRunContext run_context) + { + } + + /// + /// Called after each call to run(). + /// + public virtual void after_run(SessionRunContext run_context, SessionRunValues run_values) + { + } + + /// + /// Called at the end of session. + /// + public virtual void end(Session session) + { + } + } +} diff --git a/src/TensorFlowNET.Core/Training/SessionRunValues.cs b/src/TensorFlowNET.Core/Training/SessionRunValues.cs new file mode 100644 index 000000000..c0135d2cd --- /dev/null +++ b/src/TensorFlowNET.Core/Training/SessionRunValues.cs @@ -0,0 +1,6 @@ +namespace Tensorflow.Training +{ + public class SessionRunValues + { + } +} diff --git a/src/TensorFlowNET.Core/Training/SlotCreator.cs b/src/TensorFlowNET.Core/Training/SlotCreator.cs new file mode 100644 index 000000000..df9983ab3 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/SlotCreator.cs @@ -0,0 +1,113 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Operations.Initializers; +using static Tensorflow.Binding; + +namespace Tensorflow.Train +{ + public class SlotCreator + { + /// + /// Create a slot initialized to the given value. + /// + /// + /// + /// + /// + /// + public IVariableV1 create_slot(RefVariable primary, Tensor val, string name, bool colocate_with_primary = true) + { + var validate_shape = val.shape.IsFullyDefined; + var prefix = primary.Op.name; + return tf_with(tf.variable_scope(name: null, prefix + "/" + name), delegate + { + return _create_slot_var(primary, val, "", validate_shape, null, TF_DataType.DtInvalid); + }); + } + + /// + /// Create a slot initialized to 0 with same shape as the primary object. + /// + /// + /// + /// + /// + /// + public IVariableV1 create_zeros_slot(IVariableV1 primary, string name, TF_DataType dtype = TF_DataType.DtInvalid, bool colocate_with_primary = true) + { + if (dtype == TF_DataType.DtInvalid) + dtype = primary.dtype; + var slot_shape = primary.shape; + if (slot_shape.IsFullyDefined) + { + var initializer = new Zeros(); + return create_slot_with_initializer( + primary, initializer, slot_shape, dtype, name, + colocate_with_primary: colocate_with_primary); + } + else + { + throw new NotImplementedException("create_zeros_slot is not fully defined."); + } + } + + /// + /// Creates a slot initialized using an `Initializer`. + /// + /// + public IVariableV1 create_slot_with_initializer(IVariableV1 primary, IInitializer initializer, Shape shape, + TF_DataType dtype, string name, bool colocate_with_primary = true) + { + var validate_shape = shape.IsFullyDefined; + var prefix = primary.Op.name; + return tf_with(new variable_scope(string.Empty, prefix + "/" + name), delegate + { + return _create_slot_var(primary, initializer, "", validate_shape, shape, dtype); + }); + } + + /// + /// Helper function for creating a slot variable. + /// + /// + /// + /// + /// + /// + /// + /// + private IVariableV1 _create_slot_var(IVariableV1 primary, object val, string scope, bool validate_shape, + Shape shape, TF_DataType dtype) + { + bool use_resource = primary is ResourceVariable; + if (resource_variable_ops.is_resource_variable(primary)) + use_resource = true; + + var slot = tf.compat.v1.get_variable( + scope, + initializer: val, + trainable: false, + use_resource: use_resource, + shape: shape, + dtype: dtype, + validate_shape: validate_shape); + + return slot; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs new file mode 100644 index 000000000..3eff34875 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -0,0 +1,361 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using OneOf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Training; +using Tensorflow.Training.Saving.SavedModel; +using static Tensorflow.Binding; + +namespace Tensorflow.Train +{ + public abstract class Trackable: IWithTrackable + { + /// + /// Corresponding to tensorflow/python/trackable/constants.py + /// + public static class Constants + { + public static readonly string OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"; + public static readonly string VARIABLE_VALUE_KEY = "VARIABLE_VALUE"; + public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"; + } + protected int _self_update_uid; + protected IDictionary _unconditional_dependency_names; + + protected IList _unconditional_checkpoint_dependencies; + protected Dictionary> _unconditional_deferred_dependencies; + + protected IDictionary>> _self_saveable_object_factories = + new Dictionary>>(); + private bool _manual_tracking = true; + + private static Trackable _none = new AutoTrackable(); + /// + /// This is a trick for that CSharp does not allow the key of `Dictionary` to be null. + /// The `None` can be any object that inherits `Trackable`. + /// This Property is supposed to be used only internal. + /// + public static Trackable None + { + get + { + return _none; + } + } + public Trackable GetTrackable() + { + return this; + } + public virtual string ObjectIdentifier + { + get => "_generic_user_object"; + } + public int UpdateUid { get => _self_update_uid; set => _self_update_uid = value; } + public IList UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } + public IDictionary UnconditionalDependencyNames { get => _unconditional_dependency_names; } + public IList CheckpointDependencies { get => UnconditionalCheckpointDependencies; } + public Dictionary> DeferredDependencies => _unconditional_deferred_dependencies; + public IDictionary>> SelfSaveableObjectFactories + { + get + { + return _self_saveable_object_factories; + } + set + { + _self_saveable_object_factories = value; + } + } + public Dictionary CustomizedFields { get; set; } = new Dictionary(); + + public virtual void SetAttr(string name, object value) + { + var t = this.GetType(); + var field_info = t.GetField(name); + if(field_info is not null) + { + field_info.SetValue(this, value); + } + else + { + CustomizedFields[name] = value; + } + + // On account of performance, we don't use reflection to set the attribute if it exists in `Trackable`. + // When adding new members or properties to this class, please add corresponding process to this method. + //switch (name) + //{ + // case "_manual_tracking": + // { + // _manual_tracking = (bool)value; + // break; + // } + // case "_self_saveable_object_factories": + // { + // _self_saveable_object_factories = (IDictionary>>)value; + // break; + // } + // case "_self_update_uid": + // { + // _self_update_uid = (int)value; + // break; + // } + // case "_unconditional_checkpoint_dependencies": + // { + // _unconditional_checkpoint_dependencies = (IList)value; + // break; + // } + // case "_unconditional_deferred_dependencies": + // { + // _unconditional_deferred_dependencies = (Dictionary>)value; + // break; + // } + // case "_unconditional_dependency_names": + // { + // _unconditional_dependency_names = (IDictionary)value; + // break; + // } + // case "SelfSaveableObjectFactories": + // { + // SelfSaveableObjectFactories = (IDictionary>>)value; + // break; + // } + // case "UpdateUid": + // { + // UpdateUid = (int)value; + // break; + // } + // default: + // { + // CustomizedAttributes[name] = value; + // break; + // } + // } + } + + /// + /// Restore-on-create for a variable be saved with this `Checkpointable`. + /// + /// + protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args) + { + tf_with(ops.init_scope(), delegate + { +#pragma warning disable CS0219 // Variable is assigned but its value is never used + IInitializer checkpoint_initializer = null; +#pragma warning restore CS0219 // Variable is assigned but its value is never used + if (tf.Context.executing_eagerly()) +#pragma warning disable CS0642 // Possible mistaken empty statement + ; +#pragma warning restore CS0642 // Possible mistaken empty statement + else + checkpoint_initializer = null; + }); + + var new_variable = args.Getter(args); + + // If we set an initializer and the variable processed it, tracking will not + // assign again. It will add this variable to our dependencies, and if there + // is a non-trivial restoration queued, it will handle that. This also + // handles slot variables. + if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable) + { + var res = _track_trackable(new_variable as Trackable, args.Name, args.Overwrite); + Debug.Assert(res is IVariableV1); + return res as IVariableV1; + } + else + return new_variable; + } + + /// + /// Pop and load any deferred checkpoint restores into `trackable`. + /// + /// + /// + protected void _handle_deferred_dependencies(string name, IVariableV1 trackable) + { + _maybe_initialize_trackable(); + // TODO + } + + protected IVariableV1 _track_checkpointable(IVariableV1 checkpointable, string name, bool overwrite = false) + { + return checkpointable; + } + + /// + /// Initialize dependency management. + /// + public void _maybe_initialize_trackable() + { + if(_unconditional_checkpoint_dependencies is not null) + { + return; + } + _self_update_uid = -1; + _unconditional_checkpoint_dependencies = new List(); + _unconditional_dependency_names = new Dictionary(); + _unconditional_deferred_dependencies = new Dictionary>(); + } + + public virtual IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, + IDictionary>? cache = null) + { + _maybe_initialize_trackable(); + return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); + } + + public virtual Trackable _track_trackable(Trackable trackable, string name, bool overwrite = false) + { + _maybe_initialize_trackable(); + if (!_manual_tracking) return trackable; + var new_reference = new TrackableReference(name, trackable); + var current_object = _lookup_dependency(name); + + if(current_object is null) + { + _unconditional_checkpoint_dependencies.Add(new_reference); + _handle_deferred_dependencies(name, trackable); + } + _unconditional_dependency_names[name] = trackable; + return trackable; + } + + /// + /// Pop and load any deferred checkpoint restores into `trackable`. + /// This method does not add a new dependency on `trackable`, but it does check if any outstanding/deferred dependencies have been queued waiting for + /// this dependency to be added (matched based on `name`). If so, `trackable` and its dependencies are restored. The restorations are + /// considered fulfilled and so are deleted. + /// `_track_trackable` is more appropriate for adding a normal/unconditional dependency, and includes handling for deferred restorations. + /// This method allows objects such as `Optimizer` to use the same restoration logic while managing conditional dependencies themselves, + /// by overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the object's dependencies based on the context + /// it is saved/restored in (a single optimizer instance can have state associated with multiple graphs). + /// + /// + /// + public virtual void _handle_deferred_dependencies(string name, Trackable trackable) + { + _maybe_initialize_trackable(); + trackable._maybe_initialize_trackable(); + + if(_unconditional_deferred_dependencies.TryGetValue(name, out var dependencies)) + { + _unconditional_deferred_dependencies.Remove(name); + foreach(var checkpoint_position in dependencies.OrderByDescending(x => x.Checkpoint.RestoreUid)) + { + checkpoint_position.restore(trackable); + } + } + + // TODO(Rinne): deal with `_self_name_based_restores` + } + + public virtual Trackable? _lookup_dependency(string name) + { + if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency; + else return null; + } + + public static Trackable convert_to_trackable(object obj, object? parent = null) + { + if (obj is Trackable) + { + return (Trackable)obj; + } + else + { + throw new NotImplementedException(); + } + } + + public virtual IDictionary deserialization_dependencies(IDictionary children) + { + return new Dictionary(); + } + + public virtual (IDictionary, IDictionary) map_resources( + SaveOptions? save_options) + { + return (new Dictionary(), new Dictionary()); + } + + public virtual List export_to_saved_model_graph(IDictionary object_map, + IDictionary tensor_map, SaveOptions? options = null) + { + var (self_object_map, self_tensor_map) = map_resources(options); + foreach (var pair in self_object_map) + { + object_map.Add(pair); + } + foreach (var pair in self_tensor_map) + { + tensor_map.Add(pair); + } + + return self_tensor_map.Keys.ToList(); + } + + public virtual IDictionary>> gather_saveables_for_checkpoint() + { + OneOf create_saveable(string name = "") + { + throw new NotImplementedException(); + //return new TrackableSaveable(this, null, name, null, null); + } + if (saveable_object_util.trackable_has_serialize_to_tensor(this)) + { + // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). + Dictionary>> res = new(); + res[""] = create_saveable; + return res; + } + else + { + return _self_saveable_object_factories; + } + } + + /// + /// Gathers tensors to save to the checkpoint. You should only override `serialize_to_tensors` and `restore_from_tensors` + /// if you are defining a custom resource or variable with custom ops. + /// Otherwise, please store the state of your trackable in `tf.Variable` objects + /// and add them to Trackable object hierarchy using `setattr` (for subclasses + /// of `AutoTrackable`) or overriding the `_trackable_children` method. + /// + /// + /// + public virtual IDictionary>> serialize_to_tensors() + { + throw new NotImplementedException(); + } + + public virtual IDictionary _restore_from_tensors(IDictionary>> restored_tensors) + { + throw new NotImplementedException(); + } + } + + public record class TrackableReference(string Name, Trackable Refer); + + public record class SlotVariableRestoration(int OptimizerId, int SlotVariableId, string SlotName); +} diff --git a/src/TensorFlowNET.Core/Training/TrackableUtils.cs b/src/TensorFlowNET.Core/Training/TrackableUtils.cs new file mode 100644 index 000000000..89bb614d2 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/TrackableUtils.cs @@ -0,0 +1,173 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Exceptions; +using Tensorflow.Train; + +namespace Tensorflow.Training; + +public static class TrackableUtils +{ + public class CyclicDependencyError: System.Exception + { + public IDictionary> LeftOverDependencyMap { get; } + public CyclicDependencyError(IDictionary> leftover_dependency_map): base() + { + LeftOverDependencyMap = leftover_dependency_map; + } + public CyclicDependencyError(IDictionary> leftover_dependency_map): base() + { + LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); + } + } + internal static string _ESCAPE_CHAR = "."; + internal static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; + internal static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; + internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; + public static string object_path_to_string(IEnumerable node_path_arr) + { + return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); + } + + public static string escape_local_name(string name) + { + return name.Replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR).Replace("/", _ESCAPE_CHAR + "S"); + } + + public static string checkpoint_key(string object_path, string local_name) + { + var key_suffix = escape_local_name(local_name); + if (local_name == SERIALIZE_TO_TENSORS_NAME) + { + key_suffix = ""; + } + + return $"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{key_suffix}"; + } + + /// + /// Topologically sorts the keys of a map so that dependencies appear first. + /// Uses Kahn's algorithm: https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + /// + /// + /// + public static List order_by_dependency(IDictionary> dependency_map) + { + Dictionary> reverse_dependency_map = new(); + foreach (var pair in dependency_map) + { + foreach (var dep in pair.Value) + { + if (reverse_dependency_map.ContainsKey(dep)) + { + reverse_dependency_map[dep].Add(pair.Key); + } + else + { + reverse_dependency_map[dep] = new HashSet(); + reverse_dependency_map[dep].Add(pair.Key); + } + } + } + + // Validate that all values in the dependency map are also keys. + var unknown_keys = reverse_dependency_map.Keys.Except(dependency_map.Keys); + if (unknown_keys.Count() > 0) + { + throw new ValueError( + $"Found values in the dependency map which are not keys: {string.Join(", ", unknown_keys.Select(x => x.ToString()))}"); + } + + // Generate the list sorted by objects without dependencies -> dependencies. + // The returned list will reverse this. + List reversed_dependency_arr = new(); + + Queue to_visit = new(); + foreach (var x in dependency_map.Keys) + { + if (!reverse_dependency_map.ContainsKey(x)) + { + to_visit.Enqueue(x); + } + } + + while (to_visit.Count > 0) + { + var x = to_visit.Dequeue(); + reversed_dependency_arr.Add(x); + foreach (var dep in dependency_map[x].Distinct()) + { + var edges = reverse_dependency_map[dep]; + edges.Remove(x); + if (edges.Count == 0) + { + to_visit.Enqueue(dep); + if (!reverse_dependency_map.Remove(dep)) + { + throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map"); + } + } + } + } + + if (reverse_dependency_map.Count > 0) + { + Dictionary> leftover_dependency_map = new(); + foreach (var pair in reverse_dependency_map) + { + foreach (var x in pair.Value) + { + if (leftover_dependency_map.ContainsKey(x)) + { + leftover_dependency_map[x].Add(pair.Key); + } + else + { + leftover_dependency_map[x] = new List() { pair.Key }; + } + } + } + + throw new CyclicDependencyError(leftover_dependency_map); + } + + reversed_dependency_arr.Reverse(); + return reversed_dependency_arr; + } + + public static string pretty_print_node_path(IEnumerable paths) + { + if (paths.Count() == 0) + { + return "root object"; + } + else + { + return $"root.{string.Join(".", paths.Select(x => x.Name))}"; + } + } + + /// + /// Returns the substring after the "/.ATTIBUTES/" in the checkpoint key. + /// + /// + /// + /// + public static string extract_local_name(string key, string? prefix = null) + { + if(prefix is null) + { + prefix = ""; + } + var search_key = OBJECT_ATTRIBUTES_NAME + "/" + prefix; + try + { + return key.Substring(key.IndexOf(search_key) + search_key.Length); + } + catch(ArgumentOutOfRangeException) + { + return key; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/TrainingUtil.cs b/src/TensorFlowNET.Core/Training/TrainingUtil.cs new file mode 100644 index 000000000..1fd923353 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/TrainingUtil.cs @@ -0,0 +1,87 @@ +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow.Train +{ + public class TrainingUtil + { + public static IVariableV1 create_global_step(Graph graph = null) + { + graph = graph ?? ops.get_default_graph(); + if (get_global_step(graph) != null) + throw new ValueError("global_step already exists."); + + // Create in proper graph and base name_scope. + var g = graph.as_default(); + g.name_scope(null); + var v = tf.compat.v1.get_variable(tf.GraphKeys.GLOBAL_STEP, new int[0], dtype: dtypes.int64, + initializer: tf.zeros_initializer, + trainable: false, + aggregation: VariableAggregation.OnlyFirstReplica, + collections: new List { tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP }); + return v; + } + + public static RefVariable get_global_step(Graph graph = null) + { + graph = graph ?? ops.get_default_graph(); + RefVariable global_step_tensor = null; + var global_step_tensors = graph.get_collection(tf.GraphKeys.GLOBAL_STEP); + if (global_step_tensors.Count == 1) + { + global_step_tensor = global_step_tensors[0]; + } + else + { + try + { + global_step_tensor = graph.get_tensor_by_name("global_step:0"); + } + catch (KeyError) + { + return null; + } + } + + return global_step_tensor; + } + + public static Tensor _get_or_create_global_step_read(Graph graph = null) + { + graph = graph ?? ops.get_default_graph(); + var global_step_read_tensor = _get_global_step_read(graph); + if (global_step_read_tensor != null) + return global_step_read_tensor; + + var global_step_tensor = get_global_step(graph); + + if (global_step_tensor == null) + return null; + + var g = graph.as_default(); + g.name_scope(null); + g.name_scope(global_step_tensor.Op.name + "/"); + // using initialized_value to ensure that global_step is initialized before + // this run. This is needed for example Estimator makes all model_fn build + // under global_step_read_tensor dependency. + var global_step_value = global_step_tensor.initialized_value(); + ops.add_to_collection(tf.GraphKeys.GLOBAL_STEP_READ_KEY, global_step_value + 0); + + return _get_global_step_read(graph); + } + + private static Tensor _get_global_step_read(Graph graph = null) + { + graph = graph ?? ops.get_default_graph(); + var global_step_read_tensors = graph.get_collection(tf.GraphKeys.GLOBAL_STEP_READ_KEY); + if (global_step_read_tensors.Count > 1) + throw new RuntimeError($"There are multiple items in collection {tf.GraphKeys.GLOBAL_STEP_READ_KEY}. " + + "There should be only one."); + + if (global_step_read_tensors.Count == 1) + return global_step_read_tensors[0]; + + return null; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/VariableAggregationType.cs b/src/TensorFlowNET.Core/Training/VariableAggregationType.cs new file mode 100644 index 000000000..70976b1f8 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/VariableAggregationType.cs @@ -0,0 +1,10 @@ +namespace Tensorflow +{ + public enum VariableAggregationType + { + NONE = 0, + SUM = 1, + MEAN = 2, + ONLY_FIRST_TOWER = 3 + } +} diff --git a/src/TensorFlowNET.Core/Training/_HookTimer.cs b/src/TensorFlowNET.Core/Training/_HookTimer.cs new file mode 100644 index 000000000..8c7b299fd --- /dev/null +++ b/src/TensorFlowNET.Core/Training/_HookTimer.cs @@ -0,0 +1,32 @@ +namespace Tensorflow.Training +{ + /// + /// Base timer for determining when Hooks should trigger. + /// + public abstract class _HookTimer + { + /// + /// Resets the timer. + /// + public abstract void reset(); + + /// + /// Return true if the timer should trigger for the specified step. + /// + /// + /// + public abstract bool should_trigger_for_step(int step); + + /// + /// Update the last triggered time and step number. + /// + /// + public abstract void update_last_triggered_step(int step); + + /// + /// Returns the last triggered time step or None if never triggered. + /// + /// + public abstract int last_triggered_step(); + } +} diff --git a/src/TensorFlowNET.Core/Training/_MonitoredSession.cs b/src/TensorFlowNET.Core/Training/_MonitoredSession.cs new file mode 100644 index 000000000..26e986392 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/_MonitoredSession.cs @@ -0,0 +1,6 @@ +namespace Tensorflow.Train +{ + internal class _MonitoredSession + { + } +} diff --git a/src/TensorFlowNET.Core/Training/_OptimizableVariable.cs b/src/TensorFlowNET.Core/Training/_OptimizableVariable.cs new file mode 100644 index 000000000..86d53fe64 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/_OptimizableVariable.cs @@ -0,0 +1,8 @@ +namespace Tensorflow +{ + public interface _OptimizableVariable + { + Tensor target(); + Operation update_op(Optimizer optimizer, Tensor g); + } +} diff --git a/src/TensorFlowNET.Core/Training/data_structures.cs b/src/TensorFlowNET.Core/Training/data_structures.cs new file mode 100644 index 000000000..6b607e853 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/data_structures.cs @@ -0,0 +1,687 @@ +using Google.Protobuf; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO.Compression; +using System.Linq; +using System.Linq.Expressions; +using System.Runtime.InteropServices; +using System.Text; +using Tensorflow.Functions; +using Tensorflow.Keras; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Operations.Activation; +using Tensorflow.Train; +using static Tensorflow.ApiDef.Types; + +namespace Tensorflow.Training +{ + public class NoDependency + { + public Trackable Value { get; set; } + public NoDependency(Trackable value) + { + Value = value; + } + } + + static class TrackableWrapperUtils + { + internal static bool ShouldLoad(ITrackableWrapper wrapper, SavedUserObject proto) + { + if (proto.Identifier != wrapper.Identifier) + { + return false; + } + if (wrapper.Version < proto.Version.MinConsumer) + { + return false; + } + if (proto.Version.Producer < wrapper.MinProducerVersion) + { + return false; + } + foreach (var bad_version in proto.Version.BadConsumers) + { + if (bad_version == wrapper.Version) + { + return false; + } + } + return true; + } + + internal static bool is_function(Trackable x) + { + return x is Function or ConcreteFunction; + } + } + + public interface ITrackableWrapper + { + void SetValue(object name, object value); + String Identifier { get; } + int Version { get; } + int MinConsumerVersion { get; } + int MinProducerVersion { get; } + Trackable FromProto(SavedUserObject proto); + } + + public abstract class TrackableDataStructure : Trackable + { + private bool _self_trainable; + private List _self_extra_variables; + + public TrackableDataStructure() + { + _self_trainable = true; + _self_extra_variables = new List(); + } + + public abstract ICollection Values { get; } + public bool Trainable { get => _self_trainable; set => _self_trainable = value; } + public IEnumerable Layers + { + get + { + List collected = new(); + foreach(var obj in Values) + { + if(obj is ILayer) + { + collected.Add((ILayer)obj); + } + else if(obj is TrackableDataStructure) + { + collected.AddRange((obj as TrackableDataStructure).Layers); + } + } + return collected; + } + } + public IEnumerable TrainableWeights + { + get + { + if (!_self_trainable) + { + return new List(); + } + List trainable_variables = new(); + foreach (var obj in Values) + { + // skip the process of `module.Module`. + if (obj is TrackableDataStructure) + { + trainable_variables.AddRange((obj as TrackableDataStructure).TrainableVariables); + } + } + foreach(var v in _self_extra_variables) + { + if (v.Trainable) + { + trainable_variables.Add(v); + } + } + return trainable_variables; + } + } + public IEnumerable NonTrainableWeights + { + get + { + var trainable_extra_variables = _self_extra_variables.Where(x => x.Trainable).ToList(); + var non_trainable_extra_variables = _self_extra_variables.Where(x => !x.Trainable).ToList(); + List non_trainable_variables = new(); + foreach(var obj in Values) + { + // skip the process of `module.Module`. + if (obj is TrackableDataStructure) + { + non_trainable_variables.AddRange((obj as TrackableDataStructure).NonTrainableVariables); + } + } + + if (!_self_trainable) + { + // Return order is all trainable vars, then all non-trainable vars. + List trainable_variables = new(); + foreach(var obj in Values) + { + // skip the process of `module.Module`. + if (obj is TrackableDataStructure) + { + trainable_variables.AddRange((obj as TrackableDataStructure).TrainableVariables); + } + } + return trainable_variables.concat(trainable_extra_variables).concat(non_trainable_variables).concat(non_trainable_extra_variables); + } + else + { + return non_trainable_variables.concat(non_trainable_extra_variables); + } + } + } + public IEnumerable Weights => TrainableWeights.Concat(NonTrainableWeights); + public IEnumerable TrainableVariables => TrainableWeights; + public IEnumerable NonTrainableVariables => NonTrainableWeights; + public IEnumerable Variables => Weights; + + // TODO: `losses` property. + + /// + /// Add a dependency on `value`. + /// + /// + /// + protected virtual Trackable _track_value(Trackable value, string name) + { + value = (Trackable)sticky_attribute_assignment(this, name, value); + if(value is IVariableV1) + { + _self_extra_variables.Add(value as IVariableV1); + } + // skip the left process (need to be done in the future). + return value; + } + + public static Trackable wrap_or_unwrap(NoDependency value) + { + return value.Value; + } + + public static object wrap_or_unwrap(object value) + { + if(value is NoDependency dependency) + { + return dependency.Value; + } + if(value is Trackable trackable) + { + return trackable; + } + else if(value is IDictionary obj_dict) + { + return new DictWrapper(obj_dict); + } + else if(value is IList list) + { + return new ListWrapper(list); + } + else + { + return value; + } + } + + public static object sticky_attribute_assignment(Trackable trackable, string name, object value) + { + bool add_dependency = value is not NoDependency; + value = wrap_or_unwrap(value); + if (!add_dependency) + { + return value; + } + if(value is Trackable trackable_obj) + { + trackable._track_trackable(trackable_obj, name, true); + } + return value; + } + } + // TODO(Rinne): Add Dict wrapper and Tuple wrapper + + public class DictWrapper : TrackableDataStructure, IDictionary, ICloneable, ITrackableWrapper + { + private IDictionary _storage; + private bool _non_string_key; + private bool _external_modification; + private IDictionary _last_wrapped_dict_snapshot; + + public DictWrapper(IDictionary wrapped_dict = null) + { + if(wrapped_dict is not null) + { + _storage = new Dictionary(wrapped_dict); + } + else + { + _storage = new Dictionary(); + } + _update_snapshot(); + } + + public void SetValue(object name, object value) + { + Debug.Assert(value is Trackable); + this[name] = value as Trackable; + } + public String Identifier => "trackable_dict_wrapper"; + public int Version => 1; + public int MinConsumerVersion => 1; + public int MinProducerVersion => 1; + public Trackable FromProto(SavedUserObject proto) + { + return new DictWrapper(new Dictionary()); + } + + public Trackable this[object key] + { + get + { + return _storage[key]; + } + set + { + _check_self_external_modification(); + _maybe_initialize_trackable(); + bool no_dep = value is NoDependency; + if(key is string) + { + value = _track_value(value, key); + } + else + { + value = (Trackable)wrap_or_unwrap(value); + if(!no_dep && value is Trackable) + { + _non_string_key = true; + } + } + _storage[key] = value; + _update_snapshot(); + } + } + + public ICollection Keys => _storage.Keys; + + public override ICollection Values => _storage.OrderBy(x => x.Key).Select(x => x.Value).ToArray(); + + public void Add(object key, Trackable value) + { + _storage[key] = value; + } + + public bool ContainsKey(object key) + { + return _storage.ContainsKey(key); + } + + public bool Remove(object key) + { + _check_self_external_modification(); + var res = _storage.Remove(key); + _update_snapshot(); + return res; + } + + public bool TryGetValue(object key, out Trackable value) + { + return _storage.TryGetValue(key, out value); + } + + public int Count => _storage.Count; + + public bool IsReadOnly => _storage.IsReadOnly; + + public void Add(KeyValuePair item) + { + Add(item.Key, item.Value); + } + + public void Clear() + { + _storage.Clear(); + _update_snapshot(); + } + + public bool Contains(KeyValuePair item) + { + return _storage.Contains(item); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + _storage.CopyTo(array, arrayIndex); + } + + public bool Remove(KeyValuePair item) + { + _check_self_external_modification(); + var res = Remove(item); + _update_snapshot(); + return res; + } + + public IEnumerator> GetEnumerator() + { + return _storage.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator(); + + public object Clone() + { + var copied = new DictWrapper(_storage); + copied._external_modification = _external_modification; + copied._non_string_key = _non_string_key; + return copied; + } + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) + { + _check_self_external_modification(); + if (_non_string_key) + { + throw new ValueError($"Unable to save the object {this} (a dictionary wrapper constructed \"" + + $"automatically on attribute assignment). The wrapped dictionary " + + $"contains a non-string key which maps to a trackable object or " + + $"mutable data structure.\n\nIf you don't need this dictionary " + + $"checkpointed, wrap it in a non-trackable " + + $"object; it will be subsequently ignored."); + } + if (_external_modification) + { + throw new ValueError($"Unable to save the object {this} (a dictionary wrapper constructed " + + $"automatically on attribute assignment). The wrapped dictionary was " + + $"modified outside the wrapper (its final value was {this}, its value" + + $" when a checkpoint dependency was added was " + + $"{this._last_wrapped_dict_snapshot}), which breaks " + + $"restoration on object creation.\n\nIf you don't need this " + + $"dictionary checkpointed, wrap it in a " + + $"non-trackable object; it will be subsequently ignored."); + } + Debug.Assert(!Dirty); + var children = base._trackable_children(save_type, cache); + + if(save_type == SaveType.SAVEDMODEL) + { + foreach(var item in _storage) + { + var key = item.Key; + var value = item.Value; + if (TrackableWrapperUtils.is_function(value)) + { + Debug.Assert(key is string); + children[key as string] = value; + } + } + } + + return children; + } + + protected Trackable _track_value(Trackable value, object name) + { + bool string_key = name is string; + if (!string_key) + { + name = "-non_string_key"; + } + try + { + bool no_dependency = value is NoDependency; + value = base._track_value(value, name as string); + if(!(string_key || no_dependency)) + { + _non_string_key = true; + } + return value; + } + catch (ValueError) + { + return (Trackable)sticky_attribute_assignment(this, name as string, value); + } + } + + private bool Dirty => _external_modification || _non_string_key; + + private void _check_self_external_modification() + { + if (Dirty) + { + return; + } + if(!this._storage.SequenceEqual(_last_wrapped_dict_snapshot)) + { + _external_modification = true; + _last_wrapped_dict_snapshot = null; + } + } + + private void _update_snapshot() + { + // TODO(Rinne): deal with attribute_sentinel. + if (Dirty) return; + _last_wrapped_dict_snapshot = new Dictionary(_storage); + } + } + public class ListWrapper : TrackableDataStructure, IList, ICloneable, ITrackableWrapper + { + private IList _storage; + private bool _non_append_mutation_value; + private bool _external_modification_value; + private IList _last_wrapped_list_snapshot; + /// + /// + /// + /// The initial value of the data structure. A shallow copy may be maintained for error checking. `wrapped_list` itself should not be + /// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save. + public ListWrapper(IList wrapped_list) + { + _storage = new List(wrapped_list); + _non_append_mutation_value = _external_modification_value = false; + _last_wrapped_list_snapshot = new List(_storage); + } + + public string Identifier => "trackable_list_wrapper"; + public int Version => 1; + public int MinConsumerVersion => 1; + public int MinProducerVersion => 1; + public Trackable FromProto(SavedUserObject proto) + { + if(TrackableWrapperUtils.ShouldLoad(this, proto)) + { + return new ListWrapper(new Trackable[] { }); + } + else + { + return null; + } + } + public void SetValue(object name, object value) + { + Debug.Assert(name is string); + if(int.TryParse(name as string, out var index)) + { + if(value is not Trackable trackable) + { + throw new TypeError("Cannot set an object which is not trackable to ListWrapper."); + } + if(Count <= index) + { + Add(trackable); + } + else + { + this[index] = trackable; + } + } + else + { + throw new NotImplementedException("Encounter an unexpected behavior in , please " + + "submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + } + + protected bool NonAppendMuation { + get => _non_append_mutation_value; + set + { + // TODO: deal with `attribute_sentinel`. + _non_append_mutation_value = value; + } + } + + protected bool ExternalModification + { + get => _external_modification_value; + set + { + // TODO: deal with `attribute_sentinel`. + _external_modification_value = value; + } + } + + public override ICollection Values => this; + public bool IsReadOnly { get => _storage.IsReadOnly; } + + /// + /// Checks for any changes to the wrapped list not through the wrapper. + /// + private void check_external_modification() + { + if (_external_modification_value || _non_append_mutation_value) return; + if (!_storage.SequenceEqual(_last_wrapped_list_snapshot)) + { + _external_modification_value = true; + } + } + + private void update_snapshot() + { + // TODO(Rinne): deal with `attribute_sentinel`. + if (_external_modification_value || _non_append_mutation_value) return; + _last_wrapped_list_snapshot = new List(_storage); + } + + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? cache = null) + { + check_external_modification(); + if (_non_append_mutation_value) + { + throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). A list element was replaced" + + $", deleted or moved (sort). In order to support restoration on object creation, tracking is exclusively for append-only data structures." + + $"\n\nIf you don't need this list checkpointed, wrap it in a non-trackable object; it will be subsequently ignored."); + } + if (_external_modification_value) + { + throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). The wrapped list was modified " + + $"outside the wrapper (its final value was {_storage}, its value when a checkpoint dependency was added was {_last_wrapped_list_snapshot}), which breaks " + + $"restoration on object creation.\n\nIf you don't need this list checkpointed, wrap it in a NoDependency object; it will be subsequently ignored."); + } + var children = base._trackable_children(save_type, cache); + + if(save_type == SaveType.SAVEDMODEL) + { + children = children.Concat(this.Where(x => x is Function or ConcreteFunction).Select((x, idx) => new KeyValuePair(idx.ToString(), x))).ToDictionary(x => x.Key, x => x.Value); + } + + return children; + } + + private bool has_mutation_or_trackable() + { + return _non_append_mutation_value; + } + + /// + /// Allows storage of non-trackable objects. + /// + /// + /// + /// + protected override Trackable _track_value(Trackable value, string name) + { + try + { + base._track_value(value, name); + } + catch(ValueError) + { + value = (Trackable)sticky_attribute_assignment(this, name, value); + } + return value; + } + + public object Clone() + { + var res = new ListWrapper(_storage.Select(x => x).ToList()); + res.NonAppendMuation= _non_append_mutation_value; + res.ExternalModification = _external_modification_value; + return res; + } + + public Trackable this[int index] { + get => _storage[index]; + set + { + // skip the process of `Slice`, maybe support it in the future. + _non_append_mutation_value = true; + _storage[index] = _track_value(value, _name_element(index)); + + update_snapshot(); + } + } + + public int IndexOf(Trackable item) => _storage.IndexOf(item); + + public void Insert(int index, Trackable item) + { + check_external_modification(); + _non_append_mutation_value = true; + _storage.Insert(index, item); + update_snapshot(); + } + + public void RemoveAt(int index) + { + check_external_modification(); + if (has_mutation_or_trackable()) + { + _non_append_mutation_value = true; + } + _storage.RemoveAt(index); + update_snapshot(); + } + + public int Count { get => _storage.Count; } + + public void Add(Trackable item) + { + check_external_modification(); + _storage.Add(item); + update_snapshot(); + } + + public void Clear() + { + _storage.Clear(); + update_snapshot(); + } + + public bool Contains(Trackable item) => _storage.Contains(item); + + public void CopyTo(Trackable[] array, int arrayIndex) => _storage.CopyTo(array, arrayIndex); + + public bool Remove(Trackable item) + { + check_external_modification(); + if (has_mutation_or_trackable()) + { + _non_append_mutation_value = true; + } + var res = _storage.Remove(item); + update_snapshot(); + return res; + } + + public IEnumerator GetEnumerator() => _storage.GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator(); + + protected string _name_element(int index) => $"{index}"; + } +} diff --git a/src/TensorFlowNET.Core/Training/gen_training_ops.cs b/src/TensorFlowNET.Core/Training/gen_training_ops.cs new file mode 100644 index 000000000..df7dd9e65 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/gen_training_ops.cs @@ -0,0 +1,59 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class gen_training_ops + { + public static Tensor resource_apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power, + Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, + bool use_locking = false, bool use_nesterov = false, string name = null) + => tf.Context.ExecuteOp("ResourceApplyAdam", name, + new ExecuteOpArgs(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + .SetAttributes(new { use_locking, use_nesterov })); + + public static Tensor apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power, + Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, + bool use_locking = false, bool use_nesterov = false, string name = null) + => tf.Context.ExecuteOp("ApplyAdam", name, + new ExecuteOpArgs(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + .SetAttributes(new { use_locking, use_nesterov })); + + public static Tensor apply_gradient_descent(IVariableV1 var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("ApplyGradientDescent", name, new + { + var, + alpha, + delta, + use_locking + }); + + return _op.output; + } + + public static Tensor resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) + => tf.Context.ExecuteOp("ResourceApplyGradientDescent", name, + new ExecuteOpArgs(var, alpha, delta).SetAttributes(new { use_locking })); + + public static Tensor resource_apply_keras_momentum(Tensor var, Tensor accum, Tensor lr, Tensor grad, Tensor momentum, bool use_locking = false, bool use_nesterov = false, string name = null) + => tf.Context.ExecuteOp("ResourceApplyKerasMomentum", name, + new ExecuteOpArgs(var, accum, lr, grad, momentum).SetAttributes(new { use_locking, use_nesterov })); + } +} diff --git a/src/TensorFlowNET.Core/Training/learning_rate_decay.cs b/src/TensorFlowNET.Core/Training/learning_rate_decay.cs new file mode 100644 index 000000000..10259cb61 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/learning_rate_decay.cs @@ -0,0 +1,25 @@ +using System; + +namespace Tensorflow.Training +{ + public class learning_rate_decay + { + /// + /// Applies a polynomial decay to the learning rate. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor polynomial_decay(float learning_rate, RefVariable global_step, float decay_steps, + float end_learning_rate = 0.0001f, float power = 1.0f, bool cycle = false, + string name = null) + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/moving_averages.cs b/src/TensorFlowNET.Core/Training/moving_averages.cs new file mode 100644 index 000000000..f9937482f --- /dev/null +++ b/src/TensorFlowNET.Core/Training/moving_averages.cs @@ -0,0 +1,29 @@ +using static Tensorflow.Binding; + +namespace Tensorflow.Train +{ + public class moving_averages + { + /// + /// Compute the moving average of a variable. + /// + /// + /// + /// + /// + /// + /// + public static Tensor assign_moving_average(IVariableV1 variable, IVariableV1 value, Tensor decay, + bool zero_debias = true, string name = null) + { + return tf_with(ops.name_scope(name, "AssignMovingAvg", new { variable, value, decay }), scope => + { + decay = ops.convert_to_tensor(1.0f - decay, name: "decay"); + if (decay.dtype != variable.dtype.as_base_dtype()) + decay = math_ops.cast(decay, variable.dtype.as_base_dtype()); + + return state_ops.assign_sub(variable, (variable.AsTensor() - value.AsTensor()) * decay, name: scope); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/optimizer.py.cs b/src/TensorFlowNET.Core/Training/optimizer.py.cs new file mode 100644 index 000000000..115af5747 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/optimizer.py.cs @@ -0,0 +1,95 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Framework; + +namespace Tensorflow +{ + public class optimizer + { + public static _OptimizableVariable _get_processor(RefVariable v) + { + return new _RefVariableProcessor(v); + } + + public static _OptimizableVariable _get_processor(ResourceVariable v) + { + return new _DenseResourceVariableProcessor(v); + } + } + + public class _RefVariableProcessor : _OptimizableVariable + { + private RefVariable _v; + + public _RefVariableProcessor(RefVariable v) + { + _v = v; + } + + public Tensor target() + { + return _v._ref(); + } + + public Operation update_op(Optimizer optimizer, Tensor g) + { + Operation update_op = null; + + if (g.Tag == null) + { + update_op = optimizer._apply_dense(g, _v); + } + else if (g.Tag is IndexedSlices) + { + return optimizer._apply_sparse_duplicate_indices(g, _v); + } + + return update_op; + } + } + + public class _DenseResourceVariableProcessor : _OptimizableVariable + { + private ResourceVariable _v; + + public _DenseResourceVariableProcessor(ResourceVariable v) + { + _v = v; + } + + public Tensor target() + { + return _v.Handle; + } + + public Operation update_op(Optimizer optimizer, Tensor g) + { + Operation update_op = null; + + if (g.Tag == null) + { + update_op = optimizer._apply_dense(g, _v); + } + else if (g.Tag is IndexedSlices) + { + return optimizer._apply_sparse_duplicate_indices(g, _v); + } + + return update_op; + } + } +} diff --git a/src/TensorFlowNET.Core/Util/Arrays.cs b/src/TensorFlowNET.Core/Util/Arrays.cs new file mode 100644 index 000000000..bdf588bad --- /dev/null +++ b/src/TensorFlowNET.Core/Util/Arrays.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Util +{ + public static class Arrays + { + public static Type ResolveElementType(this Array arr) + { + if (arr == null) + throw new ArgumentNullException(nameof(arr)); + + var t = arr.GetType().GetElementType(); + // ReSharper disable once PossibleNullReferenceException + while (t.IsArray) + t = t.GetElementType(); + + return t; + } + } +} diff --git a/src/TensorFlowNET.Core/Util/CmdHelper.cs b/src/TensorFlowNET.Core/Util/CmdHelper.cs new file mode 100644 index 000000000..9e9fb81f6 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/CmdHelper.cs @@ -0,0 +1,50 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Diagnostics; + +namespace Tensorflow.Util +{ + public static class CmdHelper + { + public static void Command(string command) + { + Process proc = new System.Diagnostics.Process(); + proc.StartInfo.FileName = @"C:\Windows\System32\cmd.exe"; + proc.StartInfo.Arguments = "/c \" " + command + " \""; + proc.StartInfo.UseShellExecute = false; + proc.StartInfo.RedirectStandardOutput = true; + proc.Start(); + + while (!proc.StandardOutput.EndOfStream) + Binding.tf_output_redirect.WriteLine(proc.StandardOutput.ReadLine()); + } + + public static void Bash(string command) + { + Process proc = new System.Diagnostics.Process(); + proc.StartInfo.FileName = "/bin/bash"; + proc.StartInfo.Arguments = "-c \" " + command + " \""; + proc.StartInfo.UseShellExecute = false; + proc.StartInfo.RedirectStandardOutput = true; + proc.Start(); + + while (!proc.StandardOutput.EndOfStream) + Binding.tf_output_redirect.WriteLine(proc.StandardOutput.ReadLine()); + } + } +} diff --git a/src/TensorFlowNET.Core/Util/Converts.cs b/src/TensorFlowNET.Core/Util/Converts.cs new file mode 100644 index 000000000..bfc7dd138 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/Converts.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Util +{ + public class Converts + { + + } +} diff --git a/src/TensorFlowNET.Core/Util/Data.cs b/src/TensorFlowNET.Core/Util/Data.cs new file mode 100644 index 000000000..fe3466ed0 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/Data.cs @@ -0,0 +1,78 @@ +using OneOf; +using Tensorflow.NumPy; + +namespace Tensorflow.Util +{ + /// + /// ValidationDataPack is used to pass validation data to fit method. + /// It can recive data which could be A tuple `(x_val, xy_val)` or `(x_val, y_val, sample_weight_val)` of Numpy arrays. + /// + public class ValidationDataPack + { + internal OneOf val_x; + internal NDArray val_y; + internal NDArray val_sample_weight = null; + public bool val_x_is_array = false; + public ValidationDataPack((NDArray, NDArray) validation_data) + { + this.val_x = validation_data.Item1; + this.val_y = validation_data.Item2; + } + + public ValidationDataPack((NDArray, NDArray, NDArray) validation_data) + { + this.val_x = validation_data.Item1; + this.val_y = validation_data.Item2; + this.val_sample_weight = validation_data.Item3; + } + + public ValidationDataPack((IEnumerable, NDArray) validation_data) + { + this.val_x = validation_data.Item1.ToArray(); + this.val_y = validation_data.Item2; + val_x_is_array = true; + } + + public ValidationDataPack((IEnumerable, NDArray, NDArray) validation_data) + { + this.val_x = validation_data.Item1.ToArray(); + this.val_y = validation_data.Item2; + this.val_sample_weight = validation_data.Item3; + val_x_is_array = true; + } + + public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data) + => new ValidationDataPack(validation_data); + + public static implicit operator ValidationDataPack((NDArray, NDArray, NDArray) validation_data) + => new ValidationDataPack(validation_data); + + public static implicit operator ValidationDataPack((IEnumerable, NDArray) validation_data) + => new ValidationDataPack(validation_data); + + public static implicit operator ValidationDataPack((IEnumerable, NDArray, NDArray) validation_data) + => new ValidationDataPack(validation_data); + + public void Deconstruct(out NDArray val_x, out NDArray val_y) + { + val_x = this.val_x.AsT0; + val_y = this.val_y; + } + + public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) + { + val_x = this.val_x.AsT0; + val_y = this.val_y; + val_sample_weight = this.val_sample_weight; + } + + // add a unuse parameter to make it different from Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) + public void Deconstruct(out NDArray[] val_x_array, out NDArray val_y, out NDArray val_sample_weight, out NDArray unuse) + { + val_x_array = this.val_x.AsT1; + val_y = this.val_y; + val_sample_weight = this.val_sample_weight; + unuse = null; + } + } +} diff --git a/src/TensorFlowNET.Core/Util/Locks.cs b/src/TensorFlowNET.Core/Util/Locks.cs new file mode 100644 index 000000000..3b54ee2ce --- /dev/null +++ b/src/TensorFlowNET.Core/Util/Locks.cs @@ -0,0 +1,21 @@ +using System.Threading; + +namespace Tensorflow.Util +{ + /// + /// Provides a set of locks on different shared levels. + /// + public static class Locks + { + private static readonly ThreadLocal _lockpool = new ThreadLocal(() => new object()); + + /// + /// A seperate lock for every requesting thread. + /// + /// This property is thread-safe. + public static object ThreadWide => _lockpool.Value; + + + public static readonly object ProcessWide = new object(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Util/ProtoUtils.cs b/src/TensorFlowNET.Core/Util/ProtoUtils.cs new file mode 100644 index 000000000..c1557da42 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/ProtoUtils.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Util +{ + internal static class ProtoUtils + { + public static object GetSingleAttrValue(AttrValue value, AttrValue.ValueOneofCase valueCase) + { + return valueCase switch + { + AttrValue.ValueOneofCase.S => value.S.ToStringUtf8(), + AttrValue.ValueOneofCase.I => value.I, + AttrValue.ValueOneofCase.F => value.F, + AttrValue.ValueOneofCase.B => value.B, + AttrValue.ValueOneofCase.Type => value.Type, + AttrValue.ValueOneofCase.Shape => value.Shape, + AttrValue.ValueOneofCase.Tensor => value.Tensor, + AttrValue.ValueOneofCase.Func => value.Func, + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Util/SafeHandleArrayMarshaler.cs b/src/TensorFlowNET.Core/Util/SafeHandleArrayMarshaler.cs new file mode 100644 index 000000000..74846d4f7 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/SafeHandleArrayMarshaler.cs @@ -0,0 +1,132 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Runtime.ExceptionServices; +using System.Runtime.InteropServices; + +namespace Tensorflow.Util +{ + internal sealed class SafeHandleArrayMarshaler : ICustomMarshaler + { + private static readonly SafeHandleArrayMarshaler Instance = new SafeHandleArrayMarshaler(); + + private SafeHandleArrayMarshaler() + { + } + +#pragma warning disable IDE0060 // Remove unused parameter (method is used implicitly) + public static ICustomMarshaler GetInstance(string cookie) +#pragma warning restore IDE0060 // Remove unused parameter + { + return Instance; + } + + public int GetNativeDataSize() + { + return IntPtr.Size; + } + + [HandleProcessCorruptedStateExceptions] + public IntPtr MarshalManagedToNative(object ManagedObj) + { + if (ManagedObj is null) + return IntPtr.Zero; + + var array = (SafeHandle[])ManagedObj; + var native = IntPtr.Zero; + var marshaledArrayHandle = false; + try + { + native = Marshal.AllocHGlobal((array.Length + 1) * IntPtr.Size); + Marshal.WriteIntPtr(native, GCHandle.ToIntPtr(GCHandle.Alloc(array))); + marshaledArrayHandle = true; + + var i = 0; + var success = false; + try + { + for (i = 0; i < array.Length; i++) + { + success = false; + var current = array[i]; + var currentHandle = IntPtr.Zero; + if (current is object) + { + current.DangerousAddRef(ref success); + currentHandle = current.DangerousGetHandle(); + } + + Marshal.WriteIntPtr(native, ofs: (i + 1) * IntPtr.Size, currentHandle); + } + + return IntPtr.Add(native, IntPtr.Size); + } + catch + { + // Clean up any handles which were leased prior to the exception + var total = success ? i + 1 : i; + for (var j = 0; j < total; j++) + { + var current = array[i]; + if (current is object) + current.DangerousRelease(); + } + + throw; + } + } + catch + { + if (native != IntPtr.Zero) + { + if (marshaledArrayHandle) + GCHandle.FromIntPtr(Marshal.ReadIntPtr(native)).Free(); + + Marshal.FreeHGlobal(native); + } + + throw; + } + } + + public void CleanUpNativeData(IntPtr pNativeData) + { + if (pNativeData == IntPtr.Zero) + return; + + var managedHandle = GCHandle.FromIntPtr(Marshal.ReadIntPtr(pNativeData, -IntPtr.Size)); + var array = (SafeHandle[])managedHandle.Target; + managedHandle.Free(); + + for (var i = 0; i < array.Length; i++) + { + if (array[i] is object && !array[i].IsClosed) + array[i].DangerousRelease(); + } + } + + public object MarshalNativeToManaged(IntPtr pNativeData) + { + throw new NotSupportedException(); + } + + public void CleanUpManagedData(object ManagedObj) + { + throw new NotSupportedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Util/SafeHandleExtensions.cs b/src/TensorFlowNET.Core/Util/SafeHandleExtensions.cs new file mode 100644 index 000000000..6594b0b59 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/SafeHandleExtensions.cs @@ -0,0 +1,59 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Tensorflow.Util +{ + internal static class SafeHandleExtensions + { + /// + /// Acquires a lease on a safe handle. The lease increments the reference count of the + /// to ensure the handle is not released prior to the lease being released. + /// + /// + /// This method is intended to be used in the initializer of a using statement. Failing to release the + /// lease will permanently prevent the underlying from being released by the garbage + /// collector. + /// + /// The to lease. + /// A , which must be disposed to release the resource. + /// If the lease could not be acquired. + public static SafeHandleLease Lease(this SafeHandle handle) + { + if (handle is null) + throw new ArgumentNullException(nameof(handle)); + + var success = false; + try + { + handle.DangerousAddRef(ref success); + Debug.Assert(success, $"'{nameof(SafeHandle.DangerousAddRef)}' does not return when '{nameof(success)}' is false."); + + return new SafeHandleLease(handle); + } + catch + { + if (success) + handle.DangerousRelease(); + + throw; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Util/SafeHandleLease.cs b/src/TensorFlowNET.Core/Util/SafeHandleLease.cs new file mode 100644 index 000000000..19f4ec57e --- /dev/null +++ b/src/TensorFlowNET.Core/Util/SafeHandleLease.cs @@ -0,0 +1,46 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Runtime.InteropServices; + +namespace Tensorflow.Util +{ + /// + /// Represents a lease of a . + /// + /// + /// + /// Elements in this section may be referenced by <inheritdoc> elements to provide common + /// language in documentation remarks. + /// + /// + /// The result of this method is only valid when the underlying handle has not been disposed. If the lifetime + /// of the object is unclear, a lease may be used to prevent disposal while the object is in use. See + /// . + /// + /// + public readonly struct SafeHandleLease : IDisposable + { + private readonly SafeHandle _handle; + + internal SafeHandleLease(SafeHandle handle) + => _handle = handle; + + public void Dispose() + => _handle?.DangerousRelease(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Util/SafeTensorflowHandle.cs b/src/TensorFlowNET.Core/Util/SafeTensorflowHandle.cs new file mode 100644 index 000000000..a3f5dfed2 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/SafeTensorflowHandle.cs @@ -0,0 +1,46 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Runtime.InteropServices; + +namespace Tensorflow.Util +{ + public abstract class SafeTensorflowHandle : SafeHandle + { + private protected SafeTensorflowHandle() + : base(IntPtr.Zero, ownsHandle: true) + { + } + + private protected SafeTensorflowHandle(IntPtr handle) + : base(IntPtr.Zero, ownsHandle: true) + { + SetHandle(handle); + } + + private protected SafeTensorflowHandle(IntPtr handle, bool ownsHandle) + : base(IntPtr.Zero, ownsHandle) + { + SetHandle(handle); + } + + public override bool IsInvalid => handle == IntPtr.Zero; + + public override string ToString() + => $"0x{handle.ToString("x16")}"; + } +} diff --git a/src/TensorFlowNET.Core/Util/UnmanagedExtensions.cs b/src/TensorFlowNET.Core/Util/UnmanagedExtensions.cs new file mode 100644 index 000000000..5add8cada --- /dev/null +++ b/src/TensorFlowNET.Core/Util/UnmanagedExtensions.cs @@ -0,0 +1,57 @@ +using System; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Tensorflow.Util +{ + public static class UnmanagedExtensions + { + //internally UnmanagedMemoryStream can't construct with null address. + private static readonly unsafe byte* _empty = (byte*)Marshal.AllocHGlobal(1); + + /// + /// Creates a memory stream based on given . + /// + /// The block to stream. Can be IntPtr.Zero. + /// The length of the block in bytes. + /// There is no need to dispose the returned + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static UnmanagedMemoryStream Stream(this IntPtr address, long length) + { + if (length <= 0) + throw new ArgumentOutOfRangeException(nameof(length)); + + unsafe + { + if (address == IntPtr.Zero) + return new UnmanagedMemoryStream(_empty, 0); + + // ReSharper disable once AssignNullToNotNullAttribute + return new UnmanagedMemoryStream((byte*)address, length); + } + } + + /// + /// Creates a memory stream based on given . + /// + /// The block to stream. Can be IntPtr.Zero. + /// Offset from the start of the block. + /// The length of the block in bytes. + /// There is no need to dispose the returned + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static UnmanagedMemoryStream Stream(this IntPtr address, long offset, long length) + { + if (length <= 0) + throw new ArgumentOutOfRangeException(nameof(length)); + + unsafe + { + if (address == IntPtr.Zero) + return new UnmanagedMemoryStream(_empty, 0); + + return new UnmanagedMemoryStream((byte*)address + offset, length); + } + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Util/UnorderedMap.cs b/src/TensorFlowNET.Core/Util/UnorderedMap.cs new file mode 100644 index 000000000..219a3c140 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/UnorderedMap.cs @@ -0,0 +1,87 @@ +using System.Collections.Generic; + +namespace Tensorflow.Util +{ + public class UnorderedMap : Dictionary + { + /// + /// Avoid null when accessing not existed element + /// + /// + /// + public new Tv this[Tk key] + { + get + { + if (!ContainsKey(key)) + Add(key, default); + + return base[key]; + } + + set + { + base[key] = value; + } + } + + public Tv SetDefault(Tk key, Tv default_value) + { + if(TryGetValue(key, out var res)) + { + return res; + } + else + { + base[key] = default_value; + return base[key]; + } + } + + public void push_back(Tk key, Tv value) + => this[key] = value; + + public void emplace(Tk key, Tv value) + => this[key] = value; + + public bool find(Tk key) + => ContainsKey(key); + + public void erase(Tk key) + => Remove(key); + + public bool find(Tk key, out Tv value) + { + if (ContainsKey(key)) + { + value = this[key]; + return true; + } + else + { + value = default(Tv); + return false; + } + } + } + + public class UnorderedMapEnumerable : UnorderedMap + where Tv : new() + { + public new Tv this[Tk key] + { + get + { + if (!ContainsKey(key)) + Add(key, new Tv()); + + return base[key]; + } + + set + { + base[key] = value; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Util/UnorderedSet.cs b/src/TensorFlowNET.Core/Util/UnorderedSet.cs new file mode 100644 index 000000000..95f936b00 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/UnorderedSet.cs @@ -0,0 +1,16 @@ +using System.Collections.Generic; + +namespace Tensorflow.Util +{ + public class UnorderedSet : HashSet + { + public UnorderedSet(T[] elements) + { + foreach (var el in elements) + Add(el); + } + + public bool find(T value) + => Contains(value); + } +} diff --git a/src/TensorFlowNET.Core/Util/function_utils.cs b/src/TensorFlowNET.Core/Util/function_utils.cs new file mode 100644 index 000000000..d4ba44237 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/function_utils.cs @@ -0,0 +1,23 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Util +{ + internal static class function_utils + { + private static ByteString _rewriter_config_optimizer_disabled; + public static ByteString get_disabled_rewriter_config() + { + if(_rewriter_config_optimizer_disabled is null) + { + var config = new ConfigProto(); + var rewriter_config = config.GraphOptions.RewriteOptions; + rewriter_config.DisableMetaOptimizer = true; + _rewriter_config_optimizer_disabled = config.ToByteString(); + } + return _rewriter_config_optimizer_disabled; + } + } +} diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs new file mode 100644 index 000000000..3ba3ce78b --- /dev/null +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -0,0 +1,1009 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Util +{ + //Functions for working with arbitrarily nested sequences of elements. + + //This module can perform operations on nested structures. A nested structure is a + //Python sequence, tuple (including `namedtuple`), or dict that can contain + //further sequences, tuples, and dicts. + + //The utilities here assume (and do not check) that the nested structures form a + //'tree', i.e., no references in the structure of the input of these functions + //should be recursive. + + //Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0), + // (np.array([3, 4]), tf.constant([3, 4])))` + // + + [Obsolete] + public static class nest + { + + + /// + /// Untyped implementation of zip for arbitrary data + /// + /// Converts an list of lists or arrays [[1,2,3], [4,5,6], [7,8,9]] into a list of arrays + /// representing tuples of the same index of all source arrays [[1,4,7], [2,5,9], [3,6,9]] + /// + /// one or multiple sequences to be zipped + /// + public static IEnumerable zip_many(params IEnumerable[] lists) + { + if (lists.Length == 0) + yield break; + var first = lists[0]; + if (first == null) + yield break; + var arity = first.Count(); + for (int i = 0; i < arity; i++) + { + var array = new object[lists.Length]; + for (int j = 0; j < lists.Length; j++) + array[j] = GetSequenceElementAt(lists[j], i); + yield return array; + } + } + + private static object GetSequenceElementAt(object sequence, int i) + { + switch (sequence) + { + case Array array: + return array.GetValue(i); + case IList list: + return list[i]; + default: + return _yield_value(sequence).Skip(Math.Max(0, i)).FirstOrDefault(); + } + } + + public static IEnumerable<(T1, T2)> zip(IEnumerable e1, IEnumerable e2) + => zip(e1, e2); + + public static Dictionary ConvertToDict(object dyn) + => ConvertToDict(dyn); + + //def _get_attrs_values(obj): + // """Returns the list of values from an attrs instance.""" + // attrs = getattr(obj.__class__, "__attrs_attrs__") + // return [getattr(obj, a.name) for a in attrs] + + /// + /// Returns a sorted list of the dict keys, with error if keys not sortable. + /// + private static IEnumerable _sorted(IDictionary dict_) + { + return dict_.Keys.OfType().OrderBy(x => x); + } + + + //def _is_namedtuple(instance, strict=False): + // """Returns True iff `instance` is a `namedtuple`. + + // Args: + // instance: An instance of a Python object. + // strict: If True, `instance` is considered to be a `namedtuple` only if + // it is a "plain" namedtuple. For instance, a class inheriting + // from a `namedtuple` will be considered to be a `namedtuple` + // iff `strict=False`. + + // Returns: + // True if `instance` is a `namedtuple`. + // """ + // return _pywrap_tensorflow.IsNamedtuple(instance, strict) + + + //# See the swig file (util.i) for documentation. + //_is_mapping = _pywrap_tensorflow.IsMapping + //_is_attrs = _pywrap_tensorflow.IsAttrs + + /// + /// Converts the sequence `args` to the same type as `instance`. + /// + /// an instance of `tuple`, `list`, `namedtuple`, `dict`, or + /// `collections.OrderedDict`. + /// elements to be converted to the `instance` type. + /// `args` with the type of `instance`. + private static object _sequence_like(object instance, IEnumerable args) + { + if (is_mapping(instance)) + { + //# Pack dictionaries in a deterministic order by sorting the keys. + //# Notice this means that we ignore the original order of `OrderedDict` + //# instances. This is intentional, to avoid potential bugs caused by mixing + //# ordered and plain dicts (e.g., flattening a dict but using a + //# corresponding `OrderedDict` to pack it back). + switch (instance) + { + case Hashtable hash: + { + var result = new Hashtable(); + foreach ((object key, object value) in zip(_sorted(hash), args)) + result[key] = value; + return result; + } + } + } + //else if( _is_namedtuple(instance) || _is_attrs(instance)) + // return type(instance)(*args) + else + { + // Not a namedtuple + switch (instance) + { + case object[] array: + var result_array = new object[args.Count()]; + int i = 0; + foreach (var x in args) + { + result_array[i] = x; + i++; + } + return result_array; + case List list: + return new List(args); + default: + throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); + } + } + throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); + } + + /// + /// Yields the next value from the given iterable. + /// + private static IEnumerable _yield_value(object iterable) + { + if (is_mapping(iterable)) + { + var dict = iterable as IDictionary; + //# Iterate through dictionaries in a deterministic order by sorting the + //# keys. Notice this means that we ignore the original order of `OrderedDict` + //# instances. This is intentional, to avoid potential bugs caused by mixing + //# ordered and plain dicts (e.g., flattening a dict but using a + //# corresponding `OrderedDict` to pack it back). + foreach (var key in _sorted(dict)) + yield return dict[key]; + } + //else if (_is_attrs(iterable)) + //{ + // // for value in _get_attrs_values(iterable): + // // yield value + //} + else if (iterable is IEnumerable) + { + var enumerable = iterable as IEnumerable; + foreach (var value in enumerable) + yield return value; + } + else + { + throw new TypeError("Unexpected iterable type: " + iterable.GetType()); + //var jobj = JObject.FromObject(iterable); + //foreach (var key in _sorted()) + // yield return jobj[key]; + } + } + + //# See the swig file (util.i) for documentation. + public static bool is_sequence(object arg) + => arg is IEnumerable && !(arg is string) && !(arg is NDArray) && + !(arg.GetType().IsGenericType && arg.GetType().GetGenericTypeDefinition() == typeof(HashSet<>)); + + public static bool is_mapping(object arg) => arg is IDictionary; + + //# See the swig file (util.i) for documentation. + //flatten = _pywrap_tensorflow.Flatten + + public static List flatten(T structure) + { + var list = new List(); + _flatten_recursive(structure, list); + return list; + } + + public static List flatten(IEnumerable structure) + { + var list = new List(); + foreach(var item in structure) + { + _flatten_recursive(item, list); + } + return list; + } + + public static object[] flatten2(ICanBeFlattened structure) + => structure.Flatten(); + + public static T[] flatten2(T[] structure) + => structure; + + private static void _flatten_recursive(T obj, List list) + { + switch (obj) + { + case IDictionary dict: + foreach (var key in _sorted(dict)) + _flatten_recursive((T)dict[key], list); + break; + case String str: + list.Add(obj); + break; + case NDArray nd: + list.Add(obj); + break; + case IEnumerable structure: + foreach (var child in structure) + _flatten_recursive((T)child, list); + break; + default: + list.Add(obj); + break; + } + } + + + //# See the swig file (util.i) for documentation. + //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples + + + //class _DotString(object): + + // def __str__(self): + // return "." + + // def __repr__(self): + // return "." + + + //_DOT = _DotString() + + + //def assert_same_structure(nest1, nest2, check_types=True): + // """Asserts that two structures are nested in the same way. + + // Note that namedtuples with identical name and fields are always considered + // to have the same shallow structure (even with `check_types=True`). + // For intance, this code will print `True`: + + // ```python + // def nt(a, b): + // return collections.namedtuple('foo', 'a b')(a, b) + // print(assert_same_structure(nt(0, 1), nt(2, 3))) + // ``` + + // Args: + // nest1: an arbitrarily nested structure. + // nest2: an arbitrarily nested structure. + // check_types: if `True` (default) types of sequences are checked as well, + // including the keys of dictionaries. If set to `False`, for example a + // list and a tuple of objects will look the same if they have the same + // size. Note that namedtuples with identical name and fields are always + // considered to have the same shallow structure. Two types will also be + // considered the same if they are both list subtypes (which allows "list" + // and "_ListWrapper" from checkpointable dependency tracking to compare + // equal). + + // Raises: + // ValueError: If the two structures do not have the same number of elements or + // if the two structures are not nested in the same way. + // TypeError: If the two structures differ in the type of sequence in any of + // their substructures. Only possible if `check_types` is `True`. + // """ + // try: + // _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types) + // except (ValueError, TypeError) as e: + // str1 = str(map_structure(lambda _: _DOT, nest1)) + // str2 = str(map_structure(lambda _: _DOT, nest2)) + // raise type(e)("%s\n" + // "Entire first structure:\n%s\n" + // "Entire second structure:\n%s" + // % (str(e), str1, str2)) + + + //def flatten_dict_items(dictionary): + // """Returns a dictionary with flattened keys and values. + + // This function flattens the keys and values of a dictionary, which can be + // arbitrarily nested structures, and returns the flattened version of such + // structures: + + // ```python + // example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} + // result = {4: "a", 5: "b", 6: "c", 8: "d"} + // flatten_dict_items(example_dictionary) == result + // ``` + + // The input dictionary must satisfy two properties: + + // 1. Its keys and values should have the same exact nested structure. + // 2. The set of all flattened keys of the dictionary must not contain repeated + // keys. + + // Args: + // dictionary: the dictionary to zip + + // Returns: + // The zipped dictionary. + + // Raises: + // TypeError: If the input is not a dictionary. + // ValueError: If any key and value have not the same structure, or if keys are + // not unique. + // """ + // if not isinstance(dictionary, (dict, _collections.Mapping)): + // raise TypeError("input must be a dictionary") + // flat_dictionary = {} + // for i, v in _six.iteritems(dictionary): + // if not is_sequence(i): + // if i in flat_dictionary: + // raise ValueError( + // "Could not flatten dictionary: key %s is not unique." % i) + // flat_dictionary[i] = v + // else: + // flat_i = flatten(i) + // flat_v = flatten(v) + // if len(flat_i) != len(flat_v): + // raise ValueError( + // "Could not flatten dictionary. Key had %d elements, but value had " + // "%d elements. Key: %s, value: %s." + // % (len(flat_i), len(flat_v), flat_i, flat_v)) + // for new_i, new_v in zip(flat_i, flat_v): + // if new_i in flat_dictionary: + // raise ValueError( + // "Could not flatten dictionary: key %s is not unique." + // % (new_i)) + // flat_dictionary[new_i] = new_v + // return flat_dictionary + + /// + /// Helper function for pack_sequence_as. + /// + /// Substructure (list / tuple / dict) to mimic. + /// Flattened values to output substructure for. + /// Index at which to start reading from flat. + /// + /// The tuple(new_index, child), where: + /// * new_index - the updated index into `flat` having processed `structure`. + /// * packed - the subset of `flat` corresponding to `structure`, + /// having started at `index`, and packed into the same nested + /// format. + private static (int new_index, List child) _packed_nest_with_indices(object structure, List flat, + int index) + { + var packed = new List(); + foreach (var s in _yield_value(structure)) + { + if (is_sequence(s)) + { + var (new_index, child) = _packed_nest_with_indices(s, flat, index); + packed.Add(_sequence_like(s, child)); + index = new_index; + } + else + { + packed.Add(flat[index]); + index += 1; + } + } + return (index, packed); + } + + private static int len(IEnumerable x) => x.Count(); + + public static T pack_sequence_as2(T structure, object[] flat_sequence, bool expand_composites = false) + where T : IPackable + => structure.Pack(flat_sequence); + + /// + /// Returns a given flattened sequence packed into a given structure. + /// If `structure` is a scalar, `flat_sequence` must be a single-element list; + /// in this case the return value is `flat_sequence[0]`. + /// + /// If `structure` is or contains a dict instance, the keys will be sorted to + /// pack the flat sequence in deterministic order. This is true also for + /// `OrderedDict` instances: their sequence order is ignored, the sorting order of + /// keys is used instead. The same convention is followed in `flatten`. + /// This correctly repacks dicts and `OrderedDict`s after they have been + /// flattened, and also allows flattening an `OrderedDict` and then repacking it + /// back using a corresponding plain dict, or vice-versa. + /// Dictionaries with non-sortable keys cannot be flattened. + /// + /// + /// Nested structure, whose structure is given by nested lists, + /// tuples, and dicts. Note: numpy arrays and strings are considered + /// scalars. + /// + /// flat sequence to pack. + /// `flat_sequence` converted to have the same recursive structure as + /// `structure`. + /// + public static object pack_sequence_as(object structure, IEnumerable flat_sequence, bool expand_composites = false) + { + List flat = null; + if (flat_sequence is List) + flat = flat_sequence as List; + else + flat = new List(flat_sequence); + if (flat_sequence == null) + throw new ArgumentException("flat_sequence must not be null"); + // if not is_sequence(flat_sequence): + // raise TypeError("flat_sequence must be a sequence") + + if (!is_sequence(structure)) + { + if (len(flat) != 1) + throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat)} > 1"); + return flat.FirstOrDefault(); + } + int final_index = 0; + List packed = null; + try + { + (final_index, packed) = _packed_nest_with_indices(structure, flat, 0); + if (final_index < len(flat)) + throw new IndexOutOfRangeException( + $"Final index: {final_index} was smaller than len(flat_sequence): {len(flat)}"); + return _sequence_like(structure, packed); + } + catch (IndexOutOfRangeException) + { + var flat_structure = flatten(structure); + if (len(flat_structure) != len(flat)) + { + throw new ValueError("Could not pack sequence. Structure had {len(structure)} elements, but " + + $"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat)}"); + } + return _sequence_like(structure, packed); + } + catch (ArgumentOutOfRangeException) + { + var flat_structure = flatten(structure); + if (len(flat_structure) != len(flat)) + { + throw new ValueError("Could not pack sequence. Structure had {len(structure)} elements, but " + + $"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat)}"); + } + return _sequence_like(structure, packed); + } + } + + /// + /// Applies `func` to each entry in `structure` and returns a new structure. + /// + /// Applies `func(x[0], x[1], ...)` where x[i] is an entry in + /// `structure[i]`. All structures in `structure` must have the same arity, + /// and the return value will contain the results in the same structure. + /// + /// A callable that accepts as many arguments as there are structures. + /// one or many IEnumerable of object + /// + /// A new structure with the same arity as `structure`, whose values correspond + /// to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding + /// location in `structure[i]`. If there are different sequence types and + /// `check_types` is `False` the sequence types of the first structure will be + /// used. + /// + public static IEnumerable map_structure(Func func, params IEnumerable[] structure) + { + // TODO: check structure and types + // for other in structure[1:]: + // assert_same_structure(structure[0], other, check_types=check_types) + + if (structure.Length == 1) + { + // we don't need to zip if we have only one structure + return map_structure(a => func(new object[] { a }), structure[0]); + } + var flat_structures = structure.Select(flatten).ToArray(); // ToArray is important here! + var entries = zip_many(flat_structures); + var mapped_flat_structure = entries.Select(func); + + return _yield_value(pack_sequence_as(structure[0], mapped_flat_structure)).ToList(); + } + + public static Tensor map_structure(Func func, T structure) + { + var flat_structure = flatten(structure); + var mapped_flat_structure = flat_structure.Select(func).ToList(); + + return pack_sequence_as(structure, mapped_flat_structure) as Tensor; + } + + public static T2 map_structure(Func func, T1 structure) where T2: class + { + var flat_structure = flatten(structure); + var mapped_flat_structure = flat_structure.Select(func).Select(x => (object)x); + + return pack_sequence_as(structure, mapped_flat_structure) as T2; + } + + public static IEnumerable map_structure(Func func, IEnumerable structure) where T2 : class + { + var flat_structure = flatten(structure); + var mapped_flat_structure = flat_structure.Select(func).Select(x => (object)x); + + return pack_sequence_as(structure, mapped_flat_structure) as IEnumerable; + } + + /// + /// Same as map_structure, but with only one structure (no combining of multiple structures) + /// + /// + /// + /// + public static IEnumerable map_structure(Func func, IEnumerable structure) + { + // TODO: check structure and types + // for other in structure[1:]: + // assert_same_structure(structure[0], other, check_types=check_types) + + var flat_structure = flatten(structure); + var mapped_flat_structure = flat_structure.Select(func).ToList(); + + return _yield_value(pack_sequence_as(structure, mapped_flat_structure)).ToList(); + } + + //def map_structure_with_paths(func, *structure, **kwargs): + // """Applies `func` to each entry in `structure` and returns a new structure. + + // Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in + // `structure[i]` and `path` is the common path to x[i] in the structures. All + // structures in `structure` must have the same arity, and the return value will + // contain the results in the same structure. Special kwarg `check_types` + // determines whether the types of iterables within the structure must be the + // same-- see **kwargs definition below. + + // Args: + // func: A callable with the signature func(path, *values, **kwargs) that is + // evaluated on the leaves of the structure. + // *structure: A variable number of compatible structures to process. + // **kwargs: Optional kwargs to be passed through to func. Special kwarg + // `check_types` is not passed to func, but instead determines whether the + // types of iterables within the structures have to be same (e.g., + // `map_structure(func, [1], (1,))` raises a `TypeError` exception). By + // default, the types must match. To allow iteration over structures of + // different types (but common arity), set this kwarg to `False`. + + // Returns: + // A structure of the same form as the input structures whose leaves are the + // result of evaluating func on corresponding leaves of the input structures. + + // Raises: + // TypeError: If `func` is not callable or if the structures do not match + // each other by depth tree. + // TypeError: If `check_types` is not `False` and the two structures differ in + // the type of sequence in any of their substructures. + // ValueError: If no structures are provided. + // """ + // if not callable(func): + // raise TypeError("func must be callable, got: %s" % func) + // if not structure: + // raise ValueError("Must provide at least one structure") + + // check_types = kwargs.pop("check_types", True) + // for other in structure[1:]: + // assert_same_structure(structure[0], other, check_types=check_types) + + //# First set paths_and_values to: + //# [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]] + // paths_and_values = [flatten_with_joined_string_paths(s) for s in structure] + + //# Now zip(*paths_and_values) would be: + //# [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))] + //# so grouped_by_path is set to: + //# [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]] + //# Note that p1i, ... pmi must all be equal since the structures are the same. + // grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)] + + // return pack_sequence_as(structure[0], [ + // func(paths[0], *values, **kwargs) for paths, values in grouped_by_path]) + + + //def _yield_flat_up_to(shallow_tree, input_tree): + // """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" + // if is_sequence(shallow_tree): + // for shallow_branch, input_branch in zip(_yield_value(shallow_tree), + // _yield_value(input_tree)): + // for input_leaf in _yield_flat_up_to(shallow_branch, input_branch): + // yield input_leaf + // else: + // yield input_tree + + + //def assert_shallow_structure(shallow_tree, input_tree, check_types=True): + // """Asserts that `shallow_tree` is a shallow structure of `input_tree`. + + // That is, this function tests if the `input_tree` structure can be created from + // the `shallow_tree` structure by replacing its leaf nodes with deeper + // tree structures. + + // Examples: + + // The following code will raise an exception: + // ```python + // shallow_tree = ["a", "b"] + // input_tree = ["c", ["d", "e"], "f"] + // assert_shallow_structure(shallow_tree, input_tree) + // ``` + + // The following code will not raise an exception: + // ```python + // shallow_tree = ["a", "b"] + // input_tree = ["c", ["d", "e"]] + // assert_shallow_structure(shallow_tree, input_tree) + // ``` + + // Args: + // shallow_tree: an arbitrarily nested structure. + // input_tree: an arbitrarily nested structure. + // check_types: if `True` (default) the sequence types of `shallow_tree` and + // `input_tree` have to be the same. Note that even with check_types==True, + // this function will consider two different namedtuple classes with the same + // name and _fields attribute to be the same class. + + // Raises: + // TypeError: If `shallow_tree` is a sequence but `input_tree` is not. + // TypeError: If the sequence types of `shallow_tree` are different from + // `input_tree`. Only raised if `check_types` is `True`. + // ValueError: If the sequence lengths of `shallow_tree` are different from + // `input_tree`. + // """ + // if is_sequence(shallow_tree): + // if not is_sequence(input_tree): + // raise TypeError( + // "If shallow structure is a sequence, input must also be a sequence. " + // "Input has type: %s." % type(input_tree)) + + // if check_types and not isinstance(input_tree, type(shallow_tree)): + //# Duck-typing means that nest should be fine with two different + //# namedtuples with identical name and fields. + // shallow_is_namedtuple = _is_namedtuple(shallow_tree, False) + // input_is_namedtuple = _is_namedtuple(input_tree, False) + // if shallow_is_namedtuple and input_is_namedtuple: + // if not _same_namedtuples(shallow_tree, input_tree): + // raise TypeError( + // "The two namedtuples don't have the same sequence type. Input " + // "structure has type %s, while shallow structure has type %s." + // % (type(input_tree), type(shallow_tree))) + // elif not (isinstance(shallow_tree, _collections.Mapping) + // and isinstance(input_tree, _collections.Mapping)): + // raise TypeError( + // "The two structures don't have the same sequence type. Input " + // "structure has type %s, while shallow structure has type %s." + // % (type(input_tree), type(shallow_tree))) + + // if len(input_tree) != len(shallow_tree): + // raise ValueError( + // "The two structures don't have the same sequence length. Input " + // "structure has length %s, while shallow structure has length %s." + // % (len(input_tree), len(shallow_tree))) + + // if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)): + // if set(input_tree) != set(shallow_tree): + // raise ValueError( + // "The two structures don't have the same keys. Input " + // "structure has keys %s, while shallow structure has keys %s." % + // (list(_six.iterkeys(input_tree)), + // list(_six.iterkeys(shallow_tree)))) + + // input_tree = list(sorted(_six.iteritems(input_tree))) + // shallow_tree = list(sorted(_six.iteritems(shallow_tree))) + + // for shallow_branch, input_branch in zip(shallow_tree, input_tree): + // assert_shallow_structure(shallow_branch, input_branch, + // check_types=check_types) + + + //def flatten_up_to(shallow_tree, input_tree): + // """Flattens `input_tree` up to `shallow_tree`. + + // Any further depth in structure in `input_tree` is retained as elements in the + // partially flatten output. + + // If `shallow_tree` and `input_tree` are not sequences, this returns a + // single-element list: `[input_tree]`. + + // Use Case: + + // Sometimes we may wish to partially flatten a nested sequence, retaining some + // of the nested structure. We achieve this by specifying a shallow structure, + // `shallow_tree`, we wish to flatten up to. + + // The input, `input_tree`, can be thought of as having the same structure as + // `shallow_tree`, but with leaf nodes that are themselves tree structures. + + // Examples: + + // ```python + // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] + // shallow_tree = [[True, True], [False, True]] + + // flattened_input_tree = flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree) + + //# Output is: + //# [[2, 2], [3, 3], [4, 9], [5, 5]] + //# [True, True, False, True] + // ``` + + // ```python + // input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] + // shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] + + // input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) + // input_tree_flattened = flatten(input_tree) + + //# Output is: + //# [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + //# ['a', 1, 'b', 2, 'c', 3, 'd', 4] + // ``` + + // Non-Sequence Edge Cases: + + // ```python + // flatten_up_to(0, 0) # Output: [0] + // flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]] + // flatten_up_to([0, 1, 2], 0) # Output: TypeError + // flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2] + // ``` + + // Args: + // shallow_tree: a possibly pruned structure of input_tree. + // input_tree: an arbitrarily nested structure or a scalar object. + // Note, numpy arrays are considered scalars. + + // Returns: + // A Python list, the partially flattened version of `input_tree` according to + // the structure of `shallow_tree`. + + // Raises: + // TypeError: If `shallow_tree` is a sequence but `input_tree` is not. + // TypeError: If the sequence types of `shallow_tree` are different from + // `input_tree`. + // ValueError: If the sequence lengths of `shallow_tree` are different from + // `input_tree`. + // """ + // assert_shallow_structure(shallow_tree, input_tree) + // return list(_yield_flat_up_to(shallow_tree, input_tree)) + + + //def map_structure_up_to(shallow_tree, func, *inputs): + // """Applies a function or op to a number of partially flattened inputs. + + // The `inputs` are flattened up to `shallow_tree` before being mapped. + + // Use Case: + + // Sometimes we wish to apply a function to a partially flattened + // sequence (for example when the function itself takes sequence inputs). We + // achieve this by specifying a shallow structure, `shallow_tree` we wish to + // flatten up to. + + // The `inputs`, can be thought of as having the same structure as + // `shallow_tree`, but with leaf nodes that are themselves tree structures. + + // This function therefore will return something with the same base structure as + // `shallow_tree`. + + // Examples: + + // ```python + // ab_tuple = collections.namedtuple("ab_tuple", "a, b") + // op_tuple = collections.namedtuple("op_tuple", "add, mul") + // inp_val = ab_tuple(a=2, b=3) + // inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) + // out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, + // inp_val, inp_ops) + + //# Output is: ab_tuple(a=6, b=15) + // ``` + + // ```python + // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] + // name_list = ['evens', ['odds', 'primes']] + // out = map_structure_up_to( + // name_list, + // lambda name, sec: "first_{}_{}".format(len(sec), name), + // name_list, data_list) + + //# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']] + // ``` + + // Args: + // shallow_tree: a shallow tree, common to all the inputs. + // func: callable which will be applied to each input individually. + // *inputs: arbitrarily nested combination of objects that are compatible with + // shallow_tree. The function `func` is applied to corresponding + // partially flattened elements of each input, so the function must support + // arity of `len(inputs)`. + + // Raises: + // TypeError: If `shallow_tree` is a sequence but `input_tree` is not. + // TypeError: If the sequence types of `shallow_tree` are different from + // `input_tree`. + // ValueError: If the sequence lengths of `shallow_tree` are different from + // `input_tree`. + + // Returns: + // result of repeatedly applying `func`, with same structure as + // `shallow_tree`. + // """ + // if not inputs: + // raise ValueError("Cannot map over no sequences") + // for input_tree in inputs: + // assert_shallow_structure(shallow_tree, input_tree) + + //# Flatten each input separately, apply the function to corresponding elements, + //# then repack based on the structure of the first input. + // all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree) + // for input_tree in inputs] + // results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] + // return pack_sequence_as(structure=shallow_tree, flat_sequence=results) + + + //def get_traverse_shallow_structure(traverse_fn, structure): + // """Generates a shallow structure from a `traverse_fn` and `structure`. + + // `traverse_fn` must accept any possible subtree of `structure` and return + // a depth=1 structure containing `True` or `False` values, describing which + // of the top-level subtrees may be traversed. It may also + // return scalar `True` or `False` "traversal is OK / not OK for all subtrees." + + // Examples are available in the unit tests (nest_test.py). + + // Args: + // traverse_fn: Function taking a substructure and returning either a scalar + // `bool` (whether to traverse that substructure or not) or a depth=1 + // shallow structure of the same type, describing which parts of the + // substructure to traverse. + // structure: The structure to traverse. + + // Returns: + // A shallow structure containing python bools, which can be passed to + // `map_structure_up_to` and `flatten_up_to`. + + // Raises: + // TypeError: if `traverse_fn` returns a sequence for a non-sequence input, + // or a structure with depth higher than 1 for a sequence input, + // or if any leaf values in the returned structure or scalar are not type + // `bool`. + // """ + // to_traverse = traverse_fn(structure) + // if not is_sequence(structure): + // if not isinstance(to_traverse, bool): + // raise TypeError("traverse_fn returned structure: %s for non-structure: %s" + // % (to_traverse, structure)) + // return to_traverse + // level_traverse = [] + // if isinstance(to_traverse, bool): + // if not to_traverse: + //# Do not traverse this substructure at all. Exit early. + // return False + // else: + //# Traverse the entire substructure. + // for branch in _yield_value(structure): + // level_traverse.append( + // get_traverse_shallow_structure(traverse_fn, branch)) + // elif not is_sequence(to_traverse): + // raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" + // % (to_traverse, structure)) + // else: + //# Traverse some subset of this substructure. + // assert_shallow_structure(to_traverse, structure) + // for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)): + // if not isinstance(t, bool): + // raise TypeError( + // "traverse_fn didn't return a depth=1 structure of bools. saw: %s " + // " for structure: %s" % (to_traverse, structure)) + // if t: + // level_traverse.append( + // get_traverse_shallow_structure(traverse_fn, branch)) + // else: + // level_traverse.append(False) + // return _sequence_like(structure, level_traverse) + + + //def yield_flat_paths(nest): + // """Yields paths for some nested structure. + + // Paths are lists of objects which can be str-converted, which may include + // integers or other types which are used as indices in a dict. + + // The flat list will be in the corresponding order as if you called + // `snt.nest.flatten` on the structure. This is handy for naming Tensors such + // the TF scope structure matches the tuple structure. + + // E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))` + + // ```shell + // >>> nest.flatten(value) + // [3, 23, 42] + // >>> list(nest.yield_flat_paths(value)) + // [('a',), ('b', 'c'), ('b', 'd')] + // ``` + + // ```shell + // >>> list(nest.yield_flat_paths({'a': [3]})) + // [('a', 0)] + // >>> list(nest.yield_flat_paths({'a': 3})) + // [('a',)] + // ``` + + // Args: + // nest: the value to produce a flattened paths list for. + + // Yields: + // Tuples containing index or key values which form the path to a specific + // leaf value in the nested structure. + // """ + + //# The _maybe_add_final_path_element function is used below in order to avoid + //# adding trailing slashes when the sub-element recursed into is a leaf. + // if isinstance(nest, (dict, _collections.Mapping)): + // for key in _sorted(nest): + // value = nest[key] + // for sub_path in yield_flat_paths(value): + // yield (key,) + sub_path + // elif _is_namedtuple(nest): + // for key in nest._fields: + // value = getattr(nest, key) + // for sub_path in yield_flat_paths(value): + // yield (key,) + sub_path + // elif isinstance(nest, _six.string_types): + // yield () + // elif isinstance(nest, _collections.Sequence): + // for idx, value in enumerate(nest): + // for sub_path in yield_flat_paths(value): + // yield (idx,) + sub_path + // else: + // yield () + + + //def flatten_with_joined_string_paths(structure, separator="/"): + // """Returns a list of (string path, data element) tuples. + + // The order of tuples produced matches that of `nest.flatten`. This allows you + // to flatten a nested structure while keeping information about where in the + // structure each data element was located. See `nest.yield_flat_paths` + // for more information. + + // Args: + // structure: the nested structure to flatten. + // separator: string to separate levels of hierarchy in the results, defaults + // to '/'. + + // Returns: + // A list of (string, data element) tuples. + // """ + // flat_paths = yield_flat_paths(structure) + // def stringify_and_join(path_elements): + // return separator.join(str(path_element) for path_element in path_elements) + // flat_string_paths = [stringify_and_join(path) for path in flat_paths] + // return list(zip(flat_string_paths, flatten(structure))) + + + } +} diff --git a/src/TensorFlowNET.Core/Util/variable_utils.cs b/src/TensorFlowNET.Core/Util/variable_utils.cs new file mode 100644 index 000000000..13237f9d4 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/variable_utils.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Framework; + +namespace Tensorflow.Util +{ + internal static class variable_utils + { + public static Tensor[] convert_variables_to_tensors(object[] values) + { + return values.Select(x => + { + if (resource_variable_ops.is_resource_variable(x)) + { + return ops.convert_to_tensor(x); + } + else if (x is CompositeTensor) + { + throw new NotImplementedException("The composite tensor has not been fully supported."); + } + else if(x is Tensor tensor) + { + return tensor; + } + else + { + throw new TypeError("Currently the output of function to be traced must be `Tensor`."); + } + }).ToArray(); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs new file mode 100644 index 000000000..a54283bd4 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -0,0 +1,379 @@ +using Tensorflow.NumPy; +using System; +using Tensorflow.Eager; +using Tensorflow.Variables; +using Tensorflow.Train; +using static Tensorflow.Binding; +using System.Collections.Generic; +using System.Diagnostics; +using Tensorflow.Checkpoint; +using Tensorflow.Training.Saving.SavedModel; +using OneOf; +using Tensorflow.Graphs; + +namespace Tensorflow +{ + public class BaseResourceVariable : DisposableTrackableObject + { + protected string _name; + public virtual string Name => _handle_name; + public virtual string SharedName + { + get + { + // TODO(Rinne): optimize the implementation with refactor of variable. + return _handle_name.Substring(0, _handle_name.IndexOf(':') + 1); + } + } + protected TF_DataType _dtype; + public TF_DataType dtype => _dtype; + protected string _handle_name; + public string handle_name + { + get { return _handle_name; } + set { _handle_name = value; } + } + + protected string _unique_id; + public string UniqueId => _unique_id; + + protected bool _in_graph_mode; + internal bool InGraphMode => _in_graph_mode; + + protected bool _trainable; + public bool Trainable => _trainable; + + protected Tensor _initial_value; + + public Operation initializer => initializer_op; + + protected Tensor _parent_op; + public Tensor parent_op => _parent_op; + + /// + /// Tensor handle + /// + protected Tensor handle; + public Tensor Handle => handle; + protected Tensor _graph_element; + public Tensor GraphElement => _graph_element; + protected Shape _shape; + public Shape shape => _shape; + + protected Operation initializer_op; + public Operation Initializer => initializer_op; + public Operation Op => handle.op; + public Graph Graph => handle.graph; + public string Device => handle.Device; + EagerResourceDeleter eager_resource_deleter; + public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None; + + public BaseResourceVariable() + { + } + + public void __init__(bool trainable = true, + Shape shape = null, + TF_DataType dtype = TF_DataType.DtInvalid, + Tensor handle = null, + string name = null, + string unique_id = null, + string handle_name = null) + { + _trainable = trainable; + _handle_name = handle_name + ":0"; + _unique_id = unique_id; + this.handle = handle; + _name = name; + if(shape is not null) + { + _shape = shape; + } + if(dtype != TF_DataType.DtInvalid) + { + _dtype = dtype; + } + + // After the handle has been created, set up a way to clean it up when + // executing eagerly. We'll hold the only reference to the deleter, so that + // when this object is garbage collected the deleter will be too. This + // means ResourceVariables can be part of reference cycles without those + // cycles being uncollectable. + if (handle is EagerTensor) + { + _handle = handle.EagerTensorHandle.DangerousGetHandle(); + // eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); + } + else if(handle is null) + { + // TODO: fix this dangerous change. + _handle = IntPtr.Zero; + } + else + { + _handle = handle.Handle == null ? IntPtr.Zero : handle.Handle.DangerousGetHandle(); + } + +#if TRACK_TENSOR_LIFE + print($"Created Resource 0x{_handle.ToString("x16")} {_name}"); +#endif + } + + public Tensor assign(T value, bool use_locking = false, string name = null, bool read_value = true) + { + if (value.GetType() == typeof(Tensor)) + { + var assign = gen_state_ops.assign(handle, value, use_locking: use_locking, name: name); + if (read_value) + return assign; + return assign.op; + } + + var value_tensor = ops.convert_to_tensor(value, dtype: dtype); + var assign_op = gen_resource_variable_ops.assign_variable_op( + handle, value_tensor, name: name); + + if (read_value) + return gen_resource_variable_ops.read_variable_op(handle, dtype); + + if (assign_op == null) + return null; + + return assign_op; + } + + public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice) + { + _strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value); + } + + void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null, + int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0) + { + var op = gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value, + begin_mask: begin_mask, + end_mask: end_mask, + ellipsis_mask: ellipsis_mask, + new_axis_mask: new_axis_mask, + shrink_axis_mask: shrink_axis_mask); + } + + public IVariableV1 assign_lazy_load(Tensor value, string name = null) + { + var value_tensor = ops.convert_to_tensor(value, dtype: dtype); + var assign_op = gen_resource_variable_ops.assign_variable_op( + handle, value_tensor, name: name); + var variable = _lazy_read(assign_op, value_tensor); + return variable; + } + + public Tensor value() + => GraphElement ?? _read_variable_op(); + + protected Tensor _read_variable_op(bool no_copy = false) + { + variable_accessed(this); + + Tensor read_and_set_handle(bool no_copy) + { + if (no_copy) + { + gen_resource_variable_ops.disable_copy_on_read(handle); + } + var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); + resource_variable_ops._maybe_set_handle_data(_dtype, handle, result); + return result; + } + + // TODO(Rinne): deal with caching device. + var result = read_and_set_handle(no_copy); + if (!tf.Context.executing_eagerly()) + { + tf.Runner.TFE_TapeSetRecordOperation("ReadVariableOp", new Tensor[] { result }, new Tensor[] { handle }, + backward_function: (x, _) => x); + } + + // have to set shape when converting to substituent placeholder + if (result.shape.ndim == -1) + { + c_api.TF_GraphSetTensorShape(result.graph, + result._as_tf_output(), + shape.dims, + shape.ndim, + tf.Status); + tf.Status.Check(true); + } + + return result; + } + + IVariableV1 _lazy_read(Operation op, Tensor value) + { + variable_accessed(this); + return new _UnreadVariable(handle, _dtype, _shape, _in_graph_mode, _unique_id); + } + + /// + /// Records that `variable` was accessed for the tape and FuncGraph. + /// + void variable_accessed(BaseResourceVariable variable) + { + if(ops.get_default_graph() is FuncGraph func_graph) + { + func_graph.watch_variable(variable as IVariableV1); + } + if (variable.Trainable) + { + foreach (var tape in tf.GetTapeSet()) + tape.VariableAccessed(variable as ResourceVariable); + } + } + + /// + /// Constructs an op which reads the value of this variable. + /// + /// Should be used when there are multiple reads, or when it is desirable to + /// read the value only after some condition is true. + /// + /// + protected Tensor read_value() + { + var value = tf_with(ops.name_scope("Read"), delegate + { + return _read_variable_op(); + }); + return array_ops.identity(value); + } + + + public Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true) + { + var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(Handle, + ops.convert_to_tensor(delta, dtype: dtype), name: name); + + if (read_value) + return gen_resource_variable_ops.read_variable_op(handle, dtype); + // return _lazy_read(assign_add_op); + return assign_add_op; + } + + public Tensor assign_sub(T delta, bool use_locking = false, string name = null, bool read_value = true) + { + var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle, + ops.convert_to_tensor(delta, dtype: dtype), name: name); + + if (read_value) + return gen_resource_variable_ops.read_variable_op(handle, dtype); + // return _lazy_read(assign_add_op); + return assign_sub_op; + } + + public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null) + { + var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle, + ops.convert_to_tensor(delta, dtype: dtype), name: name); + + return _lazy_read(assign_sub_op, delta); + } + + public override string ToString() + { + if (tf.Context.executing_eagerly()) + return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={read_value().numpy()}"; + else + return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}"; + } + + public NDArray numpy() => read_value().numpy(); + + protected override void DisposeUnmanagedResources(IntPtr handle) + { +#if TRACK_TENSOR_LIFE + print($"Deleted Resource 0x{handle.ToString("x16")} {_name}"); +#endif + } + + public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) + { + if (as_ref) + return read_value().op.inputs[0]; + else + return value(); + } + + public override (IDictionary, IDictionary) map_resources(SaveOptions save_options) + { + BaseResourceVariable new_variable; + if (save_options.experimental_variable_policy.save_variable_devices()) + { + Debug.Assert(this is ResourceVariable); + new_variable = tf_with(ops.device(this.Device), _ => + { + return resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + }); + } + else + { + new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + } + Dictionary obj_map = new(); + Dictionary resource_map = new(); + obj_map[this] = new_variable; + resource_map[this.handle] = new_variable.handle; + return (obj_map, resource_map); + } + + /// + /// Writes additional information of the variable into the SavedObject proto. + /// ubclasses of ResourceVariables could choose to override this method to + /// customize extra information to provide when saving a SavedModel. + /// + /// + /// + public virtual void write_object_proto(SavedObject proto, SaveOptions options) + { + resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); + } + + public override IDictionary>> gather_saveables_for_checkpoint() + { + var res = new Dictionary>>(); + res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this; + return res; + } + + public Tensor is_initialized(string name = null) + { + return gen_resource_variable_ops.var_is_initialized_op(this.handle, name); + } + + public Tensor read_value_no_copy() + { + Tensor value = null; + tf_with(ops.name_scope("Read"), _ => + { + // TODO: `no_copy = true`. + value = _read_variable_op(); + }); + return array_ops.identity(value); + } + + //public static Tensor operator +(BaseResourceVariable x, int y) => x.value() + y; + //public static Tensor operator +(BaseResourceVariable x, float y) => x.value() + y; + //public static Tensor operator +(BaseResourceVariable x, double y) => x.value() + y; + //public static Tensor operator +(BaseResourceVariable x, BaseResourceVariable y) => x.value() + y.value(); + //public static Tensor operator -(BaseResourceVariable x, int y) => x.value() - y; + //public static Tensor operator -(BaseResourceVariable x, float y) => x.value() - y; + //public static Tensor operator -(BaseResourceVariable x, double y) => x.value() - y; + //public static Tensor operator -(BaseResourceVariable x, Tensor y) => x.value() - y; + //public static Tensor operator -(BaseResourceVariable x, BaseResourceVariable y) => x.value() - y.value(); + + //public static Tensor operator *(BaseResourceVariable x, BaseResourceVariable y) => x.value() * y.value(); + //public static Tensor operator *(BaseResourceVariable x, Tensor y) => x.value() * y; + //public static Tensor operator *(BaseResourceVariable x, NDArray y) => x.value() * y; + + //public static Tensor operator <(BaseResourceVariable x, Tensor y) => x.value() < y; + + //public static Tensor operator >(BaseResourceVariable x, Tensor y) => x.value() > y; + } +} diff --git a/src/TensorFlowNET.Core/Variables/EagerResourceDeleter.cs b/src/TensorFlowNET.Core/Variables/EagerResourceDeleter.cs new file mode 100644 index 000000000..77bf471b0 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/EagerResourceDeleter.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Variables +{ + public class EagerResourceDeleter : DisposableObject + { + Tensor _tensor; + string _handle_device; + public EagerResourceDeleter(Tensor handle, string handle_device) + { + _tensor = handle; + _handle = handle.EagerTensorHandle.DangerousGetHandle(); + _handle_device = handle_device; + } + + protected override void DisposeUnmanagedResources(IntPtr handle) + { + // gen_resource_variable_ops.destroy_resource_op(_tensor, ignore_lookup_error: true); + + // tf.device(_handle_device); + tf.Runner.TFE_Execute(tf.Context, _handle_device, "DestroyResourceOp", + new[] { _tensor }, + new object[] { "ignore_lookup_error", true }, 0); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs new file mode 100644 index 000000000..3eb78153a --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -0,0 +1,58 @@ +/***************************************************************************** + Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; + +namespace Tensorflow +{ + /// + /// A variable maintains state in the graph across calls to `run()`. You add a + /// variable to the graph by constructing an instance of the class `Variable`. + /// + /// The `Variable()` constructor requires an initial value for the variable, + /// which can be a `Tensor` of any type and shape. The initial value defines the + /// type and shape of the variable. After construction, the type and shape of + /// the variable are fixed. The value can be changed using one of the assign methods. + /// https://tensorflow.org/guide/variables + /// + public interface IVariableV1 + { + string UniqueId { get; } + string Name { get; } + /// + /// Handle is ref type + /// + Tensor Handle { get; } + string Device { get; } + Operation Initializer { get; } + Operation Op { get; } + /// + /// GraphElement is a copy of Handle + /// + Tensor GraphElement { get; } + Graph Graph { get; } + TF_DataType dtype { get; } + Shape shape { get; } + bool Trainable { get; } + Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); + Tensor assign_sub(T delta, bool use_locking = false, string name = null, bool read_value = true); + IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null); + Tensor assign(T value, bool use_locking = false, string name = null, bool read_value = true); + IVariableV1 assign_lazy_load(Tensor value, string name = null); + Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false); + NDArray numpy(); + } +} diff --git a/src/TensorFlowNET.Core/Variables/PureVariableScope.cs b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs new file mode 100644 index 000000000..32c016b44 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs @@ -0,0 +1,119 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow +{ + public class PureVariableScope : ITensorFlowObject + { + private string _name; + private VariableScope _scope; + private string _new_name; + private string _old_name_scope; + private bool _reuse; + private _VariableStore _var_store; + private VariableScope _old; + private _VariableScopeStore _var_scope_store; + private VariableScope variable_scope_object; + private VariableScope _cached_variable_scope_object; + VariableScope _last_variable_scope_object; + Dictionary _old_subscopes; + public PureVariableScope(string name, + string old_name_scope = null, + TF_DataType dtype = TF_DataType.DtInvalid) + { + _name = name; + _old_name_scope = old_name_scope; + _var_store = variable_scope._get_default_variable_store(); + _var_scope_store = variable_scope.get_variable_scope_store(); + } + + public PureVariableScope(VariableScope scope, + string old_name_scope = null, + TF_DataType dtype = TF_DataType.DtInvalid) + { + _scope = scope; + _old_name_scope = old_name_scope; + _var_store = variable_scope._get_default_variable_store(); + _var_scope_store = variable_scope.get_variable_scope_store(); + _new_name = _scope.name; + + string name_scope = _scope._name_scope; + variable_scope_object = new VariableScope(_reuse, + name: _new_name, + name_scope: name_scope); + + _cached_variable_scope_object = variable_scope_object; + } + + public void __enter__() + { + _old = _var_scope_store.current_scope; + if (_scope != null) + { + _var_scope_store.open_variable_scope(_new_name); + _old_subscopes = _var_scope_store.variable_scopes_count.ToDictionary(kv => kv.Key, kv => kv.Value); + variable_scope_object = _cached_variable_scope_object; + } + else + { + _new_name = string.IsNullOrEmpty(_old.name) ? _name : _old.name + "/" + _name; + _reuse = _reuse || _old.resue; + string name_scope = _old_name_scope == null ? _name : _old_name_scope; + + variable_scope_object = new VariableScope(_reuse, + name: _new_name, + name_scope: name_scope); + + _var_scope_store.open_variable_scope(_new_name); + } + _var_scope_store.current_scope = variable_scope_object; + _last_variable_scope_object = variable_scope_object; + } + + public void Dispose() + { + + } + + public void __exit__() + { + // If jumping out from a non-prolonged scope, restore counts. + if (_scope != null) + _var_scope_store.variable_scopes_count = _old_subscopes; + else + _var_scope_store.close_variable_subscopes(_new_name); + _var_scope_store.current_scope = _old; + } + + public void __init__() + { + + } + + public void __del__() + { + + } + + public static implicit operator VariableScope(PureVariableScope scope) + { + return scope.variable_scope_object; + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs new file mode 100644 index 000000000..6bc90ae9c --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs @@ -0,0 +1,25 @@ +namespace Tensorflow +{ + public partial class RefVariable + { + public static implicit operator _VariableScopeStore(RefVariable variable) + { + return null; + } + + public static implicit operator RefVariable(_VariableScopeStore store) + { + return null; + } + + public static implicit operator Tensor(RefVariable var) + { + return var.AsTensor(); + } + + public static implicit operator RefVariable(Tensor var) + { + return null; + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs new file mode 100644 index 000000000..92fbddb6d --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs @@ -0,0 +1,60 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class RefVariable + { + public static Tensor operator +(RefVariable x, int y) => op_helper("add", x, y); + public static Tensor operator +(RefVariable x, float y) => op_helper("add", x, y); + public static Tensor operator +(RefVariable x, double y) => op_helper("add", x, y); + + public static Tensor operator -(RefVariable x, int y) => op_helper("sub", x, y); + public static Tensor operator -(RefVariable x, float y) => op_helper("sub", x, y); + public static Tensor operator -(RefVariable x, double y) => op_helper("sub", x, y); + public static Tensor operator -(RefVariable x, Tensor y) => op_helper("sub", x, y); + + public static Tensor operator <(RefVariable x, Tensor y) => gen_math_ops.less(x.value(), y); + + public static Tensor operator >(RefVariable x, Tensor y) => gen_math_ops.greater(x.value(), y); + + private static Tensor op_helper(string default_name, RefVariable x, T y) + { + var xVal = x.value(); + return tf_with(ops.name_scope(null, default_name, new { xVal, y }), scope => + { + string name = scope; + var yTensor = ops.convert_to_tensor(y, xVal.dtype.as_base_dtype(), "y"); + Tensor result = null; + switch (default_name) + { + case "add": + result = gen_math_ops.add(xVal, yTensor, name); + break; + case "sub": + result = gen_math_ops.sub(xVal, yTensor, name); + break; + default: + throw new NotImplementedException(""); + } + return result; + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs new file mode 100644 index 000000000..7b08f3ea4 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -0,0 +1,450 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using static Tensorflow.Binding; +using Tensorflow.Train; + +namespace Tensorflow +{ + [Obsolete] + public partial class RefVariable: Trackable, IVariableV1, IProtoBuf + { + protected string _name; + public string UniqueId => _name; + public Tensor GraphElement { get; } + public Tensor _variable; + public Tensor Handle => _variable; + protected string _graph_key; + public Graph Graph => _variable.graph; + + public Tensor _is_initialized_op { get; set; } + + protected TF_DataType _dtype; + + public bool _in_graph_mode = true; + public Tensor _initial_value; + public bool _trainable; + + public Tensor _snapshot; + public bool _save_slice_info; + + private Operation _initializer_op; + public Operation Initializer => _initializer_op; + public Operation Op => _variable.op; + + public TF_DataType dtype => _variable.dtype; + public Shape shape => _variable.shape; + public string Device => ""; + + public string Name => _variable.name; + + public Tensor eval() => _variable; + public bool Trainable => _trainable; + + public RefVariable(object initial_value = null, + bool trainable = true, + List collections = null, + bool validate_shape = true, + string caching_device = "", + string name = null, + VariableDef variable_def = null, + TF_DataType dtype = TF_DataType.DtInvalid, + string import_scope = "") : base() + { + _in_graph_mode = true; + + if (initial_value is Operation op) + { + _init_from_op(op); + } + else if (variable_def != null) + { + if (initial_value != null) + throw new ValueError("variable_def and initial_value are mutually exclusive."); + _init_from_proto(variable_def, import_scope: import_scope); + } + else + { + _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); + } + } + + private void _init_from_op(Operation op) + { + var g = ops.get_default_graph(); + _initializer_op = op; + _variable = op.output; + } + + private void _init_from_proto(VariableDef variable_def, string import_scope = "") + { + var g = ops.get_default_graph(); + + _variable = g.as_graph_element( + ops.prepend_name_scope(variable_def.VariableName, + import_scope: import_scope)) as Tensor; + + _initializer_op = g.as_graph_element( + ops.prepend_name_scope(variable_def.InitializerName, + import_scope: import_scope)) as Operation; + + // Tests whether initial_value_name exists first for backwards compatibility. + if (!string.IsNullOrEmpty(variable_def.InitialValueName)) + _initial_value = g.as_graph_element( + ops.prepend_name_scope(variable_def.InitialValueName, + import_scope: import_scope)) as Tensor; + else + _initial_value = null; + + _trainable = variable_def.Trainable; + _snapshot = g.as_graph_element( + ops.prepend_name_scope(variable_def.SnapshotName, + import_scope: import_scope)) as Tensor; + + if (variable_def.SaveSliceInfoDef != null) + throw new NotImplementedException("save_slice_info_def"); + else +#pragma warning disable CS0642 // Possible mistaken empty statement + ;// _save_slice_info = null; +#pragma warning restore CS0642 // Possible mistaken empty statement + + //_caching_device = null; + //_constraint = null; + } + + private void _init_from_args(object initial_value, + bool trainable = true, + List collections = null, + bool validate_shape = true, + string caching_device = "", + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid) + { + if (initial_value is null) + throw new ValueError("initial_value must be specified."); + + var init_from_fn = initial_value.GetType().Name == "Func`1"; + + if (collections == null) + { + collections = new List { tf.GraphKeys.GLOBAL_VARIABLES }; + } + + // Store the graph key so optimizers know how to only retrieve variables from + // this graph. + _graph_key = ops.get_default_graph().graph_key; + + _trainable = trainable; + if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) + collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); + + tf_with(ops.init_scope(), init_scope => + { + var values = init_from_fn ? new object[0] : new object[] { initial_value }; + tf_with(ops.name_scope(name, "Variable", values), scope => + { + name = scope; + + if (init_from_fn) + { + // Use attr_scope and device(None) to simulate the behavior of + // colocate_with when the variable we want to colocate with doesn't + // yet exist. + string true_name = ops.name_from_scope_name(name); + var attr = new AttrValue + { + List = new AttrValue.Types.ListValue() + }; + attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}")); + tf_with(ops.name_scope("Initializer"), scope2 => + { + _initial_value = (initial_value as Func)(); + _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype); + }); + _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); + } + // Or get the initial value from a Tensor or Python object. + else + { + _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype); + + var shape = _initial_value.shape; + dtype = _initial_value.dtype; + _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope); + } + + // Manually overrides the variable's shape with the initial value's. + if (validate_shape) + { + var initial_value_shape = _initial_value.shape; + if (!initial_value_shape.IsFullyDefined) + throw new ValueError($"initial_value must have a shape specified: {_initial_value}"); + } + + // If 'initial_value' makes use of other variables, make sure we don't + // have an issue if these other variables aren't initialized first by + // using their initialized_value() method. + var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value); + + _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; + + if (!String.IsNullOrEmpty(caching_device)) + { + + } + else + { + ops.colocate_with(_initializer_op); + + _snapshot = gen_array_ops.identity(_variable, name = "read"); + } + + ops.add_to_collections(collections, this as IVariableV1); + }); + }); + } + + public Tensor _ref() => _variable; + + public Tensor value() => _snapshot; + + public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) => _snapshot; + + public Tensor _as_graph_element() => _variable; + + public Tensor _TensorConversionFunction(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) + { + if (as_ref) + return _ref(); + else + return value(); + } + + /// + /// Attempt to guard against dependencies on uninitialized variables. + /// + /// + private Tensor _try_guard_against_uninitialized_dependencies(string name, Tensor initial_value) + { + return _safe_initial_value_from_tensor(name, initial_value, op_cache: new Dictionary()); + } + + /// + /// Replace dependencies on variables with their initialized values. + /// + /// A `Tensor`. The tensor to replace. + /// A dict mapping operation names to `Operation`s. + /// A `Tensor` compatible with `tensor`. + private Tensor _safe_initial_value_from_tensor(string name, Tensor tensor, Dictionary op_cache) + { + var op = tensor.op; + var new_op = op_cache.ContainsKey(op.name) ? op_cache[op.name] : null; + if (new_op == null) + { + new_op = _safe_initial_value_from_op(name, op, op_cache); + op_cache[op.name] = new_op; + } + return new_op.outputs[tensor.value_index]; + } + + private Operation _safe_initial_value_from_op(string name, Operation op, Dictionary op_cache) + { + var op_type = op.node_def.Op; + switch (op_type) + { + case "IsVariableInitialized": + case "VarIsInitializedOp": + case "ReadVariableOp": + return op; + case "Variable": + case "VariableV2": + case "VarHandleOp": + var initialized_value = _find_initialized_value_for_variable(op); + return initialized_value == null ? op : initialized_value.op; + } + + // Recursively build initializer expressions for inputs. + var modified = false; + var new_op_inputs = new List(); + foreach (var op_input in op.inputs) + { + var new_op_input = _safe_initial_value_from_tensor(name, op_input as Tensor, op_cache); + new_op_inputs.Add(new_op_input); + modified = modified || new_op_input != op_input; + } + + // If at least one input was modified, replace the op. + if (modified) + { + var new_op_type = op_type; + if (new_op_type == "RefSwitch") + new_op_type = "Switch"; + var new_op_name = op.node_def.Name + "_" + name; + new_op_name = new_op_name.Replace(":", "_"); + + // Convert attr values to AttrValue protos. + var attr_protos = new Dictionary(); + foreach (var attr_def in op.node_def.Attr) + attr_protos[attr_def.Key] = attr_def.Value; + + return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types, + name: new_op_name, attrs: attr_protos); + } + return op; + } + + private Operation _find_initialized_value_for_variable(Operation variable_op) + { + var var_names = new[] { variable_op.node_def.Name, variable_op.node_def.Name + ":0" }; + foreach (var collection_name in new[]{tf.GraphKeys.GLOBAL_VARIABLES, + tf.GraphKeys.LOCAL_VARIABLES }) + { + foreach (var var in variable_op.graph.get_collection(collection_name)) + if (var_names.Contains(var.Name)) + return var.initialized_value(); + } + + return null; + } + + /// + /// Assigns a new value to the variable. + /// + /// The new value for this variable. + /// If `True`, use locking during the assignment. + /// The name of the operation to be created + /// + /// if True, will return something which evaluates to the + /// new value of the variable; if False will return the assign op. + /// + /// + /// A `Tensor` that will hold the new value of this variable after + /// the assignment has completed. + /// + public Tensor assign(T value, bool use_locking = false, string name = null, bool read_value = true) + { + var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); + if (read_value) + return assign; + return assign.op; + } + + public override string ToString() + { + return $"tf.RefVariable '{Name}' shape={shape} dtype={dtype}"; + } + + public VariableDef to_proto(string export_scope) + { + if (string.IsNullOrEmpty(export_scope) || _variable.name.StartsWith(export_scope)) + { + var var_def = new VariableDef(); + var_def.VariableName = ops.strip_name_scope(_variable.name, export_scope); + if (_initial_value != null) + var_def.InitialValueName = ops.strip_name_scope(_initial_value.name, export_scope); + var_def.Trainable = _trainable; + var_def.InitializerName = ops.strip_name_scope(Initializer.name, export_scope); + var_def.SnapshotName = ops.strip_name_scope(_snapshot.name, export_scope); + if (_save_slice_info) + throw new NotImplementedException("to_proto _save_slice_info"); + + return var_def; + } + + throw new NotImplementedException("to_proto RefVariable"); + } + + public RefVariable from_proto(VariableDef proto, string import_scope) + { + throw new NotImplementedException(); + } + + /// + /// Returns the value of this variable, read in the current context. + /// + /// + private ITensorOrOperation read_value() + { + return array_ops.identity(_variable, name: "read"); + } + + /// + /// Returns the Tensor used as the initial value for the variable. + /// + /// + private ITensorOrOperation initial_value() + { + return _initial_value; + } + + public Tensor is_variable_initialized(RefVariable variable) + { + return state_ops.is_variable_initialized(variable); + } + + public Tensor initialized_value() + { + ops.init_scope(); + return control_flow_ops.cond(is_variable_initialized(this), + read_value, + initial_value); + } + + // Update 'ref' by adding 'value' to it. + // This operation outputs "ref" after the update is done. + // This makes it easier to chain operations that need to use the reset value. + // Args: + // ref: A mutable `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. + // Should be from a `Variable` node. + // value: A `Tensor`. Must have the same type as `ref`. + // The value to be added to the variable. + // use_locking: An optional `bool`. Defaults to `False`. + // If True, the addition will be protected by a lock; + // otherwise the behavior is undefined, but may exhibit less contention. + // name: A name for the operation(optional). + // Returns: + // A mutable `Tensor`. Has the same type as `ref`. + public Tensor assign_add(T value, bool use_locking = false, string name = null, bool read_value = true) + { + var variable = this; + var _op = tf.OpDefLib._apply_op_helper("AssignAdd", name: name, args: new { variable, value, use_locking }); + return _op; + } + + public NDArray numpy() + => throw new RuntimeError("Graph mode can't use numpy()."); + + public Tensor assign_sub(T delta, bool use_locking = false, string name = null, bool read_value = true) + { + throw new NotImplementedException(); + } + + public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null) + { + throw new NotImplementedException(); + } + + public IVariableV1 assign_lazy_load(Tensor value, string name = null) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Functions.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Functions.cs new file mode 100644 index 000000000..d3e77c76a --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Functions.cs @@ -0,0 +1,45 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class ResourceVariable + { + /// + /// Subtracts a value from this variable. + /// + /// + /// + /// + /// + public void assign_sub(Tensor delta, bool use_locking = false, string name = null, bool read_value = true) + { + gen_resource_variable_ops.assign_sub_variable_op(handle, delta, name: name); + } + + /// + /// Adds a value to this variable. + /// + /// + /// + /// + /// + public void assign_add(Tensor delta, bool use_locking = false, string name = null, bool read_value = true) + { + gen_resource_variable_ops.assign_add_variable_op(handle, delta, name: name); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs new file mode 100644 index 000000000..29771c06b --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs @@ -0,0 +1,42 @@ +using System; +using Tensorflow.Eager; + +namespace Tensorflow +{ + public partial class ResourceVariable + { + public static implicit operator _VariableScopeStore(ResourceVariable variable) + { + return null; + } + + public static implicit operator ResourceVariable(_VariableScopeStore store) + { + return null; + } + + public static implicit operator Tensor(ResourceVariable var) + => var._dense_var_to_tensor(); + + public static implicit operator EagerTensor(ResourceVariable var) + => var._dense_var_to_tensor() as EagerTensor; + + public static implicit operator IntPtr(ResourceVariable var) + => var._handle; + + Tensor _dense_var_to_tensor(TF_DataType dtype = TF_DataType.DtInvalid, + string name = null, + bool as_ref = false) + { + return value(); + } + + public Tensor _TensorConversionFunction(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) + { + if (as_ref) + return handle; + else + return GraphElement ?? read_value(); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs new file mode 100644 index 000000000..7876a9904 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs @@ -0,0 +1,70 @@ +/***************************************************************************** + Copyright 2020 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using System.Linq; + +namespace Tensorflow +{ + public partial class ResourceVariable + { + public Tensor this[params Slice[] slices] + { + get + { + var args = tensor_util.ParseSlices(slices); + + return tf_with(ops.name_scope(null, "strided_slice", args), scope => + { + string name = scope; + if (args.Begin != null) + { + (args.PackedBegin, args.PackedEnd, args.PackedStrides) = + (array_ops.stack(args.Begin), + array_ops.stack(args.End), + array_ops.stack(args.Strides)); + + var tensor = gen_array_ops.strided_slice( + this, + args.PackedBegin, + args.PackedEnd, + args.PackedStrides, + begin_mask: args.BeginMask, + end_mask: args.EndMask, + shrink_axis_mask: args.ShrinkAxisMask, + new_axis_mask: args.NewAxisMask, + ellipsis_mask: args.EllipsisMask, + name: name); + + tensor.OriginalVar = this; + tensor.OriginalVarSlice = args; + + return tensor; + } + + throw new NotImplementedException(""); + }); + } + } + + public Tensor this[params string[] slices] + => this[slices.Select(x => new Slice(x)).ToArray()]; + } +} diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs new file mode 100644 index 000000000..2737a2191 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.NumPy; + +namespace Tensorflow +{ + public partial class ResourceVariable + { + public static Tensor operator +(ResourceVariable x, int y) => x.value() + y; + public static Tensor operator +(ResourceVariable x, float y) => x.value() + y; + public static Tensor operator +(ResourceVariable x, double y) => x.value() + y; + public static Tensor operator +(ResourceVariable x, ResourceVariable y) => x.value() + y.value(); + public static Tensor operator -(ResourceVariable x, int y) => x.value() - y; + public static Tensor operator -(ResourceVariable x, float y) => x.value() - y; + public static Tensor operator -(ResourceVariable x, double y) => x.value() - y; + public static Tensor operator -(ResourceVariable x, Tensor y) => x.value() - y; + public static Tensor operator -(ResourceVariable x, ResourceVariable y) => x.value() - y.value(); + + public static Tensor operator *(ResourceVariable x, ResourceVariable y) => x.value() * y.value(); + public static Tensor operator *(ResourceVariable x, Tensor y) => x.value() * y; + public static Tensor operator *(ResourceVariable x, NDArray y) => x.value() * y; + + public static Tensor operator <(ResourceVariable x, Tensor y) => x.value() < y; + + public static Tensor operator >(ResourceVariable x, Tensor y) => x.value() > y; + } +} diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs new file mode 100644 index 000000000..bc23df3ed --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -0,0 +1,289 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using System; +using System.Collections.Generic; +using Tensorflow.Checkpoint; +using Tensorflow.NumPy; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Variable based on resource handles. + /// + public partial class ResourceVariable : BaseResourceVariable, IVariableV1 + { + public ResourceVariable(object initial_value = null, + bool trainable = true, + List collections = null, + bool validate_shape = true, + string caching_device = "", + string name = null, + VariableDef variable_def = null, + TF_DataType dtype = TF_DataType.DtInvalid, + string import_scope = "", + VariableAggregation aggregation = VariableAggregation.None, + Shape shape = null) + { + Aggregation = aggregation; + if (variable_def != null) + { + if (initial_value != null) + throw new ValueError("variable_def and initial_value are mutually exclusive."); + _init_from_proto(variable_def, import_scope: import_scope); + } + else + { + _init_from_args(initial_value: initial_value, + trainable: trainable, + collections: collections, + caching_device: caching_device, + name: name, + dtype: dtype, + aggregation: aggregation, + shape: shape); + } + } + + private void _init_from_args(object initial_value = null, + bool trainable = true, + List collections = null, + string caching_device = "", + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + VariableAggregation aggregation = VariableAggregation.None, + Shape shape = null) + { + var init_from_fn = initial_value.GetType().Name == "Func`1" || + initial_value.GetType().GetInterface("IInitializer") != null; + if (collections == null) + collections = new List() { tf.GraphKeys.GLOBAL_VARIABLES }; + _trainable = trainable; + + if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) + collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); + + tf_with(ops.init_scope(), init_scope => + { + _in_graph_mode = !tf.Context.executing_eagerly(); + tf_with(ops.name_scope(name, "Variable", initial_value, skip_on_eager: false), scope => + { + name = scope; + var handle_name = ops.name_from_scope_name(name); + string unique_id = ""; + string shared_name = ""; + + if (_in_graph_mode) + { + shared_name = handle_name; + unique_id = shared_name; + } + else + { + unique_id = $"{handle_name}_{ops.uid()}"; + shared_name = null; + } + + var attr = new AttrValue(); + attr.List = new AttrValue.Types.ListValue(); + attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}")); + tf_with(ops.name_scope("Initializer"), delegate + { + if (initial_value.GetType().GetInterface("IInitializer") != null) + _initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype))); + else + { + var value = init_from_fn ? (initial_value as Func)() : initial_value; + _initial_value = ops.convert_to_tensor(value, + name: "initial_value", + dtype: dtype); + } + }); + + if(shape is null) + { + shape = _initial_value.shape; + } + dtype = _initial_value.dtype; + + if (_in_graph_mode) + { + // TODO(Rinne): deal with initializer_op. + //if(initial_value is not null) + //{ + // tf_with(ops.name_scope("Assign"), n => + // { + // tf_with(ops.device(handle.Device), _ => + // { + + // }); + // }); + //} + handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); + initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; + + ops.colocate_with(initializer_op); + tf_with(ops.device(handle.Device), _ => + { + var value = gen_resource_variable_ops.read_variable_op(handle, dtype); + resource_variable_ops._maybe_set_handle_data(dtype, handle, value); + _graph_element = gen_array_ops.identity(handle, name = "read"); + ops.add_to_collections(collections, this); + _dtype = handle.dtype; + }); + } + else + { + handle = resource_variable_ops.eager_safe_variable_handle( + initial_value: _initial_value, + shape: shape, + shared_name: shared_name, + name: name, + graph_mode: _in_graph_mode); + + gen_resource_variable_ops.assign_variable_op(handle, _initial_value); + initializer_op = null; + _graph_element = null; + if (!string.IsNullOrEmpty(caching_device)) + { + tf_with(ops.device(caching_device), _ => + { + var value = gen_resource_variable_ops.read_variable_op(handle, dtype); + resource_variable_ops._maybe_set_handle_data(dtype, handle, value); + }); + } + _dtype = _initial_value.dtype.as_base_dtype(); + // initial_value = _in_graph_mode ? initial_value : null; + } + + base.__init__(trainable: trainable, + shape: shape, + dtype: _dtype, + handle: handle, + name: name, + unique_id: unique_id, + handle_name: handle_name); + }); + }); + } + + private void _init_from_proto(VariableDef variable_def, string import_scope = null) + { + _in_graph_mode = true; + if (!variable_def.IsResource) + throw new ValueError("Trying to restore Variable as ResourceVariable."); + + // Create from variable_def. + var g = ops.get_default_graph(); + var prepend_name_scope = ops.prepend_name_scope(variable_def.VariableName, import_scope: import_scope); + handle = g.as_graph_element(prepend_name_scope) as Tensor; + _handle_name = handle.name; + _name = handle.name; + _shape = new Shape(handle.op.get_attr("shape") as TensorShapeProto); + + prepend_name_scope = ops.prepend_name_scope(variable_def.InitializerName, import_scope: import_scope); + initializer_op = g.as_graph_element(prepend_name_scope) as Operation; + if (!string.IsNullOrEmpty(variable_def.InitialValueName)) + { + prepend_name_scope = ops.prepend_name_scope(variable_def.InitialValueName, import_scope: import_scope); + _initial_value = g.as_graph_element(prepend_name_scope) as Tensor; + } + + _trainable = variable_def.Trainable; + /*var (synchronization, aggregation, trainable) = + variables.validate_synchronization_aggregation_trainable( + variable_def.Synchronization, + variable_def.Aggregation, + variable_def.Trainable, + variable_def.VariableName);*/ + if (!string.IsNullOrEmpty(variable_def.SnapshotName)) + { + prepend_name_scope = ops.prepend_name_scope(variable_def.SnapshotName, import_scope: import_scope); + var snapshot = g.as_graph_element(prepend_name_scope) as Tensor; + while (snapshot.op.type != "ReadVariableOp") + snapshot = snapshot.op.inputs[0]; + _graph_element = snapshot; + } + else + { + throw new NotImplementedException("SnapshotName _init_from_proto"); + } + + if (variable_def.SaveSliceInfoDef != null) + { + throw new NotImplementedException("SaveSliceInfoDef _init_from_proto"); + } + + _dtype = dtypes.as_tf_dtype((DataType)handle.op.get_attr("dtype")); + } + + public Tensor sparse_read(Tensor indices, string name = "Gather") + { + return tf_with(ops.name_scope(name), scope => + { + name = scope; + var value = gen_resource_variable_ops.resource_gather( + handle, indices, dtype: _dtype, name: name); + + return array_ops.identity(value); + }); + } + + public VariableDef to_proto(string export_scope) + { + if (string.IsNullOrEmpty(export_scope) || Handle.name.StartsWith(export_scope)) + { + var var_def = new VariableDef(); + var_def.VariableName = ops.strip_name_scope(Handle.name, export_scope); + if (_initial_value != null) + var_def.InitialValueName = ops.strip_name_scope(_initial_value.name, export_scope); + var_def.Trainable = _trainable; + var_def.InitializerName = ops.strip_name_scope(initializer.name, export_scope); + var_def.SnapshotName = ops.strip_name_scope(_graph_element.name, export_scope); + + return var_def; + } + + throw new NotImplementedException("to_proto RefVariable"); + } + + public NDArray eval(Session session = null) + { + return _graph_element.eval(session); + } + + public static (VariableSynchronization, VariableAggregation, bool) validate_synchronization_aggregation_trainable( + VariableSynchronization? synchronization, VariableAggregation? aggregation, bool? trainable, string name) + { + if(aggregation is null) + { + aggregation = VariableAggregation.None; + } + if(synchronization is null) + { + synchronization = VariableSynchronization.Auto; + } + if (trainable is null) + { + trainable = synchronization != VariableSynchronization.OnRead; + } + return (synchronization.Value, aggregation.Value, trainable.Value); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/SafeResourceVariableHandle.cs b/src/TensorFlowNET.Core/Variables/SafeResourceVariableHandle.cs new file mode 100644 index 000000000..dc3f09df6 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/SafeResourceVariableHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow.Variables +{ + public sealed class SafeResourceVariableHandle : SafeTensorflowHandle + { + private SafeResourceVariableHandle() + { + } + + public SafeResourceVariableHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TFE_DeleteResourceVariable(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs new file mode 100644 index 000000000..e26312447 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs @@ -0,0 +1,72 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Gradients; +using static Tensorflow.Binding; + +namespace Tensorflow.Variables +{ + /// + /// A variable with no initializer. + /// + public sealed class UninitializedVariable : BaseResourceVariable, IVariableV1 + { + // TODO: complete the arg list. + public UninitializedVariable( + bool trainable = true, + string caching_device = "", + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + VariableAggregation aggregation = VariableAggregation.None, + Shape shape = null, + Tensor extra_handle_data = null) + { + string unique_id = ""; + string handle_name = ""; + Tensor created_handle = null; + tf_with(ops.init_scope(), (x) => + { + _in_graph_mode = !tf.Context.executing_eagerly(); + tf_with(ops.name_scope(name, "Variable", skip_on_eager: false), name => + { + handle_name = ops.name_from_scope_name(name); + string? shared_name; + if (_in_graph_mode) + { + shared_name = handle_name; + unique_id = shared_name; + } + else + { + unique_id = $"{handle_name}-{ops.uid()}"; + shared_name = null; + } + created_handle = resource_variable_ops.variable_handle_from_shape_and_dtype( + shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data); + // skip the assignment of `handle._parent_trackable` because of lack of API. + // skip the assignment of `handle._name` and `handle._unique_id` because of accessability. + + if (_in_graph_mode) + { + tf_with(ops.name_scope("Read"), _ => + { + var value = tf_with(ops.device(created_handle.Device), _ => + { + var result = gen_resource_variable_ops.read_variable_op(created_handle, dtype); + resource_variable_ops._maybe_set_handle_data(dtype, created_handle, result); + return result; + }); + _graph_element = value; + }); + ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); + } + else + { + _graph_element = null; + } + }); + }); + base.__init__(trainable, shape, dtype, created_handle, unique_id: unique_id, handle_name: handle_name); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/VariableArgs.cs b/src/TensorFlowNET.Core/Variables/VariableArgs.cs new file mode 100644 index 000000000..ed1e3b98d --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/VariableArgs.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; + +namespace Tensorflow +{ + public class VariableArgs + { + public object InitialValue { get; set; } + public Func Getter { get; set; } + public string Name { get; set; } + public Shape Shape { get; set; } + public TF_DataType DType { get; set; } = TF_DataType.DtInvalid; + public IInitializer Initializer { get; set; } + public bool Trainable { get; set; } + public bool ValidateShape { get; set; } = true; + public bool UseResource { get; set; } = true; + public bool Overwrite { get; set; } + public List Collections { get; set; } + public string CachingDevice { get; set; } = ""; + public VariableDef VariableDef { get; set; } + public string ImportScope { get; set; } = ""; + public VariableSynchronization Synchronization { get; set; } = VariableSynchronization.Auto; + public VariableAggregation Aggregation { get; set; } = VariableAggregation.None; + } +} diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs new file mode 100644 index 000000000..c9a6fffbe --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -0,0 +1,85 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Variable scope object to carry defaults to provide to `get_variable` + /// + public class VariableScope + { + public bool use_resource { get; set; } +#pragma warning disable CS0414 // The field 'VariableScope._reuse' is assigned but its value is never used + private _ReuseMode _reuse; +#pragma warning restore CS0414 // The field 'VariableScope._reuse' is assigned but its value is never used + public bool resue; + + private TF_DataType _dtype; + string _name; + public string name => _name; + public string _name_scope { get; set; } + public string original_name_scope => _name_scope; + + public VariableScope(bool reuse, + string name = "", + string name_scope = "", + TF_DataType dtype = TF_DataType.TF_FLOAT) + { + _name = name; + _name_scope = name_scope; + _reuse = _ReuseMode.AUTO_REUSE; + _dtype = dtype; + } + + public IVariableV1 get_variable(_VariableStore var_store, + string name, + Shape shape = null, + TF_DataType dtype = TF_DataType.DtInvalid, + object initializer = null, // IInitializer or Tensor + bool? trainable = null, + List collections = null, + bool? use_resource = null, + bool validate_shape = true, + VariableSynchronization synchronization = VariableSynchronization.Auto, + VariableAggregation aggregation = VariableAggregation.None) + { + string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; + return tf_with(ops.name_scope(null), scope => + { + if (dtype == TF_DataType.DtInvalid) + dtype = _dtype; + + return var_store.get_variable(full_name, + shape: shape, + dtype: dtype, + initializer: initializer, + reuse: resue, + trainable: trainable, + collections: collections, + synchronization: synchronization, + aggregation: aggregation); + }); + } + + public void reuse_variables() + { + _reuse = _ReuseMode.AUTO_REUSE; + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/_ReuseMode.cs b/src/TensorFlowNET.Core/Variables/_ReuseMode.cs new file mode 100644 index 000000000..9344e8248 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/_ReuseMode.cs @@ -0,0 +1,13 @@ +namespace Tensorflow +{ + /// + /// Mode for variable access within a variable scope. + /// + public enum _ReuseMode + { + NOT_REUSE = 0, + // Indicates that variables are to be fetched if they already exist or + // otherwise created. + AUTO_REUSE = 1 + } +} diff --git a/src/TensorFlowNET.Core/Variables/_UnreadVariable.cs b/src/TensorFlowNET.Core/Variables/_UnreadVariable.cs new file mode 100644 index 000000000..f5d0504ec --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/_UnreadVariable.cs @@ -0,0 +1,28 @@ +using Tensorflow.Eager; + +namespace Tensorflow +{ + /// + /// Represents a future for a read of a variable. + /// Pretends to be the tensor if anyone looks. + /// + public class _UnreadVariable : BaseResourceVariable, IVariableV1 + { + public override string Name => _in_graph_mode ? _parent_op.name : "UnreadVariable"; + + public _UnreadVariable(Tensor handle, TF_DataType dtype, Shape shape, + bool in_graph_mode, string unique_id) + { + _dtype = dtype; + _shape = shape; + base.handle = handle; + _unique_id = unique_id; + _in_graph_mode = in_graph_mode; + + if (handle is EagerTensor) + _handle_name = ""; + else + _handle_name = handle.name; + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs b/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs new file mode 100644 index 000000000..083afe4fa --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs @@ -0,0 +1,59 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; + +namespace Tensorflow +{ + public class _VariableScopeStore + { + public VariableScope current_scope { get; set; } + public Dictionary variable_scopes_count; + + public _VariableScopeStore() + { + current_scope = new VariableScope(false); + variable_scopes_count = new Dictionary(); + } + + public void open_variable_scope(string scope_name) + { + if (variable_scopes_count.ContainsKey(scope_name)) + variable_scopes_count[scope_name] += 1; + else + variable_scopes_count[scope_name] = 1; + } + + public void close_variable_subscopes(string scope_name) + { + var variable_scopes_count_tmp = new Dictionary(); + foreach (var k in variable_scopes_count.Keys) + variable_scopes_count_tmp.Add(k, variable_scopes_count[k]); + + foreach (var k in variable_scopes_count_tmp.Keys) + if (scope_name == null || k.StartsWith(scope_name + "/")) + variable_scopes_count[k] = 0; + } + + public int variable_scope_count(string scope_name) + { + if (variable_scopes_count.ContainsKey(scope_name)) + return variable_scopes_count[scope_name]; + else + return 0; + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs new file mode 100644 index 000000000..0570fd067 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -0,0 +1,185 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Variable store that carries a number of named Variables. + /// + public class _VariableStore + { + private Dictionary _vars; + private Dictionary _partitioned_vars; +#pragma warning disable CS0414 // The field '_VariableStore._store_eager_variables' is assigned but its value is never used + private bool _store_eager_variables; +#pragma warning restore CS0414 // The field '_VariableStore._store_eager_variables' is assigned but its value is never used + + public _VariableStore() + { + _vars = new Dictionary(); + _partitioned_vars = new Dictionary(); + _store_eager_variables = false; + } + + public IVariableV1 get_variable(string name, + Shape shape = null, + TF_DataType dtype = TF_DataType.TF_FLOAT, + object initializer = null, // IInitializer or Tensor + bool? reuse = null, + bool? trainable = null, + List collections = null, + bool validate_shape = true, + VariableSynchronization synchronization = VariableSynchronization.Auto, + VariableAggregation aggregation = VariableAggregation.None) + { + dtype = dtype.as_base_dtype(); + trainable = variable_scope._get_trainable_value(synchronization, trainable); + + return _true_getter(name, + shape: shape, + dtype: dtype, + initializer: initializer, + trainable: trainable, + collections: collections, + validate_shape: validate_shape, + synchronization: synchronization, + aggregation: aggregation); + } + + private IVariableV1 _true_getter(string name, + Shape shape = null, + TF_DataType dtype = TF_DataType.TF_FLOAT, + object initializer = null, + bool? trainable = null, + List collections = null, + bool validate_shape = true, + VariableSynchronization synchronization = VariableSynchronization.Auto, + VariableAggregation aggregation = VariableAggregation.None) + { + bool is_scalar = !(shape is null) && shape.ndim == 0; + + if (initializer is IInitializer init) + { + return _get_single_variable(name: name, + shape: shape, + dtype: dtype, + initializer: init, + trainable: trainable, + collections: collections, + validate_shape: validate_shape, + synchronization: synchronization, + aggregation: aggregation); + } + else if (initializer is Tensor tensor) + { + return _get_single_variable(name: name, + shape: shape, + dtype: dtype, + init_value: tensor, + trainable: trainable, + validate_shape: validate_shape, + synchronization: synchronization, + aggregation: aggregation); + } + else + { + IInitializer init1 = null; + return _get_single_variable(name: name, + shape: shape, + dtype: dtype, + initializer: init1, + trainable: trainable, + validate_shape: validate_shape, + synchronization: synchronization, + aggregation: aggregation); + } + } + + private IVariableV1 _get_single_variable(string name, + Shape shape = null, + TF_DataType dtype = TF_DataType.DtInvalid, + IInitializer initializer = null, + Tensor init_value = null, + bool reuse = false, + bool? trainable = null, + List collections = null, + bool validate_shape = false, + bool? use_resource = null, + VariableSynchronization synchronization = VariableSynchronization.Auto, + VariableAggregation aggregation = VariableAggregation.None) + { + bool initializing_from_value = init_value != null; + if (use_resource == null) + use_resource = variable_scope._DEFAULT_USE_RESOURCE; + + if (_vars.ContainsKey(name)) + { + if (!reuse) + { + var var = _vars[name]; + + } + throw new NotImplementedException("_get_single_variable"); + } + + IVariableV1 v = null; + // Create the tensor to initialize the variable with default value. + if (initializer == null && init_value == null) + { + if (dtype.is_floating()) + { + initializer = tf.glorot_uniform_initializer; + initializing_from_value = false; + } + } + + // Create the variable. + ops.init_scope(); + { + if (initializing_from_value) + { + v = new ResourceVariable(init_value, + name: name, + validate_shape: validate_shape, + trainable: trainable.Value); + } + else + { + Func init_val = () => initializer.Apply(new InitializerArgs(shape, dtype: dtype)); + var variable_dtype = dtype.as_base_dtype(); + + v = variable_scope.default_variable_creator(init_val, + name: name, + trainable: trainable, + collections: collections, + dtype: variable_dtype, + use_resource: use_resource, + validate_shape: validate_shape, + synchronization: synchronization, + aggregation: aggregation); + } + } + + _vars[name] = v; + + return v; + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/c_api.variable.cs b/src/TensorFlowNET.Core/Variables/c_api.variable.cs new file mode 100644 index 000000000..78075f615 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/c_api.variable.cs @@ -0,0 +1,21 @@ +using System; +using System.Runtime.InteropServices; +using Tensorflow.Variables; + +namespace Tensorflow +{ + public partial class c_api + { + [DllImport(TensorFlowLibName)] + public static extern SafeResourceVariableHandle TFE_NewResourceVariable(); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_DeleteResourceVariable(IntPtr variable); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_SetResourceVariableHandle(SafeResourceVariableHandle variable, IntPtr tensor); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_SetResourceVariableName(SafeResourceVariableHandle variable, string name); + } +} diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs new file mode 100644 index 000000000..8d8c06999 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -0,0 +1,102 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class gen_state_ops + { + /// + /// Holds state in the form of a tensor that persists across steps. + /// Outputs a ref to the tensor state so it may be read or modified. + /// + /// The shape of the variable tensor. + /// The type of elements in the variable tensor. + /// + /// + /// + /// + public static Tensor variable_v2(int[] shape, TF_DataType dtype, string name = null, string container = "", string shared_name = "") + { + var _op = tf.OpDefLib._apply_op_helper("VariableV2", name: name, args: new { dtype, shape, container, shared_name }); + + var _result = _op.outputs; + var _inputs_flat = _op.inputs; + + var _attrs = new Dictionary(); + _attrs["dtype"] = _op.get_attr("dtype"); + _attrs["shape"] = _op.get_attr("shape"); + _attrs["container"] = _op.get_attr("container"); + _attrs["shared_name"] = _op.get_attr("shared_name"); + + return _result[0]; + } + + /// + /// Update 'ref' by assigning 'value' to it + /// + /// + /// + /// + /// + /// + public static Tensor assign(T @ref, object value, + bool validate_shape = true, + bool use_locking = true, + string name = null) + => tf.Context.ExecuteOp("Assign", name, new ExecuteOpArgs(@ref, value) + .SetAttributes(new { validate_shape, use_locking })); + + public static Tensor assign_add(IVariableV1 @ref, T value, bool use_locking = false, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking }); + return _op.outputs[0]; + } + + public static Tensor assign_sub(IVariableV1 @ref, + Tensor value, + bool use_locking = false, + string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("AssignSub", name: name, args: new { @ref, value, use_locking }); + + return _op.outputs[0]; + } + + /// + /// Adds sparse updates to a variable reference. + /// + /// + /// + /// + /// + /// + /// + public static Tensor scatter_add(IVariableV1 @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking }); + return _op.outputs[0]; + } + + public static Tensor is_variable_initialized(RefVariable @ref, string name = null) + { + var _op = tf.OpDefLib._apply_op_helper("IsVariableInitialized", name: name, args: new { @ref }); + return _op.output; + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs new file mode 100644 index 000000000..6d79f9065 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -0,0 +1,127 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class state_ops + { + /// + /// Create a variable Operation. + /// + /// + /// + /// + /// + /// + /// + public static Tensor variable_op_v2(int[] shape, + TF_DataType dtype, + string name = "Variable", + string container = "", + string shared_name = "") => gen_state_ops.variable_v2(shape, + dtype, + name: name, + container: container, + shared_name: shared_name); + + public static Tensor assign(T @ref, object value, + bool validate_shape = true, + bool use_locking = true, + string name = null) + { + return gen_state_ops.assign(@ref, + value, + validate_shape: validate_shape, + use_locking: use_locking, + name: name); + } + + public static Tensor assign(IVariableV1 @ref, object value, + bool validate_shape = true, + bool use_locking = true, + string name = null) + { + if (@ref.dtype.is_ref_dtype()) + return gen_state_ops.assign(@ref, + value, + validate_shape: validate_shape, + use_locking: use_locking, + name: name); + else + return @ref.assign(value, name: name); + } + + public static Tensor assign_sub(IVariableV1 @ref, + Tensor value, + bool use_locking = false, + string name = null) => @ref.dtype.is_ref_dtype() ? + gen_state_ops.assign_sub(@ref, + value, + use_locking: use_locking, + name: name) : + @ref.assign_sub(value, name: name); + + //"""Update 'ref' by adding 'value' to it. + // + // This operation outputs "ref" after the update is done. + // This makes it easier to chain operations that need to use the reset value. + // + // Args: + // ref: A mutable `Tensor`. Must be one of the following types: + // `float32`, `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, + // `int8`, `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. + // Should be from a `Variable` node. + // value: A `Tensor`. Must have the same type as `ref`. + // The value to be added to the variable. + // use_locking: An optional `bool`. Defaults to `False`. + // If True, the addition will be protected by a lock; + // otherwise the behavior is undefined, but may exhibit less contention. + // name: A name for the operation (optional). + // + // Returns: + // Same as "ref". Returned as a convenience for operations that want + // to use the new value after the variable has been updated. + public static Tensor assign_add(IVariableV1 @ref, + T value, + bool use_locking = false, + string name = null) + { + if (tf.executing_eagerly()) + return @ref.assign_add(value, use_locking: use_locking, name: name); + else + return gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); + } + + public static Tensor scatter_add(IVariableV1 @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) + { + if (@ref.dtype.is_ref_dtype()) + return gen_state_ops.scatter_add(@ref, indices, updates, use_locking: use_locking, name: name); + + throw new NotImplementedException("scatter_add"); + } + + public static Tensor is_variable_initialized(RefVariable @ref, string name = null) + { + if (@ref.dtype.is_ref_dtype()) + return gen_state_ops.is_variable_initialized(@ref: @ref, name: name); + throw new NotImplementedException(""); + //return @ref.is_initialized(name: name); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs new file mode 100644 index 000000000..31f3285e7 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -0,0 +1,319 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; + +namespace Tensorflow +{ + /// + /// A context manager for defining ops that creates variables (layers). + /// + public class variable_scope : ITensorFlowObject + { + public static string _VARSTORE_KEY = "__variable_store"; + public static string _VARSCOPESTORE_KEY = "__varscope"; + public static bool _DEFAULT_USE_RESOURCE = true; + + private bool _use_resource; + public bool UseResource => _use_resource; + private string _name; + private VariableScope _scope; + private string _default_name; + private Tensor[] _values; + private ops.NameScope _current_name_scope; + private bool _auxiliary_name_scope; + private PureVariableScope _cached_pure_variable_scope; + private bool? _reuse; + bool _in_graph_mode; + protected Graph _graph; + bool _building_function; + + public variable_scope(string name, + string default_name = "", + Tensor[] values = null, + bool? reuse = null, + bool auxiliary_name_scope = true) + { + _name = name; + _default_name = default_name; + _values = values; + _current_name_scope = null; + _reuse = reuse; + _use_resource = false; + if (_default_name == null && _name == null) + throw new TypeError("If default_name is None then name is required"); + + _auxiliary_name_scope = auxiliary_name_scope; + } + + public variable_scope(VariableScope scope, + string default_name = "", + Tensor[] values = null, + bool? reuse = null, + bool auxiliary_name_scope = true) + { + _scope = scope; + _default_name = default_name; + _values = values; + _current_name_scope = null; + _reuse = reuse; + _use_resource = false; + if (_default_name == null && _scope == null) + throw new TypeError("If default_name is None then scope is required"); + + if (_values == null) + _values = new Tensor[0]; + _in_graph_mode = true; + if (_in_graph_mode) + _graph = ops._get_graph_from_inputs(_values); + _auxiliary_name_scope = auxiliary_name_scope; + } + + public void __enter__() + { + // If the default graph is building a function, then we should not replace it + // with the cached graph. + if (ops.get_default_graph().building_function) + _building_function = true; + else + _building_function = false; + if (_in_graph_mode && !_building_function) + { + _graph.as_default(); + } + + _scope = _enter_scope_uncached(); + } + + private VariableScope _enter_scope_uncached() + { + ops.NameScope current_name_scope; + PureVariableScope pure_variable_scope = null; + VariableScope entered_pure_variable_scope; + + if (_auxiliary_name_scope) + // Create a new name scope later + current_name_scope = null; + else + { + // Reenter the current name scope + string name_scope = ops.get_name_scope(); + if (!string.IsNullOrEmpty(name_scope)) + // Hack to reenter + name_scope += "/"; + current_name_scope = ops.name_scope(name_scope); + } + + if (!string.IsNullOrEmpty(_name) || _scope != null) + { + var name_scope = _scope == null ? _name : _scope.name.Split('/').Last(); + if (current_name_scope == null) + current_name_scope = ops.name_scope(name_scope); + current_name_scope.__enter__(); + string current_name_scope_name = current_name_scope; + _current_name_scope = current_name_scope; + string old_name_scope = _scope == null ? current_name_scope_name : _scope.original_name_scope; + + if (_scope == null) + pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope); + else + pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope); + pure_variable_scope.__enter__(); + entered_pure_variable_scope = pure_variable_scope; + _cached_pure_variable_scope = pure_variable_scope; + return entered_pure_variable_scope; + } + else + { + current_name_scope = ops.name_scope(_default_name); + current_name_scope.__enter__(); + string current_name_scope_name = current_name_scope; + _current_name_scope = current_name_scope; + string unique_default_name = _get_unique_variable_scope(_default_name); + pure_variable_scope = new PureVariableScope(unique_default_name, + old_name_scope: current_name_scope_name); + pure_variable_scope.__enter__(); + entered_pure_variable_scope = pure_variable_scope; + _cached_pure_variable_scope = pure_variable_scope; + return entered_pure_variable_scope; + } + } + + /// + /// Get a name with the given prefix unique in the current variable scope. + /// + /// + /// + public static string _get_unique_variable_scope(string prefix) + { + var var_scope_store = get_variable_scope_store(); + var current_scope = get_variable_scope(); + string name = !string.IsNullOrEmpty(current_scope.name) ? current_scope.name + "/" + prefix : prefix; + if (var_scope_store.variable_scope_count(name) == 0) + return prefix; + var idx = 1; + while (var_scope_store.variable_scope_count($"{name}_{idx}") > 0) + idx += 1; + return $"{prefix}_{idx}"; + } + + public static IVariableV1 default_variable_creator(object initial_value, + string name = null, + bool? trainable = null, + List collections = null, + TF_DataType dtype = TF_DataType.DtInvalid, + int[] shape = null, + bool validate_shape = false, + bool? use_resource = null, + VariableSynchronization synchronization = VariableSynchronization.Auto, + VariableAggregation aggregation = VariableAggregation.None) + { + trainable = _get_trainable_value(synchronization, trainable); + if (!use_resource.HasValue) + { + use_resource = get_variable_scope().use_resource; + } + + if (!use_resource.HasValue) + use_resource = _DEFAULT_USE_RESOURCE; + + if (use_resource.Value) + { + return new ResourceVariable(initial_value, + trainable: trainable.Value, + validate_shape: validate_shape, + collections: collections, + name: name, + dtype: dtype, + shape: shape); + } + else + { + return new RefVariable(initial_value, + trainable: trainable.Value, + validate_shape: validate_shape, + collections: collections, + name: name, + dtype: dtype); + } + } + + public static _VariableStore _get_default_variable_store() + { + var store = ops.get_collection(_VARSTORE_KEY); + if (store != null) + return (store as List<_VariableStore>)[0]; + + var store1 = new _VariableStore(); + ops.add_to_collection(_VARSTORE_KEY, store1); + return store1; + } + + public static VariableScope get_variable_scope() + { + return get_variable_scope_store().current_scope; + } + + public static _VariableScopeStore get_variable_scope_store() + { + _VariableScopeStore ret = null; + var scope_store = ops.get_collection(_VARSCOPESTORE_KEY); + if (scope_store == null) + { + ret = new _VariableScopeStore(); + ops.add_to_collection(_VARSCOPESTORE_KEY, ret); + } + else + { + switch (scope_store) + { + case List values: + ret = values[0]; + break; + case List<_VariableScopeStore> values: + ret = values[0]; + break; + default: + throw new InvalidOperationException("get_variable_scope_store"); + } + + } + + return ret; + } + + public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true) + { + if (synchronization == VariableSynchronization.OnRead) + { + if (trainable.Value) + throw new ValueError("Synchronization value can be set to " + + "VariableSynchronization.ON_READ only for non-trainable variables. " + + "You have specified trainable=True and " + + "synchronization=VariableSynchronization.ON_READ."); + } + else if (!trainable.HasValue) + { + trainable = true; + } + + return trainable.Value; + } + + public static implicit operator VariableScope(variable_scope scope) + { + return scope._scope; + } + + [DebuggerHidden] + public void __exit__() + { + _cached_pure_variable_scope.__exit__(); + if (_current_name_scope != null) + _current_name_scope.__exit__(); + } + + [DebuggerHidden] + public void Dispose() + { + if (_current_name_scope != null) + _current_name_scope.Dispose(); + } + + // TODO for Switch/Case + public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource, + Shape shape = null, + TF_DataType dtype = TF_DataType.DtInvalid, + bool trainable = false, + bool validate_shape = true) + { + throw new NotImplementedException(); + } + + public void __init__() + { + + } + + public void __del__() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs new file mode 100644 index 000000000..91f57e292 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/variables.py.cs @@ -0,0 +1,158 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class variables + { + /// + /// Returns all variables created with `trainable=True` + /// + /// + public static object trainable_variables() + { + return ops.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES); + } + + /// + /// Returns all variables and `SaveableObject`s that must be checkpointed. + /// + /// + /// + public static IVariableV1[] _all_saveable_objects(string scope = "") + { + var all = new List(); + + all.AddRange(ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope)); + all.AddRange(ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope)); + + return all.ToArray(); + } + + /// + /// Returns global variables. + /// + /// + /// (Optional.) A string. If supplied, the resulting list is filtered + /// to include only items whose `name` attribute matches `scope` using + /// `re.match`. Items without a `name` attribute are never returned if a + /// scope is supplied. The choice of `re.match` means that a `scope` without + /// special tokens filters by prefix. + /// + /// A list of `Variable` objects. + public static List global_variables(string scope = null) + { + return ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); + } + + /// + /// Returns an Op that initializes a list of variables. + /// + /// List of `Variable` objects to initialize. + /// Optional name for the returned operation. + /// An Op that run the initializers of all the specified variables. + public static Operation variables_initializer(IVariableV1[] var_list, string name = "init") + { + if (var_list.Length > 0) + { + return control_flow_ops.group(var_list.Select(x => x.Initializer).ToArray(), name); + } + else + return gen_control_flow_ops.no_op(name: name); + } + + public static Tensor _try_guard_against_uninitialized_dependencies(string name, Tensor initial_value) + { + return _safe_initial_value_from_tensor(name, initial_value, new Dictionary()); + } + + public static Tensor _safe_initial_value_from_tensor(string name, Tensor tensor, Dictionary op_cache) + { + var op = tensor.op; + Operation new_op = op_cache.ContainsKey(op.name) ? op_cache[op.name] : null; + if (new_op == null) + { + new_op = _safe_initial_value_from_op(name, op, op_cache); + op_cache[op.name] = new_op; + } + + return new_op.outputs[tensor.value_index]; + } + + /// + /// Replace dependencies on variables with their initialized values. + /// + /// + /// + /// + /// + public static Operation _safe_initial_value_from_op(string name, Operation op, Dictionary op_cache) + { + var op_type = op.node_def.Op; + if (op_type == "IsVariableInitialized" || + op_type == "VarIsInitializedOp" || + op_type == "ReadVariableOp") + return op; + + if (op_type == "Variable" || + op_type == "VariableV2" || + op_type == "VarHandleOp") + { + throw new NotImplementedException(""); + } + + // Recursively build initializer expressions for inputs. + bool modified = false; + var new_op_inputs = new List(); + foreach (Tensor op_input in op.inputs) + { + var new_op_input = _safe_initial_value_from_tensor(name, op_input, op_cache); + new_op_inputs.Add(new_op_input); + modified = modified || new_op_input != op_input; + } + + // If at least one input was modified, replace the op. + if (modified) + { + var new_op_type = op_type; + if (new_op_type == "RefSwitch") + new_op_type = "Switch"; + var new_op_name = op.node_def.Name + "_" + name; + new_op_name = new_op_name.Replace(":", "_"); + var _output_types = op._output_types; + + // Convert attr values to AttrValue protos. + var attr_protos = new Dictionary(); + foreach (var attr_def in op.node_def.Attr) + attr_protos[attr_def.Key] = attr_def.Value; + + return op.graph.create_op( + new_op_type, + new_op_inputs.ToArray(), + _output_types, + name: new_op_name, + attrs: attr_protos); + } + + return op; + } + } +} diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs deleted file mode 100644 index bb49a0497..000000000 --- a/src/TensorFlowNET.Core/c_api.cs +++ /dev/null @@ -1,23 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Runtime.InteropServices; -using System.Text; - -namespace Tensorflow -{ - /// - /// C API for TensorFlow. - /// - /// The API leans towards simplicity and uniformity instead of convenience - /// since most usage will be by language specific wrappers. - /// - public static partial class c_api - { - public const string TensorFlowLibName = "tensorflow"; - - public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocatorData); - - [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_Version(); - } -} diff --git a/src/TensorFlowNET.Core/c_api_util.cs b/src/TensorFlowNET.Core/c_api_util.cs deleted file mode 100644 index f6d54062e..000000000 --- a/src/TensorFlowNET.Core/c_api_util.cs +++ /dev/null @@ -1,18 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow -{ - public class c_api_util - { - public static TF_Output tf_output(IntPtr c_op, int index) - { - var ret = new TF_Output(); - ret.oper = c_op; - ret.index = index; - - return ret; - } - } -} diff --git a/src/TensorFlowNET.Core/globals.regen b/src/TensorFlowNET.Core/globals.regen new file mode 100644 index 000000000..86cbee675 --- /dev/null +++ b/src/TensorFlowNET.Core/globals.regen @@ -0,0 +1,40 @@ +%all_dtypes = ["NDArray","Complex","Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] +%all_dtypes_lowercase = ["NDArray","Complex","bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] + +%supported_primitives = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] +%supported_primitives_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] + +%supported_numericals = ["Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] +%supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"] +%supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] +%supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"] +%supported_numericals_TF_DataType = ["TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] +%supported_numericals_TF_DataType_full = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] + +//this is the type we use in summerizing/reducting: +%supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] +%supported_numericals_accumulatingType_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] + +%supported_numericals_signed = ["Int16","Int32","Int64","Double","Single"] +%supported_numericals_signed_lowercase = ["short","int","long","double","float"] +%supported_numericals_signed_defaultvals = ["0","0","0L","0d","0f"] +%supported_numericals_signed_onevales = ["1","1","1L","1d","1f"] + +%supported_numericals_unsigned = ["Byte","UInt16","UInt32","UInt64","Char"] +%supported_numericals_unsigned_lowercase = ["byte","ushort","uint","ulong","char"] +%supported_numericals_unsigned_defaultvals = ["0","0","0U","0UL","'\0'"] +%supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"] + +%supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] +%supported_dtypes_TF_DataType = ["TF_BOOL","TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] +%supported_dtypes_TF_DataType_full = ["TF_DataType.TF_BOOL","TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] + +%supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"] +%supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] +%supported_dtypes_onevales = [true,"1","1","1","1","1u","1L","1UL","'\1'","1d","1f"] +%supported_dtypes_dtype = ["bool","uint8","int16","uint16","int32","uint32","int64","uint64","uint8","float64","float32"] + +//this is the type we use in summerizing/reducting: +%supported_dtypes_accumulatingType = ["Int32","UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] +%supported_dtypes_accumulatingType_defaultvals = [false, "0","0","0","0u","0L","0UL","'\0'","0d","0f"] + diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs new file mode 100644 index 000000000..adf2bb109 --- /dev/null +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -0,0 +1,173 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class ops + { + /// + /// Standard names to use for graph collections. + /// The standard library uses various well-known names to collect and + /// retrieve values associated with a graph. For example, the + /// `tf.Optimizer` subclasses default to optimizing the variables + /// collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is + /// specified, but it is also possible to pass an explicit list of + /// variables. + /// + public class GraphKeys + { + #region const + /// + /// Key to collect concatenated sharded variables. + /// + public const string CONCATENATED_VARIABLES_ = "concatenated_variables"; + /// + /// the subset of `Variable` objects that will be trained by an optimizer. + /// + public const string TRAINABLE_VARIABLES_ = "trainable_variables"; + + /// + /// Trainable resource-style variables. + /// + public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables"; + + /// + /// Key for streaming model ports. + /// + public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports"; + + /// + /// Key to collect losses + /// + public const string LOSSES_ = "losses"; + + public const string LOCAL_VARIABLES_ = "local_variables"; + + public const string METRIC_VARIABLES_ = "metric_variables"; + public const string MODEL_VARIABLES_ = "model_variables"; + + public const string MOVING_AVERAGE_VARIABLES_ = "moving_average_variables"; + + /// + /// Key to collect Variable objects that are global (shared across machines). + /// Default collection for all variables, except local ones. + /// + public const string GLOBAL_VARIABLES_ = "variables"; + + public const string TRAIN_OP_ = "train_op"; + + public const string GLOBAL_STEP_ = "global_step"; + + /// + /// List of all collections that keep track of variables. + /// + public string[] _VARIABLE_COLLECTIONS_ = new string[] + { + GLOBAL_VARIABLES_, + LOCAL_VARIABLES_, + METRIC_VARIABLES_, + MODEL_VARIABLES_, + TRAINABLE_VARIABLES_, + MOVING_AVERAGE_VARIABLES_, + CONCATENATED_VARIABLES_, + TRAINABLE_RESOURCE_VARIABLES_ + }; + + /// + /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. + /// + public const string SAVEABLE_OBJECTS_ = "saveable_objects"; + /// + /// Key to collect update_ops + /// + public const string UPDATE_OPS_ = "update_ops"; + + // Key to collect summaries. + public const string SUMMARIES_ = "summaries"; + + // Used to store v2 summary names. + public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2"; + + // Key for control flow context. + public const string COND_CONTEXT_ = "cond_context"; + public const string WHILE_CONTEXT_ = "while_context"; + + #endregion + + + public string CONCATENATED_VARIABLES => CONCATENATED_VARIABLES_; + /// + /// the subset of `Variable` objects that will be trained by an optimizer. + /// + public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_; + + /// + /// Trainable resource-style variables. + /// + public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_; + + /// + /// Key for streaming model ports. + /// + public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_; + + /// + /// Key to collect local variables that are local to the machine and are not + /// saved/restored. + /// + public string LOCAL_VARIABLES = LOCAL_VARIABLES_; + + /// + /// Key to collect losses + /// + public string LOSSES => LOSSES_; + + public string METRIC_VARIABLES => METRIC_VARIABLES_; + public string MOVING_AVERAGE_VARIABLES = MOVING_AVERAGE_VARIABLES_; + + /// + /// Key to collect Variable objects that are global (shared across machines). + /// Default collection for all variables, except local ones. + /// + public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; + + public string TRAIN_OP => TRAIN_OP_; + + public string GLOBAL_STEP => GLOBAL_STEP_; + public string GLOBAL_STEP_READ_KEY = "global_step_read_op_cache"; + + public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; + /// + /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. + /// + public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_; + /// + /// Key to collect update_ops + /// + public string UPDATE_OPS => UPDATE_OPS_; + + // Key to collect summaries. + public string SUMMARIES => SUMMARIES_; + + // Used to store v2 summary names. + public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_; + + // Key for control flow context. + public string COND_CONTEXT => COND_CONTEXT_; + public string WHILE_CONTEXT => WHILE_CONTEXT_; + } + } +} diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs new file mode 100644 index 000000000..6f51150a2 --- /dev/null +++ b/src/TensorFlowNET.Core/ops.cs @@ -0,0 +1,634 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Google.Protobuf; +using Google.Protobuf.Collections; +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using Tensorflow.Graphs; +using Tensorflow.Util; +using static Tensorflow.Binding; +using static Tensorflow.CppShapeInferenceResult.Types; + +namespace Tensorflow +{ + public partial class ops + { + public static long tensor_id(Tensor tensor) + { + return tensor.Id; + } + + public static void add_to_collection(string name, T value) + { + var graph = tf.get_default_graph(); + graph.add_to_collection(name, value); + } + + public static void add_to_collections(List names, T value) + { + var graph = tf.get_default_graph(); + graph.add_to_collections(names, value); + } + + /// + /// Wrapper for `Graph.get_collection()` using the default graph. + /// contains many standard names for collections. + /// + /// + /// The key for the collection. For example, the `GraphKeys` class + /// + /// + /// + /// The list of values in the collection with the given `name`, or + /// an empty list if no value has been added to that collection. The + /// list contains the values in the order under which they were + /// collected. + /// + public static object get_collection(string key, string scope = null) + { + return get_default_graph().get_collection(key, scope); + } + + public static List get_collection(string key, string scope = null) + { + return get_default_graph().get_collection(key, scope); + } + + public static List get_collection_ref(string key) + { + return get_default_graph().get_collection_ref(key); + } + + public static Graph _get_graph_from_inputs(params object[] op_input_list) + { + var current_default_graph = get_default_graph(); + if (current_default_graph.building_function) + return current_default_graph; + + Graph graph = null; + foreach (var op_input in op_input_list) + { + if (op_input is Tensor op_input_tensor) + graph = graph ?? op_input_tensor.graph; + } + return graph ?? current_default_graph; + } + + public static Graph _get_graph_from_inputs(Tensors op_input_list) + => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); + + public static Graph _get_graph_from_inputs(Tensors op_input_list, Graph graph = null) + { + foreach (var op_input in op_input_list) + { + // Determine if this is a valid graph_element. + // var graph_element = op_input; + } + + return get_default_graph(); + } + + /// + /// Converts the given `value` to a `Tensor`. + /// + /// + /// + /// + /// + public static Tensor convert_to_tensor(object value, + TF_DataType dtype = TF_DataType.DtInvalid, + string name = null, + bool as_ref = false, + TF_DataType preferred_dtype = TF_DataType.DtInvalid, + Context ctx = null) + { + if (dtype == TF_DataType.DtInvalid) + dtype = preferred_dtype; + + if (dtype == TF_DataType.DtInvalid) + dtype = value.GetDataType(); + + if (value is EagerTensor eager_tensor) + { + if (tf.executing_eagerly()) + { + if (dtype != TF_DataType.DtInvalid && dtype != eager_tensor.dtype) + return gen_math_ops.cast(eager_tensor, dtype.as_base_dtype(), name: name); + return eager_tensor; + } + else + { + var graph = get_default_graph(); + if (graph is FuncGraph funcGraph) + { + return funcGraph.capture(eager_tensor, name: name); + } + if (!graph.building_function) + { + // throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); + return eager_tensor.AsPlaceholder(name: name); + } + } + } + else if (value is KerasTensor kt) + { + if (kt.inferred_value != null) + { + return convert_to_tensor(kt.inferred_value, dtype: kt.dtype, name: name); + } + } + + // graph mode + Tensor ret = value switch + { + NDArray nd => constant_op.constant(nd, dtype: dtype, name: name), + EagerTensor tensor => tensor.dtype == TF_DataType.TF_RESOURCE + ? tensor.AsPlaceholder(name: name) + : tensor.AsConstant(name: name), + Tensor tensor => tensor, + IEnumerable tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name), + RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), + ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), + Axis ts => constant_op.constant(ts, dtype: dtype, name: name), + Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), + string str => constant_op.constant(str, dtype: tf.@string, name: name), + string[] str => constant_op.constant(str, dtype: tf.@string, name: name), + IEnumerable objects => array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name), + _ => constant_op.constant(value, dtype: dtype, name: name) + }; + + if (dtype == TF_DataType.TF_STRING) + return ret; + + if (dtype != TF_DataType.DtInvalid && dtype.as_base_dtype() != ret.dtype.as_base_dtype()) + ret = gen_math_ops.cast(ret, dtype, name: name); + + return ret; + } + + + public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) + { + return internal_convert_to_tensor_or_composite(value: value, dtype: dtype, name: name, as_ref: false); + } + + public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) + => convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref); + + /// + /// Wrapper for `Graph.control_dependencies()` using the default graph. + /// + /// See `tf.Graph.control_dependencies` for more details. + /// + /// When eager execution is enabled, any callable object in the `control_inputs` + /// list will be called. + /// + /// + /// A list of `Operation` or `Tensor` objects which + /// must be executed or computed before running the operations + /// defined in the context.Can also be `None` to clear the control + /// dependencies.If eager execution is enabled, any callable object in the + /// `control_inputs` list will be called. + /// + /// + /// A context manager that specifies control dependencies for all + /// operations constructed within the context. + /// + public static _ControlDependenciesController control_dependencies(object[] control_inputs) + => get_default_graph().control_dependencies(control_inputs); + + /// + /// Creates a TF_Operation. + /// + /// a `Graph`. + /// `node_def_pb2.NodeDef` for the operation to create. + /// + /// A list of `Tensor`s (corresponding to scalar inputs) and lists of + /// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N", + /// "list(int64)"). The length of the list should be equal to the number of + /// inputs specified by this operation's op def. + /// + /// A list of `Operation`s to set as control dependencies. + /// A wrapped TF_Operation*. + public static (IntPtr, OperationDescription) _create_c_op(Graph graph, NodeDef node_def, Tensor[] inputs, Operation[] control_inputs, + OpDef op_def = null) + { + if (op_def == null) + op_def = graph.GetOpDef(node_def.Op); + + var input_tensors = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); + + var op_desc = graph.NewOperation(node_def.Op, node_def.Name); + + if (!string.IsNullOrEmpty(node_def.Device)) + c_api.TF_SetDevice(op_desc, node_def.Device); + + // Add inputs + foreach (var op_input in input_tensors) + { + if (op_input.IsList) + c_api.TF_AddInputList(op_desc, op_input.Select(x => x._as_tf_output()).ToArray(), op_input.Count()); + else if (op_input.Count() == 1) + c_api.TF_AddInput(op_desc, op_input[0]._as_tf_output()); + } + + var status = tf.Status; + + // Add control inputs + foreach (var control_input in control_inputs) + c_api.TF_AddControlInput(op_desc, control_input); + + // Add attrs + foreach (var attr in node_def.Attr) + { + var bytes = attr.Value.ToByteArray(); + c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: (ulong)bytes.Length, status: status); + status.Check(true); + } + + var c_op = op_desc.FinishOperation(status); + + status.Check(true); + + return (c_op, op_desc); + } + + public static Tensors[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField attrs) + { + var grouped_inputs = new List(); + int i = 0; + + foreach (var input_arg in op_def.InputArg) + { + int input_len = 1; + bool is_sequence = false; + + if (!string.IsNullOrEmpty(input_arg.NumberAttr)) + { + input_len = (int)attrs[input_arg.NumberAttr].I; + is_sequence = true; + } + else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) + { + input_len = attrs[input_arg.TypeListAttr].List.Type.Count; + is_sequence = true; + } + + if (is_sequence) + { + var input_tensors = new Tensors(inputs.Skip(i).Take(input_len).ToArray()); + input_tensors.IsList = true; + grouped_inputs.Add(input_tensors); + } + else + grouped_inputs.Add(inputs[i]); + + i += input_len; + } + + return grouped_inputs.ToArray(); + } + + public static OpDef _get_op_def(Graph graph, string type) + { + return graph.GetOpDef(type); + } + + public static NodeDef _NodeDef(string op_type, string name, Dictionary attrs = null) + { + var node_def = new NodeDef(); + node_def.Op = op_type; + node_def.Name = name; + + if (attrs != null) + { + foreach (var attr in attrs) + node_def.Attr.Add(attr.Key, attr.Value); + } + + return node_def; + } + + public static string name_from_scope_name(string name) + { + if (name == null) + return null; + else if (name.EndsWith("/")) + return name.Substring(0, name.Length - 1); + else + return name; + } + + /// + /// A context manager that lifts ops out of control-flow scopes and function-building graphs. + /// + /// + public static NameScope init_scope() + { + // Retrieve the active name scope: entering an `init_scope` preserves + // the name scope of the current context. + var default_graph = get_default_graph(); + var scope = default_graph.get_name_scope(); + if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) + // Names that end with trailing slashes are treated by `name_scope` as + // absolute. + scope += "/"; + // inner_device_stack = default_graph._device_function_stack + // var outer_context = default_graph.as_default; + + tf_with(ops.control_dependencies(null), delegate + { + // var outer_graph = get_default_graph(); + // outer_device_stack = None + }); + + tf.Context.ScopeName = scope; + return ops.name_scope(scope); + } + + private static int uid_number = -1; + + /// + /// A unique (within this program execution) integer. + /// Not thread safe + /// + /// + public static int uid() + { + return Interlocked.Increment(ref uid_number); + } + + static int graph_uid_number = -1; + public static int GraphUniqueId() + { + return Interlocked.Increment(ref graph_uid_number); + } + + static int uid_number_for_function = 0; + public static int uid_function() + => Interlocked.Increment(ref uid_number_for_function); + + static int uid_number_for_layer = 0; + public static int uid_layer() + => Interlocked.Increment(ref uid_number_for_layer); + + public static void reset_uid() + { + uid_number = -1; + graph_uid_number = -1; + uid_number_for_function = 0; + uid_number_for_layer = 0; + } + + public static void colocate_with(bool ignore_existing = false) + { + _colocate_with_for_gradient(null, null, ignore_existing); + } + + public static void colocate_with(Operation op, bool ignore_existing = false) + { + _colocate_with_for_gradient(op, null, ignore_existing); + } + + public static void colocate_with(Tensor tensor, bool ignore_existing = false) + { + _colocate_with_for_gradient(tensor.op, null, ignore_existing); + } + + public static void colocate_with(IVariableV1 variable, bool ignore_existing = false) + { + _colocate_with_for_gradient(variable.AsTensor(), null, ignore_existing); + } + + public static void _colocate_with_for_gradient(Operation op, string gradient_uid, bool ignore_existing = false) + { + var default_graph = get_default_graph(); + default_graph._colocate_with_for_gradient(op, gradient_uid, ignore_existing); + } + + /// + /// Uses the default session to evaluate one or more tensors. + /// + /// A single Tensor, or a list of Tensor objects. + /// + /// A dictionary that maps Tensor objects (or tensor names) to lists, + /// numpy ndarrays, TensorProtos, or strings. + /// + /// The graph in which the tensors are defined. + /// A different session to use to evaluate "tensors". + /// + /// Either a single numpy ndarray if "tensors" is a single tensor; or a list + /// of numpy ndarrays that each correspond to the respective element in + /// "tensors". + /// + public static NDArray _eval_using_default_session(Tensor tensor, FeedItem[] feed_dict, Graph graph, Session session = null) + { + if (session == null) + { + session = get_default_session(); + + if (session == null) + throw new ValueError("Cannot evaluate tensor using `eval()`: No default " + + "session is registered. Use `with " + + "sess.as_default()` or pass an explicit session to " + + "`eval(session=sess)`"); + + if (session.graph != graph) + throw new ValueError("Cannot use the default session to evaluate tensor: " + + "the tensor's graph is different from the session's " + + "graph. Pass an explicit session to " + + "`eval(session=sess)`."); + } + else + { + if (session.graph != graph) + throw new ValueError("Cannot use the default session to evaluate tensor: " + + "the tensor's graph is different from the session's " + + "graph. Pass an explicit session to " + + "`eval(session=sess)`."); + } + + return session.run(tensor, feed_dict); + } + + /// + /// Prepends name scope to a name. + /// + /// + /// + /// + public static string prepend_name_scope(string name, string import_scope) + { + if (!string.IsNullOrEmpty(import_scope)) + { + if (import_scope.EndsWith("/")) + import_scope = import_scope.Substring(0, import_scope.Length - 1); + + return $"{import_scope}/{name}"; + } + else + return name; + } + + public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session) + { + if (session == null) + { + session = get_default_session(); + if (session == null) + throw new ValueError("Cannot execute operation using `run()`: No default " + + "session is registered. Use `with " + + "sess.as_default():` or pass an explicit session to " + + "`run(session=sess)`"); + } + + if (session.graph != graph) + throw new ValueError("Cannot use the default session to execute operation: " + + "the operation's graph is different from the " + + "session's graph. Pass an explicit session to " + + "run(session=sess)."); + + session.run(operation, feed_dict); + } + + public static Tensor[] convert_n_to_tensor(object[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) + => internal_convert_n_to_tensor(values, dtype: dtype, name: name, as_ref: false); + + public static Tensor[] convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) + => internal_convert_n_to_tensor_or_indexed_slices(values, dtype: dtype, name: name); + + public static Tensor convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) + => internal_convert_to_tensor_or_indexed_slices(value: value, dtype: dtype, name: name, as_ref: false); + + public static Tensor internal_convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) + => value; + + public static Tensor[] internal_convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) + { + var ret = new List(); + + foreach (var (i, value) in enumerate(values)) + { + if (value == null) + { + ret.Add(value); + } + else + { + var n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}"; + ret.Add(internal_convert_to_tensor_or_indexed_slices(value, dtype: dtype, name: n, as_ref: as_ref)); + } + } + + return ret.ToArray(); + } + + public static Tensor[] internal_convert_n_to_tensor(object[] values, TF_DataType dtype = TF_DataType.DtInvalid, + string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid, + bool as_ref = false) + { + var ret = new List(); + foreach ((int i, object value) in enumerate(values)) + { + string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}"; + ret.Add(convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype)); + } + return ret.ToArray(); + } + + public static string strip_name_scope(string name, string export_scope = "") + { + if (!string.IsNullOrEmpty(export_scope)) + { + throw new NotImplementedException("ops.strip_name_scope"); + } + else + { + return name; + } + } + + public static string get_name_scope() + { + var g = get_default_graph(); + return g.get_name_scope(); + } + + public static bool executing_eagerly_outside_functions() + { + if (tf.Context.executing_eagerly()) + return true; + else + // TODO(Wanglongzhi2001), implement the false case + return true; + //throw new NotImplementedException(""); + } + + public static bool inside_function() + { + return get_default_graph().building_function; + } + + public static HandleData get_resource_handle_data(Tensor graph_op) + { + var handle_data = c_api.TF_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); + try{ + var handle_str = c_api.ByteStringPiece(handle_data.DangerousGetHandle() == IntPtr.Zero ? null : new Buffer(handle_data)); + return HandleData.Parser.ParseFrom(handle_str); + } + catch(Exception){ + var handle_str = c_api.ByteStringPieceFromNativeString(handle_data.DangerousGetHandle()); + return HandleData.Parser.ParseFrom(handle_str); + } + } + + public static void dismantle_graph(Graph graph) + { + + } + + public static ITensorFlowObject device(string device_name) + { + if (tf.Context.executing_eagerly()) + { + return tf.Context.device(device_name); + } + //else if (ops.executing_eagerly_outside_functions()) + //{ + // throw new NotImplementedException(); + //} + else + { + return get_default_graph().device(device_name); + } + // TODO(Rinne): deal with `ops.executing_eagerly_outside_functions()`. + } + + public class NullContextManager: IDisposable + { + public void Dispose() + { + + } + } + } +} diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs new file mode 100644 index 000000000..3872d5b1a --- /dev/null +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -0,0 +1,124 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using System.Diagnostics; +using Tensorflow.Contexts; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class ops + { + public static NameScope name_scope(string name, + string default_name = "", + object values = null, + bool skip_on_eager = true) => new NameScope(name, default_name, values: values, skip_on_eager: skip_on_eager); + + /// + /// Returns a context manager that creates hierarchical names for operations. + /// + public class NameScope : ITensorFlowObject + { + public string _name; + public string _default_name; + public object _values; + public string scope_name; + public string old_scope_name = ""; + bool _skip_on_eager = false; + + public NameScope(string name, string default_name = "", object values = null, bool skip_on_eager = true) + { + _name = name; + _default_name = default_name; + _values = values; + _skip_on_eager = skip_on_eager; + } + + [DebuggerStepThrough] + public void __enter__() + { + if (tf.Context.executing_eagerly()) + { + (scope_name, old_scope_name) = enter_eager_name_scope(tf.Context, _name); + } + else + { + _name = _name ?? _default_name; + Graph g = null; + + if (_values is List vList) + g = _get_graph_from_inputs(vList.ToArray()); + else if (_values is Tensor[] vArray) + g = _get_graph_from_inputs(vArray); + + if (g == null) + g = get_default_graph(); + + old_scope_name = g._name_stack; + scope_name = g.name_scope(_name); + } + } + + private (string, string) enter_eager_name_scope(Context ctx, string name) + { + if (_skip_on_eager) + return (null, null); + + if (name == null) + name = _default_name; + + var scope_name = name; + var old_name = ctx.ScopeName; + // A trailing slash breaks out of nested name scopes, indicating a + // fully specified scope name, for compatibility with Graph.name_scope. + if (!name.EndsWith("/")) + { + scope_name = name + "/"; + if (!string.IsNullOrEmpty(old_name)) + scope_name = old_name + scope_name; + } + + ctx.ScopeName = scope_name; + return (scope_name, old_name); + } + + [DebuggerStepThrough] + public void Dispose() + { + + } + + [DebuggerStepThrough] + public void __exit__() + { + if (tf.Context.executing_eagerly()) + tf.Context.ScopeName = old_scope_name; + else + get_default_graph()._name_stack = old_scope_name; + } + + /// + /// __enter__() + /// + /// + public static implicit operator string(NameScope ns) + { + return ns.scope_name; + } + } + } +} diff --git a/src/TensorFlowNET.Core/ops.threading.cs b/src/TensorFlowNET.Core/ops.threading.cs new file mode 100644 index 000000000..6c6476a51 --- /dev/null +++ b/src/TensorFlowNET.Core/ops.threading.cs @@ -0,0 +1,94 @@ +using System; +using System.Threading; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class ops + { + [ThreadStatic] + static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); + [ThreadStatic] + static Session defaultSession; + + /// + /// Returns the default session for the current thread. + /// + /// The default `Session` being used in the current thread. + public static Session get_default_session() + { + if (defaultSession == null) + defaultSession = new Session(tf.get_default_graph()); + + return defaultSession; + } + + /// + /// Returns the default session for the current thread. + /// + /// The default `Session` being used in the current thread. + public static Session set_default_session(Session sess) + { + defaultSession = sess; + return sess; + } + + /// + /// Returns the default graph for the current thread. + /// + /// The returned graph will be the innermost graph on which a + /// `Graph.as_default()` context has been entered, or a global default + /// graph if none has been explicitly created. + /// + /// NOTE: The default graph is a property of the current thread.If you + /// create a new thread, and wish to use the default graph in that + /// thread, you must explicitly add a `with g.as_default():` in that + /// thread's function. + /// + /// + public static Graph get_default_graph() + { + if (default_graph_stack == null) + default_graph_stack = new DefaultGraphStack(); + return default_graph_stack.get_default(); + } + + public static Graph set_default_graph(Graph g) + { + if (default_graph_stack == null) + default_graph_stack = new DefaultGraphStack(); + return default_graph_stack.get_controller(g); + } + + /// + /// Clears the default graph stack and resets the global default graph. + /// + /// NOTE: The default graph is a property of the current thread.This + /// function applies only to the current thread.Calling this function while + /// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined + /// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects + /// after calling this function will result in undefined behavior. + /// + /// + public static void reset_default_graph() + { + if (default_graph_stack == null) + return; + //if (!_default_graph_stack.is_cleared()) + // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + + // "nested graphs. If you need a cleared graph, " + + // "exit the nesting and create a new graph."); + default_graph_stack.reset(); + } + + public static Graph peak_default_graph() + { + if (default_graph_stack == null) + default_graph_stack = new DefaultGraphStack(); + return default_graph_stack.peak_controller(); + } + + public static void pop_graph() + => default_graph_stack.pop(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll b/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll new file mode 100644 index 000000000..82e86b4e5 Binary files /dev/null and b/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll differ diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs new file mode 100644 index 000000000..e368b37cd --- /dev/null +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -0,0 +1,155 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Razorvine.Pickle; +using Serilog; +using Serilog.Core; +using System.Reflection; +using System.Threading; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using Tensorflow.Gradients; +using Tensorflow.Keras; +using Tensorflow.NumPy.Pickle; + +namespace Tensorflow +{ + public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); + + public partial class tensorflow + { + public TF_DataType byte8 = TF_DataType.TF_UINT8; + public TF_DataType int8 = TF_DataType.TF_INT8; + public TF_DataType int16 = TF_DataType.TF_INT16; + public TF_DataType int32 = TF_DataType.TF_INT32; + public TF_DataType int64 = TF_DataType.TF_INT64; + public TF_DataType float16 = TF_DataType.TF_HALF; + public TF_DataType float32 = TF_DataType.TF_FLOAT; + public TF_DataType float64 = TF_DataType.TF_DOUBLE; + public TF_DataType @bool = TF_DataType.TF_BOOL; + public TF_DataType chars = TF_DataType.TF_STRING; + public TF_DataType @string = TF_DataType.TF_STRING; + + public OpDefLibrary OpDefLib; + public Logger Logger; + + ThreadLocal _status = new ThreadLocal(() => new Status()); + public Status Status => _status.Value; + + ThreadLocal _context = new ThreadLocal(() => new Context()); + public Context Context => _context.Value; + + ThreadLocal _runner = new ThreadLocal(() => new EagerRunner()); + public IEagerRunner Runner => _runner.Value; + + private IKerasApi _keras; + public IKerasApi keras + { + get + { + if (_keras != null) + { + return _keras; + } + + var k = Assembly.Load("Tensorflow.Keras"); + var cls = k.GetTypes().FirstOrDefault(x => x.GetInterfaces().Contains(typeof(IKerasApi))); + if (cls != null) + { + _keras = Activator.CreateInstance(cls) as IKerasApi; + return _keras; + } + else + { + throw new Exception("Can't find keras library."); + } + } + } + + public tensorflow() + { + Logger = new LoggerConfiguration() + .MinimumLevel.Error() + .WriteTo.Console() + .CreateLogger(); + + OpDefLib = new OpDefLibrary(); + InitGradientEnvironment(); + + try + { + var handle = c_api.TF_Version(); + } + catch (DllNotFoundException) + { + throw new RuntimeError("Tensorflow.NET cannot find a backend. Please install one of the following packages for your program: " + + "SciSharp.TensorFlow.Redist, SciSharp.TensorFlow.Redist-Linux-GPU, SciSharp.TensorFlow.Redist-Windows-GPU. For more details, " + + "please visit https://github.com/SciSharp/TensorFlow.NET. If it still not work after installing the backend, please submit an " + + "issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + + // register numpy reconstructor for pickle + Unpickler.registerConstructor("numpy.core.multiarray", "_reconstruct", new MultiArrayConstructor()); + Unpickler.registerConstructor("numpy", "dtype", new DtypeConstructor()); + } + + public string VERSION => c_api.StringPiece(c_api.TF_Version()); + + private void InitGradientEnvironment() + { + _tapeSet = new GradientTape(); + ops.RegisterFromAssembly(); + } + + public ResourceVariable Variable(T data, + bool trainable = true, + bool validate_shape = true, + bool use_resource = true, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + VariableAggregation aggregation = VariableAggregation.None, + Shape shape = null) + => new ResourceVariable(data, + trainable: trainable, + validate_shape: validate_shape, + name: name, + dtype: dtype, + aggregation: aggregation, + shape: shape); + + public Tensor placeholder(TF_DataType dtype, Shape shape = null, string name = null) + => array_ops.placeholder(dtype, shape, name); + + public void enable_eager_execution() + => Context.eager_mode(); + + public Session get_default_session() + => ops.get_default_session(); + + public Session Session() + => compat.v1.Session(); + + public Session Session(Graph graph, ConfigProto config = null) + { + return new Session(graph, config: config).as_default(); + } + + public Session Session(ConfigProto config) + { + return new Session(null, config).as_default(); + } + } +} diff --git a/src/TensorFlowNET.Core/tensorflow.memory.cs b/src/TensorFlowNET.Core/tensorflow.memory.cs new file mode 100644 index 000000000..ae8590fe8 --- /dev/null +++ b/src/TensorFlowNET.Core/tensorflow.memory.cs @@ -0,0 +1,56 @@ +using System; + +namespace Tensorflow +{ + public partial class tensorflow + { + public unsafe void memcpy(T* dst, void* src, ulong size) + where T : unmanaged + { + System.Buffer.MemoryCopy(src, dst, size, size); + } + + public unsafe void memcpy(void* dst, T* src, ulong size) + where T : unmanaged + { + System.Buffer.MemoryCopy(src, dst, size, size); + } + + public unsafe void memcpy(void* dst, IntPtr src, ulong size) + { + System.Buffer.MemoryCopy(src.ToPointer(), dst, size, size); + } + + public unsafe void memcpy(T[] dst, IntPtr src, ulong size) + where T : unmanaged + { + fixed (void* p = &dst[0]) + System.Buffer.MemoryCopy(src.ToPointer(), p, size, size); + } + + public unsafe void memcpy(T[] dst, IntPtr src, long size) + where T : unmanaged + { + fixed (void* p = &dst[0]) + System.Buffer.MemoryCopy(src.ToPointer(), p, size, size); + } + + public unsafe void memcpy(IntPtr dst, T[] src, ulong size) + where T : unmanaged + { + if (src.Length == 0) return; + + fixed (void* p = &src[0]) + System.Buffer.MemoryCopy(p, dst.ToPointer(), size, size); + } + + public unsafe void memcpy(IntPtr dst, T[] src, long size) + where T : unmanaged + { + if (src.Length == 0) return; + + fixed (void* p = &src[0]) + System.Buffer.MemoryCopy(p, dst.ToPointer(), size, size); + } + } +} diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs deleted file mode 100644 index 09783aa13..000000000 --- a/src/TensorFlowNET.Core/tf.cs +++ /dev/null @@ -1,76 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Runtime.InteropServices; -using System.Text; -using TF_DataType = Tensorflow.DataType; -using attr_value_pb2 = Tensorflow; -using Tensorflow.Eager; - -namespace Tensorflow -{ - public static class tf - { - public static TF_DataType float32 = TF_DataType.TF_FLOAT; - public static TF_DataType chars = TF_DataType.TF_STRING; - - public static Context context = new Context(); - - public static Graph g = new Graph(c_api.TF_NewGraph()); - - public static object Variable(T data, TF_DataType dtype) - { - return new Variable(null, TF_DataType.DtInvalid); - } - - public static unsafe Tensor add(Tensor a, Tensor b) - { - return gen_math_ops.add(a, b); - } - - public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) - { - return gen_array_ops.placeholder(dtype, shape); - } - - public static unsafe Tensor constant(object value) - { - var g = ops.get_default_graph(); - var tensor_value = new attr_value_pb2.AttrValue(); - var tensor_pb = tensor_util.make_tensor_proto(value); - tensor_value.Tensor = tensor_pb; - var dtype_value = new attr_value_pb2.AttrValue - { - Type = tensor_value.Tensor.Dtype, - }; - - var attrs = new Dictionary(); - attrs["dtype"] = dtype_value; - attrs["value"] = tensor_value; - var const_tensor = g.create_op("Const", null, new TF_DataType[] { (TF_DataType)dtype_value.Type }, attrs: attrs).outputs[0]; - - return const_tensor; - } - - public static void enable_eager_execution() - { - context.default_execution_mode = Context.EAGER_MODE; - } - - public static string VERSION => Marshal.PtrToStringAnsi(c_api.TF_Version()); - - public static Graph get_default_graph() - { - return ops.get_default_graph(); - } - - public static Graph Graph() - { - return g; - } - - public static Session Session() - { - return new Session(); - } - } -} diff --git a/src/TensorFlowNET.Keras/Activations.cs b/src/TensorFlowNET.Keras/Activations.cs new file mode 100644 index 000000000..d3801902f --- /dev/null +++ b/src/TensorFlowNET.Keras/Activations.cs @@ -0,0 +1,100 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations.Activation; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras +{ + public class Activations: IActivationsApi + { + private static Dictionary _nameActivationMap; + + private static Activation _linear = new Activation() + { + Name = "linear", + ActivationFunction = (features, name) => features + }; + private static Activation _relu = new Activation() + { + Name = "relu", + ActivationFunction = (features, name) => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features)) + }; + private static Activation _relu6 = new Activation() + { + Name = "relu6", + ActivationFunction = (features, name) => tf.Context.ExecuteOp("Relu6", name, new ExecuteOpArgs(features)) + }; + private static Activation _sigmoid = new Activation() + { + Name = "sigmoid", + ActivationFunction = (features, name) => tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features)) + }; + private static Activation _softmax = new Activation() + { + Name = "softmax", + ActivationFunction = (features, name) => tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features)) + }; + private static Activation _tanh = new Activation() + { + Name = "tanh", + ActivationFunction = (features, name) => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features)) + }; + private static Activation _mish = new Activation() + { + Name = "mish", + ActivationFunction = (features, name) => features * tf.math.tanh(tf.math.softplus(features)) + }; + + /// + /// Register the name-activation mapping in this static class. + /// + /// + private static void RegisterActivation(Activation activation) + { + _nameActivationMap[activation.Name] = activation; + } + + static Activations() + { + _nameActivationMap = new Dictionary(); + + RegisterActivation(_relu); + RegisterActivation(_relu6); + RegisterActivation(_linear); + RegisterActivation(_sigmoid); + RegisterActivation(_softmax); + RegisterActivation(_tanh); + RegisterActivation(_mish); + } + + public Activation Linear => _linear; + + public Activation Relu => _relu; + public Activation Relu6 => _relu6; + + public Activation Sigmoid => _sigmoid; + + public Activation Softmax => _softmax; + + public Activation Tanh => _tanh; + + public Activation Mish => _mish; + + public Activation GetActivationFromName(string name) + { + if (name == null) + { + return _linear; + } + if (!_nameActivationMap.TryGetValue(name, out var res)) + { + throw new Exception($"Activation {name} not found"); + } + else + { + return res; + } + } + } +} diff --git a/src/TensorFlowNET.Keras/BackendBase.cs b/src/TensorFlowNET.Keras/BackendBase.cs new file mode 100644 index 000000000..c29fa273b --- /dev/null +++ b/src/TensorFlowNET.Keras/BackendBase.cs @@ -0,0 +1,86 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras +{ + public abstract class BackendBase + { + TF_DataType _FLOATX = dtypes.float32; + float _EPSILON = 1e-7f; + ImageDataFormat _IMAGE_DATA_FORMAT = ImageDataFormat.channels_last; + + + public float epsilon() => _EPSILON; + + public void set_epsilon(float e) => _EPSILON = e; + + public TF_DataType floatx() => _FLOATX; + + public void set_floatx(TF_DataType floatx) => _FLOATX = floatx; + + //public NDArray cast_to_floatx(NDArray x) => np.array(x, dtype: _FLOATX.as_numpy_datatype()); + + public ImageDataFormat image_data_format() => _IMAGE_DATA_FORMAT; + + public void set_image_data_format(ImageDataFormat data_format) => _IMAGE_DATA_FORMAT = data_format; + + public ImageDataFormat normalize_data_format(object value = null) + { + if (value == null) + value = _IMAGE_DATA_FORMAT; + if (isinstance(value, typeof(ImageDataFormat))) + return (ImageDataFormat)value; + else if (isinstance(value, typeof(string))) + { + ImageDataFormat dataFormat; + if (Enum.TryParse((string)value, true, out dataFormat)) + { + if (Enum.IsDefined(typeof(ImageDataFormat), dataFormat) | dataFormat.ToString().Contains(",")) + return dataFormat; + } + } + throw new Exception("The `data_format` argument must be one of \"channels_first\", \"channels_last\". Received: " + value.ToString()); + } + + //Legacy Methods + + public void set_image_dim_ordering(ImageDimOrder dim_ordering) + { + if (dim_ordering == ImageDimOrder.th) + _IMAGE_DATA_FORMAT = ImageDataFormat.channels_first; + else if (dim_ordering == ImageDimOrder.tf) + _IMAGE_DATA_FORMAT = ImageDataFormat.channels_last; + else + throw new Exception("Unknown dim_ordering:" + dim_ordering); + } + + public ImageDimOrder image_dim_ordering() + { + if (_IMAGE_DATA_FORMAT == ImageDataFormat.channels_first) + return ImageDimOrder.th; + else + return ImageDimOrder.tf; + } + } + public enum ImageDimOrder + { + tf, + th + } +} diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs new file mode 100644 index 000000000..574cf5990 --- /dev/null +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -0,0 +1,1005 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Linq; +using System.Collections.Generic; +using Tensorflow.Functions; +using Tensorflow.Graphs; +using Tensorflow.Common.Extensions; +using static Tensorflow.Binding; +using static Tensorflow.Graphs.SubGraphUtility; +using Tensorflow.Util; +using Tensorflow.Common.Types; +using System.Diagnostics; + +namespace Tensorflow.Keras +{ + public class BackendImpl : BackendBase + { + /* ---------------------------------------- KERAS BACKEND NATIVE OBJECTS ---------------------------------------- */ + public Func py_sum = sum; + public Func py_all = all; + //Func py_any = any; + //Func> py_slice = slice; + + public Session _SESSION => ops.get_default_session(); + + public Graph _GRAPH; + FuncGraph _CURRENT_SCRATCH_GRAPH; + public Dictionary _GRAPH_LEARNING_PHASES; + //Dictionary> PER_GRAPH_LAYER_NAME_UIDS; + public bool _MANUAL_VAR_INIT = false; + public List _LOCAL_DEVICES = null; + /* -------------------------------------- KERAS BACKEND NATIVE OBJECTS END -------------------------------------- */ + + /// + /// A global dictionary mapping graph objects to an index of counters used + /// for various layer names in each graph. + /// Allows to give unique autogenerated names to layers, in a graph-specific way. + /// + public Dictionary> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>(); + public Dictionary _GRAPH_VARIABLES = new Dictionary(); + public Dictionary _GRAPH_TF_OPTIMIZERS = new Dictionary(); + + public _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph(); + + public BackendImpl() + { + } + + public void track_variable(IVariableV1 v) + { + if (tf.Context.executing_eagerly()) + { + return; + } + var graph = v.Graph; + if(graph is null) + { + graph = get_graph(); + } + _GRAPH_VARIABLES[graph.graph_key] = v; + } + + public KerasTensor placeholder(Shape shape = null, + int ndim = -1, + TF_DataType dtype = TF_DataType.DtInvalid, + bool sparse = false, + string name = null, + bool ragged = false) + { + if (sparse) + { + throw new NotImplementedException("placeholder sparse is true"); + } + else + { + return array_ops.placeholder(dtype: dtype, shape: shape, name: name); + } + } + + public Graph get_graph() + { + if (tf.Context.executing_eagerly()) + { + if (_GRAPH == null) + _GRAPH = new FuncGraph("keras_graph"); + + return _GRAPH; + } + return ops.get_default_graph(); + } + + FuncGraph _scratch_graph() + { + if (_CURRENT_SCRATCH_GRAPH == null) + _CURRENT_SCRATCH_GRAPH = new FuncGraph("keras_scratch_graph"); + + return _CURRENT_SCRATCH_GRAPH; + } + + public int get_uid(string prefix) + { + var graph = tf.get_default_graph(); + if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) + PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict()); + if (!PER_GRAPH_LAYER_NAME_UIDS[graph].ContainsKey(prefix)) + PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] = 0; + PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] += 1; + + return PER_GRAPH_LAYER_NAME_UIDS[graph][prefix]; + } + + public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>(); + public void clear_session() + { + tf.Context.reset_context(); + reset_uids(); + // var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase"); + if (_GRAPH_LEARNING_PHASES != null) + _GRAPH_LEARNING_PHASES.Clear(); + if (_GRAPH_LEARNING_PHASES != null) + _GRAPH_LEARNING_PHASES.Clear(); + PER_GRAPH_LAYER_NAME_UIDS.Clear(); + _CURRENT_SCRATCH_GRAPH = null; + _GRAPH = null; + + ops.set_default_session(tf.Session(ops.get_default_graph())); + tf.enable_eager_execution(); + tf.Runner.ClearEagerOperationMap(); + + GC.Collect(); + GC.WaitForPendingFinalizers(); + } + public void manual_variable_initialization(bool value) + { + _MANUAL_VAR_INIT = value; + } + + public Tensor mean(Tensor x, int axis = -1, bool keepdims = false) + { + if (x.dtype.as_base_dtype() == TF_DataType.TF_BOOL) + x = math_ops.cast(x, TF_DataType.TF_FLOAT); + return math_ops.reduce_mean(x, axis: axis, keepdims: false); + } + + public GraphLearningPhase learning_phase() + { + var graph = tf.get_default_graph(); + if (_GRAPH_LEARNING_PHASES.ContainsKey(graph)) + { + var phase = tf.placeholder_with_default(false, shape: new int[] { }, name: "keras_learning_phase"); + _GRAPH_LEARNING_PHASES[graph] = 0; + } + return _GRAPH_LEARNING_PHASES[graph]; + } + public void set_learning_phase(bool value) + { + _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0); + } + + public void set_value(IVariableV1 x, object value) + { + // TODO(Rinne): check the implementation. + x.assign(value); + } + + public void batch_set_value(List<(IVariableV1, NDArray)> tuples) + { + if (ops.executing_eagerly_outside_functions()) + { + foreach (var (x, value) in tuples) + x.assign(value, read_value: false); + } + else + { + throw new NotImplementedException(""); + } + } + + /// + /// Pads the 2nd and 3rd dimensions of a 4D tensor. + /// + /// + /// + /// + /// + public Tensor spatial_2d_padding(Tensor x, NDArray padding = null, string data_format = null) + { + if (padding == null) + padding = new[,] { { 1, 1 }, { 1, 1 } }; + + NDArray pattern; + + if (data_format == "channels_first") + pattern = new int[,] + { + { 0, 0 }, + { 0, 0 }, + { padding[0][0], padding[0][1] }, + { padding[1][0], padding[1][1] } + }; + else + pattern = new int[,] + { + { 0, 0 }, + { padding[0][0], padding[0][1] }, + { padding[1][0], padding[1][1] }, + { 0, 0 } + }; + return array_ops.pad(x, pattern); + } + + /// + /// Method to evaluate a tensor in eager or in a tf.function. + /// + /// + /// + public NDArray eval_in_eager_or_function(Tensors outputs) + { + if (outputs[0].op.type == "Const") + return tensor_util.constant_value(outputs); + + var source_graph = outputs.graph; + var exec_graph = _scratch_graph(); + var global_graph = get_graph(); + if (source_graph == global_graph && exec_graph != global_graph) + { + var lifted_map = lift_to_graph(outputs, exec_graph, + new List(), + add_sources: true, + handle_captures: true, + base_graph: source_graph); + } + if (outputs[0].op.type == "Placeholder" + || outputs[0].op.type == "StridedSlice") + return exec_graph.external_captures.Last().numpy(); + + // Consolidate updates + exec_graph.as_default(); + exec_graph.Inputs = exec_graph.internal_captures; + exec_graph.Outputs = outputs; + + var graph_fn = new ConcreteFunction(exec_graph); + + _CURRENT_SCRATCH_GRAPH = null; + tf.Context.restore_mode(); + // return outputs.eval(); + throw new NotImplementedException(""); + } + + public class _DummyEagerGraph + { } + + /// + /// Categorical crossentropy between an output tensor and a target tensor. + /// + /// + /// + /// + /// + /// + public Tensor categorical_crossentropy(Tensor target, Tensor output, bool from_logits = false, int axis = -1) + { + if (from_logits) + return tf.nn.softmax_cross_entropy_with_logits_v2(labels: target, logits: output, axis: axis); + + if (output.op != null && output.op.type == "Softmax") + { + if (output.op.inputs.Length != 1) throw new ApplicationException(); + var o = output.op.inputs[0]; + return tf.nn.softmax_cross_entropy_with_logits_v2(labels: target, logits: o, axis: axis); + } + + // scale preds so that the class probas of each sample sum to 1 + output = output / math_ops.reduce_sum(output, new Axis(axis), true); + // Compute cross entropy from probabilities. + var epsilon_ = constant_op.constant(epsilon(), output.dtype.as_base_dtype()); + output = clip_ops.clip_by_value(output, epsilon_, 1.0f - epsilon_); + return -math_ops.reduce_sum(target * math_ops.log(output), new Axis(axis)); + } + + public Tensor sparse_categorical_crossentropy(Tensor target, Tensor output, bool from_logits = false, int axis = -1, int? ignore_class = null) + { + target = tf.cast(target, tf.int64); + if (!from_logits) + { + var epsilon_ = constant_op.constant(epsilon(), output.dtype.as_base_dtype()); + output = tf.clip_by_value(output, epsilon_, 1 - epsilon_); + output = tf.math.log(output); + } + var output_rank = output.shape.ndim; + if (output_rank > -1) + { + axis = Math.Abs(axis) % output_rank; + if (axis != output_rank - 1) + { + /*var permutation = list( + itertools.chain( + range(axis), range(axis + 1, output_rank), [axis] + ) + ); + output = tf.transpose(output, perm: permutation);*/ + throw new NotImplementedException(""); + } + + } + + var output_shape = tf.shape(output); + var target_rank = target.shape.ndim; + var update_shape = target_rank > -1 && output_rank > -1 && target_rank != output_rank - 1; + if (update_shape) + { + target = tf.reshape(target, -1); + output = tf.reshape(output, (-1, output.shape[-1])); + } + + if (ignore_class.HasValue) + { + throw new NotImplementedException(""); + } + + var res = tf.nn.sparse_softmax_cross_entropy_with_logits(labels: target, logits: output); + + if (ignore_class.HasValue) + { + throw new NotImplementedException(""); + } + + if (update_shape && output_rank >= 3) + { + // If our output includes timesteps or + // spatial dimensions we need to reshape + res = tf.reshape(res, output_shape[":-1"]); + } + + return res; + } + + public Tensor binary_crossentropy(Tensor target, Tensor output, bool from_logits = false) + { + if (from_logits) + return tf.nn.sigmoid_cross_entropy_with_logits(labels: target, logits: output); + + var epsilon_ = constant_op.constant(epsilon(), dtype: output.dtype.as_base_dtype()); + output = tf.clip_by_value(output, epsilon_, 1.0f - epsilon_); + + // Compute cross entropy from probabilities. + var bce = target * tf.math.log(output + epsilon()); + bce += (1 - target) * tf.math.log(1 - output + epsilon()); + return -bce; + } + + /// + /// Resizes the images contained in a 4D tensor. + /// + /// + /// + /// + /// + /// + /// + public Tensor resize_images(Tensor x, int height_factor, int width_factor, + string data_format, string interpolation = "nearest") + { + var (rows, cols) = (0, 0); + if (data_format == "channels_first") + (rows, cols) = (2, 3); + else if (data_format == "channels_last") + (rows, cols) = (1, 2); + else + throw new ValueError($"Invalid `data_format` argument: {data_format}"); + + var original_shape = x.shape; + var new_shape = array_ops.shape(x)[new Slice(rows, cols + 1)]; + new_shape *= constant_op.constant(np.array(height_factor, width_factor)); + + if (data_format == "channels_first") + // x = permute_dimensions(x, [0, 2, 3, 1]); + throw new NotImplementedException(""); + if (interpolation == "nearest") + x = tf.image.resize_images_v2(x, new_shape, method: ResizeMethod.NEAREST_NEIGHBOR); + + if (data_format == "channels_first") + // x = permute_dimensions(x, [0, 3, 1, 2]); + throw new NotImplementedException(""); + + int new_height = original_shape[rows] < 0 ? -1 : (int)original_shape[rows] * height_factor; + int new_width = original_shape[cols] < 0 ? -1 : (int)original_shape[cols] * width_factor; + + Shape output_shape = data_format == "channels_first" ? + (-1, -1, new_height, new_width) : (-1, new_height, new_width, -1); + x.shape = output_shape; + return x; + } + + /// + /// Concatenates a list of tensors alongside the specified axis. + /// + /// list of tensors to concatenate. + /// concatenation axis. + /// + public Tensor concatenate(Tensors tensors, int axis = -1) + { + if(axis < 0) + { + var rank = tensors[0].ndim; + if (rank > -1) + axis += rank; + else + axis = 0; + } + + return array_ops.concat(tensors, axis); + } + + public Tensor conv2d_transpose(Tensor x, + IVariableV1 kernel, + Tensor output_shape, + Shape strides = null, + string padding = "valid", + string data_format = null, + Shape dilation_rate = null) + { + /* + var force_transpose = false; + if (data_format == "channels_first" && !dilation_rate.Equals(new[] { 1, 1 })) + force_transpose = true; + x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose) + */ + var tf_data_format = "NHWC"; + padding = padding.ToUpper(); + strides = new Shape(1, strides[0], strides[1], 1); + if (dilation_rate.Equals(new long[] { 1, 1 })) + x = nn_impl.conv2d_transpose(x, kernel, output_shape, strides, + padding: padding, + data_format: tf_data_format); + else + throw new NotImplementedException("dilation_rate other than [1,1] is not yet supported"); + + return x; + } + + public (Tensors, Tensors, Tensors) rnn( + Func step_function, // args:inputs, states, return:output, new_states + Tensors inputs, // inputs is a tuple of tensors (one per input sequence) + Tensors initial_states, + bool go_backwards = false, + Tensor? mask = null, + Tensors? constants = null, + bool unroll = false, + Tensors? input_length = null, // An integer or a 1-D Tensor,depending on whether the time dimension is fixed-length or not + bool time_major = false, + bool zero_output_for_mask = false, + bool return_all_outputs = true) + { + + Tensor swap_batch_timestep(Tensor input_t) + { + var axes = Enumerable.Range(0, input_t.rank).ToArray(); + axes[0] = 1; + axes[1] = 0; + return tf.transpose(input_t, axes); + } + + if (!time_major) + { + inputs = Nest.MapStructure(swap_batch_timestep, inputs).ToTensors(); + } + + var flatted_inptus = Nest.Flatten(inputs).ToList(); + var first_flatted_input = flatted_inptus[0]; + var time_steps = first_flatted_input.shape[0]; + var batch = first_flatted_input.shape[1]; + var time_steps_t = tf.shape(first_flatted_input)[0]; + + foreach (var input_ in flatted_inptus) + { + input_.shape.with_rank_at_least(3); + } + + if (mask != null) + { + if (mask.dtype != TF_DataType.TF_BOOL) + { + mask = tf.cast(mask, TF_DataType.TF_BOOL); + } + + if (mask.rank == 2) + { + mask = tf.expand_dims(mask, -1); + } + + if (!time_major) + { + mask = swap_batch_timestep(mask); + } + + } + + // tf.where needs its condition tensor to be the same shape as its two + // result tensors, but in our case the condition (mask) tensor is + // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. + // So we need to broadcast the mask to match the shape of inputs. + // That's what the tile call does, it just repeats the mask along its + // second dimension n times. + + Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1) + { + if (!mask_t.IsSingle()) + { + throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}"); + } + + if (!input_t.IsSingle()) + { + throw new ValueError($"input_t is expected to be tensor, but got {input_t}"); + } + + var rank_diff = input_t.rank - mask_t.rank; + for (int i = 0; i < rank_diff; i++) + { + mask_t = tf.expand_dims(mask_t, -1); + } + var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().Skip(fixed_dim).ToArray()); + return tf.tile(mask_t, multiples); + } + + Tensors outputs = new Tensors(); + Tensors output_time_zero = new Tensors(); + Tensors last_output = new Tensors(); + Tensors new_states = new Tensors(); + if (unroll) + { + if (time_steps == 0) + { + throw new ValueError("Unrolling requires a fixed number of timesteps."); + } + + // Process the input tensors. The input tensor need to be split on the + // time_step dim, and reverse if go_backwards is True. In the case of + // nested input, the input is flattened and then transformed + // individually. The result of this will be a tuple of lists, each of + // the item in tuple is list of the tensor with shape (batch, feature) + + + // TODO(Wanglongzhi2001),step_func接受的第二个参数为List,但是最后却用的tuple + //var states = Tuple.Create(initial_states); + var states = initial_states; + + var successive_states = new Tensors(); + var successive_outputs = new Tensors(); + + // Process the input tensors. The input tensor need to be split on the + // time_step dim, and reverse if go_backwards is True. In the case of + // nested input, the input is flattened and then transformed + // individually. The result of this will be a tuple of lists, each of + // the item in tuple is list of the tensor with shape (batch, feature) + + Tensors _process_single_input_t(Tensor input_t) + { + var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim + if (go_backwards) + { + unstaked_input_t = unstaked_input_t.Reverse().ToArray(); + } + return unstaked_input_t; + } + + // TODO(Wanglongzhi2001) + Tensors processed_input; + if (!inputs.IsSingle()) + { + processed_input = inputs.MapStructure(_process_single_input_t).ReduceTo().ToTensors(); + } + else + { + processed_input = _process_single_input_t(inputs); + } + + object _get_input_tensor(int time) + { + List inp = new List(); + foreach (var t_ in processed_input) + { + inp.Add(t_[time]); + } + return Nest.PackSequenceAs(inputs, inp); + } + + if (mask != null) + { + var mask_list = tf.unstack(mask); + if (go_backwards) + { + mask_list.Reverse().ToArray(); + } + + for (int i = 0; i < time_steps; i++) + { + // TODO(Wanglongzhi2001),deal with _get_input_tensor + var inp = _get_input_tensor(i); + var mask_t = mask_list[i]; + // TODO + var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants)); + + var tiled_mask_t = _expand_mask(mask_t, output); + + Tensors prev_output; + if (successive_outputs == null) + { + prev_output = tf.zeros_like(output); + } + else + { + prev_output = successive_outputs.Last(); + } + + // output could be a tensor + output = tf.where(tiled_mask_t, output, prev_output); + + var flat_states = Nest.Flatten(states).ToList(); + var flat_new_states = Nest.Flatten(newStates).ToList(); + + var tiledMaskT = flat_states + .Select(s => _expand_mask(mask_t, s)) + .ToArray(); + var tuple = Tuple.Create(tiledMaskT); + + List flat_final_states = new List(); + foreach (var (m, s, ps) in zip(tiled_mask_t.ToList(), flat_new_states, flat_states)) + { + flat_final_states.Add(tf.where(m, s, ps)); + } + + states = Nest.PackSequenceAs(states, flat_final_states).ToTensors(); + if (return_all_outputs) + { + successive_outputs = successive_outputs.MergeWith(output); + successive_outputs = successive_states.MergeWith(states); + } + else + { + successive_outputs = new Tensors(output); + successive_states = new Tensors(states); + } + + } + last_output = successive_outputs.Last(); + new_states = successive_states.Last(); + outputs = tf.stack(successive_outputs); + + if (zero_output_for_mask) + { + last_output = tf.where(_expand_mask(mask_list.Last(), last_output), last_output, tf.zeros_like(last_output)); + outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); + } + else // mask is null + { + for (int i = 0; i < time_steps; i++) + { + var inp = _get_input_tensor(i); + var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants)); + states = newStates; + + if (return_all_outputs) + { + successive_outputs.Add(output); + successive_states.Add(newStates); + } + else + { + successive_outputs = new Tensors { output }; + successive_states = new Tensors { newStates }; + } + } + last_output = successive_outputs.Last(); + new_states = successive_states.Last(); + outputs = tf.stack(successive_outputs); + } + } + } + else // unroll == false + { + var states = initial_states; + // Create input tensor array, if the inputs is nested tensors, then it + // will be flattened first, and tensor array will be created one per + // flattened tensor. + + + var input_ta = new List(); + for (int i = 0; i < flatted_inptus.Count; i++) + { + input_ta.Add(TensorArray.Create(dtype: flatted_inptus[i].dtype, size: time_steps_t)); + } + + foreach(var (ta, input_) in zip(input_ta, flatted_inptus)) + { + if (!go_backwards) + { + ta.unstack(input_); + } + else + { + ta.unstack(reverse(input_, 0)); + } + } + + + // Get the time(0) input and compute the output for that, the output will + // be used to determine the dtype of output tensor array. Don't read from + // input_ta due to TensorArray clear_after_read default to True. + var input_time_zero = Nest.PackSequenceAs(inputs, flatted_inptus.Select(x => x[0]).ToArray()).ToTensors(); + + // output_time_zero is used to determine the cell output shape and its + // dtype. the value is discarded. + (output_time_zero, _) = step_function(input_time_zero, + constants is null ? initial_states : initial_states.MergeWith(constants)); + + Tensor output_ta_size = return_all_outputs ? time_steps_t : constant_op.constant(1); + var output_ta = new List(); + foreach(var output in output_time_zero.Flatten()) + { + output_ta.Add(TensorArray.Create(dtype: output.dtype, size: output_ta_size, element_shape: output.shape)); + } + + var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); + + Func? masking_fn; + Func? compute_masked_output = null; + if (mask != null) + { + if (go_backwards) + { + mask = tf.reverse(mask, axis: new[] { 0 }); + } + var mask_ta = TensorArray.Create(dtype: TF_DataType.TF_BOOL, size: time_steps_t); + mask_ta = mask_ta.unstack(mask); + + masking_fn = (time) => + { + return mask_ta.read(time); + }; + + compute_masked_output = (mask_t, flat_out, flat_mask) => + { + var tiled_mask_t = new Tensors(); + foreach (var o in flat_out) + { + tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); + } + + Tensors res = new Tensors(); + foreach (var (m, o, fm) in zip(tiled_mask_t.ToList(), flat_out.ToList(), flat_mask.ToList())) + { + res.Add(tf.where(m, o, fm)); + } + return res; + }; + } + // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor), it could be an integer or tensor + else if (input_length is Tensor) + { + if (go_backwards) + { + var max_len = tf.reduce_max(input_length, axis: 0); + var rev_input_length = tf.subtract(max_len - 1, input_length); + + masking_fn = (time) => + { + return tf.less(rev_input_length, time); + }; + } + else + { + masking_fn = (time) => + { + return tf.greater(input_length, time); + }; + } + + compute_masked_output = (mask_t, flat_out, flat_mask) => + { + var res = new List(); + foreach (var (o, zo) in zip(flat_out, flat_mask)) + { + res.Add(tf.where(mask_t, o, zo)); + } + return res; + }; + } + else + { + masking_fn = null; + } + + Func cond = (time) => (time[0] < time_steps_t); + int parallel_iterations = 32; + Tensors final_outputs; + if (masking_fn != null) + { + // Mask for the T output will be base on the output of T - 1. In the + // case T = 0, a zero filled tensor will be used. + var flat_zero_output = new Tensors(); + foreach (var o in Nest.Flatten(output_time_zero)) + { + flat_zero_output.Add(tf.zeros_like(o)); + } + + var prev_output = flat_zero_output; + var output_ta_t = output_ta; + Tensors _step(Tensors tensors) + { + /* + RNN step function. + Args: + time: Current timestep value. + output_ta_t: TensorArray. + prev_output: tuple of outputs from time - 1. + *states: List of states. + Returns: + Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` + */ + + Tensor time = tensors[0]; + TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray; + Tensors prev_output = tensors.GetShallow(2); + Tensors states = new Tensors(tensors.Skip(2 + prev_output.Length).ToArray()); + + var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); + // maybe set shape + // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); + var mask_t = masking_fn(time); + var (output, new_states) = step_function(current_input, states.MergeWith(constants)); + // mask output + var flat_output = Nest.Flatten(output).ToList(); + + var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.Flatten().ToList(); + + // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type + var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); + + // mask states + var flat_state = states.Flatten().ToList(); + var flat_new_state = new_states.Flatten().ToList(); + + foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + { + if (new_state is Tensor) + { + new_state.shape = state.shape; + } + } + + var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); + new_states = Nest.PackSequenceAs(new_states, flat_final_state.ToArray()).ToTensors(); + + var ta_index_to_write = return_all_outputs ? time : tf.constant(0); + Debug.Assert(flat_output.Count() == 1); + output_ta_t = output_ta_t.write(ta_index_to_write, flat_new_output.First()); + + return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(flat_new_output).Concat(new_states) + .ToArray().ToTensors(); + + } + var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) } + .Concat(flat_zero_output.Flatten()).Concat(states).ToArray().ToTensors(); + final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations); + new_states = final_outputs.Skip(3).ToList(); + } + else + { + var output_ta_t = output_ta; + new_states = states; + Tensors _step(Tensors tensors) + { + Tensor time = tensors[0]; + TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray; + Tensors states = new Tensors(tensors.Skip(2).ToArray()); + var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); + // maybe set shape + // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); + var (output, new_states) = step_function(current_input, states.MergeWith(constants)); + var flat_state = new_states.Flatten().ToList(); + var flat_new_state = new_states.Flatten().ToList(); + foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + { + if (new_state is Tensor) + { + new_state.shape = state.shape; + } + } + var flat_output = Nest.Flatten(output); + var ta_index_to_write = return_all_outputs ? time : tf.constant(0); + Debug.Assert(flat_output.Count() == 1); + output_ta_t = output_ta_t.write(ta_index_to_write, flat_output.First()); + + new_states = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); + return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(new_states).ToArray().ToTensors(); + } + Debug.Assert(output_ta.Count == 1); + var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) }.Concat(states).ToArray().ToTensors(); + final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations); + new_states = final_outputs.Skip(2).ToList(); + } + + output_ta = new List { (final_outputs[1] as FakeTensorByTensorArray).TensorArray }; + outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToArray().ToTensors()); + last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToArray().ToTensors()); + outputs = Nest.PackSequenceAs(output_time_zero, (Tensor[])outputs).ToTensors(); + last_output = Nest.PackSequenceAs(output_time_zero, (Tensor[])last_output).ToTensors(); + } + + Func set_shape; + set_shape = (output_) => + { + if (output_ is Tensor) + { + var shape = output_.shape.as_int_list(); + if (return_all_outputs) + { + shape[0] = (int)time_steps; + } + else + { + shape[0] = 1; + } + shape[1] = (int)batch; + output_.shape = shape; + } + return output_; + }; + + outputs = Nest.MapStructure(set_shape, outputs).ToTensors(); + if (!time_major) + { + outputs = Nest.MapStructure(swap_batch_timestep, outputs).ToTensors(); + } + return (last_output, outputs, new_states); + + } + + /// + /// Repeats the elements of a tensor along an axis, like `np.repeat`. + /// + /// + /// + /// + /// + public Tensor repeat_elements(Tensor x, int rep, int axis) + { + var x_shape = x.shape.as_int_list(); + if (x_shape[axis] != -1) + { + var splits = tf.split(x, x_shape[axis], axis:axis); + var x_rep = splits.SelectMany(s => Enumerable.Repeat(s, rep)).ToArray(); + return concatenate(x_rep, axis); + } + //var auxiliary_axis = axis + 1; + //x_shape = x.shape; + //var x_rep = tf.expand_dims(x, auxiliary_axis); + //var reps = np.ones(x_shape.Length + 1); + //reps[auxiliary_axis] = rep; + //x_rep = tf.tile(x_rep, reps); + + throw new NotImplementedException(); + + } + public Tensor reverse(Tensor input, int axis) + { + return reverse(input, new int[] { axis }); + } + + public Tensor reverse(Tensor input, int[] axes) + { + return tf.reverse(input, axes); + } + + public Tensor maybe_convert_to_ragged(bool is_ragged_output, Tensor output, int nested_row_lengths, bool go_backwards = false) + { + if (!is_ragged_output) + { + return output; + } + + throw new NotImplementedException("Not implemented currently, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + } +} diff --git a/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs b/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs new file mode 100644 index 000000000..cb16aafa3 --- /dev/null +++ b/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs @@ -0,0 +1,81 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Callbacks; + +public class CallbackList +{ + // 改成public使得新定义的callback可以加入到callbacks里 + public List callbacks = new List(); + public History History => callbacks[0] as History; + + public CallbackList(CallbackParams parameters) + { + callbacks.Add(new History(parameters)); + callbacks.Add(new ProgbarLogger(parameters)); + } + + public void on_train_begin() + { + callbacks.ForEach(x => x.on_train_begin()); + } + public void on_test_begin() + { + callbacks.ForEach(x => x.on_test_begin()); + } + public void on_epoch_begin(int epoch) + { + callbacks.ForEach(x => x.on_epoch_begin(epoch)); + } + + public void on_train_batch_begin(long step) + { + callbacks.ForEach(x => x.on_train_batch_begin(step)); + } + + public void on_train_batch_end(long end_step, Dictionary logs) + { + callbacks.ForEach(x => x.on_train_batch_end(end_step, logs)); + } + + public void on_epoch_end(int epoch, Dictionary epoch_logs) + { + callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs)); + } + + public void on_predict_begin() + { + callbacks.ForEach(x => x.on_predict_begin()); + } + + public void on_predict_batch_begin(long step) + { + callbacks.ForEach(x => x.on_predict_batch_begin(step)); + } + + public void on_predict_batch_end(long end_step, Dictionary logs) + { + callbacks.ForEach(x => x.on_predict_batch_end(end_step, logs)); + } + + public void on_predict_end() + { + callbacks.ForEach(x => x.on_predict_end()); + } + + public void on_test_batch_begin(long step) + { + callbacks.ForEach(x => x.on_test_batch_begin(step)); + } + public void on_test_batch_end(long end_step, Dictionary logs) + { + callbacks.ForEach(x => x.on_test_batch_end(end_step, logs)); + } + + public void on_test_end(Dictionary logs) + { + callbacks.ForEach(x => x.on_test_end(logs)); + } +} diff --git a/src/TensorFlowNET.Keras/Callbacks/CallbackParams.cs b/src/TensorFlowNET.Keras/Callbacks/CallbackParams.cs new file mode 100644 index 000000000..fe859c8a2 --- /dev/null +++ b/src/TensorFlowNET.Keras/Callbacks/CallbackParams.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Callbacks +{ + public class CallbackParams + { + public IModel Model { get; set; } + public int Verbose { get; set; } + public int Epochs { get; set; } + public long Steps { get; set; } + } +} diff --git a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs new file mode 100644 index 000000000..a2a2ecfe2 --- /dev/null +++ b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs @@ -0,0 +1,174 @@ +using Tensorflow.Keras.Engine; +namespace Tensorflow.Keras.Callbacks; + + +/// +/// Stop training when a monitored metric has stopped improving. +/// +public class EarlyStopping: ICallback +{ + int _paitence; + float _min_delta; + int _verbose; + int _stopped_epoch; + int _wait; + int _best_epoch; + int _start_from_epoch; + float _best; + float _baseline; + string _monitor; + string _mode; + bool _restore_best_weights; + List? _best_weights; + CallbackParams _parameters; + Func _monitor_op; + + public Dictionary>? history { get; set; } + // user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model + public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0, + int verbose = 1, string mode = "auto", float baseline = 0f, bool restore_best_weights = false, + int start_from_epoch = 0) + { + _parameters = parameters; + _stopped_epoch = 0; + _wait = 0; + _monitor = monitor; + _paitence = patience; + _verbose = verbose; + _baseline = baseline; + _start_from_epoch = start_from_epoch; + _min_delta = Math.Abs(min_delta); + _restore_best_weights = restore_best_weights; + _mode = mode; + + if (_mode != "auto" && _mode != "min" && _mode != "max") + { + Console.WriteLine($"EarlyStopping mode {_mode} is unknown, fallback to auto mode."); + _mode = "auto"; + } + + if (_mode == "min") + { + _monitor_op = np.less; + } + else if (_mode == "max") + { + _monitor_op = np.greater; + } + else + { + if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc")) + { + _monitor_op = np.greater; + } + else + { + _monitor_op = np.less; + } + } + + if (_monitor_op == np.greater) + { + _min_delta *= 1; + } + else + { + _min_delta *= -1; + } + } + public void on_train_begin() + { + _wait = 0; + _stopped_epoch = 0; + _best = _monitor_op == np.less ? (float)np.Inf : (float)-np.Inf; + _best_weights = null; + _best_epoch = 0; + } + + public void on_epoch_begin(int epoch) + { + + } + + public void on_train_batch_begin(long step) + { + + } + + public void on_train_batch_end(long end_step, Dictionary logs) + { + } + + public void on_epoch_end(int epoch, Dictionary epoch_logs) + { + var current = get_monitor_value(epoch_logs); + // If no monitor value exists or still in initial warm-up stage. + if (current == 0f || epoch < _start_from_epoch) + return; + // Restore the weights after first epoch if no progress is ever made. + if (_restore_best_weights && _best_weights == null) + { + _best_weights = _parameters.Model.get_weights(); + } + _wait += 1; + + if (_is_improvement(current, _best)) + { + _best = current; + _best_epoch = epoch; + if (_restore_best_weights) + _best_weights = _parameters.Model.get_weights(); + // Only restart wait if we beat both the baseline and our previous best. + if (_baseline == 0f || _is_improvement(current, _baseline)) + _wait = 0; + } + // Only check after the first epoch. + if (_wait >= _paitence && epoch > 0) + { + _stopped_epoch = epoch; + _parameters.Model.Stop_training = true; + if (_restore_best_weights && _best_weights != null) + { + if (_verbose > 0) + { + Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); + } + _parameters.Model.set_weights(_best_weights); + } + } + } + public void on_train_end() + { + if (_stopped_epoch > 0 && _verbose > 0) + { + Console.WriteLine($"Epoch {_stopped_epoch + 1}: early stopping"); + } + } + public void on_predict_begin() { } + public void on_predict_batch_begin(long step) { } + public void on_predict_batch_end(long end_step, Dictionary logs) { } + public void on_predict_end() { } + public void on_test_begin() { } + public void on_test_batch_begin(long step) { } + public void on_test_batch_end(long end_step, Dictionary logs) { } + + float get_monitor_value(Dictionary logs) + { + logs = logs ?? new Dictionary(); + float monitor_value = logs[_monitor]; + if (monitor_value == 0f) + { + Console.WriteLine($"Early stopping conditioned on metric {_monitor} " + + $"which is not available. Available metrics are: {string.Join(", ", logs.Keys)}"); + } + return monitor_value; + } + public bool _is_improvement(float monitor_value, float reference_value) + { + return _monitor_op(monitor_value - _min_delta, reference_value); + } + + public void on_test_end(Dictionary logs) + { + } +} diff --git a/src/TensorFlowNET.Keras/Callbacks/History.cs b/src/TensorFlowNET.Keras/Callbacks/History.cs new file mode 100644 index 000000000..6d3ff6c38 --- /dev/null +++ b/src/TensorFlowNET.Keras/Callbacks/History.cs @@ -0,0 +1,88 @@ +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Callbacks; + +public class History : ICallback +{ + List epochs; + CallbackParams _parameters; + public Dictionary> history { get; set; } + + public History(CallbackParams parameters) + { + _parameters = parameters; + } + + public void on_train_begin() + { + epochs = new List(); + history = new Dictionary>(); + } + public void on_test_begin() + { + epochs = new List(); + history = new Dictionary>(); + } + public void on_train_end() { } + public void on_epoch_begin(int epoch) + { + + } + + public void on_train_batch_begin(long step) + { + + } + + public void on_train_batch_end(long end_step, Dictionary logs) + { + } + + public void on_epoch_end(int epoch, Dictionary epoch_logs) + { + epochs.Add(epoch); + + foreach (var log in epoch_logs) + { + if (!history.ContainsKey(log.Key)) + { + history[log.Key] = new List(); + } + history[log.Key].Add(log.Value); + } + } + + public void on_predict_begin() + { + epochs = new List(); + history = new Dictionary>(); + } + + public void on_predict_batch_begin(long step) + { + + } + + public void on_predict_batch_end(long end_step, Dictionary logs) + { + + } + + public void on_predict_end() + { + + } + + public void on_test_batch_begin(long step) + { + + } + + public void on_test_batch_end(long end_step, Dictionary logs) + { + } + + public void on_test_end(Dictionary logs) + { + } +} diff --git a/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs b/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs new file mode 100644 index 000000000..23b18cd47 --- /dev/null +++ b/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs @@ -0,0 +1,125 @@ +using System.Diagnostics; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Callbacks +{ + public class ProgbarLogger : ICallback + { + bool _called_in_fit = false; + int seen = 0; + CallbackParams _parameters; + Stopwatch _sw; + + public Dictionary> history { get; set; } + + public ProgbarLogger(CallbackParams parameters) + { + _parameters = parameters; + } + + public void on_train_begin() + { + _called_in_fit = true; + _sw = new Stopwatch(); + } + public void on_train_end() { } + public void on_test_begin() + { + _sw = new Stopwatch(); + } + public void on_epoch_begin(int epoch) + { + _reset_progbar(); + _maybe_init_progbar(); + Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{_parameters.Epochs:D3}"); + } + + public void on_train_batch_begin(long step) + { + _sw.Restart(); + } + + public void on_train_batch_end(long end_step, Dictionary logs) + { + _sw.Stop(); + var elapse = _sw.ElapsedMilliseconds; + var results = string.Join(" - ", logs.Select(x => $"{x.Key}: {(float)x.Value:F6}")); + + var progress = ""; + var length = 30.0 / _parameters.Steps; + for (int i = 0; i < Math.Floor(end_step * length - 1); i++) + progress += "="; + if (progress.Length < 28) + progress += ">"; + else + progress += "="; + + var remaining = ""; + for (int i = 1; i < 30 - progress.Length; i++) + remaining += "."; + + Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} [{progress}{remaining}] - {elapse}ms/step - {results}"); + if (!Console.IsOutputRedirected) + { + Console.CursorLeft = 0; + } + } + + public void on_epoch_end(int epoch, Dictionary epoch_logs) + { + Console.WriteLine(); + } + + void _reset_progbar() + { + seen = 0; + } + + void _maybe_init_progbar() + { + + } + + public void on_predict_begin() + { + _reset_progbar(); + _maybe_init_progbar(); + } + + public void on_predict_batch_begin(long step) + { + + } + + public void on_predict_batch_end(long end_step, Dictionary logs) + { + + } + + public void on_predict_end() + { + + } + + public void on_test_batch_begin(long step) + { + _sw.Restart(); + } + public void on_test_batch_end(long end_step, Dictionary logs) + { + _sw.Stop(); + var elapse = _sw.ElapsedMilliseconds; + var results = string.Join(" - ", logs.Select(x => $"{x.Key}: {x.Value:F6}")); + + Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}"); + if (!Console.IsOutputRedirected) + { + Console.CursorLeft = 0; + } + } + + public void on_test_end(Dictionary logs) + { + } + } +} diff --git a/src/TensorFlowNET.Keras/Datasets/Cifar10.cs b/src/TensorFlowNET.Keras/Datasets/Cifar10.cs new file mode 100644 index 000000000..dc1fb76d5 --- /dev/null +++ b/src/TensorFlowNET.Keras/Datasets/Cifar10.cs @@ -0,0 +1,133 @@ +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using static Tensorflow.Binding; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Datasets +{ + public class Cifar10 + { + string origin_folder = "https://www.cs.toronto.edu/~kriz/"; + string file_name = "cifar-10-python.tar.gz"; + string dest_folder = "cifar-10-batches"; + + /// + /// Loads [CIFAR10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). + /// + /// + public DatasetPass load_data() + { + var dst = Download(); + + var data_list = new List(); + var label_list = new List(); + + foreach (var i in range(1, 6)) + { + var fpath = Path.Combine(dst, $"data_batch_{i}"); + var (data, labels) = load_batch(fpath); + data_list.Add(data); + label_list.Add(labels); + } + + var x_train_tensor = tf.concat(data_list, 0); + var y_train_tensor = tf.concat(label_list, 0); + var y_train = np.array(y_train_tensor.BufferToArray()).reshape(y_train_tensor.shape); + + // test data + var fpath_test = Path.Combine(dst, "test_batch"); + var (x_test, y_test) = load_batch(fpath_test); + + // channels_last + x_train_tensor = tf.transpose(x_train_tensor, new[] { 0, 2, 3, 1 }); + var x_train = np.array(x_train_tensor.BufferToArray()).reshape(x_train_tensor.shape); + + var x_test_tensor = tf.transpose(x_test, new[] { 0, 2, 3, 1 }); + x_test = np.array(x_test_tensor.BufferToArray()).reshape(x_test_tensor.shape); + + return new DatasetPass + { + Train = (x_train, y_train), + Test = (x_test, y_test) + }; + } + + (NDArray, NDArray) load_batch(string fpath, string label_key = "labels") + { + var pickle = File.ReadAllBytes(fpath); + // read description + var start_pos = 7; + var desc = read_description(ref start_pos, pickle); + var labels = read_labels(ref start_pos, pickle); + var data = read_data(ref start_pos, pickle); + + return (data.Item2, labels.Item2); + } + + (string, string) read_description(ref int start_pos, byte[] pickle) + { + var key_length = pickle[start_pos]; + start_pos++; + var span = new Span(pickle, start_pos, key_length); + var key = Encoding.ASCII.GetString(span.ToArray()); + start_pos += key_length + 3; + + var value_length = pickle[start_pos]; + start_pos++; + var value = Encoding.ASCII.GetString(new Span(pickle, start_pos, value_length).ToArray()); + start_pos += value_length; + start_pos += 3; + + return (key, value); + } + + (string, NDArray) read_labels(ref int start_pos, byte[] pickle) + { + byte[] value = new byte[10000]; + + var key_length = pickle[start_pos]; + start_pos++; + var span = new Span(pickle, start_pos, key_length); + var key = Encoding.ASCII.GetString(span.ToArray()); + start_pos += key_length + 6; + + var value_length = 10000; + for (int i = 0; i < value_length; i++) + { + if (i > 0 && i % 1000 == 0) + start_pos += 2; + value[i] = pickle[start_pos + 1]; + start_pos += 2; + } + start_pos += 2; + + return (key, np.array(value)); + } + + (string, NDArray) read_data(ref int start_pos, byte[] pickle) + { + var key_length = pickle[start_pos]; + start_pos++; + var span = new Span(pickle, start_pos, key_length); + var key = Encoding.ASCII.GetString(span.ToArray()); + start_pos += key_length + 133; + var value_length = 3072 * 10000; + var value = new Span(pickle, start_pos, value_length).ToArray(); + start_pos += value_length; + + return (key, np.array(value).reshape((10000, 3, 32, 32))); + } + + string Download() + { + var dst = Path.Combine(Path.GetTempPath(), dest_folder); + Web.Download(origin_folder + file_name, dst, file_name); + Compress.ExtractTGZ(Path.Combine(dst, file_name), dst); + + return Path.Combine(dst, "cifar-10-batches-py"); + } + } +} diff --git a/src/TensorFlowNET.Keras/Datasets/DatasetPass.cs b/src/TensorFlowNET.Keras/Datasets/DatasetPass.cs new file mode 100644 index 000000000..80bafaa36 --- /dev/null +++ b/src/TensorFlowNET.Keras/Datasets/DatasetPass.cs @@ -0,0 +1,24 @@ +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.Datasets +{ + public class DatasetPass + { + public (NDArray, NDArray) Train { get; set; } + public (NDArray, NDArray) Test { get; set; } + + public void Deconstruct(out NDArray x_train, out NDArray y_train, out NDArray x_test, out NDArray y_test) + { + x_train = Train.Item1; + y_train = Train.Item2; + x_test = Test.Item1; + y_test = Test.Item2; + } + + public void Deconstruct(out (NDArray, NDArray) train, out (NDArray, NDArray) test) + { + train = Train; + test = Test; + } + } +} diff --git a/src/TensorFlowNET.Keras/Datasets/Imdb.cs b/src/TensorFlowNET.Keras/Datasets/Imdb.cs new file mode 100644 index 000000000..4d6df913b --- /dev/null +++ b/src/TensorFlowNET.Keras/Datasets/Imdb.cs @@ -0,0 +1,243 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Datasets +{ + /// + /// This is a dataset of 25,000 movies reviews from IMDB, labeled by sentiment + /// (positive/negative). Reviews have been preprocessed, and each review is + /// encoded as a list of word indexes(integers). + /// For convenience, words are indexed by overall frequency in the dataset, + /// so that for instance the integer "3" encodes the 3rd most frequent word in + /// the data.This allows for quick filtering operations such as: + /// "only consider the top 10,000 most + /// common words, but eliminate the top 20 most common words". + /// As a convention, "0" does not stand for a specific word, but instead is used + /// to encode the pad token. + /// Args: + /// path: where to cache the data (relative to %TEMP%/imdb/imdb.npz). + /// num_words: integer or None.Words are + /// ranked by how often they occur(in the training set) and only + /// the `num_words` most frequent words are kept.Any less frequent word + /// will appear as `oov_char` value in the sequence data.If None, + /// all words are kept.Defaults to `None`. + /// skip_top: skip the top N most frequently occurring words + /// (which may not be informative). These words will appear as + /// `oov_char` value in the dataset.When 0, no words are + /// skipped. Defaults to `0`. + /// maxlen: int or None.Maximum sequence length. + /// Any longer sequence will be truncated. None, means no truncation. + /// Defaults to `None`. + /// seed: int. Seed for reproducible data shuffling. + /// start_char: int. The start of a sequence will be marked with this + /// character. 0 is usually the padding character. Defaults to `1`. + /// oov_char: int. The out-of-vocabulary character. + /// Words that were cut out because of the `num_words` or + /// `skip_top` limits will be replaced with this character. + /// index_from: int. Index actual words with this index and higher. + /// Returns: + /// Tuple of Numpy arrays: `(x_train, labels_train), (x_test, labels_test)`. + /// + /// ** x_train, x_test**: lists of sequences, which are lists of indexes + /// (integers). If the num_words argument was specific, the maximum + /// possible index value is `num_words - 1`. If the `maxlen` argument was + /// specified, the largest possible sequence length is `maxlen`. + /// + /// ** labels_train, labels_test**: lists of integer labels(1 or 0). + /// + /// Raises: + /// ValueError: in case `maxlen` is so low + /// that no input sequence could be kept. + /// Note that the 'out of vocabulary' character is only used for + /// words that were present in the training set but are not included + /// because they're not making the `num_words` cut here. + /// Words that were not seen in the training set but are in the test set + /// have simply been skipped. + /// + /// """Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/). + public class Imdb + { + string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"; + string dest_folder = "imdb"; + + /// + /// Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/). + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public DatasetPass load_data( + string path = "imdb.npz", + int? num_words = null, + int skip_top = 0, + int? maxlen = null, + int seed = 113, + int? start_char = 1, + int? oov_char = 2, + int index_from = 3) + { + path = data_utils.get_file( + path, + origin: Path.Combine(origin_folder, "imdb.npz"), + file_hash: "69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f" + ); + path = Path.Combine(path, "imdb.npz"); + var fileBytes = File.ReadAllBytes(path); + var (x_train, x_test) = LoadX(fileBytes); + var (labels_train, labels_test) = LoadY(fileBytes); + + var indices = np.arange(len(x_train)); + np.random.shuffle(indices, seed); + x_train = x_train[indices]; + labels_train = labels_train[indices]; + + indices = np.arange(len(x_test)); + np.random.shuffle(indices, seed); + x_test = x_test[indices]; + labels_test = labels_test[indices]; + + var x_train_array = (int[,])x_train.ToMultiDimArray(); + var x_test_array = (int[,])x_test.ToMultiDimArray(); + var labels_train_array = (long[])labels_train.ToArray(); + var labels_test_array = (long[])labels_test.ToArray(); + + if (start_char != null) + { + var (d1, d2) = (x_train_array.GetLength(0), x_train_array.GetLength(1)); + int[,] new_x_train_array = new int[d1, d2 + 1]; + for (var i = 0; i < d1; i++) + { + new_x_train_array[i, 0] = (int)start_char; + Array.Copy(x_train_array, i * d2, new_x_train_array, i * (d2 + 1) + 1, d2); + } + (d1, d2) = (x_test_array.GetLength(0), x_test_array.GetLength(1)); + int[,] new_x_test_array = new int[d1, d2 + 1]; + for (var i = 0; i < d1; i++) + { + new_x_test_array[i, 0] = (int)start_char; + Array.Copy(x_test_array, i * d2, new_x_test_array, i * (d2 + 1) + 1, d2); + } + x_train_array = new_x_train_array; + x_test_array = new_x_test_array; + } + else if (index_from != 0) + { + var (d1, d2) = (x_train_array.GetLength(0), x_train_array.GetLength(1)); + for (var i = 0; i < d1; i++) + { + for (var j = 0; j < d2; j++) + { + if (x_train_array[i, j] == 0) + break; + x_train_array[i, j] += index_from; + } + } + (d1, d2) = (x_test_array.GetLength(0), x_test_array.GetLength(1)); + for (var i = 0; i < d1; i++) + { + for (var j = 0; j < d2; j++) + { + if (x_test_array[i, j] == 0) + break; + x_test[i, j] += index_from; + } + } + } + + if (maxlen == null) + { + maxlen = max(x_train_array.GetLength(1), x_test_array.GetLength(1)); + } + (x_train_array, labels_train_array) = data_utils._remove_long_seq((int)maxlen, x_train_array, labels_train_array); + (x_test_array, labels_test_array) = data_utils._remove_long_seq((int)maxlen, x_test_array, labels_test_array); + if (x_train_array.Length == 0 || x_test_array.Length == 0) + throw new ValueError("After filtering for sequences shorter than maxlen=" + + $"{maxlen}, no sequence was kept. Increase maxlen."); + + int[,] xs_array = new int[x_train_array.GetLength(0) + x_test_array.GetLength(0), (int)maxlen]; + Array.Copy(x_train_array, xs_array, x_train_array.Length); + Array.Copy(x_test_array, 0, xs_array, x_train_array.Length, x_train_array.Length); + + long[] labels_array = new long[labels_train_array.Length + labels_test_array.Length]; + Array.Copy(labels_train_array, labels_array, labels_train_array.Length); + Array.Copy(labels_test_array, 0, labels_array, labels_train_array.Length, labels_test_array.Length); + + if (num_words == null) + { + var (d1, d2) = (xs_array.GetLength(0), xs_array.GetLength(1)); + num_words = 0; + for (var i = 0; i < d1; i++) + for (var j = 0; j < d2; j++) + num_words = max((int)num_words, (int)xs_array[i, j]); + } + + // by convention, use 2 as OOV word + // reserve 'index_from' (=3 by default) characters: + // 0 (padding), 1 (start), 2 (OOV) + if (oov_char != null) + { + var (d1, d2) = (xs_array.GetLength(0), xs_array.GetLength(1)); + int[,] new_xs_array = new int[d1, d2]; + for (var i = 0; i < d1; i++) + { + for (var j = 0; j < d2; j++) + { + if (xs_array[i, j] == 0 || skip_top <= xs_array[i, j] && xs_array[i, j] < num_words) + new_xs_array[i, j] = xs_array[i, j]; + else + new_xs_array[i, j] = (int)oov_char; + } + } + xs_array = new_xs_array; + } + else + { + var (d1, d2) = (xs_array.GetLength(0), xs_array.GetLength(1)); + int[,] new_xs_array = new int[d1, d2]; + for (var i = 0; i < d1; i++) + { + int k = 0; + for (var j = 0; j < d2; j++) + { + if (xs_array[i, j] == 0 || skip_top <= xs_array[i, j] && xs_array[i, j] < num_words) + new_xs_array[i, k++] = xs_array[i, j]; + } + } + xs_array = new_xs_array; + } + + Array.Copy(xs_array, x_train_array, x_train_array.Length); + Array.Copy(xs_array, x_train_array.Length, x_test_array, 0, x_train_array.Length); + + Array.Copy(labels_array, labels_train_array, labels_train_array.Length); + Array.Copy(labels_array, labels_train_array.Length, labels_test_array, 0, labels_test_array.Length); + + return new DatasetPass + { + Train = (x_train_array, labels_train_array), + Test = (x_test_array, labels_test_array) + }; + } + + (NDArray, NDArray) LoadX(byte[] bytes) + { + var x = np.Load_Npz(bytes); + return (x["x_train.npy"], x["x_test.npy"]); + } + + (NDArray, NDArray) LoadY(byte[] bytes) + { + var y = np.Load_Npz(bytes); + return (y["y_train.npy"], y["y_test.npy"]); + } + } +} diff --git a/src/TensorFlowNET.Keras/Datasets/KerasDataset.cs b/src/TensorFlowNET.Keras/Datasets/KerasDataset.cs new file mode 100644 index 000000000..0a328702f --- /dev/null +++ b/src/TensorFlowNET.Keras/Datasets/KerasDataset.cs @@ -0,0 +1,25 @@ +/***************************************************************************** + Copyright 2020 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow.Keras.Datasets +{ + public class KerasDataset + { + public Mnist mnist { get; } = new Mnist(); + public Cifar10 cifar10 { get; } = new Cifar10(); + public Imdb imdb { get; } = new Imdb(); + } +} diff --git a/src/TensorFlowNET.Keras/Datasets/MNIST.cs b/src/TensorFlowNET.Keras/Datasets/MNIST.cs new file mode 100644 index 000000000..0e2dd2186 --- /dev/null +++ b/src/TensorFlowNET.Keras/Datasets/MNIST.cs @@ -0,0 +1,73 @@ +/***************************************************************************** + Copyright 2020 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.IO; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Datasets +{ + public class Mnist + { + string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"; + string file_name = "mnist.npz"; + + /// + /// Loads the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). + /// + /// + public DatasetPass load_data() + { + var file = Download(); + var bytes = File.ReadAllBytes(file); + var datax = LoadX(bytes); + var datay = LoadY(bytes); + return new DatasetPass + { + Train = (datax.Item1, datay.Item1), + Test = (datax.Item2, datay.Item2) + }; + } + + (NDArray, NDArray) LoadX(byte[] bytes) + { + var x = np.Load_Npz(bytes); + return (x["x_train.npy"], x["x_test.npy"]); + } + + (NDArray, NDArray) LoadY(byte[] bytes) + { + var y = np.Load_Npz(bytes); + return (y["y_train.npy"], y["y_test.npy"]); + } + + string Download() + { + var fileSaveTo = Path.Combine(Path.GetTempPath(), file_name); + + if (File.Exists(fileSaveTo)) + { + Binding.tf_output_redirect.WriteLine($"The file {fileSaveTo} already exists"); + return fileSaveTo; + } + + Web.Download(origin_folder + file_name, Path.GetTempPath(), file_name); + + return fileSaveTo; + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/CallContext.cs b/src/TensorFlowNET.Keras/Engine/CallContext.cs new file mode 100644 index 000000000..99dd7901f --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/CallContext.cs @@ -0,0 +1,12 @@ +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public class CallContext + { + public CallContextManager enter(bool build_graph) + { + return new CallContextManager(build_graph); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/CallContextManager.cs b/src/TensorFlowNET.Keras/Engine/CallContextManager.cs new file mode 100644 index 000000000..79cb4b30c --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/CallContextManager.cs @@ -0,0 +1,20 @@ +using System; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public class CallContextManager : IDisposable + { + bool _build_graph; + + public CallContextManager(bool build_graph) + { + _build_graph = build_graph; + } + + public void Dispose() + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs b/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs new file mode 100644 index 000000000..2e5644807 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Engine +{ + public class CombinerPreprocessingLayer : Layer + { + PreprocessingLayerArgs args; + protected ICombiner combiner; + protected bool _previously_updated; + + public CombinerPreprocessingLayer(PreprocessingLayerArgs args) + : base(args) + { + _previously_updated = false; + } + + public virtual void adapt(IDatasetV2 data, bool reset_state = true) + { + IAccumulator accumulator; + if (!reset_state) + accumulator = combiner.Restore(); + + var next_data = data.make_one_shot_iterator(); + var data_element = next_data.next(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Container.cs b/src/TensorFlowNET.Keras/Engine/Container.cs new file mode 100644 index 000000000..baf5e662b --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Container.cs @@ -0,0 +1,13 @@ +namespace Tensorflow.Keras.Engine +{ + public class Container + { + protected string[] _output_names; + protected bool _built; + + public Container(string[] output_names) + { + _output_names = output_names; + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs new file mode 100644 index 000000000..590f30a78 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs @@ -0,0 +1,109 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Util; + +namespace Tensorflow.Keras.Engine.DataAdapters +{ + public abstract class DataAdapter + { + protected DataAdapterArgs args; + protected IDatasetV2 dataset; + + public virtual bool CanHandle(Tensors x, Tensors y = null) + => throw new NotImplementedException(); + + public virtual IDatasetV2 GetDataset() + => dataset; + + public virtual int GetSize() + => throw new NotImplementedException(""); + + public virtual (Tensors, Tensors) Expand1d(Tensors x, Tensors y) + { + for(int i = 0; i < x.Length; i++) + { + if (x[i].shape.ndim == 1) + x[i] = array_ops.expand_dims(x[i], axis: -1); + } + for (int i = 0; i < y.Length; i++) + { + if (y[i].shape.ndim == 1) + y[i] = array_ops.expand_dims(y[i], axis: -1); + } + return (x, y); + } + + public virtual (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight) + { + for (int i = 0; i < x.Length; i++) + { + if (x[i].shape.ndim == 1) + x[i] = array_ops.expand_dims(x[i], axis: -1); + } + for (int i = 0; i < y.Length; i++) + { + if (y[i].shape.ndim == 1) + y[i] = array_ops.expand_dims(y[i], axis: -1); + } + for (int i = 0; i < sample_weight.Length; i++) + { + if (sample_weight[i].shape.ndim == 1) + sample_weight[i] = array_ops.expand_dims(sample_weight[i], axis: -1); + } + return (x, y, sample_weight); + } + + public virtual bool ShouldRecreateIterator() + { + return true; + } + + public static ((NDArray, NDArray, NDArray),ValidationDataPack) train_validation_split((NDArray, NDArray, NDArray) x_y_sample_weight, float validation_split) + { + var x = x_y_sample_weight.Item1; + var y = x_y_sample_weight.Item2; + var sample_weight = x_y_sample_weight.Item3; + int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); + var train_x = x[new Slice(0, train_count)]; + var train_y = y[new Slice(0, train_count)]; + ValidationDataPack validation_data; + if (sample_weight != null) + { + validation_data = (x[new Slice(train_count)], y[new Slice(train_count)], sample_weight[new Slice(train_count)]); + sample_weight = sample_weight[new Slice(0, train_count)]; + } + else + { + validation_data = (x[new Slice(train_count)], y[new Slice(train_count)]); + } + + return ((train_x, train_y, sample_weight), validation_data); + } + + public static ((IEnumerable, NDArray, NDArray), ValidationDataPack) train_validation_split((IEnumerable, NDArray, NDArray) x_y_sample_weight, float validation_split) + { + var x = x_y_sample_weight.Item1; + var y = x_y_sample_weight.Item2; + var sample_weight = x_y_sample_weight.Item3; + int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); + var train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray); + var train_y = y[new Slice(0, train_count)]; + var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); + var val_y = y[new Slice(train_count)]; + + ValidationDataPack validation_data; + if (sample_weight != null) + { + validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]); + sample_weight = sample_weight[new Slice(0, train_count)]; + } + else + { + validation_data = (val_x, val_y); + } + return ((train_x, train_y, sample_weight), validation_data); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs new file mode 100644 index 000000000..a305e5033 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs @@ -0,0 +1,211 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; +using static Tensorflow.Binding; +using Tensorflow.Keras.Utils; +using Tensorflow.Util; +using Tensorflow.Framework; + +namespace Tensorflow.Keras.Engine.DataAdapters +{ + /// + /// Handles iterating over epoch-level `tf.data.Iterator` objects. + /// + public class DataHandler + { + DataHandlerArgs args; + IDataAdapter _adapter; + public IDataAdapter DataAdapter => _adapter; + IDatasetV2 _dataset; + long _inferred_steps; + public long Inferredsteps => _inferred_steps; + long _current_step; + long _step_increment; + public long StepIncrement => _step_increment; + bool _insufficient_data; + long _steps_per_execution_value; + int _initial_epoch => args.InitialEpoch; + int _epochs => args.Epochs; + NDArray _sample_weight => args.SampleWeight; + IVariableV1 _steps_per_execution; + + public DataHandler(DataHandlerArgs args) + { + this.args = args; + + if (args.StepsPerExecution == null) + { + _steps_per_execution = tf.Variable(1L); + _steps_per_execution_value = 1L; + } + else + { + _steps_per_execution = args.StepsPerExecution; + _steps_per_execution_value = args.StepsPerExecution.numpy(); + } + + if(args.Dataset == null) + { + _adapter = new TensorLikeDataAdapter(new DataAdapterArgs + { + X = args.X, + Y = args.Y, + BatchSize = args.BatchSize, + Steps = args.StepsPerEpoch, + Epochs = args.Epochs - args.InitialEpoch, + SampleWeight = args.SampleWeight, + Shuffle = args.Shuffle, + MaxQueueSize = args.MaxQueueSize, + Worker = args.Workers, + UseMultiprocessing = args.UseMultiprocessing, + Model = args.Model + }); + } + else + { + _adapter = new DatasetAdapter(new DataAdapterArgs + { + Dataset = args.Dataset, + BatchSize = args.BatchSize, + Steps = args.StepsPerEpoch, + Epochs = args.Epochs - args.InitialEpoch, + Shuffle = args.Shuffle, + MaxQueueSize = args.MaxQueueSize, + Worker = args.Workers, + UseMultiprocessing = args.UseMultiprocessing, + Model = args.Model + }); + } + + _dataset = _adapter.GetDataset(); + _current_step = 0; + _step_increment = _steps_per_execution_value - 1; + _insufficient_data = false; + _configure_dataset_and_inferred_steps(args.X, args.ClassWeight); + } + + void _configure_dataset_and_inferred_steps(Tensors x, Dictionary class_weight) + { + if (_dataset == null) + { + _dataset = _adapter.GetDataset(); + _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); + } + + if (class_weight != null) + { + _dataset = _dataset.map(_make_class_weight_map_fn(class_weight)); + } + _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); + } + + + Func _make_class_weight_map_fn(Dictionary class_weight) + { + var class_ids = class_weight.Keys.OrderBy(key => key).ToList(); + var expected_class_ids = range(class_ids[0], class_ids[class_ids.Count - 1] + 1); + if (!class_ids.SequenceEqual(expected_class_ids)) + { + throw new ValueError("Expected `class_weight` to be a dict with keys from 0 to one less "+ + $"than the number of classes, found {class_weight}"); + } + + var class_weight_list = new List(); + foreach (var class_id in class_ids) + { + class_weight_list.Add(class_weight[class_id]); + } + var class_weight_tensor = tf.convert_to_tensor(class_weight_list.ToArray()); + + Func _class_weight_map_fn = (Tensors data) => + { + var x = data[0]; + var y = data[1]; + var sw = _sample_weight == null ? null : ops.convert_to_tensor(_sample_weight); + + if (y.shape.rank > 2) + { + throw new ValueError("`class_weight` not supported for 3+ dimensional targets."); + } + + var y_classes = smart_module.smart_cond( + y.shape.rank == 2 && y.shape[1] > 1, + () => math_ops.argmax(y, dimension: 1), + () => math_ops.cast(tf.reshape(y, (-1)), TF_DataType.TF_INT64)); + + var cw = array_ops.gather(class_weight_tensor, y_classes); + if (sw != null) + { + cw = tf.cast(cw, sw.dtype); + cw *= sw; + } + else + { + sw = cw; + } + return new Tensors { x, y, sw }; + }; + + return _class_weight_map_fn; + } + + long _infer_steps(int steps_per_epoch, IDatasetV2 dataset) + { + if (steps_per_epoch > -1) + return steps_per_epoch; + + var adapter_steps = _adapter.GetSize(); + if (adapter_steps > -1) + return adapter_steps; + + var size = dataset.cardinality(); + return size.numpy(); + } + + public IEnumerable<(int, OwnedIterator)> enumerate_epochs() + { + var data_iterator = new OwnedIterator(_dataset); + foreach (var epoch in range(_initial_epoch, _epochs)) + { + if (_insufficient_data) + break; + if (_adapter.ShouldRecreateIterator()) + { + data_iterator = new OwnedIterator(_dataset); + } + yield return (epoch, data_iterator); + } + // _adapter.on_epoch_end() + } + + public IEnumerable steps() + { + _current_step = 0; + while (_current_step < _inferred_steps) + { + if (_insufficient_data) + break; + + bool can_run_full_execution = _steps_per_execution_value == 1 + || _inferred_steps < 0 + || _inferred_steps - _current_step >= _steps_per_execution_value; + + if (can_run_full_execution) + { + _step_increment = _steps_per_execution_value - 1; + yield return _current_step; + _current_step += _steps_per_execution_value; + } + else + { + var steps_remaining = _inferred_steps - _current_step; + _steps_per_execution.assign(steps_remaining); + _step_increment = steps_remaining - 1; + yield return _current_step; + _current_step += steps_remaining; + _steps_per_execution.assign(_steps_per_execution_value); + } + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs new file mode 100644 index 000000000..29b0e58bd --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Engine.DataAdapters +{ + public class DatasetAdapter : DataAdapter, IDataAdapter + { + public DatasetAdapter(DataAdapterArgs args) + { + this.args = args; + dataset = args.Dataset; + } + + public override int GetSize() + => -1; + } +} diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs new file mode 100644 index 000000000..bb71b0a2d --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs @@ -0,0 +1,24 @@ +namespace Tensorflow.Keras.Engine.DataAdapters +{ + /// + /// In TF 2.0, tf.data is the preferred API for user to feed in data. In order + /// to simplify the training code path, all the input data object will be + /// converted to `tf.data.Dataset` if possible. + /// + public interface IDataAdapter + { + /// + /// Whether the current DataAdapter could handle the input x and y. + /// + /// input features + /// target labels + /// + bool CanHandle(Tensors x, Tensors y = null); + IDatasetV2 GetDataset(); + int GetSize(); + (Tensors, Tensors) Expand1d(Tensors x, Tensors y); + (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight); + + bool ShouldRecreateIterator(); + } +} diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs new file mode 100644 index 000000000..978a3f51c --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -0,0 +1,104 @@ +using System; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine.DataAdapters +{ + /// + /// Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy. + /// + public class TensorLikeDataAdapter : DataAdapter, IDataAdapter + { + int _size; + int _batch_size; + int num_samples; + int num_full_batches; + int _partial_batch_size; + + public TensorLikeDataAdapter(DataAdapterArgs args) + { + this.args = args; + Tensor sample_weight_tensor = args.SampleWeight != null ? _process_tensorlike(args.SampleWeight) : null; + num_samples = (int)args.X.shape[0]; + var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; + _batch_size = batch_size; + _size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f))); + num_full_batches = num_samples / batch_size; + _partial_batch_size = num_samples % batch_size; + + var indices_dataset = tf.data.Dataset.range(1); + indices_dataset = indices_dataset.repeat(args.Epochs); + indices_dataset = indices_dataset.map(permutation).prefetch(1); + indices_dataset = indices_dataset.flat_map(slice_batch_indices); + var inputs = new Tensors(); + if (args.X != null) + inputs.AddRange(args.X); + if (args.Y != null) + inputs.AddRange(args.Y); + if (sample_weight_tensor != null) + inputs.Add(sample_weight_tensor); + dataset = slice_inputs(indices_dataset, inputs); + dataset.FirstInputTensorCount = args.X.Length; + } + + Tensors permutation(Tensors tensor) + { + var indices = math_ops.range(num_samples, dtype: dtypes.int64); + if (args.Shuffle) + indices = random_ops.random_shuffle(indices); + return indices; + } + + /// + /// Convert a Tensor of indices into a dataset of batched indices. + /// + /// + /// + IDatasetV2 slice_batch_indices(Tensor indices) + { + var num_in_full_batch = num_full_batches * _batch_size; + var first_k_indices = array_ops.slice(indices, new Tensor[] { ops.convert_to_tensor(0) }, + new Tensor[] { ops.convert_to_tensor(num_in_full_batch) }); + first_k_indices = array_ops.reshape(first_k_indices, new int[] { num_full_batches, _batch_size }); + var flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices); + if (_partial_batch_size > 0) + { + var array = array_ops.slice(indices, + new[] { constant_op.constant(num_in_full_batch)}, + new[] { constant_op.constant(_partial_batch_size)}); + var index_remainder = tf.data.Dataset.from_tensors(array); + flat_dataset = flat_dataset.concatenate(index_remainder); + } + + return flat_dataset; + } + + IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements) + { + var dataset = tf.data.Dataset.from_tensors(elements).repeat(); + dataset = tf.data.Dataset.zip(indices_dataset, dataset); + + dataset = dataset.map(inputs => + { + var indices = inputs[0]; + var results = inputs.Skip(1) + .Select(x => array_ops.gather(x, indices, axis: 0)) + .ToArray(); + return new Tensors(results); + }, -1); + + return dataset.with_options(new DatasetOptions { }); + } + + public override int GetSize() => _size; + + public override bool ShouldRecreateIterator() => false; + + Tensor _process_tensorlike(NDArray sample_weights) + { + return tf.convert_to_tensor(sample_weights); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Functional.ConnectAncillaryLayers.cs b/src/TensorFlowNET.Keras/Engine/Functional.ConnectAncillaryLayers.cs new file mode 100644 index 000000000..0002aed1d --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Functional.ConnectAncillaryLayers.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Functional + { + /// + /// Adds layers that are not connected to the outputs to the model. + /// + /// + public void connect_ancillary_layers(Dictionary created_layers) + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs new file mode 100644 index 000000000..375fc9106 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs @@ -0,0 +1,141 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Functional + { + public static Functional from_config(FunctionalConfig config) + { + var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config); + var model = new Functional(input_tensors, output_tensors, name: config.Name); + model.connect_ancillary_layers(created_layers); + return model; + } + + /// + /// Reconstructs graph from config object. + /// + /// + /// + public static (Tensors, Tensors, Dictionary) reconstruct_from_config(FunctionalConfig config, Dictionary? created_layers = null) + { + // Layer instances created during the graph reconstruction process. + created_layers = created_layers ?? new Dictionary(); + var node_index_map = new Dictionary<(string, int), int>(); + var node_count_by_layer = new Dictionary(); + var unprocessed_nodes = new Dictionary>(); + // First, we create all layers and enqueue nodes to be processed + foreach (var layer_data in config.Layers) + process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer); + + // Then we process nodes in order of layer depth. + // Nodes that cannot yet be processed (if the inbound node + // does not yet exist) are re-enqueued, and the process + // is repeated until all nodes are processed. + while (unprocessed_nodes.Count > 0) + { + foreach(var layer_data in config.Layers) + { + var layer = created_layers[layer_data.Name]; + if (unprocessed_nodes.ContainsKey(layer)) + { + var node_data = unprocessed_nodes[layer]; + // foreach (var node_data in unprocessed_nodes[layer]) + { + process_node(layer, node_data, created_layers, node_count_by_layer, node_index_map); + unprocessed_nodes.Remove(layer); + } + } + } + } + + var input_tensors = new List(); + foreach (var layer_data in config.InputLayers) + { + var (layer_name, node_index, tensor_index) = (layer_data.Name, layer_data.NodeIndex, layer_data.TensorIndex); + var layer = created_layers[layer_name]; + var layer_output_tensors = layer.InboundNodes[node_index].Outputs; + input_tensors.append(layer_output_tensors[tensor_index]); + } + + var output_tensors = new List(); + foreach (var layer_data in config.OutputLayers) + { + var (layer_name, node_index, tensor_index) = (layer_data.Name, layer_data.NodeIndex, layer_data.TensorIndex); + var layer = created_layers[layer_name]; + var layer_output_tensors = layer.InboundNodes[node_index].Outputs; + output_tensors.append(layer_output_tensors[tensor_index]); + } + + return (input_tensors, output_tensors, created_layers); + } + + static void process_layer(Dictionary created_layers, + LayerConfig layer_data, + Dictionary> unprocessed_nodes, + Dictionary node_count_by_layer) + { + ILayer layer = null; + var layer_name = layer_data.Name; + if (created_layers.ContainsKey(layer_name)) + layer = created_layers[layer_name]; + else + { + layer = generic_utils.deserialize_keras_object(layer_data.ClassName, layer_data.Config); + + created_layers[layer_name] = layer; + } + node_count_by_layer[layer] = layer_data.InboundNodes.Count - (_should_skip_first_node(layer) ? 1 : 0); + + var inbound_nodes_data = layer_data.InboundNodes; + foreach (var node_data in inbound_nodes_data) + { + if (!unprocessed_nodes.ContainsKey(layer)) + unprocessed_nodes[layer] = new List() { node_data }; + else + unprocessed_nodes[layer].Add(node_data); + } + } + + static void process_node(ILayer layer, + List nodes_data, + Dictionary created_layers, + Dictionary node_count_by_layer, + Dictionary<(string, int), int> node_index_map) + { + + var input_tensors = new List(); + + for (int i = 0; i < nodes_data.Count; i++) + { + var node_data = nodes_data[i]; + var inbound_layer_name = node_data.Name; + var inbound_node_index = node_data.NodeIndex; + var inbound_tensor_index = node_data.TensorIndex; + + var inbound_layer = created_layers[inbound_layer_name]; + var inbound_node = inbound_layer.InboundNodes[inbound_node_index]; + input_tensors.Add(inbound_node.Outputs[inbound_node_index]); + } + + var output_tensors = layer.Apply(input_tensors); + + // Update node index map. + var output_index = output_tensors[0].KerasHistory.NodeIndex; + node_index_map[(layer.Name, node_count_by_layer[layer])] = output_index; + node_count_by_layer[layer] += 1; + } + + static bool _should_skip_first_node(ILayer layer) + { + return layer is Functional && layer.Layers[0] is InputLayer; + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs new file mode 100644 index 000000000..df77e5969 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs @@ -0,0 +1,109 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Functional + { + public override IKerasConfig get_config() + { + return get_network_config(); + } + + /// + /// Builds the config, which consists of the node graph and serialized layers. + /// + FunctionalConfig get_network_config() + { + var config = new FunctionalConfig + { + Name = name + }; + + var node_conversion_map = new Dictionary(); + foreach (var layer in _self_tracked_trackables) + { + var kept_nodes = _should_skip_first_node(layer) ? 1 : 0; + foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) + { + var node_key = _make_node_key(layer.Name, original_node_index); + if (NetworkNodes.Contains(node_key)) + { + node_conversion_map[node_key] = kept_nodes; + kept_nodes += 1; + } + } + } + + var layer_configs = new List(); + using (SharedObjectSavingScope.Enter()) + { + foreach (var layer in _self_tracked_trackables) + { + var filtered_inbound_nodes = new List(); + foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) + { + var node_key = _make_node_key(layer.Name, original_node_index); + if (NetworkNodes.Contains(node_key) && !node.is_input) + { + var node_data = node.serialize(_make_node_key, node_conversion_map); + filtered_inbound_nodes.append(node_data); + } + } + + var layer_config = generic_utils.serialize_layer_to_config(layer); + layer_config.Name = layer.Name; + layer_config.InboundNodes = filtered_inbound_nodes; + layer_configs.Add(layer_config); + } + } + config.Layers = layer_configs; + + // Gather info about inputs and outputs. + var model_inputs = new List(); + foreach (var i in range(_input_layers.Count)) + { + var (layer, node_index, tensor_index) = _input_coordinates[i]; + var node_key = _make_node_key(layer.Name, node_index); + if (!NetworkNodes.Contains(node_key)) + continue; + var new_node_index = node_conversion_map[node_key]; + model_inputs.append(new NodeConfig + { + Name = layer.Name, + NodeIndex = new_node_index, + TensorIndex = tensor_index + }); + } + config.InputLayers = model_inputs; + + var model_outputs = new List(); + foreach (var i in range(_output_layers.Count)) + { + var (layer, node_index, tensor_index) = _output_coordinates[i]; + var node_key = _make_node_key(layer.Name, node_index); + if (!NetworkNodes.Contains(node_key)) + continue; + var new_node_index = node_conversion_map[node_key]; + model_outputs.append(new NodeConfig + { + Name = layer.Name, + NodeIndex = new_node_index, + TensorIndex = tensor_index + }); + } + config.OutputLayers = model_outputs; + + return config; + } + + string _make_node_key(string layer_name, int node_index) + => $"{layer_name}_ib-{node_index}"; + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs new file mode 100644 index 000000000..75854d82c --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -0,0 +1,392 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + /// + /// A `Functional` model is a `Model` defined as a directed graph of layers. + /// + public partial class Functional : Model + { + List _output_layers; + List _input_layers; + List _input_coordinates; + List _output_coordinates; + public string[] NetworkNodes { get; set; } + + Dictionary tensor_usage_count; + + /// + /// Dictionary of layer dependencies to be included in the checkpoint. + /// + public IDictionary LayerCheckpointDependencies + { + get + { + int weight_layer_index = 0; + Dictionary dependencies = new(); + for(int i = 0; i < Layers.Count; i++) + { + var layer = Layers[i]; + var weights = layer.TrainableWeights.concat(layer.NonTrainableWeights).ToList(); + if(weights.Count > 0) + { + dependencies[$"layer_with_weights-{weight_layer_index}"] = layer; + weight_layer_index++; + } + dependencies[$"layer-{i}"] = layer; + } + return dependencies; + } + } + + public Functional(Tensors inputs, Tensors outputs, string name = null) + : base(new ModelArgs + { + Name = name, + Inputs = inputs, + Outputs = outputs + }) + { + Initialize(inputs, outputs, name); + } + + internal void Initialize(Tensors inputs, Tensors outputs, string name = null) + { + _input_layers = new List(); + _output_layers = new List(); + _input_coordinates = new List(); + _output_coordinates = new List(); + tensor_usage_count = new Dictionary(); + if (this is Sequential) + return; + _init_graph_network(inputs, outputs); + } + + protected void _init_graph_network(Tensors inputs, Tensors outputs) + { + _is_graph_network = true; + this.inputs = inputs; + this.outputs = outputs; + built = true; + if(inputs.Length > 0) + { + _buildInputShape = inputs.shape; + } + else + { + _buildInputShape = new TensorShapeConfig(); + } + + if (outputs.Any(x => x.KerasHistory == null)) + base_layer_utils.create_keras_history(outputs); + + // Build self._output_layers: + foreach (var x in outputs) + { + var (layer, node_index, tensor_index) = x.KerasHistory; + _output_layers.append(layer); + _output_coordinates.append(new KerasHistory(layer, node_index, tensor_index)); + } + + // Build self._input_layers: + foreach (var x in inputs) + { + var (layer, node_index, tensor_index) = x.KerasHistory; + _input_layers.append(layer); + _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index)); + } + + // Keep track of the network's nodes and layers. + (NetworkNodes, NodesByDepth, _self_tracked_trackables, _) = MapGraphNetwork(inputs, outputs); + + // Build self.input_names and self.output_names. + _set_output_names(); + + ComputeTensorUsageCount(); + } + + /// + /// Assigns unique names to the Network's outputs. + /// + void _set_output_names() + { + var uniquified = new List(); + var output_names = new List(); + var prefix_count = new Dictionary(); + + foreach (var layer in _output_layers) + { + var proposal = layer.Name; + while (output_names.Contains(proposal)) + { + var existing_count = prefix_count.Get(layer.Name, 1); + proposal = $"{layer.Name}_{existing_count}"; + prefix_count[layer.Name] = existing_count + 1; + } + output_names.add(proposal); + uniquified.append(proposal); + } + + this.output_names = uniquified.ToArray(); + } + + void ComputeTensorUsageCount() + { + var available_tensors = inputs.Select(x => x.Id).ToList(); + var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().Skip(1).ToArray(); + foreach (var depth in depth_keys) + { + foreach (var node in NodesByDepth[depth]) + { + var input_tensors = node.KerasInputs.Select(x => x.Id).ToArray(); + if (input_tensors.issubset(available_tensors)) + { + foreach (var tensor in node.KerasInputs) + { + if (!tensor_usage_count.ContainsKey(tensor.Id)) + tensor_usage_count[tensor.Id] = 0; + tensor_usage_count[tensor.Id] += 1; + } + + foreach (var output_tensor in node.Outputs) + available_tensors.Add(output_tensor.Id); + } + } + } + + foreach (var tensor in outputs) + { + if (!tensor_usage_count.ContainsKey(tensor.Id)) + tensor_usage_count[tensor.Id] = 0; + tensor_usage_count[tensor.Id] += 1; + } + } + + /// + /// Validates a network's topology and gather its layers and nodes. + /// + /// + /// + (string[], Dictionary>, List, Dictionary>) MapGraphNetwork(Tensors inputs, Tensors outputs) + { + var (nodes_in_decreasing_depth, layer_indices) = BuildMap(outputs); + var network_nodes = nodes_in_decreasing_depth + .Select(node => MakeNodeKey(node.Layer.Name, node.Layer.InboundNodes.IndexOf(node))) + .ToList(); + + var nodes_depths = new Dictionary(); + var layers_depths = new Dictionary(); + + nodes_in_decreasing_depth.Reverse(); + foreach (var node in nodes_in_decreasing_depth) + { + // If the depth is not set, the node has no outbound nodes (depth 0). + int depth = nodes_depths.SetDefault(node, 0); + // Update the depth of the corresponding layer + int previous_depth = layers_depths.Get(node.Layer, 0); + // If we've seen this layer before at a higher depth, + // we should use that depth instead of the node depth. + // This is necessary for shared layers that have inputs at different + // depth levels in the graph. + depth = Math.Max(depth, previous_depth); + layers_depths[node.Layer] = depth; + nodes_depths[node] = depth; + + // Update the depth of inbound nodes. + // The "depth" of a node is the max of the depths + // of all nodes it is connected to + 1. + foreach (var node_dep in node.ParentNodes) + { + previous_depth = nodes_depths.Get(node_dep, 0); + nodes_depths[node_dep] = Math.Max(depth + 1, previous_depth); + } + } + + // Handle inputs that are not connected to outputs. + // We do not error out here because the inputs may be used to compute losses + // and metrics. + foreach (var input_t in inputs) + { + var (input_layer, _, _) = input_t.KerasHistory; + if (!layers_depths.ContainsKey(input_layer)) + { + layers_depths[input_layer] = 0; + layer_indices[input_layer] = -1; + nodes_depths[input_layer.InboundNodes[0]] = 0; + network_nodes.Add(MakeNodeKey(input_layer.Name, 0)); + } + } + + // Build a dict {depth: list of nodes with this depth} + var nodes_by_depth = new Dictionary>(); + foreach (var (node, depth) in enumerate(nodes_depths)) + { + if (!nodes_by_depth.ContainsKey(depth)) + nodes_by_depth[depth] = new List(); + nodes_by_depth[depth].Add(node); + } + + var layers_by_depth = new Dictionary>(); + foreach (var (layer, depth) in enumerate(layers_depths)) + { + if (!layers_by_depth.ContainsKey(depth)) + layers_by_depth[depth] = new List(); + layers_by_depth[depth].Add(layer); + } + + // Get sorted list of layer depths. + var depth_keys = layers_by_depth.Keys.OrderBy(x => x).Reverse(); + + // Set self.layers ordered by depth. + var layers = new List(); + foreach (var depth in depth_keys) + { + var layers_for_depth = layers_by_depth[depth]; + + // Network.layers needs to have a deterministic order: + // here we order them by traversal order. + layers_for_depth = layers_for_depth.OrderBy(x => layer_indices[x]).ToList(); + layers.AddRange(layers_for_depth); + } + + // Get sorted list of node depths. + depth_keys = nodes_by_depth.Keys.OrderBy(x => x).Reverse(); + + return (network_nodes.ToArray(), nodes_by_depth, layers, layers_by_depth); + } + + string MakeNodeKey(string layer_name, int node_index) + => $"{layer_name}_ib-{node_index}"; + + /// + /// This method topologically sorts nodes in order from inputs to outputs. + /// + /// + (List, Dictionary) BuildMap(Tensors outputs) + { + var finished_nodes = new List(); + var nodes_in_progress = new List(); + var nodes_in_decreasing_depth = new List(); + var layer_indices = new Dictionary(); + foreach (var output in outputs) + BuildMapHelper(output, + finished_nodes, + nodes_in_progress, + nodes_in_decreasing_depth, + layer_indices); + + return (nodes_in_decreasing_depth, layer_indices); + } + + void BuildMapHelper(Tensor tensor, + List finished_nodes, + List nodes_in_progress, + List nodes_in_decreasing_depth, + Dictionary layer_indices) + { + var (layer, node_index, _) = tensor.KerasHistory; + var node = layer.InboundNodes[node_index] as Node; + + // Don't repeat work for shared subgraphs + if (finished_nodes.Contains(node)) + return; + + // Prevent cycles. + if (nodes_in_progress.Contains(node)) + throw new ValueError($"The tensor {tensor.name} at layer {layer.Name} is part of a cycle."); + + // Store the traversal order for layer sorting. + if (!layer_indices.ContainsKey(layer)) + layer_indices[layer] = layer_indices.Count; + + // Propagate to all previous tensors connected to this node. + nodes_in_progress.Add(node); + if (!node.is_input) + { + foreach (var k_tensor in node.KerasInputs) + { + BuildMapHelper(k_tensor, + finished_nodes, + nodes_in_progress, + nodes_in_decreasing_depth, + layer_indices); + } + } + + finished_nodes.Add(node); + nodes_in_progress.Remove(node); + nodes_in_decreasing_depth.append(node); + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + var tensor_dict = new Dictionary>(); + // map input values + foreach (var (x, y) in zip(this.inputs, inputs)) + { + tensor_dict[x.Id] = new Queue(Enumerable.Range(0, tensor_usage_count[x.Id]).Select(x => y)); + } + + var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray(); + + foreach (var depth in depth_keys) + { + var nodes = NodesByDepth[depth]; + foreach (Node node in nodes) + { + // Input tensors already exist. + if (node.is_input) + continue; + + var layer_inputs = node.MapArguments(tensor_dict); + + tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}"); + var outputs = node.Layer.Apply(layer_inputs, training: training ?? false); + foreach (var output in outputs.Where(x => x != null)) + tf.Logger.Information($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.shape}"); + // Update tensor_dict for next or later input + foreach (var (x_id, y) in zip(node.Outputs.Select(x => x.Id), outputs)) + tensor_dict[x_id] = new Queue(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); + } + } + + var output_tensors = new Tensors(); + + foreach (var x in outputs) + output_tensors.Add(tensor_dict[x.Id].Dequeue()); + + return output_tensors; + } + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) + { + return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) + .ToDictionary(x => x.Key, x => x.Value); + } + + protected override void _init_set_name(string name, bool zero_based = true) + { + if (string.IsNullOrEmpty(name)) + { + string class_name = GetType().Name; + if (this.GetType() == typeof(Functional)) + { + class_name = "Model"; + } + this.name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(class_name), zero_based: zero_based); + } + else + { + this.name = name; + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Interfaces/IAccumulator.cs b/src/TensorFlowNET.Keras/Engine/Interfaces/IAccumulator.cs new file mode 100644 index 000000000..df8198395 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Interfaces/IAccumulator.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Engine +{ + public interface IAccumulator + { + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Interfaces/ICombiner.cs b/src/TensorFlowNET.Keras/Engine/Interfaces/ICombiner.cs new file mode 100644 index 000000000..8fe1764d6 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Interfaces/ICombiner.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Engine +{ + /// + /// Functional object that defines a shardable computation. + /// + public interface ICombiner + { + void Compute(Tensor values, IAccumulator accumulator = null); + void Merge(); + void Extract(); + IAccumulator Restore(); + void Serialize(); + void Deserialize(); + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs b/src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs new file mode 100644 index 000000000..2925739bc --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs @@ -0,0 +1,63 @@ +using System; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + protected virtual IVariableV1 add_weight(string name, + Shape shape, + TF_DataType dtype = TF_DataType.TF_FLOAT, + IInitializer initializer = null, + IRegularizer regularizer = null, + VariableSynchronization synchronization = VariableSynchronization.Auto, + VariableAggregation aggregation = VariableAggregation.None, + bool trainable = true, + Func getter = null) + { + // Initialize variable when no initializer provided + if (initializer == null) + { + // If dtype is DT_FLOAT, provide a uniform unit scaling initializer + if (dtype.is_floating()) + initializer = tf.glorot_uniform_initializer; + else if (dtype.is_integer() || dtype.is_unsigned() || dtype.is_bool()) + initializer = tf.zeros_initializer; + else if(getter is null) + throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); + } + + if (synchronization == VariableSynchronization.OnRead) + trainable = false; + + var args = new VariableArgs + { + Name = name, + Shape = shape, + DType = dtype, + Getter = getter ?? base_layer_utils.make_variable, + Overwrite = true, + Initializer = initializer, + Synchronization = synchronization, + Aggregation = aggregation, + Trainable = trainable + }; + var variable = _add_variable_with_custom_getter(args); + + if (regularizer != null) + { + var name_in_scope = variable.Name.Split(':')[0]; + _handle_weight_regularization(name_in_scope, variable, regularizer); + } + + //backend.track_variable(variable); + if (trainable == true) + _trainable_weights.Add(variable); + else + _non_trainable_weights.Add(variable); + + return variable; + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs new file mode 100644 index 000000000..a3831bffa --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs @@ -0,0 +1,62 @@ +using System.Threading; +using Tensorflow.Common.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + /// + /// Wraps `call`, applying pre- and post-processing steps. + /// + /// + /// + /// + /// + public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null) + { + if (callContext.Value == null) + callContext.Value = new CallContext(); + + if (_in_functional_construction_mode(inputs)) + return FunctionalConstructionCall(inputs); + + var eager = tf.executing_eagerly(); + using var ctxManager = CallContext.enter(build_graph: false); + + string nameScope = eager ? name : _name_scope(); + var scope = ops.name_scope(nameScope); + scope.__enter__(); + + if (!built) + MaybeBuild(inputs); + + var outputs = Call(inputs, state: states, training: training); + + // memory leak + // _set_connectivity_metadata_(inputs, outputs); + _handle_activity_regularization(inputs, outputs); + _set_mask_metadata(inputs, outputs, null); + + // TODO(Rinne): set save spec if null + + scope.__exit__(); + + return outputs; + } + + // TODO(Rinne): remove it and completely fix issue 1084 + [Obsolete] + private bool _enforce_layer_construction = false; + [Obsolete] + internal void enforce_layer_construction() + { + _enforce_layer_construction = true; + } + [Obsolete] + internal void unset_layer_construction() + { + _enforce_layer_construction = false; + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Layer.FlattenLayers.cs b/src/TensorFlowNET.Keras/Engine/Layer.FlattenLayers.cs new file mode 100644 index 000000000..dd037e243 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Layer.FlattenLayers.cs @@ -0,0 +1,27 @@ +using System.Collections.Generic; + +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + public IEnumerable _flatten_layers(bool recursive = true, bool include_self = true) + { + if (include_self) + yield return this; + + var seen_object_ids = new List(); + var deque = new Queue(_self_tracked_trackables); + while (!deque.empty()) + { + var layer_or_container = deque.Dequeue(); + var layer_or_container_id = layer_or_container.GetHashCode(); + if (seen_object_ids.Contains(layer_or_container_id)) + continue; + seen_object_ids.Add(layer_or_container_id); + yield return layer_or_container; + if (recursive) + deque.extendleft(layer_or_container.Layers); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs b/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs new file mode 100644 index 000000000..e4023c3fd --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs @@ -0,0 +1,45 @@ +using System; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + Tensors FunctionalConstructionCall(Tensors inputs) + { + if (base_layer_utils.needs_keras_history(inputs)) + base_layer_utils.create_keras_history(inputs); + + Tensors outputs = null; + using var ctxManager = CallContext.enter(build_graph: true); + + var graph = keras.backend.get_graph(); + graph.as_default(); + + var scope = ops.name_scope(_name_scope()); + scope.__enter__(); + + MaybeBuild(inputs); + + // Wrapping `call` function in autograph to allow for dynamic control + // flow and control dependencies in call. We are limiting this to + // subclassed layers as autograph is strictly needed only for + // subclassed layers and models. + // tf_convert will respect the value of autograph setting in the + // enclosing tf.function, if any. + if (!dynamic) + throw new NotImplementedException(""); + + outputs = Call(inputs); + + _set_connectivity_metadata_(inputs, outputs); + _handle_activity_regularization(inputs, outputs); + _set_mask_metadata(inputs, outputs, null); + + scope.__exit__(); + graph.Exit(); + + return outputs; + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs new file mode 100644 index 000000000..81fc26355 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs @@ -0,0 +1,44 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + public virtual List Layers => _self_tracked_trackables; + + protected void StackLayers(params ILayer[] layers) + { + _self_tracked_trackables.AddRange(layers); + } + + public virtual Shape ComputeOutputShape(Shape input_shape) + => throw new NotImplementedException(""); + + protected List _gather_children_variables(bool include_trainable = false, bool include_non_trainable = false) + { + List res = new(); + var nested_layers = _flatten_layers(false, false); + foreach (var layer in nested_layers) + { + if (layer is Layer l) + { + if (include_trainable == true && include_non_trainable == true) + { + res.AddRange(l.Variables); + } + else if (include_trainable == true && include_non_trainable == false) + { + res.AddRange(l.TrainableVariables); + } + else if(include_trainable == false && include_non_trainable == true) + { + res.AddRange(l.NonTrainableVariables); + } + } + } + return res; + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Layer.LoadWeights.cs b/src/TensorFlowNET.Keras/Engine/Layer.LoadWeights.cs new file mode 100644 index 000000000..fa833da35 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Layer.LoadWeights.cs @@ -0,0 +1,14 @@ +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + /// + /// Loads all layer weights, either from a TensorFlow or an HDF5 weight file. + /// + /// + public void load_weights(string filepath) + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs new file mode 100644 index 000000000..49811417e --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs @@ -0,0 +1,32 @@ +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Engine; + +public abstract partial class Layer +{ + public virtual SavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); + + public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; + + public string GetTrackingMetadata() => TrackableSavedModelSaver.TrackingMetadata; + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) + { + IDictionary children; + if (save_type == SaveType.SAVEDMODEL) + { + Debug.Assert(cache is not null); + children = TrackableSavedModelSaver.trackable_children(cache); + } + else + { + children = new Dictionary(); + } + + return children.Concat(base._trackable_children(save_type, cache)).GroupBy(x => x.Key).Select(g => g.First()).ToDictionary(x => x.Key, x => x.Value); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Engine/Layer.State.cs b/src/TensorFlowNET.Keras/Engine/Layer.State.cs new file mode 100644 index 000000000..35f1a8527 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Layer.State.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; + +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + protected Dictionary trainable_state; + protected Dictionary _compiled_trainable_state; + + /// + /// Get the `trainable` state of each sublayer. + /// + /// + protected Dictionary _get_trainable_state() + { + trainable_state = new Dictionary(); + foreach (var layer in _flatten_layers()) + trainable_state[layer] = layer.Trainable; + return trainable_state; + } + + void _set_trainable_state(Dictionary trainable_state) + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs new file mode 100644 index 000000000..2f758a850 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -0,0 +1,489 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using Tensorflow.Eager; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; +using Tensorflow.NumPy; +using Tensorflow.Train; +using Tensorflow.Training; +using Tensorflow.Training.Saving.SavedModel; +using Tensorflow.Util; +using static Tensorflow.Binding; +using Tensorflow.Framework; +using Tensorflow.Sessions; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Engine +{ + /// + /// Base layer class. + /// A layer is a class implementing common neural networks operations, such + /// as convolution, batch norm, etc. These operations require managing weights, + /// losses, updates, and inter-layer connectivity. + /// + public abstract partial class Layer : AutoTrackable, ILayer + { + /// + /// Arguments initialize layer. + /// + internal LayerArgs args; + + /// + /// Indicates whether `build` needs to be called upon layer call, to create + /// the layer's weights. + /// + protected bool built; + public bool Built + { + get + { + return built; + } + internal set + { + built = value; + } + } + public bool Trainable => args.Trainable; + public TF_DataType DType => args.DType; + public bool AutoCast => args.Autocast; + public IRegularizer ActivityRegularizer => args.ActivityRegularizer; + + /// + /// A stateful layer is a layer whose updates are run during inference too, + /// for instance stateful RNNs. + /// + protected bool stateful; + /// + /// Provides information about which inputs are compatible with the layer. + /// + protected InputSpec inputSpec; + public InputSpec InputSpec => inputSpec; + bool dynamic = true; + public bool SupportsMasking { get; set; } + protected List _trainable_weights; + + public virtual List TrainableVariables => TrainableWeights; + + protected List _non_trainable_weights; + public List NonTrainableVariables => NonTrainableWeights; + public List Variables => Weights; + + public virtual List TrainableWeights + { + get + { + if (!this.Trainable) + { + return new List(); + } + var children_weights = _gather_children_variables(true); + return children_weights.Concat(_trainable_weights).Distinct().ToList(); + } + } + + public virtual List NonTrainableWeights + { + get + { + if (!this.Trainable) + { + var children_weights = _gather_children_variables(true, true); + return children_weights.Concat(_trainable_weights).Concat(_non_trainable_weights).Distinct().ToList(); + } + else + { + var children_weights = _gather_children_variables(include_non_trainable: true); + return children_weights.Concat(_non_trainable_weights).Distinct().ToList(); + } + } + } + + public virtual List Weights + { + get + { + return TrainableWeights.Concat(NonTrainableWeights).ToList(); + } + set + { + if (Weights.Count() != value.Count()) throw new ValueError( + $"You called `set_weights` on layer \"{this.name}\"" + + $"with a weight list of length {len(value)}, but the layer was " + + $"expecting {len(Weights)} weights."); + foreach (var (this_w, v_w) in zip(Weights, value)) + this_w.assign(v_w, read_value: true); + } + } + + public virtual void set_weights(IEnumerable weights) + { + if (Weights.Count() != weights.Count()) throw new ValueError( + $"You called `set_weights` on layer \"{this.name}\"" + + $"with a weight list of length {len(weights)}, but the layer was " + + $"expecting {len(Weights)} weights."); + + + + // check if the shapes are compatible + var weight_index = 0; + foreach(var w in weights) + { + if (!Weights[weight_index].AsTensor().is_compatible_with(w)) + { + throw new ValueError($"Layer weight shape {w.shape} not compatible with provided weight shape {Weights[weight_index].shape}"); + } + weight_index++; + } + + if (tf.executing_eagerly()) + { + foreach (var (this_w, v_w) in zip(Weights, weights)) + this_w.assign(v_w, read_value: true); + } + else + { + // TODO(Wanglongzhi2001):seems like there exist some bug in graph mode when define model, so uncomment the following when it fixed. + + //Tensors assign_ops = new Tensors(); + //var feed_dict = new FeedDict(); + + //Graph g = tf.Graph().as_default(); + //foreach (var (this_w, v_w) in zip(Weights, weights)) + //{ + // var tf_dtype = this_w.dtype; + // var placeholder_shape = v_w.shape; + // var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape); + // var assign_op = this_w.assign(assign_placeholder); + // assign_ops.Add(assign_op); + // feed_dict.Add(assign_placeholder, v_w); + //} + //var sess = tf.Session().as_default(); + //sess.run(assign_ops, feed_dict); + + //g.Exit(); + } + } + + public List get_weights() + { + List weights = new List(); + weights.AddRange(Weights.ConvertAll(x => x.numpy())); + return weights; + } + + protected int id; + public int Id => id; + protected string name; + protected string base_name; + public string Name + { + get + { + return name; + } + set + { + name = value; + } + } + + protected bool computePreviousMask; + protected List updates; + public KerasShapesWrapper BatchInputShape => args.BatchInputShape; + protected KerasShapesWrapper _buildInputShape = null; + public KerasShapesWrapper BuildInputShape => _buildInputShape; + + List inboundNodes; + public List InboundNodes => inboundNodes; + List outboundNodes; + public List OutboundNodes => outboundNodes; + + public Dictionary SerializedAttributes { get; set; } + + ThreadLocal callContext = new ThreadLocal(); + public CallContext CallContext => callContext.Value; + public Tensor[] input + { + get + { + if(inboundNodes is not null && inboundNodes.Count > 0) + { + return inboundNodes[0].input_tensors; + } + return null; + } + } + public Dictionary> NodesByDepth { get; set; } + public Shape OutputShape + { + get + { + if(inboundNodes is not null && inboundNodes.Count > 0) + { + return inboundNodes[0].Outputs.shape; + } + return null; + } + } + protected List _self_tracked_trackables; + + /// + /// If this value is set, the behavior of layer call will be changed to directly calling this function. + /// + public Func? ReplacedCall { get; set; } = null; + + public Layer(LayerArgs args) + { + Initialize(args); + } + + internal virtual void Initialize(LayerArgs args) + { + this.args = args; + // A stateful layer is a layer whose updates are run during inference too, + // for instance stateful RNNs. + stateful = false; + // Indicates whether `build` needs to be called upon layer call, to create + // the layer's weights. + built = false; + SupportsMasking = false; + + id = ops.uid_layer(); + _init_set_name(args.Name); + _trainable_weights = new List(); + _non_trainable_weights = new List(); + computePreviousMask = false; + updates = new List(); + _self_tracked_trackables = new List(); + + inboundNodes = new List(); + outboundNodes = new List(); + + // Manage input shape information if passed. + if (args.BatchInputShape == null && args.InputShape != null) + { + args.BatchInputShape = new KerasShapesWrapper(new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray()); + } + } + + bool _in_functional_construction_mode(Tensors inputs) + { + return tf.Context.executing_eagerly() + && inputs.Count(x => x is not EagerTensor && x is not NDArray) == inputs.Count() || _enforce_layer_construction; + } + + public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) + => _set_connectivity_metadata_(inputs, outputs); + + private void _set_connectivity_metadata_(Tensors inputs, Tensors outputs) + { + var node = new Node(new NodeArgs + { + InputTensors = inputs, + Outputs = outputs + }); + node.Connect(this); + } + + private void _handle_activity_regularization(Tensors inputs, Tensors outputs) + { + //if(_activity_regularizer != null) + { + + } + } + + private void _set_mask_metadata(Tensors inputs, Tensors outputs, Tensors previous_mask) + { + + } + + private Tensor compute_mask(Tensor inputs, Tensor mask = null) + { + return null; + } + + /// + /// Subclass has to override this method. + /// + /// + /// + /// + /// + protected virtual Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + if(ReplacedCall is not null) + { + return ReplacedCall(inputs); + } + return inputs; + } + + protected virtual string _name_scope() + { + return Name; + } + + protected void MaybeBuild(Tensors inputs) + { + // Check input assumptions set before layer building, e.g. input rank. + if (built) + return; + if (DType == TF_DataType.DtInvalid) + args.DType = inputs.dtype; + + tf.init_scope(); + + bool need_restore_mode = false; + if (inputs.Any(x => x is EagerTensor) || tf.Context.is_build_function()) + { + need_restore_mode = true; + tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); + } + + build(new KerasShapesWrapper(inputs.shape)); + + if (need_restore_mode) + tf.Context.restore_mode(); + + built = true; + } + + public virtual void build(KerasShapesWrapper input_shape) + { + _buildInputShape = input_shape; + built = true; + } + + protected virtual void add_loss(Func losses) + { + + } + + /// + /// Create lambdas which compute regularization losses. + /// + /// + /// + /// + void _handle_weight_regularization(string name, IVariableV1 variable, IRegularizer regularizer) + { + + add_loss(() => tf_with(ops.name_scope(name + "/Regularizer"), scope => + regularizer.Apply(new RegularizerArgs(variable.AsTensor()) + { + + }) + )); + } + + /*protected virtual void add_update(Tensor[] updates, bool inputs = false) + { + var updates_op = updates.Select(x => x.op).ToArray(); + this.updates.AddRange(updates_op); + }*/ + + // Determine layer name (non-unique). + protected virtual void _init_set_name(string name, bool zero_based = true) + { + base_name = name; + this.name = name; + if (name == null) + { + base_name = generic_utils.to_snake_case(this.GetType().Name); + this.name = base_layer_utils.unique_layer_name(base_name, zero_based: zero_based); + } + } + + public int count_params() + { + if (Trainable) + return layer_utils.count_params(this, Weights); + return 0; + } + + public virtual IKerasConfig get_config() + => args; + + public virtual void adapt(Tensor data, int? batch_size = null, int? steps = null) + { + + } + + public override void SetAttr(string name, object value) + { + // TODO(Rinne): deal with "_self_setattr_tracking". + + value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); + + foreach(var val in nest.flatten(value)) + { + if(val is Metric) + { + // TODO(Rinne): deal with metrics. + } + } + + // TODO(Rinne): deal with "_auto_track_sub_layers". + + foreach(var val in nest.flatten(value)) + { + if(val is not IVariableV1 variable) + { + continue; + } + if (variable.Trainable) + { + if (_trainable_weights.Contains(variable)) + { + continue; + } + _trainable_weights.Add(variable); + } + else + { + if (_non_trainable_weights.Contains(variable)) + { + continue; + } + _non_trainable_weights.Add(variable); + } + keras.backend.track_variable(variable); + } + + // Directly use the implementation of `Trackable`. + var t = this.GetType(); + var field_info = t.GetField(name); + if (field_info is not null) + { + field_info.SetValue(this, value); + } + else + { + CustomizedFields[name] = value; + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/LossesContainer.cs b/src/TensorFlowNET.Keras/Engine/LossesContainer.cs new file mode 100644 index 000000000..c06fca593 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/LossesContainer.cs @@ -0,0 +1,83 @@ +using System.Collections.Generic; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; + +namespace Tensorflow.Keras.Engine +{ + public class LossesContainer : Container + { + ILossFunc _user_losses; + ILossFunc _losses; + Mean _loss_metric; + bool _built; + Tensor[] _per_output_metrics; + + public LossesContainer(ILossFunc losses, string[] output_names = null) + : base(output_names) + { + _user_losses = losses; + _losses = losses; + _loss_metric = new Mean(name: "loss"); + _built = false; + } + + /// + /// Computes the overall loss. + /// + /// + /// + public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + { + if (!_built) + Build(y_pred); + var loss_value = _losses.Call(y_true, y_pred, sample_weight:sample_weight); + var loss_metric_value = loss_value; + var batch_dim = array_ops.shape(y_true)[0]; + + var loss_values = new List(); + var loss_metric_values = new List(); + + /*if (_losses.Reduction == ReductionV2.SUM_OVER_BATCH_SIZE + || _losses.Reduction == ReductionV2.AUTO) + loss_value = losses_utils.scale_loss_for_distribution(loss_value);*/ + loss_values.append(loss_value); + loss_metric_values.append(loss_metric_value); + + if (loss_values.Count > 0) + { + var total_loss_metric_value = math_ops.add_n(loss_metric_values.ToArray()); + _loss_metric.update_state(total_loss_metric_value, batch_dim); + // loss_values = losses_utils.cast_losses_to_common_dtype(loss_values); + var total_loss = math_ops.add_n(loss_values.ToArray()); + return total_loss; + } + else + { + // Ok for a model to have no compiled loss. + return array_ops.zeros(Shape.Null); + } + } + + public void Build(Tensor y_pred) + { + _create_metrics(); + _built = true; + } + + void _create_metrics() + { + // _per_output_metrics = _output_names.Select(x => null); + } + + public IEnumerable metrics + { + get + { + if (!_built) + return new List(); + + return new[] { _loss_metric }; + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs new file mode 100644 index 000000000..ee6384107 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs @@ -0,0 +1,116 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Engine +{ + public class MetricsContainer : Container + { + IMetricFunc[] _user_metrics = new IMetricFunc[0]; + string[] _metric_names = new string[0]; + Metric[] _metrics = new Metric[0]; + List _metrics_in_order = new List(); + + public MetricsContainer(IMetricFunc[] metrics, string[] output_names = null) + : base(output_names) + { + _user_metrics = metrics; + _built = false; + } + + public MetricsContainer(string[] metrics, string[] output_names = null) + : base(output_names) + { + _metric_names = metrics; + _built = false; + } + + public void update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + { + if (!_built) + Build(y_true, y_pred); + + foreach (var metric_obj in _metrics_in_order) + metric_obj.update_state(y_true, y_pred); + } + + void Build(Tensor y_true, Tensor y_pred) + { + _metrics = _get_metric_objects(_metric_names, y_true, y_pred); + _set_metric_names(); + _create_ordered_metrics(); + _built = true; + } + + void _set_metric_names() + { + + } + + void _create_ordered_metrics() + { + foreach (var m in _metrics) + _metrics_in_order.append(m); + + foreach(var m in _user_metrics) + _metrics_in_order.append(m); + } + + Metric[] _get_metric_objects(string[] metrics, Tensor y_t, Tensor y_p) + { + return metrics.Select(x => _get_metric_object(x, y_t, y_p)).ToArray(); + } + + public Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p) + { + Func metric_obj = null; + if (metric == "accuracy" || metric == "acc") + { + var y_t_rank = y_t.rank; + var y_p_rank = y_p.rank; + var y_t_last_dim = y_t.shape[y_t.shape.ndim - 1]; + var y_p_last_dim = y_p.shape[y_p.shape.ndim - 1]; + + bool is_binary = y_p_last_dim == 1; + bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1; + + if (is_binary) + metric_obj = keras.metrics.binary_accuracy; + else if (is_sparse_categorical) + metric_obj = keras.metrics.sparse_categorical_accuracy; + else + metric_obj = keras.metrics.categorical_accuracy; + + metric = "accuracy"; + } + else if(metric == "mean_absolute_error" || metric == "mae") + { + metric_obj = keras.metrics.mean_absolute_error; + metric = "mean_absolute_error"; + } + else if (metric == "mean_absolute_percentage_error" || metric == "mape") + { + metric_obj = keras.metrics.mean_absolute_percentage_error; + metric = "mean_absolute_percentage_error"; + } + else + throw new NotImplementedException(""); + + return new MeanMetricWrapper(metric_obj, metric); + } + + public IEnumerable metrics + { + get + { + if (!_built) + return new List(); + + return _metrics_in_order; + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Build.cs b/src/TensorFlowNET.Keras/Engine/Model.Build.cs new file mode 100644 index 000000000..233363832 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Model.Build.cs @@ -0,0 +1,51 @@ +using System; +using System.Linq; +using Tensorflow.Graphs; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + public override void build(KerasShapesWrapper input_shape) + { + if (_is_graph_network || this is Functional || this is Sequential) + { + base.build(input_shape); + return; + } + + if(input_shape is not null && this.inputs is null) + { + var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); + graph.as_default(); + var shapes = input_shape.ToShapeArray(); + var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x)).ToArray()); + try + { + Call(x, training: false); + } + catch (InvalidArgumentError) + { + throw new ValueError("You cannot build your model by calling `build` " + + "if your layers do not support float type inputs. " + + "Instead, in order to instantiate and build your " + + "model, `call` your model on real tensor data (of the correct dtype)."); + } + catch (TypeError) + { + throw new ValueError("You cannot build your model by calling `build` " + + "if your layers do not support float type inputs. " + + "Instead, in order to instantiate and build your " + + "model, `call` your model on real tensor data (of the correct dtype)."); + } + graph.Exit(); + } + + base.build(input_shape); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs new file mode 100644 index 000000000..dabdccf9d --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs @@ -0,0 +1,108 @@ +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Optimizers; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + LossesContainer compiled_loss; + MetricsContainer compiled_metrics; + + public void compile(IOptimizer optimizer, + ILossFunc loss) + { + this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs + { + }); + + this.loss = loss ?? new MeanSquaredError(); + + compiled_loss = new LossesContainer(this.loss, output_names: output_names); + compiled_metrics = new MetricsContainer(new string[0], output_names: output_names); + + int experimental_steps_per_execution = 1; + _configure_steps_per_execution(experimental_steps_per_execution); + + // Initialize cache attrs. + _reset_compile_cache(); + _is_compiled = true; + } + + public void compile(IOptimizer optimizer, + ILossFunc loss, + string[] metrics) + { + this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs + { + }); + + this.loss = loss ?? new MeanSquaredError(); + + compiled_loss = new LossesContainer(this.loss, output_names: output_names); + compiled_metrics = new MetricsContainer(metrics, output_names: output_names); + + int experimental_steps_per_execution = 1; + _configure_steps_per_execution(experimental_steps_per_execution); + + // Initialize cache attrs. + _reset_compile_cache(); + _is_compiled = true; + } + + public void compile(string optimizer, + string loss, + string[] metrics) + { + this.optimizer = optimizer switch + { + "rmsprop" => new RMSprop(new RMSpropArgs + { + + }), + _ => new RMSprop(new RMSpropArgs + { + }) + }; + + this.loss = loss switch + { + "mse" => new MeanSquaredError(), + "mae" => new MeanAbsoluteError(), + _ => new MeanSquaredError() + }; + + compiled_loss = new LossesContainer(this.loss, output_names: output_names); + compiled_metrics = new MetricsContainer(metrics, output_names: output_names); + + int experimental_steps_per_execution = 1; + _configure_steps_per_execution(experimental_steps_per_execution); + + // Initialize cache attrs. + _reset_compile_cache(); + _is_compiled = true; + } + + public void compile(IOptimizer optimizer, + ILossFunc loss, + IMetricFunc[] metrics) + { + this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs + { + }); + + this.loss = loss ?? new MeanSquaredError(); + + compiled_loss = new LossesContainer(this.loss, output_names: output_names); + compiled_metrics = new MetricsContainer(metrics, output_names: output_names); + + int experimental_steps_per_execution = 1; + _configure_steps_per_execution(experimental_steps_per_execution); + + // Initialize cache attrs. + _reset_compile_cache(); + _is_compiled = true; + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs new file mode 100644 index 000000000..ec99d7ef9 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -0,0 +1,206 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Callbacks; +using Tensorflow.Keras.Engine.DataAdapters; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Utils; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + /// + /// Returns the loss value and metrics values for the model in test mode. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Dictionary evaluate(NDArray x, NDArray y, + int batch_size = -1, + int verbose = 1, + NDArray sample_weight = null, + int steps = -1, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false, + bool return_dict = false, + bool is_val = false + ) + { + if (x.dims[0] != y.dims[0]) + { + throw new InvalidArgumentError( + $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); + } + var data_handler = new DataHandler(new DataHandlerArgs + { + X = x, + Y = y, + BatchSize = batch_size, + StepsPerEpoch = steps, + InitialEpoch = 0, + Epochs = 1, + SampleWeight = sample_weight, + MaxQueueSize = max_queue_size, + Workers = workers, + UseMultiprocessing = use_multiprocessing, + Model = this, + StepsPerExecution = _steps_per_execution + }); + + var callbacks = new CallbackList(new CallbackParams + { + Model = this, + Verbose = verbose, + Steps = data_handler.Inferredsteps + }); + + return evaluate(data_handler, callbacks, is_val, test_function); + } + + public Dictionary evaluate( + IEnumerable x, + Tensor y, + int verbose = 1, + NDArray sample_weight = null, + bool is_val = false) + { + var data_handler = new DataHandler(new DataHandlerArgs + { + X = new Tensors(x.ToArray()), + Y = y, + Model = this, + SampleWeight = sample_weight, + StepsPerExecution = _steps_per_execution + }); + + var callbacks = new CallbackList(new CallbackParams + { + Model = this, + Verbose = verbose, + Steps = data_handler.Inferredsteps + }); + + return evaluate(data_handler, callbacks, is_val, test_step_multi_inputs_function); + } + + public Dictionary evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false) + { + var data_handler = new DataHandler(new DataHandlerArgs + { + Dataset = x, + Model = this, + StepsPerExecution = _steps_per_execution + }); + + var callbacks = new CallbackList(new CallbackParams + { + Model = this, + Verbose = verbose, + Steps = data_handler.Inferredsteps + }); + + Func> testFunction; + + if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || + data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) + { + testFunction = test_step_multi_inputs_function; + } + else + { + testFunction = test_function; + } + + return evaluate(data_handler, callbacks, is_val, testFunction); + } + + /// + /// Internal bare implementation of evaluate function. + /// + /// Interations handling objects + /// + /// The function to be called on each batch of data. + /// Whether it is validation or test. + /// + Dictionary evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func> test_func) + { + callbacks.on_test_begin(); + + var logs = new Dictionary(); + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) + { + reset_metrics(); + foreach (var step in data_handler.steps()) + { + callbacks.on_test_batch_begin(step); + logs = test_func(data_handler, iterator); + var end_step = step + data_handler.StepIncrement; + if (!is_val) + callbacks.on_test_batch_end(end_step, logs); + GC.Collect(); + } + } + callbacks.on_test_end(logs); + var results = new Dictionary(logs); + return results; + } + + Dictionary test_function(DataHandler data_handler, OwnedIterator iterator) + { + var data = iterator.next(); + var outputs = data.Length == 2 ? test_step(data_handler, data[0], data[1]) : + test_step(data_handler, data[0], data[1], data[2]); + tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); + return outputs; + } + + Dictionary test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator) + { + var data = iterator.next(); + var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; + var outputs = data.Length == 2 ? + test_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) : + test_step( + data_handler, + new Tensors(data.Take(x_size).ToArray()), + new Tensors(data.Skip(x_size).Take(x_size).ToArray()), + new Tensors(data.Skip(2 * x_size).ToArray())); + tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); + return outputs; + } + + + Dictionary test_step(DataHandler data_handler, Tensors x, Tensors y) + { + (x,y) = data_handler.DataAdapter.Expand1d(x, y); + + var y_pred = Apply(x, training: false); + + var loss = compiled_loss.Call(y, y_pred); + compiled_metrics.update_state(y, y_pred); + return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2); + } + + Dictionary test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight) + { + (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight); + var y_pred = Apply(x, training: false); + var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight); + compiled_metrics.update_state(y, y_pred); + return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs new file mode 100644 index 000000000..e1303513e --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -0,0 +1,341 @@ +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine.DataAdapters; +using System.Diagnostics; +using Tensorflow.Keras.Callbacks; +using Tensorflow.Util; +using OneOf; + +namespace Tensorflow.Keras.Engine +{ + + + public partial class Model + { + /// + /// Trains the model for a fixed number of epochs (iterations on a dataset). + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public ICallback fit(NDArray x, NDArray y, + int batch_size = -1, + int epochs = 1, + int verbose = 1, + List callbacks = null, + float validation_split = 0f, + ValidationDataPack validation_data = null, + int validation_step = 10, + bool shuffle = true, + Dictionary class_weight = null, + NDArray sample_weight = null, + int initial_epoch = 0, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false) + { + if (x.dims[0] != y.dims[0]) + { + throw new InvalidArgumentError( + $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); + } + + // The default dtype in NDArray is double, so we need to cast sample_weight to float to mul with loss which's dtype is float. + sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT); + + if (validation_split != 0f && validation_data == null) + { + ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split); + } + + var data_handler = new DataHandler(new DataHandlerArgs + { + X = x, + Y = y, + SampleWeight = sample_weight, + BatchSize = batch_size, + InitialEpoch = initial_epoch, + Epochs = epochs, + Shuffle = shuffle, + ClassWeight = class_weight, + MaxQueueSize = max_queue_size, + Workers = workers, + UseMultiprocessing = use_multiprocessing, + Model = this, + StepsPerExecution = _steps_per_execution + }); + + return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, + train_step_func: train_step_function); + } + + + public ICallback fit(IEnumerable x, NDArray y, + int batch_size = -1, + int epochs = 1, + int verbose = 1, + List callbacks = null, + float validation_split = 0f, + ValidationDataPack validation_data = null, + bool shuffle = true, + Dictionary class_weight = null, + NDArray sample_weight = null, + int initial_epoch = 0, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false) + { + foreach(var tx in x) + { + if (tx.dims[0] != y.dims[0]) + { + throw new InvalidArgumentError( + $"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}"); + } + } + + sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT); + + if (validation_split != 0f && validation_data == null) + { + ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split); + } + + + var data_handler = new DataHandler(new DataHandlerArgs + { + X = new Tensors(x.ToArray()), + Y = y, + SampleWeight = sample_weight, + BatchSize = batch_size, + InitialEpoch = initial_epoch, + Epochs = epochs, + Shuffle = shuffle, + ClassWeight = class_weight, + MaxQueueSize = max_queue_size, + Workers = workers, + UseMultiprocessing = use_multiprocessing, + Model = this, + StepsPerExecution = _steps_per_execution + }); + + if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || + data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) + { + return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, + train_step_func: train_step_multi_inputs_function); + } + else + { + return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, + train_step_func: train_step_function); + } + } + + public ICallback fit(IDatasetV2 dataset, + int batch_size = -1, + int epochs = 1, + int verbose = 1, + List callbacks = null, + IDatasetV2 validation_data = null, + int validation_step = 10, + bool shuffle = true, + Dictionary class_weight = null, + int initial_epoch = 0, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false) + { + + var data_handler = new DataHandler(new DataHandlerArgs + { + Dataset = dataset, + BatchSize = batch_size, + InitialEpoch = initial_epoch, + Epochs = epochs, + Shuffle = shuffle, + ClassWeight = class_weight, + MaxQueueSize = max_queue_size, + Workers = workers, + UseMultiprocessing = use_multiprocessing, + Model = this, + StepsPerExecution = _steps_per_execution + }); + + Func> trainStepFunction; + + if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || + data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) + { + trainStepFunction = train_step_multi_inputs_function; + } + else + { + trainStepFunction = train_step_function; + } + + return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data, + train_step_func: trainStepFunction); + } + + History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List callbackList, IDatasetV2 validation_data, + Func> train_step_func) + { + stop_training = false; + _train_counter.assign(0); + var callbacks = new CallbackList(new CallbackParams + { + Model = this, + Verbose = verbose, + Epochs = epochs, + Steps = data_handler.Inferredsteps + }); + + if (callbackList != null) + { + foreach(var callback in callbackList) + callbacks.callbacks.add(callback); + } + + callbacks.on_train_begin(); + + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) + { + reset_metrics(); + callbacks.on_epoch_begin(epoch); + // data_handler.catch_stop_iteration(); + var logs = new Dictionary(); + long End_step = 0; + foreach (var step in data_handler.steps()) + { + callbacks.on_train_batch_begin(step); + logs = train_step_func(data_handler, iterator); + var end_step = step + data_handler.StepIncrement; + End_step = end_step; + callbacks.on_train_batch_end(end_step, logs); + GC.Collect(); + } + + if (validation_data != null) + { + if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0) + continue; + + var val_logs = evaluate(validation_data); + foreach(var log in val_logs) + { + logs["val_" + log.Key] = log.Value; + } + callbacks.on_train_batch_end(End_step, logs); + } + + GC.Collect(); + + callbacks.on_epoch_end(epoch, logs); + + if (stop_training) + { + break; + } + } + + return callbacks.History; + } + + History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, ValidationDataPack validation_data, + Func> train_step_func) + { + stop_training = false; + _train_counter.assign(0); + var callbacks = new CallbackList(new CallbackParams + { + Model = this, + Verbose = verbose, + Epochs = epochs, + Steps = data_handler.Inferredsteps + }); + + if (callbackList != null) + { + foreach (var callback in callbackList) + callbacks.callbacks.add(callback); + } + + callbacks.on_train_begin(); + + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) + { + reset_metrics(); + callbacks.on_epoch_begin(epoch); + // data_handler.catch_stop_iteration(); + var logs = new Dictionary(); + long End_step = 0; + foreach (var step in data_handler.steps()) + { + callbacks.on_train_batch_begin(step); + logs = train_step_func(data_handler, iterator); + var end_step = step + data_handler.StepIncrement; + End_step = end_step; + callbacks.on_train_batch_end(end_step, logs); + GC.Collect(); + } + + if (validation_data != null) + { + NDArray val_x; + NDArray[] val_x_array; + NDArray val_y; + NDArray val_sample_weight; + Dictionary val_logs; + if (!validation_data.val_x_is_array) + { + (val_x, val_y, val_sample_weight) = validation_data; + // Because evaluate calls call_test_batch_end, this interferes with our output on the screen + // so we need to pass a is_val parameter to stop on_test_batch_end + val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true); + + } + else + { + (val_x_array, val_y, val_sample_weight, _) = validation_data; + val_logs = evaluate(val_x_array, val_y, sample_weight: val_sample_weight, is_val: true); + } + foreach (var log in val_logs) + { + logs["val_" + log.Key] = log.Value; + } + // because after evaluate, logs add some new log which we need to print + callbacks.on_train_batch_end(End_step, logs); + } + + callbacks.on_epoch_end(epoch, logs); + + GC.Collect(); + if (stop_training) + { + break; + } + } + + return callbacks.History; + } + + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs b/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs new file mode 100644 index 000000000..0e33b14e3 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs @@ -0,0 +1,35 @@ +using System.Collections.Generic; +using Tensorflow.Keras.Metrics; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + public IEnumerable metrics + { + get + { + var _metrics = new List(); + + if (_is_compiled) + { + if (compiled_loss != null) + _metrics.add(compiled_loss.metrics); + if (compiled_metrics != null) + _metrics.add(compiled_metrics.metrics); + } + + /*foreach (var layer in _flatten_layers()) + _metrics.extend(layer.metrics);*/ + + return _metrics; + } + } + + void reset_metrics() + { + foreach (var metric in metrics) + metric.reset_states(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs new file mode 100644 index 000000000..e3a5aba68 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs @@ -0,0 +1,129 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine.DataAdapters; +using static Tensorflow.Binding; +using Tensorflow.Keras.Callbacks; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + public Tensors predict(IDatasetV2 dataset, + int batch_size = -1, + int verbose = 0, + int steps = -1, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false) + { + var data_handler = new DataHandler(new DataHandlerArgs + { + Dataset = dataset, + BatchSize = batch_size, + StepsPerEpoch = steps, + InitialEpoch = 0, + Epochs = 1, + MaxQueueSize = max_queue_size, + Workers = workers, + UseMultiprocessing = use_multiprocessing, + Model = this, + StepsPerExecution = _steps_per_execution + }); + + return PredictInternal(data_handler, verbose); + } + + /// + /// Generates output predictions for the input samples. + /// + /// Input samples + /// Number of samples per batch + /// Verbosity mode + /// + /// Total number of steps (batches of samples) + /// before declaring the prediction round finished. + /// + /// + /// + /// + /// + public Tensors predict(Tensors x, + int batch_size = -1, + int verbose = 0, + int steps = -1, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false) + { + var data_handler = new DataHandler(new DataHandlerArgs + { + X = x, + BatchSize = batch_size, + StepsPerEpoch = steps, + InitialEpoch = 0, + Epochs = 1, + MaxQueueSize = max_queue_size, + Workers = workers, + UseMultiprocessing = use_multiprocessing, + Model = this, + StepsPerExecution = _steps_per_execution + }); + + return PredictInternal(data_handler, verbose); + } + + Tensors PredictInternal(DataHandler data_handler, int verbose) + { + var callbacks = new CallbackList(new CallbackParams + { + Model = this, + Verbose = verbose, + Epochs = 1, + Steps = data_handler.Inferredsteps + }); + + Tensors batch_outputs = null; + _predict_counter.assign(0); + callbacks.on_predict_begin(); + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) + { + foreach (var step in data_handler.steps()) + { + callbacks.on_predict_batch_begin(step); + var tmp_batch_outputs = run_predict_step(iterator); + if (batch_outputs == null) + { + batch_outputs = tmp_batch_outputs; + } + else + { + for (int i = 0; i < batch_outputs.Length; i++) + batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0); + } + var end_step = step + data_handler.StepIncrement; + callbacks.on_predict_batch_end(end_step, new Dictionary { { "outputs", batch_outputs } }); + GC.Collect(); + } + } + + callbacks.on_predict_end(); + + return batch_outputs; + } + + Tensors run_predict_step(OwnedIterator iterator) + { + var data = iterator.next(); + var outputs = predict_step(data); + tf_with(ops.control_dependencies(Array.Empty()), ctl => _predict_counter.assign_add(1)); + return outputs; + } + + Tensors predict_step(Tensors data) + { + return Apply(data, training: false); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs new file mode 100644 index 000000000..a3956cccc --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -0,0 +1,41 @@ +using System.Collections.Generic; +using Tensorflow.Functions; +using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.ModelSaving; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + ModelSaver saver = new ModelSaver(); + + /// + /// Saves the model to Tensorflow SavedModel or a single HDF5 file. + /// + /// + /// + /// + public void save(string filepath, + bool overwrite = true, + bool include_optimizer = true, + string save_format = "tf", + SaveOptions? options = null, + ConcreteFunction? signatures = null, + bool save_traces = true) + { + if (save_format != "tf") + { + saver.save(this, filepath); + } + else + { + using (SharedObjectSavingScope.Enter()) + { + KerasSavedModelUtils.save_model(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + } + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Summary.cs b/src/TensorFlowNET.Keras/Engine/Model.Summary.cs new file mode 100644 index 000000000..830aee962 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Model.Summary.cs @@ -0,0 +1,17 @@ +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + /// + /// Prints a string summary of the network. + /// + public void summary(int line_length = -1, float[] positions = null) + { + layer_utils.print_summary(this, + line_length: line_length, + positions: positions); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs new file mode 100644 index 000000000..8f1ec808c --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs @@ -0,0 +1,111 @@ +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Gradients; +using Tensorflow.Keras.Engine.DataAdapters; +using Tensorflow.Keras.Optimizers; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + Dictionary train_step_function(DataHandler data_handler, OwnedIterator iterator) + { + var data = iterator.next(); + // whether have sample_weight + var outputs = data.Length == 2 ? train_step(data_handler, data[0], data[1]) : + train_step(data_handler, data[0], data[1], data[2]); + tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); + return outputs; + } + + Dictionary train_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator) + { + var data = iterator.next(); + var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; + var outputs = data.Length == 2 ? + train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) : + train_step( + data_handler, + new Tensors(data.Take(x_size).ToArray()), + new Tensors(data.Skip(x_size).Take(x_size).ToArray()), + new Tensors(data.Skip(2 * x_size).ToArray())); + tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); + return outputs; + } + + /// + /// The logic for one training step. + /// + /// + /// + /// + /// + Dictionary train_step(DataHandler data_handler, Tensors x, Tensors y) + { + (x, y) = data_handler.DataAdapter.Expand1d(x, y); + using var tape = tf.GradientTape(); + var y_pred = Apply(x, training: true); + var loss = compiled_loss.Call(y, y_pred); + + // For custom training steps, users can just write: + // trainable_variables = self.trainable_variables + // gradients = tape.gradient(loss, trainable_variables) + // self.optimizer.apply_gradients(zip(gradients, trainable_variables)) + // The _minimize call does a few extra steps unnecessary in most cases, + // such as loss scaling and gradient clipping. + _minimize(tape, optimizer, loss, TrainableVariables); + compiled_metrics.update_state(y, y_pred); + + var dict = new Dictionary(); + metrics.ToList().ForEach(x => + { + var r = x.result(); + if (r.ndim > 0) + { + r = tf.reduce_mean(r); + } + dict[x.Name] = (float)r; + }); + return dict; + } + Dictionary train_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null) + { + (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight); + using var tape = tf.GradientTape(); + var y_pred = Apply(x, training: true); + var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight); + + // For custom training steps, users can just write: + // trainable_variables = self.trainable_variables + // gradients = tape.gradient(loss, trainable_variables) + // self.optimizer.apply_gradients(zip(gradients, trainable_variables)) + // The _minimize call does a few extra steps unnecessary in most cases, + // such as loss scaling and gradient clipping. + _minimize(tape, optimizer, loss, TrainableVariables); + compiled_metrics.update_state(y, y_pred); + + var dict = new Dictionary(); + metrics.ToList().ForEach(x => + { + var r = x.result(); + if (r.ndim > 0) + { + r = tf.reduce_mean(r); + } + dict[x.Name] = (float)r; + }); + return dict; + } + + void _minimize(GradientTape tape, IOptimizer optimizer, Tensor loss, List trainable_variables) + { + var gradients = tape.gradient(loss, trainable_variables); + gradients = optimizer.aggregate_gradients(zip(gradients, trainable_variables)); + gradients = optimizer.clip_gradients(gradients); + + optimizer.apply_gradients(zip(gradients, trainable_variables), + experimental_aggregate_gradients: false); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Training.cs b/src/TensorFlowNET.Keras/Engine/Model.Training.cs new file mode 100644 index 000000000..457b3d694 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Model.Training.cs @@ -0,0 +1,78 @@ +using System; +using System.Collections.Generic; +using System.Text; +using HDF.PInvoke; +using HDF5CSharp; +using static Tensorflow.Binding; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + static Dictionary> weightsCache + = new Dictionary>(); + + public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null) + { + // Get from cache + if (weightsCache.ContainsKey(filepath)) + { + var filtered_layers = new List(); + foreach (var layer in Layers) + { + var weights = hdf5_format._legacy_weights(layer); + if (weights.Count > 0) + filtered_layers.append(layer); + } + + var weight_value_tuples = new List<(IVariableV1, NDArray)>(); + filtered_layers.Select((layer, i) => + { + var symbolic_weights = hdf5_format._legacy_weights(layer); + foreach(var weight in symbolic_weights) + { + var weight_value = weightsCache[filepath].First(x => x.Item1 == weight.Name).Item2; + weight_value_tuples.Add((weight, weight_value)); + } + return layer; + }).ToList(); + + keras.backend.batch_set_value(weight_value_tuples); + return; + } + + long fileId = Hdf5.OpenFile(filepath, true); + if(fileId < 0) + { + tf_output_redirect.WriteLine($"Can't find weights file {filepath}"); + return; + } + bool msuccess = Hdf5.GroupExists(fileId, "model_weights"); + bool lsuccess = Hdf5.GroupExists(fileId, "layer_names"); + + if (!lsuccess && msuccess) + fileId = H5G.open(fileId, "model_weights"); + + if (by_name) + //fdf5_format.load_weights_from_hdf5_group_by_name(); + throw new NotImplementedException(""); + else + { + var weight_value_tuples = hdf5_format.load_weights_from_hdf5_group(fileId, Layers); + Hdf5.CloseFile(fileId); + + weightsCache[filepath] = weight_value_tuples.Select(x => (x.Item1.Name, x.Item2)).ToList(); + keras.backend.batch_set_value(weight_value_tuples); + } + } + + public void save_weights(string filepath, bool overwrite = true, string save_format = null, object options = null) + { + long fileId = Hdf5.CreateFile(filepath); + hdf5_format.save_weights_to_hdf5_group(fileId, Layers); + Hdf5.CloseFile(fileId); + } + } +} + diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs new file mode 100644 index 000000000..7b35d5477 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -0,0 +1,203 @@ +using System.Diagnostics; +using Tensorflow.Common.Types; +using Tensorflow.Framework.Models; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; +using Tensorflow.Util; + +namespace Tensorflow.Keras.Engine +{ + /// + /// `Model` groups layers into an object with training and inference features. + /// + public partial class Model : Layer, IModel + { +#pragma warning disable CS0169 // The field 'Model._cloning' is never used + bool _cloning; +#pragma warning restore CS0169 // The field 'Model._cloning' is never used +#pragma warning disable CS0108 // Member hides inherited member; missing new keyword +#pragma warning disable CS0414 // The field 'Model._is_compiled' is assigned but its value is never used + bool _is_compiled; +#pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used +#pragma warning restore CS0108 // Member hides inherited member; missing new keyword + ILossFunc loss; + IOptimizer optimizer; + IVariableV1 _steps_per_execution; + protected bool _is_graph_network; + public Tensors inputs; + protected Tensors outputs; + protected List input_names; + public string[] output_names; + IVariableV1 _train_counter; + IVariableV1 _test_counter; + IVariableV1 _predict_counter; + bool _base_model_initialized; + bool stop_training; + TensorSpec _saved_model_inputs_spec; + + public bool IsGraphNetwork => _is_graph_network; + + public IOptimizer Optimizer + { + get => optimizer; + set => optimizer = value; + } + + public bool Stop_training + { + get => stop_training; + set => stop_training = value; + } + + public Model(ModelArgs args) + : base(args) + { + _init_batch_counters(); + } + + public void _set_inputs(TensorSpec inputs) + { + _set_save_spec(inputs); + } + + internal void _set_save_spec(TensorSpec inputs) + { + if(_saved_model_inputs_spec is not null) + { + return; + } + var input_names = this.input_names; + if(input_names is null || input_names.Count == 0) + { + input_names = compile_utils.create_pseudo_input_names(inputs); + } + + var flat_inputs = nest.flatten(inputs); + List specs = new(); + foreach(var (name, tensor) in zip(input_names, flat_inputs)) + { + specs.Add(tf_utils.get_tensor_spec(tensor, dynamic_batch: false, name: name)); + } + var packed_specs = nest.pack_sequence_as(inputs, specs) as TensorSpec; + Debug.Assert(specs is not null); + _saved_model_inputs_spec = packed_specs; + if(this is Sequential && _buildInputShape is null) + { + _buildInputShape = nest.map_structure(x => x is null ? null : x.shape, packed_specs); + } + } + + internal override void Initialize(LayerArgs args) + { + _init_batch_counters(); + base.Initialize(args); + } + + void _configure_steps_per_execution(int steps_per_execution) + { + _steps_per_execution = tf.Variable(steps_per_execution, + dtype: TF_DataType.TF_INT64, + aggregation: VariableAggregation.OnlyFirstReplica); + } + + void _reset_compile_cache() + { + // Used to cache `trainable` attr of `Layer`s for `fit`. + _compiled_trainable_state = _get_trainable_state(); + keras.backend._GRAPH = null; + } + + void _init_batch_counters() + { + _train_counter = tf.Variable(0L, + dtype: TF_DataType.TF_INT64, + aggregation: VariableAggregation.OnlyFirstReplica); + + _test_counter = tf.Variable(0L, + dtype: TF_DataType.TF_INT64, + aggregation: VariableAggregation.OnlyFirstReplica); + + _predict_counter = tf.Variable(0L, + dtype: TF_DataType.TF_INT64, + aggregation: VariableAggregation.OnlyFirstReplica); + } + + public override List Layers + => _flatten_layers(recursive: false, include_self: false).ToList(); + + public override List TrainableWeights + { + get + { + // skip the assertion of weights created. + var variables = new List(); + + if (!Trainable) + { + return variables; + } + + foreach (var trackable_obj in _self_tracked_trackables) + { + if (trackable_obj.Trainable) + variables.AddRange(trackable_obj.TrainableWeights); + } + + variables.AddRange(_trainable_weights); + + return variables.Distinct().ToList(); + } + } + + public override List NonTrainableWeights + { + get + { + // skip the assertion of weights created. + var variables = new List(); + + foreach (var trackable_obj in _self_tracked_trackables) + { + variables.AddRange(trackable_obj.NonTrainableWeights); + } + + if (!Trainable) + { + var trainable_variables = new List(); + foreach (var trackable_obj in _self_tracked_trackables) + { + variables.AddRange(trackable_obj.TrainableWeights); + } + variables.AddRange(trainable_variables); + variables.AddRange(_trainable_weights); + variables.AddRange(_non_trainable_weights); + } + + return variables.Distinct().ToList(); + } + } + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) + { + if(save_type == SaveType.SAVEDMODEL) + { + //TODO: deal with `train_function`, `test_function`, `predict_function`, `train_tf_function`. + } + var children = base._trackable_children(save_type, cache); + return children; + } + + public override void SetAttr(string name, object value) + { + // TODO(Rinne): deal with "_self_setattr_tracking". + //if(nest.flatten(value).All(v => v is Layer or IVariableV1 || base_layer_utils.has_weights(v))) + //{ + // this._base_model_initialized; + //} + base.SetAttr(name, value); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Node.IterateInbound.cs b/src/TensorFlowNET.Keras/Engine/Node.IterateInbound.cs new file mode 100644 index 000000000..5da2fa44f --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Node.IterateInbound.cs @@ -0,0 +1,20 @@ +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Keras.Engine +{ + public partial class Node + { + public ILayer[] InboundLayers + => iterate_inbound().Select(x => x.Item1).ToArray(); + + public IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound() + { + foreach (var kt in KerasInputs) + { + var (layer, node_index, tensor_index) = kt.KerasHistory; + yield return (layer, node_index, tensor_index, kt); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs new file mode 100644 index 000000000..7c8c805bf --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.Saving; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Node + { + /// + /// Serializes `Node` for Functional API's `get_config`. + /// + /// + public List serialize(Func make_node_key, Dictionary node_conversion_map) + { + return KerasInputs.Select(x => { + var kh = x.KerasHistory; + var node_key = make_node_key(kh.Layer.Name, kh.NodeIndex); + var new_node_index = node_conversion_map.Get(node_key, 0); + return new NodeConfig + { + Name = kh.Layer.Name, + NodeIndex = new_node_index, + TensorIndex = kh.TensorIndex + }; + }).ToList(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Node.cs b/src/TensorFlowNET.Keras/Engine/Node.cs new file mode 100644 index 000000000..bb34da6b3 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Node.cs @@ -0,0 +1,113 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + /// + /// A `Node` describes the connectivity between two layers. + /// + /// Each time a layer is connected to some new input, + /// a node is added to `layer._inbound_nodes`. + /// Each time the output of a layer is used by another layer, + /// a node is added to `layer._outbound_nodes`. + /// + public partial class Node : INode + { + NodeArgs args; + + public Tensors input_tensors => is_input ? Outputs : args.InputTensors; + public Tensors Outputs => args.Outputs; + public List KerasInputs { get; set; } = new List(); + ILayer _layer; + public ILayer Layer => _layer; + public bool is_input => args.InputTensors == null; + + public INode[] ParentNodes + { + get + { + var node_deps = new List(); + foreach (var kt in KerasInputs) + { + var (layer, node_index, _) = kt.KerasHistory; + if (layer != null) + node_deps.append(layer.InboundNodes[node_index]); + } + return node_deps.ToArray(); + } + } + + public Node(NodeArgs args) + { + this.args = args; + } + + public void Connect(Layer layer) + { + _layer = layer; + + if (args.InputTensors != null) + KerasInputs.AddRange(args.InputTensors); + + // Wire up Node to Layers. + layer.InboundNodes.Add(this); + + foreach (var kt in KerasInputs) + { + if (kt.KerasHistory == null) + continue; + var (inbound_layer, _, _) = kt.KerasHistory; + if (inbound_layer != null) + inbound_layer.OutboundNodes.Add(this); + } + + // Set metadata on outputs. + var node_index = layer.InboundNodes.Count - 1; + foreach (var (i, tensor) in enumerate(Outputs)) + tensor.KerasHistory = new KerasHistory(layer, node_index, i); + } + + /// + /// Maps Keras Tensors to computed Tensors using `tensor_dict`. + /// + /// + /// + public Tensors MapArguments(Dictionary> tensor_dict) + { + if (KerasInputs.Count() == 1) + { + var kt_id = KerasInputs[0].Id; + return tensor_dict[kt_id].Dequeue(); + } + else + { + var flat_arguments = KerasInputs.Select(x => x).ToArray(); + foreach (var (kt_index, kt) in enumerate(KerasInputs)) + flat_arguments[kt_index] = tensor_dict[kt.Id].Dequeue(); + + return flat_arguments; + } + } + + public override string ToString() + => $"{Layer.Name}, {KerasInputs.Count} inputs: {string.Join(",", KerasInputs.Select(x => x.name))}"; + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs new file mode 100644 index 000000000..6a468ad27 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -0,0 +1,222 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Utils; +using static Tensorflow.KerasApi; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Engine +{ + /// + /// `Sequential` groups a linear stack of layers into a `tf.keras.Model`. + /// `Sequential` provides training and inference features on this model. + /// + public class Sequential : Functional + { + SequentialArgs args; + + bool _compute_output_and_mask_jointly; + bool _auto_track_sub_layers; + Shape _inferred_input_shape; + bool _has_explicit_input_shape; + bool _graph_initialized; + public Shape output_shape => outputs[0].shape; + List _created_nodes; + + public Sequential(SequentialArgs args) + : base(args.Inputs, args.Outputs, name: args.Name) + { + this.args = args; + // SupportsMasking = true; + _compute_output_and_mask_jointly = true; + _auto_track_sub_layers = false; + _has_explicit_input_shape = false; + _is_graph_network = false; + _created_nodes = new List(); + + // Add to the model any layers passed to the constructor. + if (args.Layers is not null) + { + InitLayers(args.Layers); + } + } + + public void InitLayers(IEnumerable layers) + { + foreach(var layer in layers) + { + // TODO(Rinne): remove it and completely fix issue 1084 + if(layer is Sequential s) + { + s.Layers.ForEach(x => ((Layer)x).enforce_layer_construction()); + } + add(layer); + // TODO(Rinne): remove it and completely fix issue 1084 + if (layer is Sequential s2) + { + s2.Layers.ForEach(x => ((Layer)x).unset_layer_construction()); + } + } + } + + public void add(Tensor tensor) + { + var layer = tensor.KerasHistory.Layer; + add(layer); + } + + /// + /// Adds a layer instance on top of the layer stack. + /// + /// + public void add(ILayer layer) + { + built = false; + var set_inputs = false; + if (_self_tracked_trackables.Count == 0) + { + if (layer is InputLayer) + { + set_inputs = true; + } + else + { + if (layer.BatchInputShape != null) + { + // Instantiate an input layer. + var x = keras.Input( + batch_input_shape: layer.BatchInputShape.ToSingleShape(), + dtype: layer.DType, + name: layer.Name + "_input"); + + // This will build the current layer + // and create the node connecting the current layer + // to the input layer we just created. + layer.Apply(x); + set_inputs = true; + } + } + + if (set_inputs) + { + // If an input layer (placeholder) is available. + outputs = layer.InboundNodes.Last().Outputs; + inputs = layer_utils.get_source_inputs(outputs[0]); + built = true; + _has_explicit_input_shape = true; + } + } + else if (outputs != null) + { + // If the model is being built continuously on top of an input layer: + // refresh its output. + outputs = layer.Apply(outputs); + built = true; + } + + if (set_inputs || _is_graph_network) + { + _init_graph_network(inputs, outputs); + _graph_initialized = true; + } + else + { + _self_tracked_trackables.add(layer); + // TODO(Rinne): self._handle_deferred_layer_dependencies([layer]) + } + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + if (!_has_explicit_input_shape) + { + _build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype); + } + + if(_graph_initialized) + { + if (!built) + _init_graph_network(this.inputs, outputs); + return base.Call(inputs, state, training); + } + + return base.Call(inputs, state, training); + } + + void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) + { + if (_inferred_input_shape == input_shape) + return; + + ops.init_scope(); + var inputs = keras.Input(batch_input_shape: input_shape, + dtype: input_dtype, + name: _self_tracked_trackables[0].Name.EndsWith("_input") ? _self_tracked_trackables[0].Name : $"{_self_tracked_trackables[0].Name}_input"); + Tensors layer_input = inputs; + Tensors layer_output = null; + Tensors outputs = null; + List created_nodes = new List(); + foreach (var layer in Layers) + { + clear_previously_created_nodes(layer, _created_nodes); + layer_output = layer.Apply(layer_input); + // Keep track of nodes just created above + track_nodes_created_by_last_call(layer, created_nodes); + layer_input = layer_output; + outputs = layer_output; + } + _created_nodes = created_nodes; + _init_graph_network(inputs, outputs); + _graph_initialized = true; + _inferred_input_shape = input_shape; + } + + void clear_previously_created_nodes(ILayer layer, List created_nodes) + { + foreach(var node in layer.InboundNodes) + { + foreach(var prev_layer in node.InboundLayers) + { + var outNodes = prev_layer.OutboundNodes.Where(x => !created_nodes.Contains(x)).ToArray(); + prev_layer.OutboundNodes.Clear(); + prev_layer.OutboundNodes.AddRange(outNodes); + } + } + + var inNodes = layer.InboundNodes.Where(x => !created_nodes.Contains(x)).ToArray(); + layer.InboundNodes.Clear(); + layer.InboundNodes.AddRange(inNodes); + } + + void track_nodes_created_by_last_call(ILayer layer, List created_nodes) + { + var node = layer.InboundNodes.Last(); + created_nodes.Add(node); + foreach(var prev_layer in node.InboundLayers) + { + created_nodes.add(prev_layer.OutboundNodes.Last()); + } + } + + public override List Layers + => base.Layers.Where(x => x is not InputLayer).ToList(); + } +} diff --git a/src/TensorFlowNET.Keras/GlobalUsing.cs b/src/TensorFlowNET.Keras/GlobalUsing.cs new file mode 100644 index 000000000..85cd9194c --- /dev/null +++ b/src/TensorFlowNET.Keras/GlobalUsing.cs @@ -0,0 +1,8 @@ +global using System; +global using System.Collections.Generic; +global using System.Text; +global using System.Linq; +global using static Tensorflow.Binding; +global using static Tensorflow.KerasApi; +global using Tensorflow.NumPy; +global using Tensorflow.Keras.Engine; \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/GraphLearningPhase.cs b/src/TensorFlowNET.Keras/GraphLearningPhase.cs new file mode 100644 index 000000000..6f833e06b --- /dev/null +++ b/src/TensorFlowNET.Keras/GraphLearningPhase.cs @@ -0,0 +1,8 @@ +namespace Tensorflow.Keras +{ + public enum GraphLearningPhase + { + train_mode = 1, + test_mode = 0 + } +} diff --git a/src/TensorFlowNET.Keras/ImageDataFormat.cs b/src/TensorFlowNET.Keras/ImageDataFormat.cs new file mode 100644 index 000000000..d32849fe4 --- /dev/null +++ b/src/TensorFlowNET.Keras/ImageDataFormat.cs @@ -0,0 +1,8 @@ +namespace Tensorflow.Keras +{ + public enum ImageDataFormat + { + channels_last, + channels_first + } +} diff --git a/src/TensorFlowNET.Keras/InitializersApi.cs b/src/TensorFlowNET.Keras/InitializersApi.cs new file mode 100644 index 000000000..d6dfa51be --- /dev/null +++ b/src/TensorFlowNET.Keras/InitializersApi.cs @@ -0,0 +1,35 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Operations.Initializers; + +namespace Tensorflow.Keras; + +public partial class InitializersApi : IInitializersApi +{ + /// + /// He normal initializer. + /// + /// + /// + public IInitializer HeNormal(int? seed = null) + { + return new VarianceScaling(scale: 2.0f, mode: "fan_in", seed: seed); + } + + public IInitializer Orthogonal(float gain = 1.0f, int? seed = null) + => new Orthogonal(gain: gain, seed: seed); +} diff --git a/src/TensorFlowNET.Keras/IsExternalInit.cs b/src/TensorFlowNET.Keras/IsExternalInit.cs new file mode 100644 index 000000000..11f062fa8 --- /dev/null +++ b/src/TensorFlowNET.Keras/IsExternalInit.cs @@ -0,0 +1,4 @@ +namespace System.Runtime.CompilerServices +{ + internal static class IsExternalInit { } +} diff --git a/src/TensorFlowNET.Keras/KerasApi.cs b/src/TensorFlowNET.Keras/KerasApi.cs new file mode 100644 index 000000000..69c59ab82 --- /dev/null +++ b/src/TensorFlowNET.Keras/KerasApi.cs @@ -0,0 +1,12 @@ +using Tensorflow.Keras; + +namespace Tensorflow +{ + /// + /// Deprecated, will use tf.keras + /// + public static class KerasApi + { + public static KerasInterface keras { get; } = KerasInterface.Instance; + } +} diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs new file mode 100644 index 000000000..6bc381095 --- /dev/null +++ b/src/TensorFlowNET.Keras/KerasInterface.cs @@ -0,0 +1,111 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Datasets; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Models; +using Tensorflow.Keras.Optimizers; +using Tensorflow.Keras.Utils; +using System.Threading; +using Tensorflow.Framework.Models; + +namespace Tensorflow.Keras +{ + public class KerasInterface : IKerasApi + { + private static KerasInterface _instance = null; + private static readonly object _lock = new object(); + + public static KerasInterface Instance + { + get + { + lock (_lock) + { + if (_instance is null) + { + _instance = new KerasInterface(); + } + return _instance; + } + } + } + + static KerasInterface() + { + RevivedTypes.RegisterRevivedTypeCreator("optimizer", new RestoredOptimizer()); + } + + public KerasDataset datasets { get; } = new KerasDataset(); + public IInitializersApi initializers { get; } = new InitializersApi(); + public Regularizers regularizers { get; } = new Regularizers(); + public ILayersApi layers { get; } = new LayersApi(); + public ILossesApi losses { get; } = new LossesApi(); + public IActivationsApi activations { get; } = new Activations(); + public Preprocessing preprocessing { get; } = new Preprocessing(); + ThreadLocal _backend = new ThreadLocal(() => new BackendImpl()); + public BackendImpl backend => _backend.Value; + public IOptimizerApi optimizers { get; } = new OptimizerApi(); + public IMetricsApi metrics { get; } = new MetricsApi(); + public IModelsApi models { get; } = new ModelsApi(); + public KerasUtils utils { get; } = new KerasUtils(); + + public Sequential Sequential(List layers = null, + string name = null) + => new Sequential(new SequentialArgs + { + Layers = layers, + Name = name + }); + + public Sequential Sequential(params ILayer[] layers) + => new Sequential(new SequentialArgs + { + Layers = layers.ToList() + }); + + /// + /// `Model` groups layers into an object with training and inference features. + /// + /// + /// + /// + public IModel Model(Tensors inputs, Tensors outputs, string name = null) + => new Functional(inputs, outputs, name: name); + + /// + /// Instantiate a Keras tensor. + /// + /// + /// + /// + /// + /// + /// A boolean specifying whether the placeholder to be created is sparse. + /// + /// + /// A boolean specifying whether the placeholder to be created is ragged. + /// + /// + /// Optional existing tensor to wrap into the `Input` layer. + /// If set, the layer will not create a placeholder tensor. + /// + /// + public Tensors Input(Shape shape = null, + int batch_size = -1, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + bool sparse = false, + Tensor tensor = null, + bool ragged = false, + TypeSpec type_spec = null, + Shape batch_input_shape = null, + Shape batch_shape = null) => keras.layers.Input(shape, batch_size, name, + dtype, sparse, tensor, ragged, type_spec, batch_input_shape, batch_shape); + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs new file mode 100644 index 000000000..23f36c862 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs @@ -0,0 +1,45 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers { + /// + /// ELU Layer: + /// x = 0 when x > 0, x = alpha( e^x-1 ) elsewhere + /// + public class ELU : Layer + { + ELUArgs args; + float alpha => args.Alpha; + public ELU(ELUArgs args) : base(args) + { + this.args = args; + } + + public override void build(KerasShapesWrapper input_shape) + { + if (alpha < 0f) + { + throw new ValueError("Alpha must be a number greater than 0."); + } + base.build(input_shape); + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + Tensor output = inputs; + output = tf.where(output > 0f, output, + tf.multiply(alpha, tf.sub(tf.exp(output), 1f))); + return output; + } + public override Shape ComputeOutputShape(Shape input_shape) + { + return input_shape; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs new file mode 100644 index 000000000..81fefb314 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers { + public class Exponential : Layer + { + public Exponential(LayerArgs args) : base(args) + { + // Exponential has no args + } + public override void build(KerasShapesWrapper input_shape) + { + base.build(input_shape); + } + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + Tensor output = inputs; + return tf.exp(output); + } + public override Shape ComputeOutputShape(Shape input_shape) + { + return input_shape; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs b/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs new file mode 100644 index 000000000..e0f91380b --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers { + public class HardSigmoid : Layer { + public HardSigmoid ( LayerArgs args ) : base(args) { + // hard sigmoid has no arguments + } + protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null ) { + Tensor x = inputs; + return tf.clip_by_value( + tf.add(tf.multiply(x, 0.2f), 0.5f), 0f, 1f); + } + public override Shape ComputeOutputShape ( Shape input_shape ) { + return input_shape; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs b/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs new file mode 100644 index 000000000..cfbd0186d --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Leaky version of a Rectified Linear Unit. + /// + public class LeakyReLu : Layer + { + LeakyReLuArgs args; + float alpha => args.Alpha; + public LeakyReLu(LeakyReLuArgs args) : base(args) + { + this.args = args; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + return tf.nn.leaky_relu(inputs, alpha: alpha); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ReLu6.cs b/src/TensorFlowNET.Keras/Layers/Activation/ReLu6.cs new file mode 100644 index 000000000..5af3f7677 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Activation/ReLu6.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Leaky version of a Rectified Linear Unit. + /// + public class ReLu6 : Layer + { + public ReLu6() : base(new LayerArgs { }) + { + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + return tf.nn.relu6(inputs); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs new file mode 100644 index 000000000..2e943d5f7 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers { + /// + /// SELU Layer: + /// similar to ELU, but has pre-defined alpha and scale + /// + public class SELU : Layer { + protected const float alpha = 1.67326324f, scale = 1.05070098f; + public SELU ( LayerArgs args ) : base(args) { + // SELU has no arguments + } + public override void build(KerasShapesWrapper input_shape) { + if ( alpha < 0f ) { + throw new ValueError("Alpha must be a number greater than 0."); + } + base.build(input_shape); + } + protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { + Tensor output = inputs; + return tf.where(output > 0f, + tf.multiply(scale, output), + tf.multiply(scale, tf.multiply(alpha, tf.sub(tf.exp(output), 1f)))); + } + public override Shape ComputeOutputShape ( Shape input_shape ) { + return input_shape; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs new file mode 100644 index 000000000..d018128d5 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers { + public class Softmax : Layer { + Axis axis; + public Softmax ( SoftmaxArgs args ) : base(args) { + axis = args.axis; + } + protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { + Tensor x = inputs.Length == 2 ? inputs[0] + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9) + : inputs; + Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true))); + Tensor s = tf.reduce_sum(e, axis: this.axis, keepdims: true); + return tf.div(e, s); + } + public override Shape ComputeOutputShape ( Shape input_shape ) { + return input_shape; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs new file mode 100644 index 000000000..1e6c59b42 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers { + public class Softplus : Layer { + public Softplus ( LayerArgs args ) : base(args) { + // Softplus has no arguments + } + protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { + Tensor x = inputs; + return tf.log( + tf.add(tf.exp(x), 1f)); + } + public override Shape ComputeOutputShape ( Shape input_shape ) { + return input_shape; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs new file mode 100644 index 000000000..5ad33e99d --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers { + public class Softsign : Layer { + public Softsign ( LayerArgs args ) : base(args) { + // Softsign has no arguments + } + protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { + Tensor x = inputs; + // x / (abs(x) + 1) + return tf.div(x, tf.add(1f, tf.abs(x))); + } + public override Shape ComputeOutputShape ( Shape input_shape ) { + return input_shape; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs b/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs new file mode 100644 index 000000000..ed0d105a6 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers { + public class Swish : Layer { + public Swish ( LayerArgs args ) : base(args) { + // Swish has no arguments + } + protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { + Tensor x = inputs; + + // x / (1 + exp(-x)) + return tf.div(x, (tf.add(1f, tf.exp(tf.negative(x))))); + } + public override Shape ComputeOutputShape ( Shape input_shape ) { + return input_shape; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs b/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs new file mode 100644 index 000000000..7e90cf9d8 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class Tanh : Layer + { + public Tanh(LayerArgs args) : base(args) + { + // Tanh has no arguments + } + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + Tensor x = inputs; + + return tf.tanh(x); + } + public override Shape ComputeOutputShape(Shape input_shape) + { + return input_shape; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs new file mode 100644 index 000000000..e6a8e1a63 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs @@ -0,0 +1,161 @@ +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Dot-product attention layer, a.k.a. Luong-style attention. + /// Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of + /// shape `[batch_size, Tv, dim]` and `key` tensor of shape + /// `[batch_size, Tv, dim]`. The calculation follows the steps: + /// + /// 1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key` dot + /// product: `scores = tf.matmul(query, key, transpose_b=True)`. + /// + /// + /// 2. Use scores to calculate a distribution with shape + /// `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`. + /// + /// + /// 3. Use `distribution` to create a linear combination of `value` with + /// shape `[batch_size, Tq, dim]`: + /// `return tf.matmul(distribution, value)`. + /// + /// + /// 0 + /// + /// //Variable-length int sequences. + /// var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32); + /// var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32); + /// // Embedding lookup. + /// var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64); + /// // Query embeddings of shape [batch_size, Tq, dimension]. + /// var query_embeddings = token_embedding.Apply(query_input); + /// // Value embeddings of shape [batch_size, Tv, dimension]. + /// var value_embeddings = token_embedding.Apply(value_input); + /// // CNN layer. + /// var cnn_layer = keras.layers.Conv1D( + /// filters: 100, + /// kernel_size: 4, + /// // Use 'same' padding so outputs have the same shape as inputs. + /// padding: "same"); + /// var cnn_layer2 = keras.layers.Conv1D( + /// filters: 100, + /// kernel_size: 4, + /// // Use 'same' padding so outputs have the same shape as inputs. + /// padding: "same"); + /// // Query encoding of shape [batch_size, Tq, filters]. + /// var query_seq_encoding = cnn_layer.Apply(query_embeddings); + /// // Value encoding of shape [batch_size, Tv, filters]. + /// var value_seq_encoding = cnn_layer.Apply(value_embeddings); + /// // Query-value attention of shape [batch_size, Tq, filters]. + /// var query_value_attention_seq = keras.layers.Attention().Apply( + /// (query_seq_encoding, value_seq_encoding)); + /// // Reduce over the sequence axis to produce encodings of shape + /// // [batch_size, filters]. + /// var query_encoding = keras.layers.GlobalAveragePooling1D().Apply( + /// query_seq_encoding); + /// var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply( + /// query_value_attention_seq); + /// // Concatenate query and document encodings to produce a DNN input layer. + /// var input_layer = keras.layers.Concatenate().Apply( + /// (query_encoding, query_value_attention)); + /// // Add DNN layers, and create Model. + /// // ... + /// + /// + public class Attention : BaseDenseAttention + { + + public IVariableV1 concat_score_weight; + + public IVariableV1 scale; + + AttentionArgs args; + + string score_mode { get => args.score_mode; } + + bool use_scale { get => args.use_scale; } + + public Attention(AttentionArgs args) : base(args) + { + this.args = args; + if (!new List { + "dot", + "concat" + }.Contains(this.score_mode)) + throw new ValueError("Received: score_mode={score_mode}. Acceptable values are: [\"dot\", \"concat\"]"); + } + + // Creates variable when `use_scale` is True or `score_mode` is `concat`. + public override void build(KerasShapesWrapper input_shape) + { + if (this.use_scale) + this.scale = this.add_weight(name: "scale", + shape: 1, + initializer: tf.ones_initializer, + dtype: this.DType, + trainable: true); + else + this.scale = null; + + if (this.score_mode == "concat") + this.concat_score_weight = this.add_weight(name: "concat_score_weight", + shape: 1, + initializer: tf.ones_initializer, + dtype: this.DType, + trainable: true); + else + this.concat_score_weight = null; + base.build(input_shape); + } + + /// + /// Calculates attention scores as a query-key dot product. + /// + /// query: Query tensor of shape `[batch_size, Tq, dim]`. + /// key: Key tensor of shape `[batch_size, Tv, dim]`. + /// Tensor of shape `[batch_size, Tq, Tv]`. + public override Tensor _calculate_scores(Tensor query, Tensor key) + { + Tensor scores = null; + if (this.score_mode == "dot") + { + //scores = tf.matmul(query, key, transpose_b: true); + //scores = tf.matmul(tf.squeeze(query),tf.squeeze(key), transpose_b: true); + scores = tf.linalg.einsum("bij,bkj->bik", (query, key)); + if (this.scale != null) + scores *= this.scale.AsTensor(); + } else if (this.score_mode == "concat") { + // Reshape tensors to enable broadcasting. + // Reshape into [batch_size, Tq, 1, dim]. + var q_reshaped = tf.expand_dims(query, axis: -2); + // Reshape into [batch_size, 1, Tv, dim]. + var k_reshaped = tf.expand_dims(key, axis: -3); + if (this.scale != null) + scores = this.concat_score_weight.AsTensor() * + tf.reduce_sum(tf.tanh(this.scale.AsTensor() * (q_reshaped + k_reshaped)), axis: -1); + else + scores = this.concat_score_weight.AsTensor() * + tf.reduce_sum(tf.tanh(q_reshaped + k_reshaped), axis: -1); + } + return scores; + } + + public override IKerasConfig get_config() => this.args; + //var config = new Dictionary { + // { + // "use_scale", + // this.use_scale}, + // { + // "score_mode", + // this.score_mode}}; + //var base_config = base.get_config(); + //return new dict(base_config.items().ToList() + config.items().ToList()); + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs new file mode 100644 index 000000000..970a938d2 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs @@ -0,0 +1,253 @@ +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.ArgsDefinition; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + + /// + /// Base Attention class for Dense networks. + /// This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2. + /// Attention is formed by three tensors: Query, Key and Value. + /// This class is suitable for Dense or CNN networks, and not for RNN networks. + /// Implementations of attention mechanisms should inherit from this class, and + /// reuse the `apply_attention_scores()` method. + /// + public class BaseDenseAttention : Layer + { + + BaseDenseAttentionArgs args; + + bool causal { get => args.causal; } + + float dropout { get => args.dropout; } + + protected bool supports_masking; + + public BaseDenseAttention(BaseDenseAttentionArgs args) : base(args) + { + this.args = args; + this.supports_masking = true; + } + + /// + /// Calculates attention scores. + /// + /// query: Query tensor of shape `[batch_size, Tq, dim]`. + /// key: Key tensor of shape `[batch_size, Tv, dim]`. + /// Tensor of shape `[batch_size, Tq, Tv]`. + public virtual Tensor _calculate_scores(Tensor query, Tensor key) => + throw new NotImplementedException(""); + + /// + /// Applies attention scores to the given value tensor. + /// To use this method in your attention layer, follow the steps: + /// + /// * Use `query` tensor of shape `[batch_size, Tq]` and `key` tensor of shape + /// `[batch_size, Tv]` to calculate the attention `scores`. + /// + /// + /// * Pass `scores` and `value` tensors to this method. The method applies + /// `scores_mask`, calculates `attention_distribution = softmax(scores)`, then + /// returns `matmul(attention_distribution, value). + /// + /// + /// * Apply `query_mask` and return the result. + /// + /// + /// Scores float tensor of shape `[batch_size, Tq, Tv]`. + /// Value tensor of shape `[batch_size, Tv, dim]`. + /// + /// A boolean mask `Tensor` of shape `[batch_size, 1, Tv]` or + /// [batch_size, Tq, Tv]`. If given, scores at positions where + /// `scores_mask==False` do not contribute to the result. It must contain + /// at least one `True` value in each line along the last dimension. + /// + /// + /// Boolean indicating whether the layer should behave in + /// training mode (adding dropout) or in inference mode (no dropout). + /// + /// + /// + /// Tensor of shape `[batch_size, Tq, dim]`. + /// + /// + /// Attention scores after masking and softmax with shape + /// [batch_size, Tq, Tv]`. + /// + /// + public (Tensor, Tensor) _apply_scores(Tensor scores, + Tensor value, + Tensor scores_mask = null, + bool? training = null) + { + if (scores_mask != null) + { + var padding_mask = tf.logical_not(scores_mask); + // Bias so padding positions do not contribute to attention distribution. + // Note 65504. is the max float16 value. + if (scores.dtype == tf.float16) + scores -= 65504f * tf.cast(padding_mask, dtype: scores.dtype); + else + scores -= 1000000000f * tf.cast(padding_mask, dtype: scores.dtype); + } + bool _training; + training ??= false; // TODO: Delete this line when backend.learning_phase is available + if (training == null) + _training = keras.backend.learning_phase() == + Tensorflow.Keras.GraphLearningPhase.train_mode ? + true : false; + else _training = training.Value; + var weights = tf.nn.softmax(scores); + Func dropped_weights = () => tf.nn.dropout(weights, rate: this.dropout); + weights = Tensorflow.Framework.smart_module.smart_cond(_training, dropped_weights, () => tf.identity(weights)); + //return (tf.matmul(weights, value), weights); + return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights); + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + Tensors _inp; + Tensors _mask = null; + + int count = inputs.Count(); + if (count < 2 || count > 6) throw new ValueError( + $"{ this.name } layer accepts inputs list of length from 2 to 6, " + + $"namely [query, value, (key), (query_mask), (value_mask), (return_attention_scores)]." + + $"Received length: {count}."); + + bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL; + bool return_attention_scores = false; + if (has_bool) + { + return_attention_scores = (bool)inputs[count - 1]; + count--; + } + + switch (count) + { + case 2: + _inp = (inputs[0], inputs[1]); + break; + case 3: + _inp = new[] { inputs[0], inputs[1], inputs[2] }; + break; + case 4: + if (inputs[0].shape == inputs[2].shape) + if (inputs[1].shape == inputs[3].shape) + { + _inp = new[] { inputs[0], inputs[1] }; + _mask = new[] { inputs[2], inputs[3] }; + break; + } + throw new ValueError(); //TODO:Add discriptions for this err + case 5: + _inp = new[] { inputs[0], inputs[1], inputs[2] }; + _mask = (inputs[3], inputs[4]); + break; + default: + throw new ValueError(); //TODO:Add discriptions for this err + } + + return call(_inp, _mask, training, return_attention_scores); + } + + protected Tensors call(Tensors inputs, Tensors mask = null, bool? training = null, bool return_attention_scores = false) + { + Tensor causal_mask; + //this._validate_call_args(inputs: inputs, mask: mask); + var q = inputs[0]; + var v = inputs[1]; + var k = inputs.Count() > 2 ? inputs[2] : v; + var q_mask = mask != null ? mask[0] : null; + var v_mask = mask != null ? mask[1] : null; + var scores = this._calculate_scores(query: q, key: k); + if (v_mask != null) + // Mask of shape [batch_size, 1, Tv]. + v_mask = tf.expand_dims(v_mask, axis: -2); + if (this.causal) + { + // Creates a lower triangular mask, so position i cannot attend to + // positions j>i. This prevents the flow of information from the future + // into the past. + var scores_shape = tf.shape(scores); + // causal_mask_shape = [1, Tq, Tv]. + var causal_mask_shape = tf.concat(new List { + tf.ones_like(tf.slice(scores_shape, new[]{0}, new[]{-2})), + tf.concat(new[]{scores_shape[-2], scores_shape[-1]}, 0) + }, axis: 0); + var _causal_mask_shape = new Shape(causal_mask_shape.ToArray()); + causal_mask = _lower_triangular_mask(_causal_mask_shape); + } + else + causal_mask = null; + var scores_mask = _merge_masks(v_mask, causal_mask); + var (result, attention_scores) = this._apply_scores(scores: scores, value: v, scores_mask: scores_mask, training: training); + if (q_mask != null) + { + // Mask of shape [batch_size, Tq, 1]. + q_mask = tf.expand_dims(q_mask, axis: -1); + result *= tf.cast(q_mask, dtype: result.dtype); + } + if (return_attention_scores) + return new Tensors(result, attention_scores); + return result; + } + + public Tensor compute_mask(Tensors inputs, Tensors mask = null) + { + this._validate_call_args(inputs: inputs, mask: mask); + if (mask != null) + { + var q_mask = mask[0]; + if (q_mask == null) + return null; + return tf.convert_to_tensor(q_mask); + } + return null; + } + + //public Shape compute_output_shape(Shape input_shape) { + // // return_attention_scores argument of BaseDenseAttention.call method + // // is ignored. Output shape of attention_scores cannot be returned. + // return input_shape[0]; + //} + + /// + /// Validates arguments of the call method. + /// + public void _validate_call_args(Tensors inputs, Tensors mask) + { + if (inputs.Count() < 2 || inputs.Count() > 3) + throw new ValueError( + $"{this.name} layer accepts inputs list of length 2 or 3, " + + $"namely [query, value] or [query, value, key]. Received length: {len(inputs)}."); + if (mask != null) + if (mask.Count() < 2 || mask.Count() > inputs.Count()) + throw new ValueError($"{this.name} layer mask must be a list of length 2, " + + $"namely [query_mask, value_mask]. Received length: {len(mask)}."); + } + + public static Tensor _lower_triangular_mask(Shape shape) + { + var row_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -2); + var col_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -1); + return tf.greater_equal(row_index, col_index); + } + + public static Tensor _merge_masks(Tensor x, Tensor y) + { + if (x == null) + return y; + if (y == null) + return x; + return tf.logical_and(x, y); + } + + public override IKerasConfig get_config() => this.args; + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs new file mode 100644 index 000000000..75dd4a41a --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs @@ -0,0 +1,357 @@ +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.ArgsDefinition.Core; +using Tensorflow.Keras.Engine; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using System; +using System.Linq; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + public class MultiHeadAttention : Layer + { + static readonly string _CHR_IDX = "abcdefghijklmnopqrstuvwxyz"; + + MultiHeadAttentionArgs args; + Shape _query_shape = null; + Shape _key_shape = null; + Shape _value_shape = null; + bool _built_from_signature = false; + EinsumDense _query_dense = null; + EinsumDense _key_dense = null; + EinsumDense _value_dense = null; + EinsumDense _output_dense = null; + string _dot_product_equation = ""; + string _combine_equation = ""; + Softmax _softmax = null; + Dropout _dropout_layer = null; + + /// + /// Builds einsum equations for the attention computation. + /// Query, key, value inputs after projection are expected to have the shape as: + /// `(bs, [non-attention dims], [attention dims], num_heads, channels)`. + /// `bs` and `[non-attention dims]` are treated as `[batch dims]`. + /// + /// + /// The attention operations can be generalized: + /// + /// + /// (1) Query-key dot product: + /// `([batch dims], [query attention dims], num_heads, channels), ([batch dims], + /// [key attention dims], num_heads, channels) -> ([batch dim], + /// num_heads, [query attention dims], [key attention dims])` + /// + /// (2) Combination: + /// `([batch dims], num_heads, [query attention dims], [key attention dims]), + /// ([batch dims], [value attention dims], num_heads, channels) -> ([batch dims], + /// [query attention dims], num_heads, channels)` + /// + /// + /// Rank of query, key, value tensors. + /// List/tuple of axes, `[-1, rank)`, + /// that attention will be applied to. + /// + public static (string, string, int) _build_attention_equation(int rank, Shape attn_axes) + { + var target_notation = _CHR_IDX.Substring(0, rank); + // `batch_dims` includes the head dim. + // batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,))) + // Since range(rank) is an IEnumerable like (0, 1, 2 ...) whose index is equal to its value + // use IEnumerable.Except instead of np.delete which is unavailable + var batch_dims = range(rank).Except(attn_axes.as_int_list().concat(new[] { rank - 1 })); + var letter_offset = rank; + var source_notation = ""; + for (int i = 0; i < rank; i++) + { + if (batch_dims.Contains(i) || i == rank - 1) + source_notation += target_notation[i]; + else + { + source_notation += _CHR_IDX[letter_offset]; + letter_offset += 1; + } + } + var product_notation = new string((from i in batch_dims + select target_notation[i]).Concat( + + from i in attn_axes.as_int_list() + select target_notation[i]).Concat( + + from i in attn_axes.as_int_list() + select source_notation[i]).ToArray()); + var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}"; + var attn_scores_rank = product_notation.Count(); + var combine_equation = $"{product_notation},{source_notation}->{target_notation}"; + return (dot_product_equation, combine_equation, attn_scores_rank); + } + + /// + /// Builds an einsum equation for projections inside multi-head attention. + /// + public static (string, string, int) _build_proj_equation(int free_dims, int bound_dims, int output_dims) + { + char _char; + var input_str = ""; + var kernel_str = ""; + var output_str = ""; + var bias_axes = ""; + var letter_offset = 0; + foreach (var i in range(free_dims)) + { + _char = _CHR_IDX[i + letter_offset]; + input_str += _char; + output_str += _char; + } + letter_offset += free_dims; + foreach (var i in range(bound_dims)) + { + _char = _CHR_IDX[i + letter_offset]; + input_str += _char; + kernel_str += _char; + } + letter_offset += bound_dims; + foreach (var i in range(output_dims)) + { + _char = _CHR_IDX[i + letter_offset]; + kernel_str += _char; + output_str += _char; + bias_axes += _char; + } + var equation = $"{input_str},{kernel_str}->{output_str}"; + return (equation, bias_axes, output_str.Count()); + } + + static Shape _get_output_shape(int output_rank, Shape known_last_dims) + => (from _ in range(output_rank - known_last_dims.rank) + select -1).Concat(known_last_dims.as_int_list()).ToArray(); + + public MultiHeadAttention(MultiHeadAttentionArgs args) : base(args) + { + this.args = args; + } + + public void _build_from_signature(Tensor query, Tensor value, Tensor key = null) + => this._build_from_signature(query.shape, value.shape, key?.shape); + + public void _build_from_signature(Shape query, Shape value, Shape key = null) + { + this._built_from_signature = true; + this._query_shape = query; + this._value_shape = value; + if (key == null) + this._key_shape = this._value_shape; + else + this._key_shape = key; + // Any setup work performed only once should happen in an `init_scope` + // to avoid creating symbolic Tensors that will later pollute any eager + // operations. + tf_with(tf.init_scope(), _ => + { + var free_dims = this._query_shape.rank - 1; + var (einsum_equation, bias_axes, output_rank) = _build_proj_equation( + free_dims, bound_dims: 1, output_dims: 2); + this._query_dense = _get_dense(einsum_equation, + _get_output_shape(output_rank - 1, + (this.args.NumHeads, this.args.KeyDim)), + this.args.UseBias ? bias_axes : null, + "query"); + (einsum_equation, bias_axes, output_rank) = _build_proj_equation( + this._key_shape.rank - 1, bound_dims: 1, output_dims: 2); + this._key_dense = _get_dense(einsum_equation, + _get_output_shape(output_rank - 1, + (this.args.NumHeads, this.args.KeyDim)), + this.args.UseBias ? bias_axes : null, + "key"); + (einsum_equation, bias_axes, output_rank) = _build_proj_equation( + this._value_shape.rank - 1, bound_dims: 1, output_dims: 2); + this._value_dense = _get_dense(einsum_equation, + _get_output_shape(output_rank - 1, + (this.args.NumHeads, this.args.ValueDim ?? this.args.KeyDim)), + this.args.UseBias ? bias_axes : null, + "value"); + // Builds the attention computations for multi-head dot product attention. + // These computations could be wrapped into the keras attention layer once + // it support mult-head einsum computations. + this._build_attention(output_rank); + this._output_dense = _build_output_dense(free_dims, "attention_output"); + }); + this.StackLayers(_query_dense, _key_dense, _value_dense, _output_dense); + } + + EinsumDense _get_dense(string equation, Shape output_shape, string bias_axes, string name) + => new EinsumDense(new EinsumDenseArgs() + { + Equation = equation, + OutputShape = output_shape, + BiasAxes = bias_axes, + Name = name, + KernelInitializer = this.args.KernelInitializer, + BiasInitializer = this.args.BiasInitializer, + KernelRegularizer = this.args.KernelRegularizer, + BiasRegularizer = this.args.BiasRegularizer, + KernelConstraint = this.args.KernelConstraint, + BiasConstraint = this.args.BiasConstraint + }); + + EinsumDense _build_output_dense(int free_dims, string name) + { + if (this.args.OutputShape == null) this.args.OutputShape = new(this._query_shape[-1]); + var (einsum_equation, bias_axes, output_rank) = _build_proj_equation( + free_dims, bound_dims: 2, output_dims: len(this.args.OutputShape)); + return _get_dense(einsum_equation, + _get_output_shape(output_rank - 1, this.args.OutputShape), + this.args.UseBias ? bias_axes : null, + name); + } + + void _build_attention(int rank) + { + if (this.args.AttentionAxis == null) + this.args.AttentionAxis = new(range(1, rank - 2).ToArray()); + int attn_scores_rank; + (this._dot_product_equation, this._combine_equation, attn_scores_rank) + = _build_attention_equation(rank, this.args.AttentionAxis); + var norm_axes = range(attn_scores_rank - len(this.args.AttentionAxis), + attn_scores_rank).ToArray(); + this._softmax = new Softmax(new SoftmaxArgs { axis = norm_axes }); + this._dropout_layer = new Dropout(new DropoutArgs { Rate = this.args.Dropout }); + } + + Tensor _masked_softmax(Tensor attention_scores, Tensor attention_mask = null) + { + if(attention_mask != null) + { + var mask_expansion_axis = -len(this.args.AttentionAxis) * 2 - 1; + for (int i = 0; i < len(attention_scores.shape) - len(attention_mask.shape); i++) + attention_mask = tf.expand_dims(attention_mask, axis: mask_expansion_axis); + } + return this._softmax.Apply(attention_mask == null ? attention_scores : (attention_scores, attention_mask)); + } + + public Tensors _compute_attention( + Tensor query, + Tensor key, + Tensor value, + Tensor attention_mask = null, + bool training = false) + { + // Note: Applying scalar multiply at the smaller end of einsum improves + // XLA performance, but may introduce slight numeric differences in + // the Transformer attention head. + query = tf.multiply(query, 1f / tf.sqrt(tf.convert_to_tensor((float)this.args.KeyDim))); + // Take the dot product between "query" and "key" to get the raw + // attention scores. + var attention_scores = tf.linalg.einsum(this._dot_product_equation, (key, query)); + attention_scores = this._masked_softmax(attention_scores, attention_mask); + // This is actually dropping out entire tokens to attend to, which might + // seem a bit unusual, but is taken from the original Transformer paper. + var attention_scores_dropout = this._dropout_layer.Apply(attention_scores, training: training); + // `context_layer` = [B, T, N, H] + var attention_output = tf.linalg.einsum(this._combine_equation, (attention_scores_dropout, value)); + return (attention_output, attention_scores); + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + Tensors _inp; + Tensor _mask = null; + + int count = inputs.Count(); + if (count < 2 || count > 5) throw new ValueError( + $"{ this.name } layer accepts inputs list of length from 2 to 5, " + + $"namely [query, value, (key), (attention_mask), (return_attention_scores)]." + + $"Received length: {count}."); + + bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL; + bool return_attention_scores = false; + if (has_bool) + { + return_attention_scores = (bool)inputs[count - 1]; + count--; + } + + switch (count) + { + case 2: + _inp = (inputs[0], inputs[1]); + break; + case 3: + if (inputs[2].shape[-1] == inputs[1].shape[-1]) + _inp = new[] { inputs[0], inputs[1], inputs[2] }; + else + { + _inp = (inputs[0], inputs[1]); + _mask = inputs[2]; + } + break; + case 4: + _inp = new[] { inputs[0], inputs[1], inputs[2] }; + _mask = inputs[3]; + break; + default: + throw new ValueError(); //TODO:Add discriptions for this err + } + + return call(_inp, _mask, training, return_attention_scores); + } + + protected Tensors call(Tensors inputs, + Tensor attention_mask, + bool? training = null, + bool return_attention_scores = false) + { + var (query, value, key) = (inputs[0], inputs[1], inputs.Length == 3 ? inputs[2] : null); + if (!this._built_from_signature) + this._build_from_signature(query: query, value: value, key: key); + if (key == null) + key = value; + + // TODO: Add RaggedTensor support + //var query_is_ragged = query is tf.RaggedTensor; + //if (query_is_ragged) + //{ + // var query_lengths = query.nested_row_lengths(); + // query = query.to_tensor(); + //} + //var key_is_ragged = key is tf.RaggedTensor; + //var value_is_ragged = value is tf.RaggedTensor; + //if (key_is_ragged && value_is_ragged) + //{ + // // Ensure they have the same shape. + // var bounding_shape = tf.math.maximum(key.bounding_shape(), value.bounding_shape()); + // key = key.to_tensor(shape: bounding_shape); + // value = value.to_tensor(shape: bounding_shape); + //} + //else if (key_is_ragged) + //{ + // key = key.to_tensor(shape: tf.shape(value)); + //} + //else if (value_is_ragged) + //{ + // value = value.to_tensor(shape: tf.shape(key)); + //} + + // N = `num_attention_heads` + // H = `size_per_head` + // `query` = [B, T, N ,H] + query = this._query_dense.Apply(query); + // `key` = [B, S, N, H] + key = this._key_dense.Apply(key); + // `value` = [B, S, N, H] + value = this._value_dense.Apply(value); + var (attention_output, attention_scores) = this._compute_attention(query, key, value, attention_mask, training ?? false); + attention_output = this._output_dense.Apply(attention_output); + + //if (query_is_ragged) + //{ + // attention_output = tf.RaggedTensor.from_tensor(attention_output, lengths: query_lengths); + //} + + if (return_attention_scores) + return (attention_output, attention_scores.Single); + return attention_output; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv1D.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv1D.cs new file mode 100644 index 000000000..3ee61253c --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv1D.cs @@ -0,0 +1,65 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Layers +{ + public class Conv1D : Convolutional + { + public Conv1D(Conv1DArgs args) : base(InitializeUndefinedArgs(args)) + { + + } + + private static Conv1DArgs InitializeUndefinedArgs(Conv1DArgs args) + { + if(args.Rank == 0) + { + args.Rank = 1; + } + if(args.Strides is null) + { + args.Strides = 1; + } + if (string.IsNullOrEmpty(args.Padding)) + { + args.Padding = "valid"; + } + if (string.IsNullOrEmpty(args.DataFormat)) + { + args.DataFormat = "channels_last"; + } + if(args.DilationRate == 0) + { + args.DilationRate = 1; + } + if(args.Groups == 0) + { + args.Groups = 1; + } + if(args.KernelInitializer is null) + { + args.KernelInitializer = tf.glorot_uniform_initializer; + } + if(args.BiasInitializer is null) + { + args.BiasInitializer = tf.zeros_initializer; + } + return args; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2D.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2D.cs new file mode 100644 index 000000000..a6963e307 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2D.cs @@ -0,0 +1,61 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Layers +{ + public class Conv2D : Convolutional + { + public Conv2D(Conv2DArgs args) : base(InitializeUndefinedArgs(args)) + { + + } + + private static Conv2DArgs InitializeUndefinedArgs(Conv2DArgs args) + { + if(args.Rank == 0) + { + args.Rank = 2; + } + if (args.Strides is null) + { + args.Strides = (1, 1); + } + if (string.IsNullOrEmpty(args.Padding)) + { + args.Padding = "valid"; + } + if (args.DilationRate == 0) + { + args.DilationRate = (1, 1); + } + if (args.Groups == 0) + { + args.Groups = 1; + } + if (args.KernelInitializer is null) + { + args.KernelInitializer = tf.glorot_uniform_initializer; + } + if (args.BiasInitializer is null) + { + args.BiasInitializer = tf.zeros_initializer; + } + return args; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs new file mode 100644 index 000000000..94ad79141 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs @@ -0,0 +1,183 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Utils; +using static Tensorflow.KerasApi; +using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + public class Conv2DTranspose : Conv2D + { + public Conv2DTranspose(Conv2DArgs args) : base(InitializeUndefinedArgs(args)) + { + + } + + private static Conv2DArgs InitializeUndefinedArgs(Conv2DArgs args) + { + if (args.Strides is null) + { + args.Strides = (1, 1); + } + if (string.IsNullOrEmpty(args.Padding)) + { + args.Padding = "valid"; + } + if (args.DilationRate == 0) + { + args.DilationRate = (1, 1); + } + if (args.Groups == 0) + { + args.Groups = 1; + } + if (args.KernelInitializer is null) + { + args.KernelInitializer = tf.glorot_uniform_initializer; + } + if (args.BiasInitializer is null) + { + args.BiasInitializer = tf.zeros_initializer; + } + return args; + } + + public override void build(KerasShapesWrapper input_shape) + { + var single_shape = input_shape.ToSingleShape(); + if (len(single_shape) != 4) + throw new ValueError($"Inputs should have rank 4. Received input shape: {input_shape}"); + + var channel_axis = _get_channel_axis(); + var input_dim = single_shape[-1]; + var kernel_shape = new Shape(kernel_size[0], kernel_size[1], filters, input_dim); + + kernel = add_weight(name: "kernel", + shape: kernel_shape, + initializer: kernel_initializer, + regularizer: kernel_regularizer, + trainable: true); + if (use_bias) + bias = add_weight(name: "bias", + shape: filters, + initializer: bias_initializer, + trainable: true); + built = true; + _buildInputShape = input_shape; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + var inputs_shape = array_ops.shape(inputs); + var batch_size = inputs_shape[0]; + var (h_axis, w_axis) = (1, 2); + if (data_format == "channels_first") + (h_axis, w_axis) = (2, 3); + var (height, width) = (-1, -1); + if(inputs.shape.ndim > -1) + { + var dims = inputs.shape.dims; + (height, width) = ((int)dims[h_axis], (int)dims[w_axis]); + } + var (kernel_h, kernel_w) = kernel_size; + var (stride_h, stride_w) = strides; + + var (out_pad_h, out_pad_w) = (-1, -1); + + // Infer the dynamic output shape: + var out_height = conv_utils.deconv_output_length(height, + (int)kernel_h, + padding: padding, + output_padding: out_pad_h, + stride: (int)stride_h, + dilation: (int)dilation_rate[0]); + + var out_width = conv_utils.deconv_output_length(width, + (int)kernel_w, + padding: padding, + output_padding: out_pad_w, + stride: (int)stride_w, + dilation: (int)dilation_rate[1]); + + Tensor output_shape_tensor; + if (data_format == "channels_first") + output_shape_tensor = array_ops.stack(new object[] { batch_size, filters, out_height, out_width }); + else + output_shape_tensor = array_ops.stack(new object[] { batch_size, out_height, out_width, filters }); + + var outputs = keras.backend.conv2d_transpose( + inputs, + kernel, + output_shape_tensor, + strides: strides, + padding: padding, + data_format: data_format, + dilation_rate: dilation_rate); + + if (!tf.Context.executing_eagerly()) + { + var out_shape = ComputeOutputShape(inputs.shape); + outputs.shape = out_shape; + } + + if (use_bias) + tf.nn.bias_add( + outputs, + bias, + data_format: conv_utils.convert_data_format(data_format, ndim: 4)); + + if (activation != null) + return activation.Apply(outputs); + + return outputs; + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + var output_shape = input_shape.dims; + var (c_axis, h_axis, w_axis) = (3, 1, 2); + if (data_format == "channels_first") + (c_axis, h_axis, w_axis) = (1, 2, 3); + + var (kernel_h, kernel_w) = kernel_size; + var (stride_h, stride_w) = strides; + + var (out_pad_h, out_pad_w) = (-1, -1); + output_shape[c_axis] = filters; + output_shape[h_axis] = conv_utils.deconv_output_length( + (int)output_shape[h_axis], + (int)kernel_h, + padding: padding, + output_padding: out_pad_h, + stride: (int)stride_h, + dilation: (int)dilation_rate[0]); + output_shape[w_axis] = conv_utils.deconv_output_length( + (int)output_shape[w_axis], + (int)kernel_w, + padding: padding, + output_padding: out_pad_w, + stride: (int)stride_w, + dilation: (int)dilation_rate[1]); + + return new Shape(output_shape); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs new file mode 100644 index 000000000..d8e00d520 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs @@ -0,0 +1,131 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class Convolutional : Layer + { + ConvolutionalArgs args; + protected int rank => args.Rank; + protected int filters => args.Filters; + protected Shape kernel_size => args.KernelSize; + protected Shape strides => args.Strides; + protected string padding => args.Padding; + protected string data_format => args.DataFormat; + protected Shape dilation_rate => args.DilationRate; + protected Activation activation => args.Activation; + protected bool use_bias => args.UseBias; + protected IInitializer kernel_initializer => args.KernelInitializer; + protected IRegularizer kernel_regularizer => args.KernelRegularizer; + protected IInitializer bias_initializer => args.BiasInitializer; + protected IVariableV1 kernel; + protected IVariableV1 bias; + ConvolutionInternal _convolution_op; + protected string _tf_data_format; + + public Convolutional(ConvolutionalArgs args) : base(args) + { + this.args = args; + args.KernelSize = conv_utils.normalize_tuple(args.KernelSize.as_int_list(), args.Rank, "kernel_size"); + args.Strides = conv_utils.normalize_tuple(args.Strides.as_int_list(), args.Rank, "strides"); + args.Padding = conv_utils.normalize_padding(args.Padding); + args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); + args.DilationRate = conv_utils.normalize_tuple(args.DilationRate.as_int_list(), args.Rank, "dilation_rate"); + inputSpec = new InputSpec(ndim: rank + 2); + _tf_data_format = conv_utils.convert_data_format(data_format, rank + 2); + } + + public override void build(KerasShapesWrapper input_shape) + { + int channel_axis = data_format == "channels_first" ? 1 : -1; + var single_shape = input_shape.ToSingleShape(); + var input_channel = channel_axis < 0 ? + single_shape.dims[single_shape.ndim + channel_axis] : + single_shape.dims[channel_axis]; + Shape kernel_shape = kernel_size.dims.concat(new long[] { input_channel / args.Groups, filters }); + kernel = add_weight(name: "kernel", + shape: kernel_shape, + initializer: kernel_initializer, + regularizer: kernel_regularizer, + trainable: true, + dtype: DType); + if (use_bias) + bias = add_weight(name: "bias", + shape: new int[] { filters }, + initializer: bias_initializer, + trainable: true, + dtype: DType); + + var axes = new Dictionary(); + axes.Add(-1, (int)input_channel); + inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes); + + string tf_padding; + if (padding == "causal") + tf_padding = "VALID"; + else + tf_padding = padding.ToUpper(); + + string tf_op_name = GetType().Name; + + + _convolution_op = nn_ops.convolution_internal(tf_padding, + strides, + dilation_rate, + rank, + data_format: _tf_data_format, + name: tf_op_name); + + built = true; + _buildInputShape = input_shape; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = false, IOptionalArgs? optional_args = null) + { + var outputs = _convolution_op.Apply(inputs, kernel.AsTensor()); + if (use_bias) + { + if (data_format == "channels_first") + { + throw new NotImplementedException("call channels_first"); + } + else + { + outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC"); + } + } + + if (activation != null) + outputs = activation.Apply(outputs); + + return outputs; + } + + protected virtual int _get_channel_axis() + => data_format == "channels_first" ? -1 - rank : -1; + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/DepthwiseConv2D.cs b/src/TensorFlowNET.Keras/Layers/Convolution/DepthwiseConv2D.cs new file mode 100644 index 000000000..dae4a4036 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Convolution/DepthwiseConv2D.cs @@ -0,0 +1,167 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; +using Tensorflow.Keras.Utils; +using Tensorflow.Operations; +using Newtonsoft.Json; +using System.Security.Cryptography; + +namespace Tensorflow.Keras.Layers +{ + public class DepthwiseConv2DArgs: Conv2DArgs + { + /// + /// depth_multiplier: The number of depthwise convolution output channels for + /// each input channel.The total number of depthwise convolution output + /// channels will be equal to `filters_in* depth_multiplier`. + /// + [JsonProperty("depth_multiplier")] + public int DepthMultiplier { get; set; } = 1; + + [JsonProperty("depthwise_initializer")] + public IInitializer DepthwiseInitializer { get; set; } + } + + public class DepthwiseConv2D : Conv2D + { + /// + /// depth_multiplier: The number of depthwise convolution output channels for + /// each input channel.The total number of depthwise convolution output + /// channels will be equal to `filters_in* depth_multiplier`. + /// + int DepthMultiplier = 1; + + IInitializer DepthwiseInitializer; + + int[] strides; + + int[] dilation_rate; + + string getDataFormat() + { + return data_format == "channels_first" ? "NCHW" : "NHWC"; + } + + static int _id = 1; + + public DepthwiseConv2D(DepthwiseConv2DArgs args):base(args) + { + args.Padding = args.Padding.ToUpper(); + + if(string.IsNullOrEmpty(args.Name)) + name = "DepthwiseConv2D_" + _id; + + this.DepthMultiplier = args.DepthMultiplier; + this.DepthwiseInitializer = args.DepthwiseInitializer; + + } + + public override void build(KerasShapesWrapper input_shape) + { + //base.build(input_shape); + + var shape = input_shape.ToSingleShape(); + + int channel_axis = data_format == "channels_first" ? 1 : -1; + var input_channel = channel_axis < 0 ? + shape.dims[shape.ndim + channel_axis] : + shape.dims[channel_axis]; + + var arg = args as DepthwiseConv2DArgs; + + if (arg.Strides.ndim != shape.ndim) + { + if (arg.Strides.ndim == 2) + { + this.strides = new int[] { 1, (int)arg.Strides[0], (int)arg.Strides[1], 1 }; + } + else + { + this.strides = conv_utils.normalize_tuple(new int[] { (int)arg.Strides[0] }, shape.ndim, "strides"); + } + } + else + { + this.strides = arg.Strides.dims.Select(o=>(int)(o)).ToArray(); + } + + if (arg.DilationRate.ndim != shape.ndim) + { + this.dilation_rate = conv_utils.normalize_tuple(new int[] { (int)arg.DilationRate[0] }, shape.ndim, "dilation_rate"); + } + + long channel_data = data_format == "channels_first" ? shape[0] : shape[shape.Length - 1]; + + var depthwise_kernel_shape = this.kernel_size.dims.concat(new long[] { + channel_data, + this.DepthMultiplier + }); + + this.kernel = this.add_weight( + shape: depthwise_kernel_shape, + initializer: this.DepthwiseInitializer != null ? this.DepthwiseInitializer : this.kernel_initializer, + name: "depthwise_kernel", + trainable: true, + dtype: DType, + regularizer: this.kernel_regularizer + ); + + var axes = new Dictionary(); + axes.Add(-1, (int)input_channel); + inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes); + + + if (use_bias) + { + bias = add_weight(name: "bias", + shape: ((int)channel_data), + initializer: bias_initializer, + trainable: true, + dtype: DType); + } + + built = true; + _buildInputShape = input_shape; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, + bool? training = false, IOptionalArgs? optional_args = null) + { + Tensor outputs = null; + + outputs = gen_nn_ops.depthwise_conv2d_native( + inputs, + filter: this.kernel.AsTensor(), + strides: this.strides, + padding: this.padding, + dilations: this.dilation_rate, + data_format: this.getDataFormat(), + name: name + ); + + if (use_bias) + { + if (data_format == "channels_first") + { + throw new NotImplementedException("call channels_first"); + } + else + { + outputs = gen_nn_ops.bias_add(outputs, ops.convert_to_tensor(bias), + data_format: this.getDataFormat(), name: name); + } + } + + if (activation != null) + outputs = activation.Apply(outputs); + + + return outputs; + } + + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs new file mode 100644 index 000000000..db5d626ed --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -0,0 +1,94 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Just your regular densely-connected NN layer. + /// + public class Dense : Layer + { + DenseArgs args; + IVariableV1 kernel; + IVariableV1 bias; + Activation activation => args.Activation; + + public Dense(DenseArgs args) : + base(args) + { + this.args = args; + this.SupportsMasking = true; + this.inputSpec = new InputSpec(min_ndim: 2); + } + + public override void build(KerasShapesWrapper input_shape) + { + _buildInputShape = input_shape; + Debug.Assert(input_shape.Shapes.Length <= 1); + var single_shape = input_shape.ToSingleShape(); + var last_dim = single_shape.dims.Last(); + var axes = new Dictionary(); + axes[-1] = (int)last_dim; + inputSpec = new InputSpec(min_ndim: 2, axes: axes); + kernel = add_weight( + "kernel", + shape: new Shape(last_dim, args.Units), + initializer: args.KernelInitializer, + dtype: DType, + trainable: true); + if (args.UseBias) + bias = add_weight( + "bias", + shape: new Shape(args.Units), + initializer: args.BiasInitializer, + dtype: DType, + trainable: true); + + built = true; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + Tensor outputs = null; + var rank = inputs.rank; + if (rank > 2) + { + outputs = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { rank - 1 }, { 0 } }); + } + else + { + outputs = math_ops.matmul(inputs, kernel.AsTensor()); + } + + if (args.UseBias) + outputs = tf.nn.bias_add(outputs, bias); + if (args.Activation != null) + outputs = activation.Apply(outputs); + + return outputs; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs new file mode 100644 index 000000000..0cbd50846 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs @@ -0,0 +1,338 @@ +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.ArgsDefinition.Core; +using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + // A layer that uses `tf.einsum` as the backing computation. + // This layer can perform einsum calculations of arbitrary dimensionality. + // Args: + // equation: An equation describing the einsum to perform. This equation must + // be a valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or + // `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis + // expression sequence. + // output_shape: The expected shape of the output tensor (excluding the batch + // dimension and any dimensions represented by ellipses). You can specify + // None for any dimension that is unknown or can be inferred from the input + // shape. + // activation: Activation function to use. If you don't specify anything, no + // activation is applied (that is, a "linear" activation: `a(x) = x`). + // bias_axes: A string containing the output dimension(s) to apply a bias to. + // Each character in the `bias_axes` string should correspond to a character + // in the output portion of the `equation` string. + // kernel_initializer: Initializer for the `kernel` weights matrix. + // bias_initializer: Initializer for the bias vector. + // kernel_regularizer: Regularizer function applied to the `kernel` weights + // matrix. + // bias_regularizer: Regularizer function applied to the bias vector. + // activity_regularizer: Regularizer function applied to the output of the + // layer (its "activation"). + // kernel_constraint: Constraint function applied to the `kernel` weights + // matrix. + // bias_constraint: Constraint function applied to the bias vector. + // Examples: + // **Biased dense layer with einsums** + // This example shows how to instantiate a standard Keras dense layer using + // einsum operations. This example is equivalent to + // `tf.keras.layers.Dense(64, use_bias=True)`. + // >>> layer = tf.keras.layers.EinsumDense("ab,bc->ac", + // ... output_shape=64, + // ... bias_axes="c") + // >>> input_tensor = tf.keras.Input(shape=[32]) + // >>> output_tensor = layer(input_tensor) + // >>> output_tensor + // <... shape=(None, 64) dtype=...> + // **Applying a dense layer to a sequence** + // This example shows how to instantiate a layer that applies the same dense + // operation to every element in a sequence. Here, the `output_shape` has two + // values (since there are two non-batch dimensions in the output); the first + // dimension in the `output_shape` is `None`, because the sequence dimension `b` + // has an unknown shape. + // >>> layer = tf.keras.layers.EinsumDense("abc,cd->abd", + // ... output_shape=(None, 64), + // ... bias_axes="d") + // >>> input_tensor = tf.keras.Input(shape=[32, 128]) + // >>> output_tensor = layer(input_tensor) + // >>> output_tensor + // <... shape=(None, 32, 64) dtype=...> + // **Applying a dense layer to a sequence using ellipses** + // This example shows how to instantiate a layer that applies the same dense + // operation to every element in a sequence, but uses the ellipsis notation + // instead of specifying the batch and sequence dimensions. + // Because we are using ellipsis notation and have specified only one axis, the + // `output_shape` arg is a single value. When instantiated in this way, the layer + // can handle any number of sequence dimensions - including the case where no + // sequence dimension exists. + // >>> layer = tf.keras.layers.EinsumDense("...x,xy->...y", + // ... output_shape=64, + // ... bias_axes="y") + // >>> input_tensor = tf.keras.Input(shape=[32, 128]) + // >>> output_tensor = layer(input_tensor) + // >>> output_tensor + // <... shape=(None, 32, 64) dtype=...> + // + public class EinsumDense : Layer + { + + string equation; + + Activation activation; + + IVariableV1 bias; + + IVariableV1 kernel; + + string bias_axes; + + IInitializer kernel_initializer; + + IInitializer bias_initializer; + + System.Action kernel_constraint; + + System.Action bias_constraint; + + IRegularizer bias_regularizer; + + IRegularizer kernel_regularizer; + + Shape full_output_shape; + + Shape partial_output_shape; + + public EinsumDense(EinsumDenseArgs args) : base(args) + { + this.equation = args.Equation; + this.partial_output_shape = args.OutputShape; + this.bias_axes = args.BiasAxes; + this.activation = args.Activation; + this.kernel_initializer = args.KernelInitializer; + this.bias_initializer = args.BiasInitializer; + this.kernel_regularizer = args.KernelRegularizer; + this.bias_regularizer = args.BiasRegularizer; + this.kernel_constraint = args.KernelConstraint; + this.bias_constraint = args.BiasConstraint; + } + + public override void build(KerasShapesWrapper input_shape) + { + var shape_data = _analyze_einsum_string(this.equation, this.bias_axes, + input_shape.ToSingleShape(), this.partial_output_shape); + var kernel_shape = shape_data.Item1; + var bias_shape = shape_data.Item2; + this.full_output_shape = shape_data.Item3; + this.kernel = this.add_weight("kernel", shape: kernel_shape, + initializer: this.kernel_initializer, + regularizer: this.kernel_regularizer, + //constraint: this.kernel_constraint, + dtype: this.DType, + trainable: true); + if (bias_shape != null) + this.bias = this.add_weight("bias", shape: bias_shape, + initializer: this.bias_initializer, + regularizer: this.bias_regularizer, + //constraint: this.bias_constraint, + dtype: this.DType, + trainable: true); + else + this.bias = null; + base.build(input_shape); + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + return this.full_output_shape; + } + + //public virtual object get_config() { + // var config = new Dictionary { + // { + // "output_shape", + // this.partial_output_shape}, + // { + // "equation", + // this.equation}, + // { + // "activation", + // activations.serialize(this.activation)}, + // { + // "bias_axes", + // this.bias_axes}, + // { + // "kernel_initializer", + // initializers.serialize(this.kernel_initializer)}, + // { + // "bias_initializer", + // initializers.serialize(this.bias_initializer)}, + // { + // "kernel_regularizer", + // regularizers.serialize(this.kernel_regularizer)}, + // { + // "bias_regularizer", + // regularizers.serialize(this.bias_regularizer)}, + // { + // "activity_regularizer", + // regularizers.serialize(this.activity_regularizer)}, + // { + // "kernel_constraint", + // constraints.serialize(this.kernel_constraint)}, + // { + // "bias_constraint", + // constraints.serialize(this.bias_constraint)}}; + // var base_config = base.get_config(); + // return new dict(base_config.items().ToList() + config.items().ToList()); + //} + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + var ret = tf.linalg.einsum(this.equation, (inputs, this.kernel.AsTensor())); + if (this.bias != null) + ret += this.bias.AsTensor(); + if (this.activation != null) + ret = this.activation.Apply(ret); + return ret; + } + /// + /// Analyzes an einsum string to determine the required weight shape. + /// + public static (Shape, Shape, Shape) _analyze_einsum_string(string equation, string bias_axes, Shape input_shape, Shape output_shape) + { + var dot_replaced_string = Regex.Replace(equation, @"\.\.\.", "0"); + // This is the case where no ellipses are present in the string. + var split_string = Regex.Match(dot_replaced_string, "([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)"); + if (split_string.Success) + return _analyze_split_string(split_string, bias_axes, input_shape, output_shape); + // This is the case where ellipses are present on the left. + split_string = Regex.Match(dot_replaced_string, "0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)"); + if (split_string.Success) + return _analyze_split_string(split_string, bias_axes, input_shape, output_shape, left_elided: true); + // This is the case where ellipses are present on the right. + split_string = Regex.Match(dot_replaced_string, "([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0"); + if (split_string.Success) + return _analyze_split_string(split_string, bias_axes, input_shape, output_shape); + throw new ValueError($"Invalid einsum equation '{equation}'. " + + $"Equations must be in the form [X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...."); + } + + /// + /// Analyze an pre-split einsum string to find the weight shape. + /// + public static (Shape, Shape, Shape) _analyze_split_string(Match split_string, + string bias_axes, + Shape input_shape, + Shape output_shape, + bool left_elided = false) + { + List bias_shape; + Dictionary output_dim_map; + Dictionary input_dim_map; + + var input_spec = split_string.Groups[1].Value; + var weight_spec = split_string.Groups[2].Value; + var output_spec = split_string.Groups[3].Value; + var elided = input_shape.ndim - input_spec.Count(); + var _output_shape = new List(); + _output_shape.Add((int)input_shape[0]); + _output_shape.AddRange(output_shape.as_int_list()); + + if (elided > 0 && left_elided) + for (var i = 1; i < elided - 1; i++) + // We already inserted the 0th input dimension at dim 0, so we need to + // start at location 1 here. + _output_shape.Insert(1, (int)input_shape[i]); + else if (elided > 0 && !left_elided) + for (var i = input_shape.ndim - elided; i < input_shape.ndim - (input_shape.ndim - elided); i++) + _output_shape.Add((int)input_shape[i]); + + if (left_elided) + { + // If we have beginning dimensions elided, we need to use negative indexing + // to determine where in the input dimension our values are. + //input_dim_map = { dim: (i + elided) - len(input_shape) for i, dim in enumerate(input_spec) } + input_dim_map = input_spec.Select((dim, i) => (i, dim)) + .ToDictionary(_ => _.dim, _ => _.i + elided - input_shape.ndim); + // Because we've constructed the full output shape already, we don't need + // to do negative indexing. + //output_dim_map = { dim: (i + elided) for i, dim in enumerate(output_spec)} + output_dim_map = output_spec.Select((dim, i) => (i, dim)) + .ToDictionary(_ => _.dim, _ => _.i + elided); + } + else + { + input_dim_map = input_spec.Select((dim, i) => (i, dim)) + .ToDictionary(_ => _.dim, _ => _.i); + output_dim_map = output_spec.Select((dim, i) => (i, dim)) + .ToDictionary(_ => _.dim, _ => _.i); + } + + foreach (var dim in input_spec) + { + var input_shape_at_dim = input_shape[input_dim_map[dim]]; + if (output_dim_map.TryGetValue(dim, out int index)) + { + var output_shape_at_dim = _output_shape[index]; + if (output_shape_at_dim != -1 && output_shape_at_dim != input_shape_at_dim) + throw new ValueError($"Input shape and output shape do not match at shared dimension '{dim}'. " + + $"Input shape is {input_shape_at_dim}, " + + $"and output shape is {output_shape[output_dim_map[dim]]}."); + } + } + + foreach (var dim in output_spec) + { + if (!input_spec.Contains(dim) && !weight_spec.Contains(dim)) + { + throw new ValueError($"Dimension '{dim}' was specified in the output '{output_spec}' " + + $"but has no corresponding dim in the input spec '{input_spec}' " + + $"or weight spec '{output_spec}'"); + } + } + + var weight_shape = new List(); + foreach (var dim in weight_spec) + { + if (input_dim_map.ContainsKey(dim)) + weight_shape.append(input_shape[input_dim_map[dim]]); + else if (output_dim_map.ContainsKey(dim)) + weight_shape.append(_output_shape[output_dim_map[dim]]); + else throw new ValueError($"Weight dimension '{dim}' did not have a match in " + + $"either the input spec '{input_spec}' " + + $"or the output spec '{output_spec}'. " + + $"For this layer, the weight must be fully specified."); + } + + if (bias_axes != null) + { + var num_left_elided = left_elided ? elided : 0; + var idx_map = output_spec.Select((_char, i) => (i, _char)) + .ToDictionary(_ => _._char, _ => _output_shape[_.i + num_left_elided]); + foreach (var _char in bias_axes) + if (!output_spec.Contains(_char)) + throw new ValueError($"Bias dimension '{_char}' was requested," + + $" but is not part of the output spec '{output_spec}'"); + var first_bias_location = (from _char in bias_axes + select output_spec.IndexOf(_char)).ToList().Min(); + var bias_output_spec = output_spec.Substring(first_bias_location); + bias_shape = (from _char in bias_output_spec + select bias_axes.Contains(_char) ? idx_map[_char] : 1).ToList(); + if (!left_elided) + foreach (var _ in Enumerable.Range(0, elided)) + bias_shape.append(1); + } + else bias_shape = null; + + return (weight_shape.ToArray(), + (bias_shape ?? new List()).ToArray(), + _output_shape.ToArray()); + } + } +} + + diff --git a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs new file mode 100644 index 000000000..87b42bb7b --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs @@ -0,0 +1,80 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Linq; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Turns positive integers (indexes) into dense vectors of fixed size. + /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding + /// + public class Embedding : Layer + { + EmbeddingArgs args; + int input_dim => args.InputDim; + int output_dim => args.OutputDim; + bool mask_zero => args.MaskZero; + IVariableV1 embeddings; + IInitializer embeddings_initializer; + + public Embedding(EmbeddingArgs args) + : base(new LayerArgs // copy args + { + DType = args.DType, + Name = args.Name, + InputShape = args.InputShape, + BatchSize = args.BatchSize + }) + { + this.args = args; + if (args.InputShape == null) + args.InputShape = args.InputLength; + + if (args.BatchInputShape == null) + args.BatchInputShape = new KerasShapesWrapper(new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray()); + + embeddings_initializer = args.EmbeddingsInitializer ?? tf.random_uniform_initializer; + SupportsMasking = mask_zero; + } + + public override void build(KerasShapesWrapper input_shape) + { + tf.Context.eager_mode(); + embeddings = add_weight(shape: (input_dim, output_dim), + initializer: embeddings_initializer, + name: "embeddings"); + tf.Context.graph_mode(); + built = true; + _buildInputShape = input_shape; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + var dtype = inputs.dtype; + if (dtype != tf.int32 && dtype != tf.int64) + inputs = math_ops.cast(inputs, tf.int32); + + var outputs = embedding_ops.embedding_lookup(embeddings, inputs); + return outputs; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs new file mode 100644 index 000000000..f7385bad5 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs @@ -0,0 +1,106 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Linq; +using Tensorflow.Framework.Models; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving.SavedModel; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Layer to be used as an entry point into a Network (a graph of layers). + /// + public class InputLayer : Layer + { + InputLayerArgs args; + bool isPlaceholder; + TensorSpec typeSpec; + + public InputLayer(InputLayerArgs args) : + base(args) + { + this.args = args; + built = true; + SupportsMasking = true; + + if (BatchInputShape is not null) + { + args.BatchSize = (int)(BatchInputShape.ToSingleShape().dims[0]); + args.InputShape = BatchInputShape.ToSingleShape().dims.Skip(1).ToArray(); + } + + // moved to base class + if (string.IsNullOrEmpty(args.Name)) + { + var prefix = "input"; + name = prefix + '_' + keras.backend.get_uid(prefix); + args.Name = name; + } + + if (args.DType == TF_DataType.DtInvalid) + { + args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype; + } + + if (args.InputTensor == null) + { + if (args.InputShape != null) + { + args.BatchInputShape = new Saving.KerasShapesWrapper(new long[] { args.BatchSize } + .Concat(args.InputShape.dims).ToArray()); + } + else + { + args.BatchInputShape = null; + } + + var graph = keras.backend.get_graph(); + graph.as_default(); + + args.InputTensor = keras.backend.placeholder( + shape: BatchInputShape.ToSingleShape(), + dtype: DType, + name: Name, + sparse: args.Sparse, + ragged: args.Ragged); + + graph.Exit(); + + isPlaceholder = true; + } + + // Create an input node to add to self.outbound_node + // and set output_tensors' _keras_history. + // input_tensor._keras_history = base_layer.KerasHistory(self, 0, 0) + // input_tensor._keras_mask = None + var node = new Node(new NodeArgs + { + Outputs = args.InputTensor + }); + node.Connect(this); + + typeSpec = new TensorSpec(args.InputTensor.shape, + dtype: args.InputTensor.dtype, + name: Name); + } + + public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); + } +} diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Activation.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Activation.cs new file mode 100644 index 000000000..2c55f8fd5 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Activation.cs @@ -0,0 +1,23 @@ +using Tensorflow.NumPy; +using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Layers { + public partial class LayersApi { + public ILayer ELU ( float alpha = 0.1f ) + => new ELU(new ELUArgs { Alpha = alpha }); + public ILayer SELU () + => new SELU(new SELUArgs { }); + public ILayer Softmax(int axis = -1) => new Softmax(new SoftmaxArgs { axis = axis }); + public ILayer Softmax ( Axis axis ) => new Softmax(new SoftmaxArgs { axis = axis }); + public ILayer Softplus () => new Softplus(new SoftplusArgs { }); + public ILayer HardSigmoid () => new HardSigmoid(new HardSigmoidArgs { }); + public ILayer Softsign () => new Softsign(new SoftsignArgs { }); + public ILayer Swish () => new Swish(new SwishArgs { }); + public ILayer Tanh () => new Tanh(new TanhArgs { }); + public ILayer Exponential () => new Exponential(new ExponentialArgs { }); + } +} diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs new file mode 100644 index 000000000..859e9c14d --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs @@ -0,0 +1,56 @@ +using System; +using Tensorflow.NumPy; +using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Layers +{ + public partial class LayersApi + { + public ILayer Attention(bool use_scale = false, + string score_mode = "dot", + bool causal = false, + float dropout = 0f) => + new Attention(new AttentionArgs + { + use_scale = use_scale, + score_mode = score_mode, + causal = causal, + dropout = dropout + }); + public ILayer MultiHeadAttention(int num_heads, + int key_dim, + int? value_dim = null, + float dropout = 0f, + bool use_bias = true, + Shape output_shape = null, + Shape attention_axes = null, + IInitializer kernel_initializer = null, + IInitializer bias_initializer = null, + IRegularizer kernel_regularizer = null, + IRegularizer bias_regularizer = null, + IRegularizer activity_regularizer = null, + Action kernel_constraint = null, + Action bias_constraint = null) => + new MultiHeadAttention(new MultiHeadAttentionArgs + { + NumHeads = num_heads, + KeyDim = key_dim, + ValueDim = value_dim, + Dropout = dropout, + UseBias = use_bias, + OutputShape = output_shape, + AttentionAxis = attention_axes, + KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, + BiasInitializer = bias_initializer ?? tf.zeros_initializer, + KernelRegularizer = kernel_regularizer, + BiasRegularizer = bias_regularizer, + ActivityRegularizer = activity_regularizer, + KernelConstraint = kernel_constraint, + BiasConstraint = bias_constraint, + }); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs new file mode 100644 index 000000000..3e3442f25 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs @@ -0,0 +1,38 @@ +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Layers.Reshaping; +using Tensorflow.Keras.ArgsDefinition.Reshaping; + +namespace Tensorflow.Keras.Layers +{ + public partial class LayersApi { + /// + /// Cropping layer for 1D input + /// + /// cropping size + public ILayer Cropping1D ( NDArray cropping ) + => new Cropping1D(new Cropping1DArgs { + cropping = cropping + }); + + /// + /// Cropping layer for 2D input
+ ///
+ public ILayer Cropping2D ( NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last ) + => new Cropping2D(new Cropping2DArgs { + cropping = cropping, + data_format = data_format + }); + + /// + /// Cropping layer for 3D input
+ ///
+ public ILayer Cropping3D ( NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last ) + => new Cropping3D(new Cropping3DArgs { + cropping = cropping, + data_format = data_format + }); + } +} diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs new file mode 100644 index 000000000..bf06b1418 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs @@ -0,0 +1,22 @@ +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Layers +{ + public partial class LayersApi + { + /// + /// Layer that concatenates a list of inputs. + /// + /// Axis along which to concatenate. + /// + public ILayer Concatenate(int axis = -1) + => new Concatenate(new ConcatenateArgs + { + Axis = axis + }); + } +} diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs new file mode 100644 index 000000000..2ee99bc79 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs @@ -0,0 +1,70 @@ +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Layers { + public partial class LayersApi { + + /// + /// Upsampling layer for 1D inputs. Repeats each temporal step `size` times along the time axis. + /// + /// + /// + public ILayer UpSampling1D(int size) + => new UpSampling1D(new UpSampling1DArgs + { + Size = size + }); + + /// + /// Zero-padding layer for 2D input (e.g. picture). + /// + /// + /// + public ILayer ZeroPadding2D ( NDArray padding ) + => new ZeroPadding2D(new ZeroPadding2DArgs { + Padding = padding + }); + + /// + /// Upsampling layer for 2D inputs.
+ /// Repeats the rows and columns of the data by size[0] and size[1] respectively. + ///
+ /// + /// + /// + /// + public ILayer UpSampling2D(Shape size, string data_format, string interpolation) + => new UpSampling2D(new UpSampling2DArgs + { + Size = size, + DataFormat = data_format, + Interpolation = interpolation + }); + + /// + /// Permutes the dimensions of the input according to a given pattern. + /// + public ILayer Permute ( int[] dims ) + => new Permute(new PermuteArgs { + dims = dims + }); + + /// + /// Layer that reshapes inputs into the given shape. + /// + /// + /// + public ILayer Reshape ( Shape target_shape ) + => new Reshape(new ReshapeArgs { + TargetShape = target_shape + }); + + public ILayer Reshape ( object[] target_shape ) + => new Reshape(new ReshapeArgs { + TargetShapeObjects = target_shape + }); + } +} diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs new file mode 100644 index 000000000..a1e4c11b1 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -0,0 +1,1161 @@ +using System; +using Tensorflow.Framework.Models; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.ArgsDefinition.Core; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Layers +{ + public partial class LayersApi : ILayersApi + { + public IPreprocessing preprocessing { get; } = new Preprocessing(); + + /// + /// Layer that normalizes its inputs. + /// Batch normalization applies a transformation that maintains the mean output close to 0 and the output standard deviation close to 1. + /// Importantly, batch normalization works differently during training and during inference. + /// + /// http://arxiv.org/abs/1502.03167 + /// + /// The axis that should be normalized (typically the features axis). + /// For instance, after a Conv2D layer with data_format="channels_first", set axis=1 in BatchNormalization. + /// + /// Momentum for the moving average. + /// Small float added to variance to avoid dividing by zero. + /// If True, add offset of beta to normalized tensor. If False, beta is ignored. + /// If True, multiply by gamma. If False, gamma is not used. When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. + /// Initializer for the beta weight. + /// Initializer for the gamma weight. + /// Initializer for the moving mean. + /// Initializer for the moving variance. + /// Boolean, if True the variables will be marked as trainable. + /// Layer name. + /// Whether to use Batch Renormalization. This adds extra variables during training. The inference is the same for either value of this parameter. + /// Momentum used to update the moving means and standard deviations with renorm. + /// Unlike momentum, this affects training and should be neither too small (which would add noise) nor too large (which would give stale estimates). + /// Note that momentum is still applied to get the means and variances for inference. + /// + /// Tensor of the same shape as input. + public ILayer BatchNormalization(int axis = -1, + float momentum = 0.99f, + float epsilon = 0.001f, + bool center = true, + bool scale = true, + IInitializer beta_initializer = null, + IInitializer gamma_initializer = null, + IInitializer moving_mean_initializer = null, + IInitializer moving_variance_initializer = null, + bool trainable = true, + string name = null, + bool renorm = false, + float renorm_momentum = 0.99f) + => new BatchNormalization(new BatchNormalizationArgs + { + Axis = axis, + Momentum = momentum, + Epsilon = epsilon, + Center = center, + Scale = scale, + BetaInitializer = beta_initializer ?? tf.zeros_initializer, + GammaInitializer = gamma_initializer ?? tf.ones_initializer, + MovingMeanInitializer = moving_mean_initializer ?? tf.zeros_initializer, + MovingVarianceInitializer = moving_variance_initializer ?? tf.ones_initializer, + Renorm = renorm, + RenormMomentum = renorm_momentum, + Trainable = trainable, + Name = name + }); + + /// + /// 1D convolution layer (e.g. temporal convolution). + /// This layer creates a convolution kernel that is convolved with the layer input over a single spatial(or temporal) dimension to produce a tensor of outputs.If use_bias is True, a bias vector is created and added to the outputs.Finally, if activation is not None, it is applied to the outputs as well. + /// + /// Integer, the dimensionality of the output space (i.e. the number of output filters in the convolution) + /// An integer specifying the width of the 1D convolution window. + /// An integer specifying the stride of the convolution window . Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1. + /// one of "valid" or "same" (case-insensitive). "valid" means no padding. "same" results in padding evenly to the left/right or up/down of the input such that output has the same height/width dimension as the input. + /// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json. If you never set it, then it will be channels_last. + /// An integer specifying the dilation rate to use for dilated convolution.Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1. + /// A positive integer specifying the number of groups in which the input is split along the channel axis. Each group is convolved separately with filters / groups filters. The output is the concatenation of all the groups results along the channel axis. Input channels and filters must both be divisible by groups. + /// Activation function to use. If you don't specify anything, no activation is applied (see keras.activations). + /// Boolean, whether the layer uses a bias vector. + /// Initializer for the kernel weights matrix (see keras.initializers). + /// Initializer for the bias vector (see keras.initializers). + /// A tensor of rank 3 representing activation(conv1d(inputs, kernel) + bias). + public ILayer Conv1D(int filters, + Shape kernel_size, + int strides = 1, + string padding = "valid", + string data_format = "channels_last", + int dilation_rate = 1, + int groups = 1, + string activation = null, + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string bias_initializer = "zeros") + => new Conv1D(new Conv1DArgs + { + Rank = 1, + Filters = filters, + KernelSize = kernel_size ?? new Shape(1, 5), + Strides = strides, + Padding = padding, + DataFormat = data_format, + DilationRate = dilation_rate, + Groups = groups, + UseBias = use_bias, + Activation = keras.activations.GetActivationFromName(activation), + KernelInitializer = GetInitializerByName(kernel_initializer), + BiasInitializer = GetInitializerByName(bias_initializer) + }); + public ILayer Conv2D(int filters, + Shape kernel_size = null, + Shape strides = null, + string padding = "valid") + => new Conv2D(new Conv2DArgs + { + Rank = 2, + Filters = filters, + KernelSize = (kernel_size == null) ? (5, 5) : kernel_size, + Strides = strides == null ? (1, 1) : strides, + Padding = padding, + DataFormat = null, + DilationRate = (1, 1), + Groups = 1, + UseBias = false, + KernelRegularizer = null, + KernelInitializer =tf.glorot_uniform_initializer, + BiasInitializer = tf.zeros_initializer, + BiasRegularizer = null, + ActivityRegularizer = null, + Activation = keras.activations.Linear, + }); + /// + /// 2D convolution layer (e.g. spatial convolution over images). + /// This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. + /// If use_bias is True, a bias vector is created and added to the outputs.Finally, if activation is not None, it is applied to the outputs as well. + /// + /// Integer, the dimensionality of the output space (i.e. the number of output filters in the convolution) + /// An integer or tuple/list of 2 integers, specifying the height and width of the 2D convolution window. Can be a single integer to specify the same value for all spatial dimensions. + /// An integer or tuple/list of 2 integers, specifying the strides of the convolution along the height and width. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1. + /// one of "valid" or "same" (case-insensitive). "valid" means no padding. "same" results in padding evenly to the left/right or up/down of the input such that output has the same height/width dimension as the input. + /// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json. If you never set it, then it will be channels_last. + /// an integer or tuple/list of 2 integers, specifying the dilation rate to use for dilated convolution. Can be a single integer to specify the same value for all spatial dimensions. Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1. + /// A positive integer specifying the number of groups in which the input is split along the channel axis. Each group is convolved separately with filters / groups filters. The output is the concatenation of all the groups results along the channel axis. Input channels and filters must both be divisible by groups. + /// Activation function to use. If you don't specify anything, no activation is applied (see keras.activations). + /// Boolean, whether the layer uses a bias vector. + /// Initializer for the kernel weights matrix (see keras.initializers). + /// Initializer for the bias vector (see keras.initializers). + /// Regularizer function applied to the kernel weights matrix (see keras.regularizers). + /// Regularizer function applied to the bias vector (see keras.regularizers). + /// Regularizer function applied to the output of the layer (its "activation") (see keras.regularizers). + /// A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias). + public ILayer Conv2D(int filters, + Shape kernel_size = null, + Shape strides = null, + string padding = "valid", + string data_format = null, + Shape dilation_rate = null, + int groups = 1, + Activation activation = null, + bool use_bias = true, + IInitializer kernel_initializer = null, + IInitializer bias_initializer = null, + IRegularizer kernel_regularizer = null, + IRegularizer bias_regularizer = null, + IRegularizer activity_regularizer = null) + => new Conv2D(new Conv2DArgs + { + Rank = 2, + Filters = filters, + KernelSize = (kernel_size == null) ? (5, 5) : kernel_size, + Strides = strides == null ? (1, 1) : strides, + Padding = padding, + DataFormat = data_format, + DilationRate = dilation_rate == null ? (1, 1) : dilation_rate, + Groups = groups, + UseBias = use_bias, + KernelRegularizer = kernel_regularizer, + KernelInitializer = kernel_initializer == null ? tf.glorot_uniform_initializer : kernel_initializer, + BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, + BiasRegularizer = bias_regularizer, + ActivityRegularizer = activity_regularizer, + Activation = activation ?? keras.activations.Linear, + }); + + /// + /// 2D convolution layer (e.g. spatial convolution over images). + /// This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. + /// If use_bias is True, a bias vector is created and added to the outputs.Finally, if activation is not None, it is applied to the outputs as well. + /// + /// Integer, the dimensionality of the output space (i.e. the number of output filters in the convolution) + /// An integer or tuple/list of 2 integers, specifying the height and width of the 2D convolution window. Can be a single integer to specify the same value for all spatial dimensions. + /// An integer or tuple/list of 2 integers, specifying the strides of the convolution along the height and width. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1. + /// one of "valid" or "same" (case-insensitive). "valid" means no padding. "same" results in padding evenly to the left/right or up/down of the input such that output has the same height/width dimension as the input. + /// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json. If you never set it, then it will be channels_last. + /// an integer or tuple/list of 2 integers, specifying the dilation rate to use for dilated convolution. Can be a single integer to specify the same value for all spatial dimensions. Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1. + /// A positive integer specifying the number of groups in which the input is split along the channel axis. Each group is convolved separately with filters / groups filters. The output is the concatenation of all the groups results along the channel axis. Input channels and filters must both be divisible by groups. + /// Activation function to use. If you don't specify anything, no activation is applied (see keras.activations). + /// Boolean, whether the layer uses a bias vector. + /// The name of the initializer for the kernel weights matrix (see keras.initializers). + /// The name of the initializer for the bias vector (see keras.initializers). + /// A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias). + public ILayer Conv2D(int filters, + Shape kernel_size = null, + Shape strides = null, + string padding = "valid", + string data_format = null, + Shape dilation_rate = null, + int groups = 1, + string activation = null, + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string bias_initializer = "zeros") + => new Conv2D(new Conv2DArgs + { + Rank = 2, + Filters = filters, + KernelSize = (kernel_size == null) ? (5,5) : kernel_size, + Strides = strides == null ? (1, 1) : strides, + Padding = padding, + DataFormat = data_format, + DilationRate = dilation_rate == null ? (1, 1) : dilation_rate, + Groups = groups, + UseBias = use_bias, + KernelInitializer = GetInitializerByName(kernel_initializer), + BiasInitializer = GetInitializerByName(bias_initializer), + Activation = keras.activations.GetActivationFromName(activation) + }); + + public ILayer DepthwiseConv2D(Shape kernel_size = null, + Shape strides = null, + string padding = "valid", + string data_format = null, + Shape dilation_rate = null, + int groups = 1, + int depth_multiplier = 1, + string activation = null, + bool use_bias = false, + string kernel_initializer = "glorot_uniform", + string bias_initializer = "zeros", + string depthwise_initializer = "glorot_uniform" + ) + => new DepthwiseConv2D(new DepthwiseConv2DArgs + { + Rank = 2, + Filters = 1, + KernelSize = (kernel_size == null) ? (5, 5) : kernel_size, + Strides = strides == null ? (1) : strides, + Padding = padding, + DepthMultiplier = depth_multiplier, + DataFormat = data_format, + DilationRate = dilation_rate == null ? (1) : dilation_rate, + Groups = groups, + UseBias = use_bias, + KernelInitializer = GetInitializerByName(kernel_initializer), + DepthwiseInitializer = GetInitializerByName(depthwise_initializer == null ? kernel_initializer : depthwise_initializer), + BiasInitializer = GetInitializerByName(bias_initializer), + Activation = keras.activations.GetActivationFromName(activation), + }); + + + /// + /// Transposed convolution layer (sometimes called Deconvolution). + /// + /// Integer, the dimensionality of the output space (i.e. the number of output filters in the convolution) + /// An integer or tuple/list of 2 integers, specifying the height and width of the 2D convolution window. Can be a single integer to specify the same value for all spatial dimensions. + /// An integer or tuple/list of 2 integers, specifying the strides of the convolution along the height and width. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1. + /// one of "valid" or "same" (case-insensitive). "valid" means no padding. "same" results in padding evenly to the left/right or up/down of the input such that output has the same height/width dimension as the input. + /// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json. If you never set it, then it will be channels_last. + /// an integer or tuple/list of 2 integers, specifying the dilation rate to use for dilated convolution. Can be a single integer to specify the same value for all spatial dimensions. Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1. + /// Activation function to use. If you don't specify anything, no activation is applied (see keras.activations). + /// Boolean, whether the layer uses a bias vector. + /// The name of the initializer for the kernel weights matrix (see keras.initializers). + /// The name of the initializer for the bias vector (see keras.initializers). + /// The name of the regularizer function applied to the kernel weights matrix (see keras.regularizers). + /// The name of the regularizer function applied to the bias vector (see keras.regularizers). + /// The name of the regularizer function applied to the output of the layer (its "activation") (see keras.regularizers). + /// A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias). + public ILayer Conv2DTranspose(int filters, + Shape kernel_size = null, + Shape strides = null, + string output_padding = "valid", + string data_format = null, + Shape dilation_rate = null, + string activation = null, + bool use_bias = false, + string kernel_initializer = null, + string bias_initializer = null, + string kernel_regularizer = null, + string bias_regularizer = null, + string activity_regularizer = null) + => new Conv2DTranspose(new Conv2DTransposeArgs + { + Rank = 2, + Filters = filters, + KernelSize = (kernel_size == null) ? (5, 5) : kernel_size, + Strides = strides == null ? (1, 1) : strides, + Padding = output_padding, + DataFormat = data_format, + DilationRate = dilation_rate == null ? (1, 1) : dilation_rate, + UseBias = use_bias, + KernelInitializer = GetInitializerByName(kernel_initializer), + BiasInitializer = GetInitializerByName(bias_initializer), + Activation = keras.activations.GetActivationFromName(activation) + }); + + /// + /// Just your regular densely-connected NN layer. + /// + /// Dense implements the operation: output = activation(dot(input, kernel) + bias) where activation is the + /// element-wise activation function passed as the activation argument, kernel is a weights matrix created by the layer, + /// and bias is a bias vector created by the layer (only applicable if use_bias is True). + /// + /// Positive integer, dimensionality of the output space. + /// Activation function to use. If you don't specify anything, no activation is applied (ie. "linear" activation: a(x) = x). + /// Initializer for the kernel weights matrix. + /// Boolean, whether the layer uses a bias vector. + /// Initializer for the bias vector. + /// N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim). + /// N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units). + public ILayer Dense(int units, + Activation activation = null, + IInitializer kernel_initializer = null, + bool use_bias = true, + IInitializer bias_initializer = null, + Shape input_shape = null) + => new Dense(new DenseArgs + { + Units = units, + Activation = activation ?? keras.activations.Linear, + KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, + BiasInitializer = bias_initializer ?? (use_bias ? tf.zeros_initializer : null), + InputShape = input_shape + }); + + /// + /// Just your regular densely-connected NN layer. + /// + /// Dense implements the operation: output = activation(dot(input, kernel) + bias) where activation is the + /// element-wise activation function passed as the activation argument, kernel is a weights matrix created by the layer, + /// and bias is a bias vector created by the layer (only applicable if use_bias is True). + /// + /// Positive integer, dimensionality of the output space. + /// N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units). + public ILayer Dense(int units) + => new Dense(new DenseArgs + { + Units = units, + Activation = keras.activations.GetActivationFromName("linear") + }); + + /// + /// Just your regular densely-connected NN layer. + /// + /// Dense implements the operation: output = activation(dot(input, kernel) + bias) where activation is the + /// element-wise activation function passed as the activation argument, kernel is a weights matrix created by the layer, + /// and bias is a bias vector created by the layer (only applicable if use_bias is True). + /// + /// Positive integer, dimensionality of the output space. + /// Activation function to use. If you don't specify anything, no activation is applied (ie. "linear" activation: a(x) = x). + /// N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim). + /// N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units). + public ILayer Dense(int units, + string activation = null, + Shape input_shape = null) + => new Dense(new DenseArgs + { + Units = units, + Activation = keras.activations.GetActivationFromName(activation), + InputShape = input_shape + }); + + /// + /// Densely-connected layer class. aka fully-connected

+ /// `outputs = activation(inputs * kernel + bias)` + ///
+ /// + /// Python integer, dimensionality of the output space. + /// + /// Boolean, whether the layer uses a bias. + /// + /// + /// + /// + /// + /// + public Tensor dense(Tensor inputs, + int units, + Activation activation = null, + bool use_bias = true, + IInitializer kernel_initializer = null, + IInitializer bias_initializer = null, + bool trainable = true, + string name = null, + bool? reuse = null) + { + if (bias_initializer == null) + bias_initializer = tf.zeros_initializer; + + var layer = new Dense(new DenseArgs + { + Units = units, + Activation = activation, + UseBias = use_bias, + BiasInitializer = bias_initializer, + KernelInitializer = kernel_initializer, + Trainable = trainable, + Name = name + }); + + return layer.Apply(inputs); + } + + + public ILayer EinsumDense(string equation, + Shape output_shape, + string bias_axes, + Activation activation = null, + IInitializer kernel_initializer= null, + IInitializer bias_initializer= null, + IRegularizer kernel_regularizer= null, + IRegularizer bias_regularizer= null, + IRegularizer activity_regularizer= null, + Action kernel_constraint= null, + Action bias_constraint= null) => + new EinsumDense(new EinsumDenseArgs() + { + Equation = equation, + OutputShape = output_shape, + BiasAxes = bias_axes, + Activation = activation, + KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, + BiasInitializer = bias_initializer ?? tf.zeros_initializer, + KernelRegularizer = kernel_regularizer, + BiasRegularizer = bias_regularizer, + ActivityRegularizer = activity_regularizer, + KernelConstraint = kernel_constraint, + BiasConstraint = bias_constraint + }); + + /// + /// Applies Dropout to the input. + /// The Dropout layer randomly sets input units to 0 with a frequency of rate at each step during training time, + /// which helps prevent overfitting.Inputs not set to 0 are scaled up by 1/(1 - rate) such that the sum over all inputs is unchanged. + /// + /// Float between 0 and 1. Fraction of the input units to drop. + /// 1D integer tensor representing the shape of the binary dropout mask that will be multiplied with the input. For instance, + /// if your inputs have shape (batch_size, timesteps, features) and you want the dropout mask to be the same for all timesteps, + /// you can use noise_shape=(batch_size, 1, features). + /// + /// An integer to use as random seed. + /// + public ILayer Dropout(float rate, Shape noise_shape = null, int? seed = null) + => new Dropout(new DropoutArgs + { + Rate = rate, + NoiseShape = noise_shape, + Seed = seed + }); + + /// + /// Turns positive integers (indexes) into dense vectors of fixed size. + /// This layer can only be used as the first layer in a model. + /// e.g. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]] + /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding + /// + /// Size of the vocabulary, i.e. maximum integer index + 1. + /// Dimension of the dense embedding. + /// Initializer for the embeddings matrix (see keras.initializers). + /// + /// + public ILayer Embedding(int input_dim, + int output_dim, + IInitializer embeddings_initializer = null, + bool mask_zero = false, + Shape input_shape = null, + int input_length = -1) + => new Embedding(new EmbeddingArgs + { + InputDim = input_dim, + OutputDim = output_dim, + MaskZero = mask_zero, + InputShape = input_shape ?? input_length, + InputLength = input_length, + EmbeddingsInitializer = embeddings_initializer + }); + + /// + /// Flattens the input. Does not affect the batch size. + /// + /// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. + /// channels_last corresponds to inputs with shape (batch, ..., channels) while channels_first corresponds to inputs with shape (batch, channels, ...). + /// It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json. + /// If you never set it, then it will be "channels_last". + /// + /// + public ILayer Flatten(string data_format = null) + => new Flatten(new FlattenArgs + { + DataFormat = data_format + }); + + /// + /// `Input()` is used to instantiate a Keras tensor. + /// Keras tensor is a TensorFlow symbolic tensor object, which we augment with certain attributes that allow us + /// to build a Keras model just by knowing the inputs and outputs of the model. + /// + /// A shape tuple not including the batch size. + /// An optional name string for the layer. Should be unique in a model (do not reuse the same name twice). It will be autogenerated if it isn't provided. + /// A boolean specifying whether the placeholder to be created is sparse. Only one of 'ragged' and 'sparse' can be True. + /// Note that, if sparse is False, sparse tensors can still be passed into the input - they will be densified with a default value of 0. + /// + /// A boolean specifying whether the placeholder to be created is ragged. Only one of 'ragged' and 'sparse' can be True. + /// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide. + /// + /// A tensor. + public KerasTensor Input(Shape shape = null, + int batch_size = -1, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + bool sparse = false, + Tensor tensor = null, + bool ragged = false, + TypeSpec type_spec = null, + Shape batch_input_shape = null, + Shape batch_shape = null) + { + if(sparse && ragged) + { + throw new ValueError("Cannot set both `sparse` and `ragged` to `true` in a Keras `Input`."); + } + + InputLayerArgs input_layer_config = new() + { + Name = name, + DType = dtype, + Sparse = sparse, + Ragged = ragged, + InputTensor = tensor, + // skip the `type_spec` + }; + + if(shape is not null && batch_input_shape is not null) + { + throw new ValueError("Only provide the `shape` OR `batch_input_shape` argument " + + "to Input, not both at the same time."); + } + + if(batch_input_shape is null && shape is null && tensor is null && type_spec is null) + { + throw new ValueError("Please provide to Input a `shape` or a `tensor` or a `type_spec` argument. Note that " + + "`shape` does not include the batch dimension."); + } + + if(batch_input_shape is not null) + { + shape = batch_input_shape["1:"]; + input_layer_config.BatchInputShape = batch_input_shape; + } + else + { + input_layer_config.BatchSize = batch_size; + input_layer_config.InputShape = shape; + } + + var input_layer = new InputLayer(input_layer_config); + + return input_layer.InboundNodes[0].Outputs; + } + + public ILayer InputLayer(Shape input_shape, + string name = null, + bool sparse = false, + bool ragged = false) + => new InputLayer(new InputLayerArgs + { + InputShape = input_shape, + Name = name, + Sparse = sparse, + Ragged = ragged + }); + + /// + /// Average pooling operation for spatial data. + /// + /// + /// + /// + /// + /// + public ILayer AveragePooling2D(Shape pool_size = null, + Shape strides = null, + string padding = "valid", + string data_format = null) + => new AveragePooling2D(new AveragePooling2DArgs + { + PoolSize = pool_size ?? (2, 2), + Strides = strides, + Padding = padding, + DataFormat = data_format + }); + + /// + /// Max pooling operation for 1D temporal data. + /// + /// Integer, size of the max pooling window. + /// Integer, or null. Specifies how much the pooling window moves for each pooling step. If null, it will default to pool_size. + /// One of "valid" or "same" (case-insensitive). "valid" means no padding. + /// "same" results in padding evenly to the left/right or up/down of the input such that output has the same height/width dimension as the input. + /// + /// + /// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. + /// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps). + /// + /// + public ILayer MaxPooling1D(int? pool_size = null, + int? strides = null, + string padding = "valid", + string data_format = null) + => new MaxPooling1D(new MaxPooling1DArgs + { + PoolSize = pool_size ?? 2, + Strides = strides ?? (pool_size ?? 2), + Padding = padding, + DataFormat = data_format + }); + + /// + /// Max pooling operation for 2D spatial data. + /// Downsamples the input representation by taking the maximum value over the window defined by pool_size for each dimension along the features axis. + /// The window is shifted by strides in each dimension. The resulting output when using "valid" padding option has a shape(number of rows or columns) + /// of: output_shape = (input_shape - pool_size + 1) / strides) + /// The resulting output shape when using the "same" padding option is: output_shape = input_shape / strides + /// + /// + /// Integer or tuple of 2 integers, window size over which to take the maximum. + /// (2, 2) will take the max value over a 2x2 pooling window. If only one integer is specified, the same window length will be used for both dimensions. + /// + /// + /// Integer, tuple of 2 integers, or null. Strides values. Specifies how far the pooling window moves for each pooling step. + /// If null, it will default to pool_size. + /// + /// One of "valid" or "same" (case-insensitive). "valid" means no padding. + /// "same" results in padding evenly to the left/right or up/down of the input such that output has the same height/width dimension as the input. + /// + /// + /// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. + /// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to + /// inputs with shape (batch, channels, height, width). + /// It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json. + /// If you never set it, then it will be "channels_last" + /// + public ILayer MaxPooling2D(Shape pool_size = null, + Shape strides = null, + string padding = "valid", + string data_format = null) + => new MaxPooling2D(new MaxPooling2DArgs + { + PoolSize = pool_size ?? (2, 2), + Strides = strides, + Padding = padding, + DataFormat = data_format + }); + + /// + /// Max pooling layer for 2D inputs (e.g. images). + /// + /// The tensor over which to pool. Must have rank 4. + /// + /// Integer or tuple of 2 integers, window size over which to take the maximum. + /// (2, 2) will take the max value over a 2x2 pooling window. If only one integer is specified, the same window length will be used for both dimensions. + /// + /// + /// Integer, tuple of 2 integers, or null. Strides values. Specifies how far the pooling window moves for each pooling step. + /// If null, it will default to pool_size. + /// + /// One of "valid" or "same" (case-insensitive). "valid" means no padding. + /// "same" results in padding evenly to the left/right or up/down of the input such that output has the same height/width dimension as the input. + /// + /// + /// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. + /// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to + /// inputs with shape (batch, channels, height, width). + /// It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json. + /// If you never set it, then it will be "channels_last" + /// A name for the layer + /// + public Tensor max_pooling2d(Tensor inputs, + int[] pool_size, + int[] strides, + string padding = "valid", + string data_format = "channels_last", + string name = null) + { + var layer = new MaxPooling2D(new MaxPooling2DArgs + { + PoolSize = pool_size, + Strides = strides, + Padding = padding, + DataFormat = data_format, + Name = name + }); + + return layer.Apply(inputs); + } + + public ILayer LayerNormalization(Axis? axis, + float epsilon = 1e-3f, + bool center = true, + bool scale = true, + IInitializer beta_initializer = null, + IInitializer gamma_initializer = null) + => new LayerNormalization(new LayerNormalizationArgs + { + Axis = axis ?? -1, + Epsilon = epsilon, + Center = center, + Scale = scale, + BetaInitializer = beta_initializer ?? tf.zeros_initializer + }); + + /// + /// Leaky version of a Rectified Linear Unit. + /// + /// Negative slope coefficient. + /// + public ILayer LeakyReLU(float alpha = 0.3f) + => new LeakyReLu(new LeakyReLuArgs + { + Alpha = alpha + }); + + + /// + /// Leaky version of a Rectified Linear Unit. + /// + /// Negative slope coefficient. + /// + public ILayer ReLU6() + => new ReLu6(); + + + public IRnnCell SimpleRNNCell( + int units, + string activation = "tanh", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros", + float dropout = 0f, + float recurrent_dropout = 0f) + => new SimpleRNNCell(new SimpleRNNCellArgs + { + Units = units, + Activation = keras.activations.GetActivationFromName(activation), + UseBias = use_bias, + KernelInitializer = GetInitializerByName(kernel_initializer), + RecurrentInitializer = GetInitializerByName(recurrent_initializer), + BiasInitializer = GetInitializerByName(bias_initializer), + Dropout = dropout, + RecurrentDropout = recurrent_dropout + }); + + public IRnnCell StackedRNNCells( + IEnumerable cells) + => new StackedRNNCells(cells.ToList(), new StackedRNNCellsArgs()); + + /// + /// + /// + /// Positive integer, dimensionality of the output space. + /// The name of the activation function to use. Default: hyperbolic tangent (tanh).. + /// + public ILayer SimpleRNN(int units, + string activation = "tanh", + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros", + bool return_sequences = false, + bool return_state = false) + => new SimpleRNN(new SimpleRNNArgs + { + Units = units, + Activation = keras.activations.GetActivationFromName(activation), + KernelInitializer = GetInitializerByName(kernel_initializer), + RecurrentInitializer = GetInitializerByName(recurrent_initializer), + BiasInitializer = GetInitializerByName(bias_initializer), + ReturnSequences = return_sequences, + ReturnState = return_state + }); + + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public ILayer RNN( + IRnnCell cell, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool unroll = false, + bool time_major = false) + => new RNN(cell, new RNNArgs + { + ReturnSequences = return_sequences, + ReturnState = return_state, + GoBackwards = go_backwards, + Stateful = stateful, + Unroll = unroll, + TimeMajor = time_major + }); + + public ILayer RNN( + IEnumerable cell, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool unroll = false, + bool time_major = false) + => new RNN(cell, new RNNArgs + { + ReturnSequences = return_sequences, + ReturnState = return_state, + GoBackwards = go_backwards, + Stateful = stateful, + Unroll = unroll, + TimeMajor = time_major + }); + + + public IRnnCell LSTMCell(int uints, + string activation = "tanh", + string recurrent_activation = "sigmoid", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros", + bool unit_forget_bias = true, + float dropout = 0f, + float recurrent_dropout = 0f, + int implementation = 2) + => new LSTMCell(new LSTMCellArgs + { + Units = uints, + Activation = keras.activations.GetActivationFromName(activation), + RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation), + UseBias = use_bias, + KernelInitializer = GetInitializerByName(kernel_initializer), + RecurrentInitializer = GetInitializerByName(recurrent_initializer), + BiasInitializer = GetInitializerByName(bias_initializer), + UnitForgetBias = unit_forget_bias, + Dropout = dropout, + RecurrentDropout = recurrent_dropout, + Implementation = implementation + }); + + /// + /// Long Short-Term Memory layer - Hochreiter 1997. + /// + /// Positive integer, dimensionality of the output space. + /// Activation function to use. If you pass null, no activation is applied (ie. "linear" activation: a(x) = x). + /// Activation function to use for the recurrent step. If you pass null, no activation is applied (ie. "linear" activation: a(x) = x). + /// Boolean (default True), whether the layer uses a bias vector. + /// Initializer for the kernel weights matrix, used for the linear transformation of the inputs. Default: glorot_uniform. + /// Initializer for the recurrent_kernel weights matrix, used for the linear transformation of the recurrent state. Default: orthogonal. + /// Initializer for the bias vector. Default: zeros. + /// Boolean (default True). If True, add 1 to the bias of the forget gate at initialization. Setting it to true will also force bias_initializer="zeros". This is recommended in Jozefowicz et al.. + /// Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Default: 0. + /// Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Default: 0. + /// + /// Boolean. Whether to return the last output. in the output sequence, or the full sequence. Default: False. + /// Whether to return the last state in addition to the output. Default: False. + /// Boolean (default false). If True, process the input sequence backwards and return the reversed sequence. + /// Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch. + /// + /// The shape format of the inputs and outputs tensors. If True, the inputs and outputs will be in shape [timesteps, batch, feature], + /// whereas in the False case, it will be [batch, timesteps, feature]. Using time_major = True is a bit more efficient because it avoids transposes at the + /// beginning and end of the RNN calculation. However, most TensorFlow data is batch-major, so by default this function accepts input and emits output in batch-major form. + /// + /// Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed-up a RNN, + /// although it tends to be more memory-intensive. Unrolling is only suitable for short sequences. + /// + /// + public ILayer LSTM(int units, + Activation activation = null, + Activation recurrent_activation = null, + bool use_bias = true, + IInitializer kernel_initializer = null, + IInitializer recurrent_initializer = null, + IInitializer bias_initializer = null, + bool unit_forget_bias = true, + float dropout = 0f, + float recurrent_dropout = 0f, + int implementation = 2, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool time_major = false, + bool unroll = false) + => new LSTM(new LSTMArgs + { + Units = units, + Activation = activation ?? keras.activations.Tanh, + RecurrentActivation = recurrent_activation ?? keras.activations.Sigmoid, + KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, + RecurrentInitializer = recurrent_initializer ?? tf.orthogonal_initializer, + BiasInitializer = bias_initializer ?? tf.zeros_initializer, + Dropout = dropout, + RecurrentDropout = recurrent_dropout, + Implementation = implementation, + ReturnSequences = return_sequences, + ReturnState = return_state, + GoBackwards = go_backwards, + Stateful = stateful, + TimeMajor = time_major, + Unroll = unroll, + UnitForgetBias = unit_forget_bias + }); + + /// + /// Cell class for the GRU layer. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public IRnnCell GRUCell( + int units, + string activation = "tanh", + string recurrent_activation = "sigmoid", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros", + float dropout = 0f, + float recurrent_dropout = 0f, + bool reset_after = true) + => new GRUCell(new GRUCellArgs + { + Units = units, + Activation = keras.activations.GetActivationFromName(activation), + RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation), + KernelInitializer = GetInitializerByName(kernel_initializer), + RecurrentInitializer = GetInitializerByName(recurrent_initializer), + BiasInitializer = GetInitializerByName(bias_initializer), + UseBias = use_bias, + Dropout = dropout, + RecurrentDropout = recurrent_dropout, + ResetAfter = reset_after + }); + + /// + /// Gated Recurrent Unit - Cho et al. 2014. + /// + /// Positive integer, dimensionality of the output space. + /// Activation function to use. If you pass `None`, no activation is applied.(ie. "linear" activation: `a(x) = x`). + /// Activation function to use for the recurrent step. If you pass `None`, no activation is applied. (ie. "linear" activation: `a(x) = x`). + /// Boolean, (default `True`), whether the layer uses a bias vector. + /// Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. Default: `glorot_uniform`. + /// Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. Default: `orthogonal`. + /// Initializer for the bias vector. Default: `zeros`. + /// Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Default: 0. + /// Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Default: 0. + /// + /// Boolean. Whether to return the last output in the output sequence, or the full sequence. Default: `False`. + /// Boolean. Whether to return the last state in addition to the output. Default: `False`. + /// Boolean (default `False`). If True, process the input sequence backwards and return the reversed sequence. + /// Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch. + /// Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed-up a RNN, + /// The shape format of the `inputs` and `outputs` tensors. + /// GRU convention (whether to apply reset gate after or before matrix multiplication). False = "before", True = "after" (default and cuDNN compatible). + /// + public ILayer GRU( + int units, + string activation = "tanh", + string recurrent_activation = "sigmoid", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros", + float dropout = 0f, + float recurrent_dropout = 0f, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool unroll = false, + bool time_major = false, + bool reset_after = true + ) + => new GRU(new GRUArgs + { + Units = units, + Activation = keras.activations.GetActivationFromName(activation), + RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation), + KernelInitializer = GetInitializerByName(kernel_initializer), + RecurrentInitializer = GetInitializerByName(recurrent_initializer), + BiasInitializer = GetInitializerByName(bias_initializer), + UseBias = use_bias, + Dropout = dropout, + RecurrentDropout = recurrent_dropout, + ReturnSequences = return_sequences, + ReturnState = return_state, + GoBackwards = go_backwards, + Stateful = stateful, + TimeMajor = time_major, + Unroll = unroll, + ResetAfter = reset_after + }); + + public ILayer Bidirectional( + ILayer layer, + string merge_mode = "concat", + NDArray weights = null, + ILayer backward_layer = null) + => new Bidirectional(new BidirectionalArgs + { + Layer = layer, + MergeMode = merge_mode, + Weights = weights, + BackwardLayer = backward_layer + }); + + + /// + /// + /// + /// + /// + /// + /// + public ILayer Rescaling(float scale, + float offset = 0, + Shape input_shape = null) + => new Rescaling(new RescalingArgs + { + Scale = scale, + Offset = offset, + InputShape = input_shape + }); + + /// + /// + /// + /// + public ILayer Add() + => new Add(new AddArgs { }); + + /// + /// + /// + /// + public ILayer Subtract() + => new Subtract(new SubtractArgs { }); + + /// + /// Global max pooling operation for spatial data. + /// + /// + public ILayer GlobalAveragePooling2D() + => new GlobalAveragePooling2D(new GlobalAveragePooling2DArgs { }); + + /// + /// Global average pooling operation for temporal data. + /// + /// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. + /// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps). + /// + /// + public ILayer GlobalAveragePooling1D(string data_format = "channels_last") + => new GlobalAveragePooling1D(new GlobalAveragePooling1DArgs { DataFormat = data_format }); + + /// + /// Global max pooling operation for spatial data. + /// + /// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. + /// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to inputs with shape (batch, channels, height, width). + /// + public ILayer GlobalAveragePooling2D(string data_format = "channels_last") + => new GlobalAveragePooling2D(new GlobalAveragePooling2DArgs { DataFormat = data_format }); + + /// + /// Global max pooling operation for 1D temporal data. + /// Downsamples the input representation by taking the maximum value over the time dimension. + /// + /// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. + /// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps). + /// + /// + public ILayer GlobalMaxPooling1D(string data_format = "channels_last") + => new GlobalMaxPooling1D(new GlobalMaxPooling1DArgs { DataFormat = data_format }); + + /// + /// Global max pooling operation for spatial data. + /// + /// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. + /// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to inputs with shape (batch, channels, height, width). + /// + public ILayer GlobalMaxPooling2D(string data_format = "channels_last") + => new GlobalMaxPooling2D(new GlobalMaxPooling2DArgs { DataFormat = data_format }); + + /// + /// Get an weights initializer from its name. + /// + /// The name of the initializer. One of zeros, ones, and glorot_uniform. + /// + IInitializer GetInitializerByName(string name) + => name switch + { + "glorot_uniform" => tf.glorot_uniform_initializer, + "zeros" => tf.zeros_initializer, + "ones" => tf.ones_initializer, + "orthogonal" => tf.orthogonal_initializer, + _ => tf.glorot_uniform_initializer + }; + + public ILayer CategoryEncoding(int num_tokens, string output_mode = "one_hot", bool sparse = false, NDArray count_weights = null) + => new CategoryEncoding(new CategoryEncodingArgs + { + NumTokens = num_tokens, + OutputMode = output_mode, + Sparse = sparse, + CountWeights = count_weights + }); + + public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false) + => new Normalization(new NormalizationArgs + { + InputShape = input_shape, + Axis = axis, + Mean = mean, + Variance = variance, + Invert = invert + }); + + + + + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Add.cs b/src/TensorFlowNET.Keras/Layers/Merging/Add.cs new file mode 100644 index 000000000..94c8c5918 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Merging/Add.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Layers +{ + public class Add : Merge + { + public Add(MergeArgs args) : base(args) + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs new file mode 100644 index 000000000..fa82426ce --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs @@ -0,0 +1,50 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Layer that concatenates a list of inputs. + /// + public class Concatenate : Merge + { + MergeArgs args; + int axis => args.Axis; + + public Concatenate(MergeArgs args) : base(args) + { + this.args = args; + } + + public override void build(KerasShapesWrapper input_shape) + { + /*var shape_set = new HashSet(); + var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray(); + for (var i = 0; i < reduced_inputs_shapes.Length; i++) + { + int seq = -1; + Shape shape = reduced_inputs_shapes[i].Where(x => + { + seq++; + return seq != i; + }).ToArray(); + shape_set.Add(shape); + }*/ + _buildInputShape = input_shape; + built = true; + } + + protected override Tensors _merge_function(Tensors inputs) + { + return keras.backend.concatenate(inputs, axis: axis); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs new file mode 100644 index 000000000..bcbb20d88 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + public abstract class Merge : Layer + { + public Merge(MergeArgs args) : base(args) + { + + } + + public override void build(KerasShapesWrapper input_shape) + { + // output_shape = input_shape.dims[1^]; + _buildInputShape = input_shape; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + return _merge_function(inputs); + } + + protected virtual Tensors _merge_function(Tensors inputs) + { + var output = inputs[0]; + foreach (var i in range(1, inputs.Length)) + output += inputs[i]; + return output; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Subtract.cs b/src/TensorFlowNET.Keras/Layers/Merging/Subtract.cs new file mode 100644 index 000000000..b6a1039ec --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Merging/Subtract.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class Subtract : Merge + { + public Subtract(MergeArgs args) : base(args) + { + + } + + protected override Tensors _merge_function(Tensors inputs) + { + if (len(inputs) != 2) + throw new ValueError($"A `Subtract` layer should be called on exactly 2 inputs"); + return inputs[0] - inputs[1]; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs new file mode 100644 index 000000000..655581576 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -0,0 +1,302 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class BatchNormalization : Layer + { + BatchNormalizationArgs args; + + float momentum => args.Momentum; + float epsilon => args.Epsilon; + bool center => args.Center; + bool scale => args.Scale; + bool renorm => args.Renorm; + bool fused; + int[] axis; + string _data_format; + Shape kernel_size; + IInitializer beta_initializer => args.BetaInitializer; + IInitializer gamma_initializer => args.GammaInitializer; + IInitializer moving_mean_initializer => args.MovingMeanInitializer; + IInitializer moving_variance_initializer => args.MovingVarianceInitializer; + IRegularizer gamma_regularizer => args.GammaRegularizer; + IVariableV1 gamma; + IVariableV1 beta; + IVariableV1 moving_mean; + IVariableV1 moving_variance; + + public BatchNormalization(BatchNormalizationArgs args) : base(args) + { + this.args = args; + axis = args.Axis.dims.Select(x => (int)x).ToArray(); + } + + public override void build(KerasShapesWrapper input_shape) + { + var single_shape = input_shape.ToSingleShape(); + var ndims = single_shape.ndim; + foreach (var (idx, x) in enumerate(axis)) + if (x < 0) + args.Axis.dims[idx] = axis[idx] = ndims + x; + + fused = ndims == 4; + + if (fused) + { + if (Enumerable.SequenceEqual(axis, new int[] { 1 })) + _data_format = "NCHW"; + else if (Enumerable.SequenceEqual(axis, new int[] { 3 })) + _data_format = "NHWC"; + else + throw new ValueError($"Unsupported axis, fused batch norm only supports axis == [1] or axis == [3]"); + } + + var axis_to_dim = new Dictionary(); + foreach (var x in axis) + axis_to_dim[x] = (int)single_shape[x]; + + inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); + var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; + var param_shape = inputSpec.AllAxisDim; + + if (scale) + gamma = add_weight("gamma", + param_shape, + dtype: param_dtype, + initializer: gamma_initializer, + trainable: true); + else + throw new NotImplementedException("add_weight gamma"); + + if (center) + beta = add_weight("beta", + param_shape, + dtype: param_dtype, + initializer: beta_initializer, + trainable: true); + else + throw new NotImplementedException("add_weight beta"); + + moving_mean = add_weight("moving_mean", + param_shape, + dtype: param_dtype, + initializer: moving_mean_initializer, + synchronization: VariableSynchronization.OnRead, + aggregation: VariableAggregation.Mean, + trainable: false); + + moving_variance = add_weight("moving_variance", + shape: param_shape, + dtype: param_dtype, + initializer: moving_variance_initializer, + synchronization: VariableSynchronization.OnRead, + aggregation: VariableAggregation.Mean, + trainable: false); + + if (renorm) + throw new NotImplementedException("build when renorm is true"); + + built = true; + _buildInputShape = input_shape; + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + return input_shape; + } + + (Tensor, Tensor) _moments(Tensors inputs, int[] reduction_axes, bool keep_dims) + { + var (mean, variance) = _calculate_mean_and_var(inputs, reduction_axes, keep_dims); + if (_support_zero_size_input()) + throw new NotImplementedException(""); + return (mean, variance); + } + + (Tensor, Tensor) _calculate_mean_and_var(Tensors inputs, int[] reduction_axes, bool keep_dims) + { + return nn_impl.moments(inputs, reduction_axes, keep_dims: keep_dims); + } + + bool _support_zero_size_input() + { + return false; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + Tensor outputs = null; + var training_tensor = training == null + ? tf.placeholder(tf.@bool, Shape.Scalar) + : tf.logical_and(training.Value, Trainable); + if (fused) + { + // var training = tf.convert_to_tensor(training); + outputs = _fused_batch_norm(inputs, training: training_tensor); + return outputs; + } + + var inputs_dtype = inputs.dtype.as_base_dtype(); + var input_shape = inputs.shape; + var ndims = len(input_shape); + var reduction_axes = range(ndims).Where(x => !axis.Contains(x)).ToArray(); + + // Broadcasting only necessary for single-axis batch norm where the axis is + // not the last dimension + var broadcast_shape = range(ndims).Select(x => 1).ToArray(); + broadcast_shape[axis[0]] = (int)input_shape.dims[axis[0]]; + + var (scale, offset) = (gamma, beta); + var training_value = tf_utils.constant_value(training_tensor); + + Tensor mean; + Tensor variance; + if (training_value.HasValue && training_value.Value == false) + { + (mean, variance) = (moving_mean.AsTensor(), moving_variance.AsTensor()); + } + else + { + var keep_dims = len(axis) > 1; + (mean, variance) = _moments(inputs, reduction_axes, keep_dims: keep_dims); + mean = tf_utils.smart_cond(training_tensor, + () => new[] { mean }, + () => new[] { ops.convert_to_tensor(moving_mean) }).FirstOrDefault(); + + variance = tf_utils.smart_cond(training_tensor, + () => new[] { variance }, + () => new[] { ops.convert_to_tensor(moving_variance) }).FirstOrDefault(); + + var (new_mean, new_variance) = (mean, variance); + } + + mean = math_ops.cast(mean, inputs.dtype); + variance = math_ops.cast(variance, inputs.dtype); + var offset_tensor = math_ops.cast(offset, inputs.dtype); + var scale_tensor = math_ops.cast(scale, inputs.dtype); + outputs = nn_impl.batch_normalization(inputs, mean, variance, + offset_tensor, scale_tensor, epsilon); + // If some components of the shape got lost due to adjustments, fix that. + outputs.shape = input_shape; + return outputs; + } + + private Tensor _fused_batch_norm(Tensor inputs, Tensor training) + { + Shape input_batch_size = null; + var use_fused_avg_updates = true; + float exponential_avg_factor = 0; + if (use_fused_avg_updates) + exponential_avg_factor = 1.0f - momentum; + + Func _fused_batch_norm_training = () => + { + return tf.nn.fused_batch_norm( + inputs, + gamma.AsTensor(), + beta.AsTensor(), + mean: moving_mean.AsTensor(), + variance: moving_variance.AsTensor(), + epsilon: epsilon, + is_training: true, + data_format: _data_format, + exponential_avg_factor: exponential_avg_factor); + }; + + Func _fused_batch_norm_inference = () => + { + return tf.nn.fused_batch_norm( + inputs, + gamma.AsTensor(), + beta.AsTensor(), + mean: moving_mean.AsTensor(), + variance: moving_variance.AsTensor(), + epsilon: epsilon, + is_training: false, + data_format: _data_format); + }; + + if (use_fused_avg_updates && input_batch_size != null) + throw new NotImplementedException(""); + + var results = tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); + var (output, mean, variance) = (results[0], results[1], results[2]); + var training_value = tf_utils.constant_value(training); + + if (!training_value.HasValue || (training_value.HasValue && training_value.Value)) + { + Tensor momentum_tensor = null; + if (!use_fused_avg_updates) + { + if (training_value == null) + momentum_tensor = tf_utils.smart_cond(training, + () => new float[] { momentum }, + () => new float[] { 1.0f })[0]; + else + momentum_tensor = ops.convert_to_tensor(momentum); + } + + if (use_fused_avg_updates) + _assign_new_value(moving_mean, mean); + else + _assign_moving_average(moving_variance, variance, momentum_tensor); + + if (use_fused_avg_updates) + _assign_new_value(moving_variance, variance); + else + _assign_moving_average(moving_variance, variance, momentum_tensor); + + // var mean_update = _assign_moving_average(moving_mean.AsTensor(), mean, momentum_tensor); + // var variance_update = _assign_moving_average(moving_variance.AsTensor(), variance, momentum_tensor); + // add_update(new Tensor[] { mean_update }, inputs: true); + // add_update(new Tensor[] { variance_update }, inputs: true); + } + + return output; + } + + void _assign_new_value(IVariableV1 variable, Tensor value) + { + tf_with(ops.name_scope("AssignNewValue", null, new { variable, value, momentum }), scope => + { + // var cm = ops.colocate_with(variable); + variable.assign_lazy_load(value, name: scope); + }); + } + + void _assign_moving_average(IVariableV1 variable, Tensor value, Tensor momentum) + { + tf_with(ops.name_scope("AssignMovingAvg", null, new { variable, value, momentum }), scope => + { + // var cm = ops.colocate_with(variable); + var decay = ops.convert_to_tensor(1.0f - momentum, name: "decay"); + var update_delta = (variable.AsTensor() - math_ops.cast(value, variable.dtype)) * decay; + variable.assign_sub_lazy_load(update_delta, name: scope); + }); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs new file mode 100644 index 000000000..69bdfbaa0 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs @@ -0,0 +1,178 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class LayerNormalization : Layer + { + LayerNormalizationArgs args; + + float epsilon => args.Epsilon; + bool center => args.Center; + bool scale => args.Scale; + bool _fused; + int[] axis; + string _data_format; + Shape kernel_size; + IInitializer beta_initializer => args.BetaInitializer; + IInitializer gamma_initializer => args.GammaInitializer; + IRegularizer gamma_regularizer => args.GammaRegularizer; + IVariableV1 gamma; + IVariableV1 beta; + IVariableV1 moving_mean; + IVariableV1 moving_variance; + + public LayerNormalization(LayerNormalizationArgs args) : base(args) + { + this.args = args; + axis = args.Axis.axis; + } + + public override void build(KerasShapesWrapper input_shape) + { + var single_shape = input_shape.ToSingleShape(); + var ndims = single_shape.ndim; + foreach (var (idx, x) in enumerate(axis)) + if (x < 0) + axis[idx] = ndims + x; + + var axis_to_dim = new Dictionary(); + foreach (var x in axis) + axis_to_dim[x] = (int)single_shape[x]; + + inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); + var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; + var param_shape = inputSpec.AllAxisDim; + + if (scale) + gamma = add_weight("gamma", + param_shape, + dtype: param_dtype, + initializer: gamma_initializer, + trainable: true); + + if (center) + beta = add_weight("beta", + param_shape, + dtype: param_dtype, + initializer: beta_initializer, + trainable: true); + + _fused = _fused_can_be_used(ndims); + + built = true; + _buildInputShape = input_shape; + } + + bool _fused_can_be_used(int ndims) + { + var can_use_fused = false; + if (axis.Last() == ndims - 1 && axis.Last() - axis[0] == len(axis) - 1) + can_use_fused = true; + if (epsilon < 1.001e-5 || DType != tf.float32) + can_use_fused = false; + return can_use_fused; + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + return input_shape; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + Tensor outputs = null; + var inputs_dtype = inputs.dtype.as_base_dtype(); + var input_shape = inputs.shape; + var ndims = len(input_shape); + var broadcast_shape = range(ndims).Select(x => 1).ToArray(); + foreach (var dim in axis) + broadcast_shape[dim] = input_shape.as_int_list()[dim]; + + Func _broadcast = v => + { + if (v.shape.ndim != ndims && !axis.SequenceEqual(new int[] { ndims - 1 })) + return tf.reshape(v.AsTensor(), broadcast_shape); + return v.AsTensor(); + }; + + if (_fused) + { + var tensor_shape = tf.shape(inputs); + var pre_dim = tf.constant(1); + var in_dim = tf.constant(1); + foreach (var dim in range(ndims)) + { + var dim_tensor = tensor_shape[dim]; + if (dim < axis[0]) + pre_dim = pre_dim * dim_tensor; + else + in_dim = in_dim * dim_tensor; + } + inputs = tf.reshape(inputs, new object[] { 1, pre_dim, in_dim, 1 }); + + var scale = tf.ones(new Shape((int)pre_dim), dtype: DType); + var offset = tf.zeros(new Shape((int)pre_dim), dtype: DType); + + outputs = tf.nn.fused_batch_norm( + inputs, + scale: scale, + offset: offset, + epsilon: epsilon, + data_format: "NCHW")[0]; + + outputs = tf.reshape(outputs, tensor_shape); + + (scale, offset) = (_broadcast(gamma), _broadcast(beta)); + + outputs = outputs * tf.cast(scale, outputs.dtype); + outputs = outputs + tf.cast(offset, outputs.dtype); + } + else + { + var input_dtype = inputs.dtype; + if ((input_dtype == tf.float16) && DType == tf.float32) inputs = tf.cast(inputs, tf.float32); + (Tensor mean, Tensor variance) = tf.nn.moments(inputs, axis, keep_dims: true); + + (Tensor scale, Tensor offset) = (_broadcast(gamma), _broadcast(beta)); + + outputs = tf.nn.batch_normalization( + inputs, + mean, + variance, + offset: offset, + scale: scale, + variance_epsilon: epsilon); + + outputs = tf.cast(outputs, input_dtype); + } + // If some components of the shape got lost due to adjustments, fix that. + outputs.shape = input_shape; + + return outputs; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs new file mode 100644 index 000000000..987b56bc4 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs @@ -0,0 +1,176 @@ +/***************************************************************************** + Copyright 2023 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Layers +{ + public class Normalization : PreprocessingLayer + { + NormalizationArgs _args; + + int[] axis; + int[] _reduce_axis; + IVariableV1 adapt_mean, adapt_variance, count; + Tensor mean, variance; + Shape _broadcast_shape; + float? input_mean, input_variance; + TF_DataType compute_dtype = tf.float32; + + public Normalization(NormalizationArgs args) : base(args) + { + _args = args; + if (args.Axis == null) + { + axis = new int[0]; + } + else + { + axis = args.Axis.axis; + } + input_mean = args.Mean; + input_variance = args.Variance; + } + + public override void build(KerasShapesWrapper input_shape) + { + base.build(input_shape); + var single_shape = input_shape.ToSingleShape(); + var ndim = single_shape.ndim; + foreach (var (idx, x) in enumerate(axis)) + if (x < 0) + axis[idx] = ndim + x; + + var _keep_axis = axis.Select(d => d >= 0 ? d : d + ndim).ToArray(); + _reduce_axis = range(ndim).Where(d => !_keep_axis.Contains(d)).ToArray(); + var _reduce_axis_mask = range(ndim).Select(d => _keep_axis.Contains(d) ? 0 : 1).ToArray(); + // Broadcast any reduced axes. + _broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? single_shape.dims[d] : 1).ToArray()); + var mean_and_var_shape = _keep_axis.Select(d => single_shape.dims[d]).ToArray(); + + var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; + var param_shape = input_shape; + + if(input_mean == null) + { + adapt_mean = add_weight("mean", + mean_and_var_shape, + dtype: tf.float32, + initializer: tf.zeros_initializer, + trainable: false); + + adapt_variance = add_weight("variance", + mean_and_var_shape, + dtype: tf.float32, + initializer: tf.ones_initializer, + trainable: false); + + count = add_weight("count", + Shape.Scalar, + dtype: tf.int64, + initializer: tf.zeros_initializer, + trainable: false); + + finalize_state(); + } + else + { + mean = input_mean * np.ones(mean_and_var_shape); + variance = input_variance * np.ones(mean_and_var_shape); + mean = tf.reshape(mean, _broadcast_shape); + variance = tf.reshape(variance, _broadcast_shape); + mean = tf.cast(mean, compute_dtype); + variance = tf.cast(variance, compute_dtype); + } + } + + public override void reset_state() + { + if (input_mean != null && !built) + { + return; + } + adapt_mean.assign(tf.zeros_like(adapt_mean.AsTensor())); + adapt_variance.assign(tf.ones_like(adapt_variance.AsTensor())); + count.assign(tf.zeros_like(count.AsTensor())); + } + + public override void finalize_state() + { + if (input_mean != null && !built) + { + return; + } + mean = tf.reshape(adapt_mean.AsTensor(), _broadcast_shape); + variance = tf.reshape(adapt_variance.AsTensor(), _broadcast_shape); + } + + public override void update_state(Tensor data) + { + data = tf.cast(data, adapt_mean.dtype); + var (batch_mean, batch_variance) = tf.nn.moments(data, axes: _reduce_axis); + var batch_shape = tf.shape(data, out_type: count.dtype); + + var batch_count = constant_op.constant(1L); + if (_reduce_axis != null) + { + var batch_reduce_shape = tf.gather(batch_shape, constant_op.constant(_reduce_axis)); + batch_count = tf.reduce_prod(batch_reduce_shape); + } + var total_count = batch_count + count.AsTensor(); + var batch_weight = tf.cast(batch_count, dtype: compute_dtype) / tf.cast( + total_count, dtype: compute_dtype); + var existing_weight = 1.0 - batch_weight; + var total_mean = adapt_mean.AsTensor() * existing_weight + batch_mean * batch_weight; + + var total_variance = ( + adapt_variance.AsTensor() + tf.square(adapt_mean.AsTensor() - total_mean) + ) * existing_weight + ( + batch_variance + tf.square(batch_mean - total_mean) + ) * batch_weight; + adapt_mean.assign(total_mean); + adapt_variance.assign(total_variance); + count.assign(total_count); + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + return input_shape; + } + + public override void adapt(Tensor data, int? batch_size = null, int? steps = null) + { + base.adapt(data, batch_size: batch_size, steps: steps); + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + if (_args.Invert) + { + return mean + ( + inputs * tf.maximum(tf.sqrt(variance), keras.backend.epsilon()) + ); + } + else + { + return (inputs - mean) / tf.maximum( + tf.sqrt(variance), keras.backend.epsilon()); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/AveragePooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/AveragePooling2D.cs new file mode 100644 index 000000000..fbdb557cc --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Pooling/AveragePooling2D.cs @@ -0,0 +1,14 @@ +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Operations; + +namespace Tensorflow.Keras.Layers +{ + public class AveragePooling2D : Pooling2D + { + public AveragePooling2D(AveragePooling2DArgs args) + : base(args) + { + args.PoolFunction = new AveragePoolFunction(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs new file mode 100644 index 000000000..ffaabec97 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + public class GlobalAveragePooling1D : GlobalPooling1D + { + public GlobalAveragePooling1D(Pooling1DArgs args) + : base(args) + { + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + if (data_format == "channels_last") + return math_ops.reduce_mean(inputs, 1, false); + else + return math_ops.reduce_mean(inputs, 2, false); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs new file mode 100644 index 000000000..e06665173 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + public class GlobalAveragePooling2D : GlobalPooling2D + { + public GlobalAveragePooling2D(Pooling2DArgs args) + : base(args) + { + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + if (data_format == "channels_last") + return math_ops.reduce_mean(inputs, (1, 2), false); + else + return math_ops.reduce_mean(inputs, (2, 3), false); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs new file mode 100644 index 000000000..15695e8a7 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + public class GlobalMaxPooling1D : GlobalPooling1D + { + public GlobalMaxPooling1D(Pooling1DArgs args) + : base(args) + { + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + if (data_format == "channels_last") + return math_ops.reduce_max(inputs, 1, false); + else + return math_ops.reduce_max(inputs, 2, false); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs new file mode 100644 index 000000000..76db858da --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + public class GlobalMaxPooling2D : GlobalPooling2D + { + public GlobalMaxPooling2D(Pooling2DArgs args) + : base(args) + { + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + if (data_format == "channels_last") + return math_ops.reduce_max(inputs, (1, 2), false); + else + return math_ops.reduce_max(inputs, (2, 3), false); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalPooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalPooling1D.cs new file mode 100644 index 000000000..04fadeeb8 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalPooling1D.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Layers +{ + public abstract class GlobalPooling1D : Layer + { + Pooling1DArgs args; + protected string data_format => args.DataFormat; + protected InputSpec input_spec; + + public GlobalPooling1D(Pooling1DArgs args) : base(args) + { + this.args = args; + args.DataFormat = conv_utils.normalize_data_format(data_format); + input_spec = new InputSpec(ndim: 3); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalPooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalPooling2D.cs new file mode 100644 index 000000000..e944aef05 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalPooling2D.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Layers +{ + public abstract class GlobalPooling2D : Layer + { + Pooling2DArgs args; + protected string data_format => args.DataFormat; + protected InputSpec input_spec; + + public GlobalPooling2D(Pooling2DArgs args) : base(args) + { + this.args = args; + args.DataFormat = conv_utils.normalize_data_format(data_format); + input_spec = new InputSpec(ndim: 4); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/MaxPooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/MaxPooling1D.cs new file mode 100644 index 000000000..c1deb9bfd --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Pooling/MaxPooling1D.cs @@ -0,0 +1,14 @@ +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Operations; + +namespace Tensorflow.Keras.Layers +{ + public class MaxPooling1D : Pooling1D + { + public MaxPooling1D(Pooling1DArgs args) + : base(args) + { + args.PoolFunction = new MaxPoolFunction(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/MaxPooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/MaxPooling2D.cs new file mode 100644 index 000000000..90a45cb10 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Pooling/MaxPooling2D.cs @@ -0,0 +1,14 @@ +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Operations; + +namespace Tensorflow.Keras.Layers +{ + public class MaxPooling2D : Pooling2D + { + public MaxPooling2D(MaxPooling2DArgs args) + : base(args) + { + args.PoolFunction = new MaxPoolFunction(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs new file mode 100644 index 000000000..81a340199 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs @@ -0,0 +1,69 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.Common.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class Pooling1D : Layer + { + Pooling1DArgs args; + InputSpec input_spec; + + public Pooling1D(Pooling1DArgs args) + : base(args) + { + this.args = args; + args.Padding = conv_utils.normalize_padding(args.Padding); + args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); + input_spec = new InputSpec(ndim: 3); + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + int pad_axis = args.DataFormat == "channels_first" ? 2 : 3; + inputs = tf.expand_dims(inputs, pad_axis); + int[] pool_shape = new int[] { args.PoolSize, 1 }; + int[] strides = new int[] { args.Strides, 1 }; + var ndim = inputs[0].ndim; + + if (args.DataFormat == "channels_last") + { + pool_shape = new int[] { 1 }.Concat(pool_shape).Concat(new int[] { 1 }).ToArray(); + strides = new int[] { 1 }.Concat(strides).Concat(new int[] { 1 }).ToArray(); + } + else + { + pool_shape = new int[] { 1, 1 }.Concat(pool_shape).ToArray(); + strides = new int[] { 1, 1 }.Concat(strides).ToArray(); + } + + var outputs = args.PoolFunction.Apply( + inputs, + ksize: pool_shape, + strides: strides, + padding: args.Padding.ToUpper(), + data_format: conv_utils.convert_data_format(args.DataFormat, ndim)); + + return tf.squeeze(outputs, pad_axis); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs new file mode 100644 index 000000000..f83f1e152 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs @@ -0,0 +1,65 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + public class Pooling2D : Layer + { + Pooling2DArgs args; + InputSpec input_spec; + + public Pooling2D(Pooling2DArgs args) + : base(args) + { + this.args = args; + args.PoolSize = conv_utils.normalize_tuple(args.PoolSize, 2, "pool_size"); + args.Strides = conv_utils.normalize_tuple(args.Strides ?? args.PoolSize, 2, "strides"); + args.Padding = conv_utils.normalize_padding(args.Padding); + args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); + input_spec = new InputSpec(ndim: 4); + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + int[] pool_shape; + int[] strides; + if (args.DataFormat == "channels_last") + { + pool_shape = new int[] { 1, (int)args.PoolSize.dims[0], (int)args.PoolSize.dims[1], 1 }; + strides = new int[] { 1, (int)args.Strides.dims[0], (int)args.Strides.dims[1], 1 }; + } + else + { + pool_shape = new int[] { 1, 1, (int)args.PoolSize.dims[0], (int)args.PoolSize.dims[1] }; + strides = new int[] { 1, 1, (int)args.Strides.dims[0], (int)args.Strides.dims[1] }; + } + + var outputs = args.PoolFunction.Apply( + inputs, + ksize: pool_shape, + strides: strides, + padding: args.Padding.ToUpper(), + data_format: conv_utils.convert_data_format(args.DataFormat, 4)); + + return outputs; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs new file mode 100644 index 000000000..20d2a53d5 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs @@ -0,0 +1,75 @@ +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; +namespace Tensorflow.Keras.Layers +{ + /// + /// This layer provides options for condensing data into a categorical encoding when the total number of tokens are known in advance. + /// + public class CategoryEncoding : Layer + { + CategoryEncodingArgs args; + + public CategoryEncoding(CategoryEncodingArgs args) : base(args) + { + this.args = args; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + var depth = args.NumTokens; + var max_value = tf.reduce_max(inputs); + var min_value = tf.reduce_min(inputs); + + /*var condition = tf.logical_and(tf.greater(tf.cast(constant_op.constant(depth), max_value.dtype), max_value), + tf.greater_equal(min_value, tf.cast(constant_op.constant(0), min_value.dtype)));*/ + + var bincounts = encode_categorical_inputs(inputs, args.OutputMode, depth, args.DType, + sparse: args.Sparse, + count_weights: args.CountWeights); + + if(args.OutputMode != "tf_idf") + { + return bincounts; + } + + return inputs; + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + return input_shape; + } + + Tensors encode_categorical_inputs(Tensor inputs, string output_mode, int depth, + TF_DataType dtype = TF_DataType.TF_FLOAT, + bool sparse = false, + Tensor count_weights = null) + { + bool binary_output = false; + if (output_mode == "one_hot") + { + binary_output = true; + if (inputs.shape[-1] != 1) + { + inputs = tf.expand_dims(inputs, -1); + } + } + else if (output_mode == "multi_hot") + { + binary_output = true; + } + + var depth_tensor = constant_op.constant(depth); + var result = tf.math.bincount(inputs, + weights: count_weights, + minlength: depth_tensor, + maxlength: depth_tensor, + dtype: dtype, + axis: -1, + binary_output: binary_output); + + return result; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookup.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookup.cs new file mode 100644 index 000000000..5e02f5626 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookup.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + public class IndexLookup : CombinerPreprocessingLayer + { + public IndexLookup(int max_tokens = -1, + int num_oov_indices = 1, + string mask_token = "", + string oov_token = "[UNK]", + string encoding = "utf-8", + bool invert = false) : base(new PreprocessingLayerArgs()) + { + var num_mask_tokens = mask_token == null ? 0 : 1; + var vocab_size = max_tokens - (num_oov_indices + num_mask_tokens); + combiner = new IndexLookupCombiner(vocab_size, mask_token); + } + + public override void adapt(IDatasetV2 data, bool reset_state = true) + { + if (!reset_state) + throw new ValueError("IndexLookup does not support streaming adapts."); + base.adapt(data, reset_state); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupAccumulator.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupAccumulator.cs new file mode 100644 index 000000000..e2de669d8 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupAccumulator.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + public class IndexLookupAccumulator : IAccumulator + { + public Dictionary CountDict { get; set; } + public IndexLookupAccumulator() + { + CountDict = new Dictionary(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupCombiner.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupCombiner.cs new file mode 100644 index 000000000..ac4c5dc95 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupCombiner.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Combiner for the IndexLookup preprocessing layer. + /// + public class IndexLookupCombiner : ICombiner + { + int _vocab_size; + string _mask_value; + + public IndexLookupCombiner(int vocab_size = -1, string mask_value = null) + { + _vocab_size = vocab_size; + _mask_value = mask_value; + } + + public void Compute(Tensor values, IAccumulator accumulator = null) + { + if(accumulator == null) + { + accumulator = new IndexLookupAccumulator(); + } + } + + public void Deserialize() + { + throw new NotImplementedException(); + } + + public void Extract() + { + throw new NotImplementedException(); + } + + public void Merge() + { + throw new NotImplementedException(); + } + + public IAccumulator Restore() + { + throw new NotImplementedException(); + } + + public void Serialize() + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs new file mode 100644 index 000000000..a032dcd09 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs @@ -0,0 +1,97 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Engine.DataAdapters; + +namespace Tensorflow.Keras.Layers +{ + public class PreprocessingLayer : Layer + { + bool _is_compiled; + bool _is_adapted; + IVariableV1 _steps_per_execution; + PreprocessingLayerArgs _args; + public PreprocessingLayer(PreprocessingLayerArgs args) : base(args) + { + _args = args; + } + + public override void adapt(Tensor data, int? batch_size = null, int? steps = null) + { + if (!_is_compiled) + { + compile(); + } + + if (built) + { + reset_state(); + } + + var data_handler = new DataHandler(new DataHandlerArgs + { + X = new Tensors(data), + BatchSize = _args.BatchSize, + Epochs = 1, + StepsPerExecution = _steps_per_execution + }); + + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) + { + foreach (var _ in data_handler.steps()) + { + run_step(iterator); + } + } + finalize_state(); + _is_adapted = true; + } + + private void run_step(OwnedIterator iterator) + { + var data = iterator.next(); + _adapt_maybe_build(data[0]); + update_state(data[0]); + } + + public virtual void reset_state() + { + + } + + public virtual void finalize_state() + { + + } + + public virtual void update_state(Tensor data) + { + + } + + private void _adapt_maybe_build(Tensor data) + { + if (!built) + { + var data_shape = data.shape; + var data_shape_nones = Enumerable.Range(0, data.ndim).Select(x => -1).ToArray(); + _args.BatchInputShape = BatchInputShape ?? new Saving.KerasShapesWrapper(new Shape(data_shape_nones)); + build(new Saving.KerasShapesWrapper(data_shape)); + built = true; + } + } + + public void compile(bool run_eagerly = false, int steps_per_execution = 1) + { + _steps_per_execution = tf.Variable( + steps_per_execution, + dtype: tf.int64, + aggregation: VariableAggregation.OnlyFirstReplica + ); + + _is_compiled = true; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs new file mode 100644 index 000000000..7fa367eea --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs @@ -0,0 +1,33 @@ +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Multiply inputs by `scale` and adds `offset`. + /// + public class Rescaling : Layer + { + RescalingArgs args; + Tensor scale; + Tensor offset; + + public Rescaling(RescalingArgs args) : base(args) + { + this.args = args; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + scale = constant_op.constant(args.Scale, args.DType); + offset = constant_op.constant(args.Offset, args.DType); + return math_ops.cast(inputs, args.DType) * scale + offset; + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + return input_shape; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs new file mode 100644 index 000000000..081966ad4 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs @@ -0,0 +1,40 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Resize the batched image input to target height and width. + /// The input should be a 4-D tensor in the format of NHWC. + /// + public class Resizing : PreprocessingLayer + { + ResizingArgs args; + public Resizing(ResizingArgs args) : base(args) + { + this.args = args; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + return image_ops_impl.resize_images_v2(inputs, new[] { args.Height, args.Width }, method: args.Interpolation); + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + return new Shape(input_shape.dims[0], args.Height, args.Width, input_shape.dims[3]); + } + + public static Resizing from_config(JObject config) + { + var args = JsonConvert.DeserializeObject(config.ToString()); + args.IsFromConfig = true; + return new Resizing(args); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/StringLookup.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/StringLookup.cs new file mode 100644 index 000000000..616af1c6c --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/StringLookup.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Maps strings from a vocabulary to integer indices. + /// + class StringLookup : IndexLookup + { + public StringLookup(int max_tokens = -1, + int num_oov_indices = 1, + string mask_token = "", + string[] vocabulary = null, + string oov_token = "[UNK]", + string encoding = "utf-8", + bool invert = false) + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs new file mode 100644 index 000000000..6c504006a --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs @@ -0,0 +1,64 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class TextVectorization : CombinerPreprocessingLayer + { + TextVectorizationArgs args; + IndexLookup _index_lookup_layer; + + public TextVectorization(TextVectorizationArgs args) + : base(args) + { + this.args = args; + args.DType = TF_DataType.TF_STRING; + // string standardize = "lower_and_strip_punctuation", + + var mask_token = args.OutputMode == "int" ? "" : null; + _index_lookup_layer = new StringLookup(max_tokens: args.MaxTokens, + mask_token: mask_token, + vocabulary: args.Vocabulary); + } + + /// + /// Fits the state of the preprocessing layer to the dataset. + /// + /// + /// + public override void adapt(IDatasetV2 data, bool reset_state = true) + { + var shape = data.output_shapes[0]; + if (shape.ndim == 1) + data = data.map(tensor => array_ops.expand_dims(tensor, -1)); + build(new KerasShapesWrapper(data.variant_tensor.shape)); + var preprocessed_inputs = data.map(_preprocess); + _index_lookup_layer.adapt(preprocessed_inputs); + } + + public override void build(KerasShapesWrapper input_shape) + { + base.build(input_shape); + } + + Tensors _preprocess(Tensors inputs) + { + Tensor input_tensor = null; + if (args.Standardize != null) + input_tensor = args.Standardize(inputs); + if (!string.IsNullOrEmpty(args.Split)) + { + if (inputs.shape.ndim > 1) + input_tensor = array_ops.squeeze(inputs, axis: new[] { -1 }); + if (args.Split == "whitespace") + input_tensor = tf.strings.split(input_tensor); + } + return input_tensor; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs b/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs new file mode 100644 index 000000000..ada1851ce --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs @@ -0,0 +1,42 @@ +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class Dropout : Layer + { + DropoutArgs args; + + public Dropout(DropoutArgs args) + : base(args) + { + this.args = args; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + if (training == null) + training = false; + + var output = tf_utils.smart_cond(training.Value, + () => tf.nn.dropout(inputs, + noise_shape: get_noise_shape(inputs), + seed: args.Seed, + rate: args.Rate), + () => array_ops.identity(inputs)); + + return output; + } + + Tensor get_noise_shape(Tensor inputs) + { + if (args.NoiseShape == null) + return null; + + return null; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs new file mode 100644 index 000000000..7d5385e6f --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs @@ -0,0 +1,66 @@ +using Tensorflow.Keras.ArgsDefinition.Reshaping; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers.Reshaping +{ + public class Cropping1D : Layer + { + Cropping1DArgs args; + public Cropping1D(Cropping1DArgs args) : base(args) + { + this.args = args; + } + + public override void build(KerasShapesWrapper input_shape) + { + if (args.cropping.rank != 1) + { + // throw an ValueError exception + throw new ValueError(""); + } + else if (args.cropping.shape[0] > 2 || args.cropping.shape[0] < 1) + { + throw new ValueError("The `cropping` argument must be a tuple of 2 integers."); + } + built = true; + _buildInputShape = input_shape; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + Tensor output = inputs; + if (output.rank != 3) + { + // throw an ValueError exception + throw new ValueError("Expected dim=3, found dim=" + output.rank); + } + if (args.cropping.shape[0] == 1) + { + int crop_start = args.cropping[0]; + output = output[new Slice(), new Slice(crop_start, (int)output.shape[1] - crop_start), new Slice()]; + } + else + { + int crop_start = args.cropping[0], crop_end = args.cropping[1]; + output = output[new Slice(), new Slice(crop_start, (int)output.shape[1] - crop_end), new Slice()]; + } + return output; + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + if (args.cropping.shape[0] == 1) + { + int crop = args.cropping[0]; + return new Shape((int)input_shape[0], (int)(input_shape[1] - crop * 2), (int)input_shape[2]); + } + else + { + int crop_start = args.cropping[0], crop_end = args.cropping[1]; + return new Shape((int)input_shape[0], (int)(input_shape[1] - crop_start - crop_end), (int)input_shape[2]); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs new file mode 100644 index 000000000..4a5c6eabc --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs @@ -0,0 +1,142 @@ +using Tensorflow.Keras.ArgsDefinition.Reshaping; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers.Reshaping +{ + /// + /// Crop the input along axis 1 and 2. + /// For example: + /// shape (1, 5, 5, 5) -- crop2D ((1, 2), (1, 3)) --> shape (1, 2, 1, 5) + /// + public class Cropping2D : Layer + { + Cropping2DArgs args; + public Cropping2D(Cropping2DArgs args) : base(args) + { + this.args = args; + } + public override void build(KerasShapesWrapper input_shape) + { + built = true; + _buildInputShape = input_shape; + } + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + Tensor output = inputs; + if (output.rank != 4) + { + // throw an ValueError exception + throw new ValueError("Expected dim=4, found dim=" + output.rank); + } + if (args.cropping.shape == new Shape(1)) + { + int crop = args.cropping[0]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + output = output[new Slice(), + new Slice(crop, (int)output.shape[1] - crop), + new Slice(crop, (int)output.shape[2] - crop), + new Slice()]; + } + else + { + output = output[new Slice(), + new Slice(), + new Slice(crop, (int)output.shape[2] - crop), + new Slice(crop, (int)output.shape[3] - crop)]; + } + } + // a tuple of 2 integers + else if (args.cropping.shape == new Shape(2)) + { + int crop_1 = args.cropping[0]; + int crop_2 = args.cropping[1]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + output = output[new Slice(), + new Slice(crop_1, (int)output.shape[1] - crop_1), + new Slice(crop_2, (int)output.shape[2] - crop_2), + new Slice()]; + } + else + { + output = output[new Slice(), + new Slice(), + new Slice(crop_1, (int)output.shape[2] - crop_1), + new Slice(crop_2, (int)output.shape[3] - crop_2)]; + } + } + else if (args.cropping.shape[0] == 2 && args.cropping.shape[1] == 2) + { + int x_start = args.cropping[0, 0], x_end = args.cropping[0, 1]; + int y_start = args.cropping[1, 0], y_end = args.cropping[1, 1]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + output = output[new Slice(), + new Slice(x_start, (int)output.shape[1] - x_end), + new Slice(y_start, (int)output.shape[2] - y_end), + new Slice()]; + } + else + { + output = output[new Slice(), + new Slice(), + new Slice(x_start, (int)output.shape[2] - x_end), + new Slice(y_start, (int)output.shape[3] - y_end) + ]; + } + } + return output; + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + if (args.cropping.shape == new Shape(1)) + { + int crop = args.cropping[0]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop * 2, (int)input_shape[2] - crop * 2, (int)input_shape[3]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2); + } + } + // a tuple of 2 integers + else if (args.cropping.shape == new Shape(2)) + { + int crop_1 = args.cropping[0], crop_2 = args.cropping[1]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop_1 * 2, (int)input_shape[2] - crop_2 * 2, (int)input_shape[3]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop_1 * 2, (int)input_shape[3] - crop_2 * 2); + } + } + else if (args.cropping.shape == new Shape(2, 2)) + { + int crop_1_start = args.cropping[0, 0], crop_1_end = args.cropping[0, 1]; + int crop_2_start = args.cropping[1, 0], crop_2_end = args.cropping[1, 1]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop_1_start - crop_1_end, + (int)input_shape[2] - crop_2_start - crop_2_end, (int)input_shape[3]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], + (int)input_shape[2] - crop_1_start - crop_1_end, (int)input_shape[3] - crop_2_start - crop_2_end); + } + } + else + { + throw new ValueError(); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs new file mode 100644 index 000000000..83f86c6fc --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs @@ -0,0 +1,152 @@ +using Tensorflow.Keras.ArgsDefinition.Reshaping; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers.Reshaping +{ + /// + /// Similar to copping 2D + /// + public class Cropping3D : Layer + { + Cropping3DArgs args; + public Cropping3D(Cropping3DArgs args) : base(args) + { + this.args = args; + } + + public override void build(KerasShapesWrapper input_shape) + { + built = true; + _buildInputShape = input_shape; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + Tensor output = inputs; + if (output.rank != 5) + { + // throw an ValueError exception + throw new ValueError("Expected dim=5, found dim=" + output.rank); + } + + if (args.cropping.shape == new Shape(1)) + { + int crop = args.cropping[0]; + if (args.data_format == Cropping3DArgs.DataFormat.channels_last) + { + output = output[new Slice(), + new Slice(crop, (int)output.shape[1] - crop), + new Slice(crop, (int)output.shape[2] - crop), + new Slice(crop, (int)output.shape[3] - crop), + new Slice()]; + } + else + { + output = output[new Slice(), + new Slice(), + new Slice(crop, (int)output.shape[2] - crop), + new Slice(crop, (int)output.shape[3] - crop), + new Slice(crop, (int)output.shape[4] - crop)]; + } + + } + // int[1][3] equivalent to a tuple of 3 integers + else if (args.cropping.shape == new Shape(3)) + { + var crop_1 = args.cropping[0]; + var crop_2 = args.cropping[1]; + var crop_3 = args.cropping[2]; + if (args.data_format == Cropping3DArgs.DataFormat.channels_last) + { + output = output[new Slice(), + new Slice(crop_1, (int)output.shape[1] - crop_1), + new Slice(crop_2, (int)output.shape[2] - crop_2), + new Slice(crop_3, (int)output.shape[3] - crop_3), + new Slice()]; + } + else + { + output = output[new Slice(), + new Slice(), + new Slice(crop_1, (int)output.shape[2] - crop_1), + new Slice(crop_2, (int)output.shape[3] - crop_2), + new Slice(crop_3, (int)output.shape[4] - crop_3)]; + } + } + else if (args.cropping.shape[0] == 3 && args.cropping.shape[1] == 2) + { + int x = args.cropping[0, 0], x_end = args.cropping[0, 1]; + int y = args.cropping[1, 0], y_end = args.cropping[1, 1]; + int z = args.cropping[2, 0], z_end = args.cropping[2, 1]; + if (args.data_format == Cropping3DArgs.DataFormat.channels_last) + { + output = output[new Slice(), + new Slice(x, (int)output.shape[1] - x_end), + new Slice(y, (int)output.shape[2] - y_end), + new Slice(z, (int)output.shape[3] - z_end), + new Slice()]; + } + else + { + output = output[new Slice(), + new Slice(), + new Slice(x, (int)output.shape[2] - x_end), + new Slice(y, (int)output.shape[3] - y_end), + new Slice(z, (int)output.shape[4] - z_end) + ]; + } + } + return output; + } + public override Shape ComputeOutputShape(Shape input_shape) + { + if (args.cropping.shape == new Shape(1)) + { + int crop = args.cropping[0]; + if (args.data_format == Cropping3DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop * 2, (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2, (int)input_shape[4]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2, (int)input_shape[4] - crop * 2); + } + } + // int[1][3] equivalent to a tuple of 3 integers + else if (args.cropping.shape == new Shape(3)) + { + var crop_start_1 = args.cropping[0]; + var crop_start_2 = args.cropping[1]; + var crop_start_3 = args.cropping[2]; + if (args.data_format == Cropping3DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop_start_1 * 2, (int)input_shape[2] - crop_start_2 * 2, (int)input_shape[3] - crop_start_3 * 2, (int)input_shape[4]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop_start_1 * 2, (int)input_shape[3] - crop_start_2 * 2, (int)input_shape[4] - crop_start_3 * 2); + } + } + else if (args.cropping.shape == new Shape(3, 2)) + { + int x = args.cropping[0, 0], x_end = args.cropping[0, 1]; + int y = args.cropping[1, 0], y_end = args.cropping[1, 1]; + int z = args.cropping[2, 0], z_end = args.cropping[2, 1]; + if (args.data_format == Cropping3DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - x - x_end, (int)input_shape[2] - y - y_end, (int)input_shape[3] - z - z_end, (int)input_shape[4]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - x - x_end, (int)input_shape[3] - y - y_end, (int)input_shape[4] - z - z_end); + } + } + else + { + throw new ValueError(); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs new file mode 100644 index 000000000..a6192849d --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs @@ -0,0 +1,63 @@ +using System; +using System.Linq; +using Tensorflow.Common.Types; +using Tensorflow.Framework; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class Flatten : Layer + { + FlattenArgs args; + InputSpec input_spec; + bool _channels_first; + + public Flatten(FlattenArgs args) + : base(args) + { + this.args = args; + args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); + input_spec = new InputSpec(min_ndim: 1); + _channels_first = args.DataFormat == "channels_first"; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + if (_channels_first) + { + throw new NotImplementedException(""); + } + + if (tf.executing_eagerly()) + { + return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 }); + } + else + { + var input_shape = inputs.shape; + var rank = inputs.shape.ndim; + if (rank == 1) + return array_ops.expand_dims(inputs, axis: 1); + var batch_dim = tensor_shape.dimension_value(input_shape[0]); + if (batch_dim != -1) + { + return array_ops.reshape(inputs, new[] { batch_dim, -1 }); + } + + var non_batch_dims = ((int[])input_shape).Skip(1).ToArray(); + var num = 1; + if (non_batch_dims.Length > 0) + { + for (var i = 0; i < non_batch_dims.Length; i++) + { + num *= non_batch_dims[i]; + } + } + return array_ops.reshape(inputs, new[] { inputs.shape[0], num }); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs new file mode 100644 index 000000000..7fdb816bf --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs @@ -0,0 +1,49 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers { + public class Permute : Layer + { + int[] dims, permute; + public Permute(PermuteArgs args) : base(args) + { + this.dims = args.dims; + } + public override void build(KerasShapesWrapper input_shape) + { + var single_shape = input_shape.ToSingleShape(); + var rank = single_shape.rank; + if (dims.Length != rank - 1) + { + throw new ValueError("Dimensions must match."); + } + permute = new int[single_shape.rank]; + dims.CopyTo(permute, 1); + built = true; + _buildInputShape = input_shape; + } + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + Tensor outputs = inputs; + return tf.transpose(outputs, new Axis(permute)); + } + public override Shape ComputeOutputShape(Shape input_shape) + { + Shape output_shape = new Shape(input_shape.dims); + for (int i = 0; i < dims.Length; i += 1) + { + var d = dims[i]; + var target_dim = input_shape[d]; + output_shape[i + 1] = target_dim; + } + return output_shape; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs new file mode 100644 index 000000000..4b3d30e29 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs @@ -0,0 +1,55 @@ +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; +using System.Collections.Generic; +using System; +using System.Linq; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Layer that reshapes inputs into the given shape. + /// + public class Reshape : Layer + { + ReshapeArgs args; + public Reshape(ReshapeArgs args) + : base(args) + { + this.args = args; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + var shapes = new List(); + shapes.Add(array_ops.shape(inputs)[0]); + var dtype = shapes[0].dtype; + if (args.TargetShapeObjects != null) + // shapes.AddRange(args.TargetShapeObjects); + throw new NotImplementedException(""); + if (args.TargetShape != null) + shapes.AddRange(args.TargetShape.dims.Select(x => constant_op.constant(x, dtype))); + var shape = ops.convert_to_tensor(shapes); + + var result = array_ops.reshape(inputs, shape); + if (!tf.Context.executing_eagerly()) + result.shape = ComputeOutputShape(inputs.shape); + return result; + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + if (input_shape.dims.Skip(1).Contains(-1)) + { + throw new NotImplementedException(""); + } + else + { + input_shape = new Shape(input_shape.dims[0]); + var output_shape = input_shape.concatenate(args.TargetShape.dims); + return output_shape; + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling1D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling1D.cs new file mode 100644 index 000000000..3bc8d6c6b --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling1D.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + + +namespace Tensorflow.Keras.Layers +{ + /// + /// Upsampling layer for 1D inputs. + /// + public class UpSampling1D : Layer + { + UpSampling1DArgs args; + int size; + + public UpSampling1D(UpSampling1DArgs args) : base(args) + { + this.args = args; + size = args.Size; + inputSpec = new InputSpec(ndim: 3); + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + var output = keras.backend.repeat_elements(inputs, size, axis: 1); + return output; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs new file mode 100644 index 000000000..cb579d61e --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Upsampling layer for 2D inputs. + /// + public class UpSampling2D : Layer + { + UpSampling2DArgs args; + int[] size; + string data_format; + string interpolation => args.Interpolation; + + public UpSampling2D(UpSampling2DArgs args) : base(args) + { + this.args = args; + data_format = conv_utils.normalize_data_format(args.DataFormat); + size = conv_utils.normalize_tuple(args.Size, 2, "size"); + inputSpec = new InputSpec(ndim: 4); + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + return keras.backend.resize_images(inputs, + size[0], size[1], + data_format, + interpolation: interpolation); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs new file mode 100644 index 000000000..3b37dac46 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs @@ -0,0 +1,37 @@ +using Tensorflow.NumPy; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.Common.Types; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Zero-padding layer for 2D input (e.g. picture). + /// + /// This layer can add rows and columns of zeros + /// at the top, bottom, left and right side of an image tensor. + /// + public class ZeroPadding2D : Layer + { + string data_format; + NDArray padding; + InputSpec input_spec; + + public ZeroPadding2D(ZeroPadding2DArgs args, string data_format = null) + : base(args) + { + this.data_format = conv_utils.normalize_data_format(data_format); + this.padding = args.Padding; + this.input_spec = new InputSpec(ndim: 4); + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + return keras.backend.spatial_2d_padding(inputs, + padding: padding, + data_format: data_format); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs b/src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs new file mode 100644 index 000000000..737f88cd4 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Abstract wrapper base class. Wrappers take another layer and augment it in various ways. + /// Do not use this class as a layer, it is only an abstract base class. + /// Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers. + /// + public abstract class Wrapper: Layer + { + public ILayer _layer; + public Wrapper(WrapperArgs args):base(args) + { + _layer = args.Layer; + } + + public virtual void Build(KerasShapesWrapper input_shape) + { + if (!_layer.Built) + { + _layer.build(input_shape); + } + built = true; + } + + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs b/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs new file mode 100644 index 000000000..0566b08ad --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs @@ -0,0 +1,285 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Bidirectional wrapper for RNNs. + /// + public class Bidirectional: Wrapper + { + int _num_constants = 0; + bool _support_masking = true; + bool _return_state; + bool _stateful; + bool _return_sequences; + BidirectionalArgs _args; + RNNArgs _layer_args_copy; + RNN _forward_layer; + RNN _backward_layer; + RNN _layer; + InputSpec _input_spec; + public Bidirectional(BidirectionalArgs args):base(args) + { + _args = args; + if (_args.Layer is not ILayer) + throw new ValueError( + "Please initialize `Bidirectional` layer with a " + + $"`tf.keras.layers.Layer` instance. Received: {_args.Layer}"); + + if (_args.BackwardLayer is not null && _args.BackwardLayer is not ILayer) + throw new ValueError( + "`backward_layer` need to be a `tf.keras.layers.Layer` " + + $"instance. Received: {_args.BackwardLayer}"); + if (!new List { "sum", "mul", "ave", "concat", null }.Contains(_args.MergeMode)) + { + throw new ValueError( + $"Invalid merge mode. Received: {_args.MergeMode}. " + + "Merge mode should be one of " + + "{\"sum\", \"mul\", \"ave\", \"concat\", null}" + ); + } + if (_args.Layer is RNN) + { + _layer = _args.Layer as RNN; + } + else + { + throw new ValueError( + "Bidirectional only support RNN instance such as LSTM or GRU"); + } + _return_state = _layer.Args.ReturnState; + _return_sequences = _layer.Args.ReturnSequences; + _stateful = _layer.Args.Stateful; + _layer_args_copy = _layer.Args.Clone(); + // We don't want to track `layer` since we're already tracking the two + // copies of it we actually run. + // TODO(Wanglongzhi2001), since the feature of setattr_tracking has not been implemented. + // _setattr_tracking = false; + // super().__init__(layer, **kwargs) + // _setattr_tracking = true; + + // Recreate the forward layer from the original layer config, so that it + // will not carry over any state from the layer. + if (_layer is LSTM) + { + var arg = _layer_args_copy as LSTMArgs; + _forward_layer = new LSTM(arg); + } + else if(_layer is SimpleRNN) + { + var arg = _layer_args_copy as SimpleRNNArgs; + _forward_layer = new SimpleRNN(arg); + } + // TODO(Wanglongzhi2001), add GRU if case. + else + { + _forward_layer = new RNN(_layer.Cell, _layer_args_copy); + } + //_forward_layer = _recreate_layer_from_config(_layer); + if (_args.BackwardLayer is null) + { + _backward_layer = _recreate_layer_from_config(_layer, go_backwards:true); + } + else + { + _backward_layer = _args.BackwardLayer as RNN; + } + _forward_layer.Name = "forward_" + _forward_layer.Name; + _backward_layer.Name = "backward_" + _backward_layer.Name; + _verify_layer_config(); + + void force_zero_output_for_mask(RNN layer) + { + layer.Args.ZeroOutputForMask = layer.Args.ReturnSequences; + } + + force_zero_output_for_mask(_forward_layer); + force_zero_output_for_mask(_backward_layer); + + if (_args.Weights is not null) + { + var nw = len(_args.Weights); + _forward_layer.set_weights(_args.Weights[$":,{nw / 2}"]); + _backward_layer.set_weights(_args.Weights[$"{nw / 2},:"]); + } + + _input_spec = _layer.InputSpec; + } + + private void _verify_layer_config() + { + if (_forward_layer.Args.GoBackwards == _backward_layer.Args.GoBackwards) + { + throw new ValueError( + "Forward layer and backward layer should have different " + + "`go_backwards` value." + + "forward_layer.go_backwards = " + + $"{_forward_layer.Args.GoBackwards}," + + "backward_layer.go_backwards = " + + $"{_backward_layer.Args.GoBackwards}"); + } + if (_forward_layer.Args.Stateful != _backward_layer.Args.Stateful) + { + throw new ValueError( + "Forward layer and backward layer are expected to have "+ + $"the same value for attribute stateful, got "+ + $"{_forward_layer.Args.Stateful} for forward layer and "+ + $"{_backward_layer.Args.Stateful} for backward layer"); + } + if (_forward_layer.Args.ReturnState != _backward_layer.Args.ReturnState) + { + throw new ValueError( + "Forward layer and backward layer are expected to have " + + $"the same value for attribute return_state, got " + + $"{_forward_layer.Args.ReturnState} for forward layer and " + + $"{_backward_layer.Args.ReturnState} for backward layer"); + } + if (_forward_layer.Args.ReturnSequences != _backward_layer.Args.ReturnSequences) + { + throw new ValueError( + "Forward layer and backward layer are expected to have " + + $"the same value for attribute return_sequences, got " + + $"{_forward_layer.Args.ReturnSequences} for forward layer and " + + $"{_backward_layer.Args.ReturnSequences} for backward layer"); + } + } + + private RNN _recreate_layer_from_config(RNN layer, bool go_backwards = false) + { + var config = layer.get_config() as RNNArgs; + var cell = layer.Cell; + if (go_backwards) + { + config.GoBackwards = !config.GoBackwards; + } + + if (layer is LSTM) + { + var arg = config as LSTMArgs; + return new LSTM(arg); + } + else if(layer is SimpleRNN) + { + var arg = config as SimpleRNNArgs; + return new SimpleRNN(arg); + } + // TODO(Wanglongzhi2001), add GRU if case. + else + { + return new RNN(cell, config); + } + } + + public override void build(KerasShapesWrapper input_shape) + { + _buildInputShape = input_shape; + tf_with(ops.name_scope(_forward_layer.Name), scope=> + { + _forward_layer.build(input_shape); + }); + tf_with(ops.name_scope(_backward_layer.Name), scope => + { + _backward_layer.build(input_shape); + }); + built = true; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + // `Bidirectional.call` implements the same API as the wrapped `RNN`. + Tensors forward_inputs; + Tensors backward_inputs; + Tensors forward_state; + Tensors backward_state; + // if isinstance(inputs, list) and len(inputs) > 1: + if (inputs.Length > 1) + { + // initial_states are keras tensors, which means they are passed + // in together with inputs as list. The initial_states need to be + // split into forward and backward section, and be feed to layers + // accordingly. + forward_inputs = new Tensors { inputs[0] }; + backward_inputs = new Tensors { inputs[0] }; + var pivot = (len(inputs) - _num_constants) / 2 + 1; + // add forward initial state + forward_inputs.Concat(new Tensors { inputs[$"1:{pivot}"] }); + if (_num_constants != 0) + // add backward initial state + backward_inputs.Concat(new Tensors { inputs[$"{pivot}:"] }); + else + { + // add backward initial state + backward_inputs.Concat(new Tensors { inputs[$"{pivot}:{-_num_constants}"] }); + // add constants for forward and backward layers + forward_inputs.Concat(new Tensors { inputs[$"{-_num_constants}:"] }); + backward_inputs.Concat(new Tensors { inputs[$"{-_num_constants}:"] }); + } + forward_state = null; + backward_state = null; + } + else if (state is not null) + { + // initial_states are not keras tensors, eg eager tensor from np + // array. They are only passed in from kwarg initial_state, and + // should be passed to forward/backward layer via kwarg + // initial_state as well. + forward_inputs = inputs; + backward_inputs = inputs; + var half = len(state) / 2; + forward_state = state[$":{half}"]; + backward_state = state[$"{half}:"]; + } + else + { + forward_inputs = inputs; + backward_inputs = inputs; + forward_state = null; + backward_state = null; + } + var y = _forward_layer.Apply(forward_inputs, forward_state); + var y_rev = _backward_layer.Apply(backward_inputs, backward_state); + + Tensors states = new(); + if (_return_state) + { + states = y["1:"] + y_rev["1:"]; + y = y[0]; + y_rev = y_rev[0]; + } + + if (_return_sequences) + { + int time_dim = _forward_layer.Args.TimeMajor ? 0 : 1; + y_rev = keras.backend.reverse(y_rev, time_dim); + } + Tensors output; + if (_args.MergeMode == "concat") + output = keras.backend.concatenate(new Tensors { y.Single(), y_rev.Single() }); + else if (_args.MergeMode == "sum") + output = y.Single() + y_rev.Single(); + else if (_args.MergeMode == "ave") + output = (y.Single() + y_rev.Single()) / 2; + else if (_args.MergeMode == "mul") + output = y.Single() * y_rev.Single(); + else if (_args.MergeMode is null) + output = new Tensors { y.Single(), y_rev.Single() }; + else + throw new ValueError( + "Unrecognized value for `merge_mode`. " + + $"Received: {_args.MergeMode}" + + "Expected values are [\"concat\", \"sum\", \"ave\", \"mul\"]"); + if (_return_state) + { + if (_args.MergeMode is not null) + return new Tensors { output.Single(), states.Single()}; + } + return output; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs new file mode 100644 index 000000000..27c13f349 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs @@ -0,0 +1,109 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Layers +{ + public abstract class DropoutRNNCellMixin: Layer, IRnnCell + { + public float dropout; + public float recurrent_dropout; + // TODO(Rinne): deal with cache. + public DropoutRNNCellMixin(LayerArgs args): base(args) + { + + } + + public abstract INestStructure StateSize { get; } + public abstract INestStructure OutputSize { get; } + public abstract bool SupportOptionalArgs { get; } + public virtual Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype) + { + return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); + } + + protected void _create_non_trackable_mask_cache() + { + + } + + public void reset_dropout_mask() + { + + } + + public void reset_recurrent_dropout_mask() + { + + } + + public Tensors? get_dropout_mask_for_cell(Tensors input, bool training, int count = 1) + { + if (dropout == 0f) + return null; + return _generate_dropout_mask( + tf.ones_like(input), + dropout, + training, + count); + } + + // Get the recurrent dropout mask for RNN cell. + public Tensors? get_recurrent_dropout_mask_for_cell(Tensors input, bool training, int count = 1) + { + if (dropout == 0f) + return null; + return _generate_dropout_mask( + tf.ones_like(input), + recurrent_dropout, + training, + count); + } + + public Tensors _create_dropout_mask(Tensors input, bool training, int count = 1) + { + return _generate_dropout_mask( + tf.ones_like(input), + dropout, + training, + count); + } + + public Tensors _create_recurrent_dropout_mask(Tensors input, bool training, int count = 1) + { + return _generate_dropout_mask( + tf.ones_like(input), + recurrent_dropout, + training, + count); + } + + public Tensors _generate_dropout_mask(Tensor ones, float rate, bool training, int count = 1) + { + Tensors dropped_inputs() + { + DropoutArgs args = new DropoutArgs(); + args.Rate = rate; + var DropoutLayer = new Dropout(args); + var mask = DropoutLayer.Apply(ones, training: training); + return mask; + } + + if (count > 1) + { + Tensors results = new Tensors(); + for (int i = 0; i < count; i++) + { + results.Add(dropped_inputs()); + } + return results; + } + + return dropped_inputs(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs b/src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs new file mode 100644 index 000000000..0919883d2 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs @@ -0,0 +1,168 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Common.Extensions; +using Tensorflow.Common.Types; +using Tensorflow.Keras.Saving; + + +namespace Tensorflow.Keras.Layers +{ + public class GRU : RNN + { + GRUArgs _args; + private static GRUCell _cell; + + bool _return_runtime; + public GRUCell Cell { get => _cell; } + public int units { get => _args.Units; } + public Activation activation { get => _args.Activation; } + public Activation recurrent_activation { get => _args.RecurrentActivation; } + public bool use_bias { get => _args.UseBias; } + public float dropout { get => _args.Dropout; } + public float recurrent_dropout { get => _args.RecurrentDropout; } + public IInitializer kernel_initializer { get => _args.KernelInitializer; } + public IInitializer recurrent_initializer { get => _args.RecurrentInitializer; } + public IInitializer bias_initializer { get => _args.BiasInitializer; } + public int implementation { get => _args.Implementation; } + public bool reset_after { get => _args.ResetAfter; } + + public GRU(GRUArgs args) : base(CreateCell(args), PreConstruct(args)) + { + _args = args; + + if (_args.Implementation == 0) + { + // Use the red output to act as a warning message that can also be used under the release version + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine("Warning: `implementation=0` has been deprecated, "+ + "and now defaults to `implementation=2`."+ + "Please update your layer call."); + Console.ResetColor(); + } + + GRUCell cell = new GRUCell(new GRUCellArgs + { + Units = _args.Units, + Activation = _args.Activation, + RecurrentActivation = _args.RecurrentActivation, + UseBias = _args.UseBias, + Dropout = _args.Dropout, + RecurrentDropout = _args.RecurrentDropout, + KernelInitializer = _args.KernelInitializer, + RecurrentInitializer = _args.RecurrentInitializer, + BiasInitializer = _args.BiasInitializer, + ResetAfter = _args.ResetAfter, + Implementation = _args.Implementation + }); + _cell = cell; + } + + protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + GRUOptionalArgs? gru_optional_args = optional_args as GRUOptionalArgs; + if (optional_args is not null && gru_optional_args is null) + { + throw new ArgumentException("The type of optional args should be `GRUOptionalArgs`."); + } + Tensors? mask = gru_optional_args?.Mask; + + // Not support ragger input temporarily; + int row_length = 0; + bool is_ragged_input = false; + + _validate_args_if_ragged(is_ragged_input, mask); + + // GRU does not support constants.Ignore it during process. + (inputs, initial_state, _) = this._process_inputs(inputs, initial_state, null); + + if (mask.Length > 1) + { + mask = mask[0]; + } + + var input_shape = inputs.shape; + var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1]; + + + // TODO(Wanglongzhi2001), finish _could_use_gpu_kernel part + Func step = (cell_inputs, cell_states) => + { + var res = Cell.Apply(cell_inputs, cell_states, training is null ? true : training.Value); + var (output, state) = res; + return (output, state); + }; + + var (last_output, outputs, states) = keras.backend.rnn( + step, + inputs, + initial_state, + constants: null, + go_backwards: _args.GoBackwards, + mask: mask, + unroll: _args.Unroll, + input_length: ops.convert_to_tensor(timesteps), + time_major: _args.TimeMajor, + zero_output_for_mask: base.Args.ZeroOutputForMask, + return_all_outputs: _args.ReturnSequences + ); + + Tensors output; + if (_args.ReturnSequences) + { + output = outputs; + } + else + { + output = last_output; + } + + if (_args.ReturnState) + { + output = new Tensors { output, states }; + } + return output; + } + + private static IRnnCell CreateCell(GRUArgs gruArgs) + { + return new GRUCell(new GRUCellArgs + { + Units = gruArgs.Units, + Activation = gruArgs.Activation, + RecurrentActivation = gruArgs.RecurrentActivation, + UseBias = gruArgs.UseBias, + Dropout = gruArgs.Dropout, + RecurrentDropout = gruArgs.RecurrentDropout, + KernelInitializer = gruArgs.KernelInitializer, + RecurrentInitializer = gruArgs.RecurrentInitializer, + BiasInitializer = gruArgs.BiasInitializer, + ResetAfter = gruArgs.ResetAfter, + Implementation = gruArgs.Implementation + }); + } + + private static RNNArgs PreConstruct(GRUArgs args) + { + return new RNNArgs + { + ReturnSequences = args.ReturnSequences, + ReturnState = args.ReturnState, + GoBackwards = args.GoBackwards, + Stateful = args.Stateful, + Unroll = args.Unroll, + TimeMajor = args.TimeMajor, + Units = args.Units, + Activation = args.Activation, + RecurrentActivation = args.RecurrentActivation, + UseBias = args.UseBias, + Dropout = args.Dropout, + RecurrentDropout = args.RecurrentDropout, + KernelInitializer = args.KernelInitializer, + RecurrentInitializer = args.RecurrentInitializer, + BiasInitializer = args.BiasInitializer + }; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs new file mode 100644 index 000000000..2b9c01e31 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs @@ -0,0 +1,281 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Common.Extensions; +using Tensorflow.Common.Types; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Cell class for the GRU layer. + /// + public class GRUCell : DropoutRNNCellMixin + { + GRUCellArgs _args; + IVariableV1 _kernel; + IVariableV1 _recurrent_kernel; + IInitializer _bias_initializer; + IVariableV1 _bias; + INestStructure _state_size; + INestStructure _output_size; + int Units; + public override INestStructure StateSize => _state_size; + + public override INestStructure OutputSize => _output_size; + + public override bool SupportOptionalArgs => false; + public GRUCell(GRUCellArgs args) : base(args) + { + _args = args; + if (_args.Units <= 0) + { + throw new ValueError( + $"units must be a positive integer, got {args.Units}"); + } + _args.Dropout = Math.Min(1f, Math.Max(0f, _args.Dropout)); + _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout)); + if (_args.RecurrentDropout != 0f && _args.Implementation != 1) + { + Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set." + + "Using `implementation=1`."); + _args.Implementation = 1; + } + Units = _args.Units; + _state_size = new NestList(Units); + _output_size = new NestNode(Units); + } + + public override void build(KerasShapesWrapper input_shape) + { + //base.build(input_shape); + + var single_shape = input_shape.ToSingleShape(); + var input_dim = single_shape[-1]; + + _kernel = add_weight("kernel", (input_dim, _args.Units * 3), + initializer: _args.KernelInitializer + ); + + _recurrent_kernel = add_weight("recurrent_kernel", (Units, Units * 3), + initializer: _args.RecurrentInitializer + ); + if (_args.UseBias) + { + Shape bias_shape; + if (!_args.ResetAfter) + { + bias_shape = new Shape(3 * Units); + } + else + { + bias_shape = (2, 3 * Units); + } + _bias = add_weight("bias", bias_shape, + initializer: _bias_initializer + ); + } + built = true; + } + + protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) + { + var h_tm1 = states.IsNested() ? states[0] : states.Single(); + var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 3); + var rec_dp_mask = get_recurrent_dropout_mask_for_cell(h_tm1, training.Value, count: 3); + + IVariableV1 input_bias = _bias; + IVariableV1 recurrent_bias = _bias; + if (_args.UseBias) + { + if (!_args.ResetAfter) + { + input_bias = _bias; + recurrent_bias = null; + } + else + { + input_bias = tf.Variable(tf.unstack(_bias.AsTensor())[0]); + recurrent_bias = tf.Variable(tf.unstack(_bias.AsTensor())[1]); + } + } + + + Tensor hh; + Tensor z; + if ( _args.Implementation == 1) + { + Tensor inputs_z; + Tensor inputs_r; + Tensor inputs_h; + if (0f < _args.Dropout && _args.Dropout < 1f) + { + inputs_z = inputs * dp_mask[0]; + inputs_r = inputs * dp_mask[1]; + inputs_h = inputs * dp_mask[2]; + } + else + { + inputs_z = inputs.Single(); + inputs_r = inputs.Single(); + inputs_h = inputs.Single(); + } + + + int startIndex = (int)_kernel.AsTensor().shape[0]; + var _kernel_slice = tf.slice(_kernel.AsTensor(), + new[] { 0, 0 }, new[] { startIndex, Units }); + var x_z = math_ops.matmul(inputs_z, _kernel_slice); + _kernel_slice = tf.slice(_kernel.AsTensor(), + new[] { 0, Units }, new[] { Units, Units}); + var x_r = math_ops.matmul( + inputs_r, _kernel_slice); + int endIndex = (int)_kernel.AsTensor().shape[1]; + _kernel_slice = tf.slice(_kernel.AsTensor(), + new[] { 0, Units * 2 }, new[] { startIndex, endIndex - Units * 2 }); + var x_h = math_ops.matmul(inputs_h, _kernel_slice); + + if(_args.UseBias) + { + x_z = tf.nn.bias_add( + x_z, tf.Variable(input_bias.AsTensor()[$":{Units}"])); + x_r = tf.nn.bias_add( + x_r, tf.Variable(input_bias.AsTensor()[$"{Units}:{Units * 2}"])); + x_h = tf.nn.bias_add( + x_h, tf.Variable(input_bias.AsTensor()[$"{Units * 2}:"])); + } + + Tensor h_tm1_z; + Tensor h_tm1_r; + Tensor h_tm1_h; + if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f) + { + h_tm1_z = h_tm1 * rec_dp_mask[0]; + h_tm1_r = h_tm1 * rec_dp_mask[1]; + h_tm1_h = h_tm1 * rec_dp_mask[2]; + } + else + { + h_tm1_z = h_tm1; + h_tm1_r = h_tm1; + h_tm1_h = h_tm1; + } + + startIndex = (int)_recurrent_kernel.AsTensor().shape[0]; + var _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(), + new[] { 0, 0 }, new[] { startIndex, Units }); + var recurrent_z = math_ops.matmul( + h_tm1_z, _recurrent_kernel_slice); + _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(), + new[] { 0, Units }, new[] { startIndex, Units}); + var recurrent_r = math_ops.matmul( + h_tm1_r, _recurrent_kernel_slice); + if(_args.ResetAfter && _args.UseBias) + { + recurrent_z = tf.nn.bias_add( + recurrent_z, tf.Variable(recurrent_bias.AsTensor()[$":{Units}"])); + recurrent_r = tf.nn.bias_add( + recurrent_r, tf.Variable(recurrent_bias.AsTensor()[$"{Units}: {Units * 2}"])); + } + z = _args.RecurrentActivation.Apply(x_z + recurrent_z); + var r = _args.RecurrentActivation.Apply(x_r + recurrent_r); + + Tensor recurrent_h; + if (_args.ResetAfter) + { + endIndex = (int)_recurrent_kernel.AsTensor().shape[1]; + _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(), + new[] { 0, Units * 2 }, new[] { startIndex, endIndex - Units * 2 }); + recurrent_h = math_ops.matmul( + h_tm1_h, _recurrent_kernel_slice); + if(_args.UseBias) + { + recurrent_h = tf.nn.bias_add( + recurrent_h, tf.Variable(recurrent_bias.AsTensor()[$"{Units * 2}:"])); + } + recurrent_h *= r; + } + else + { + _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(), + new[] { 0, Units * 2 }, new[] { startIndex, endIndex - Units * 2 }); + recurrent_h = math_ops.matmul( + r * h_tm1_h, _recurrent_kernel_slice); + } + hh = _args.Activation.Apply(x_h + recurrent_h); + } + else + { + if (0f < _args.Dropout && _args.Dropout < 1f) + { + inputs = inputs * dp_mask[0]; + } + + var matrix_x = math_ops.matmul(inputs, _kernel.AsTensor()); + if(_args.UseBias) + { + matrix_x = tf.nn.bias_add(matrix_x, input_bias); + } + var matrix_x_spilted = tf.split(matrix_x, 3, axis: -1); + var x_z = matrix_x_spilted[0]; + var x_r = matrix_x_spilted[1]; + var x_h = matrix_x_spilted[2]; + + Tensor matrix_inner; + if (_args.ResetAfter) + { + matrix_inner = math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor()); + if ( _args.UseBias) + { + matrix_inner = tf.nn.bias_add( + matrix_inner, recurrent_bias); + } + } + else + { + var startIndex = (int)_recurrent_kernel.AsTensor().shape[0]; + var _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(), + new[] { 0, 0 }, new[] { startIndex, Units * 2 }); + matrix_inner = math_ops.matmul( + h_tm1, _recurrent_kernel_slice); + } + + var matrix_inner_splitted = tf.split(matrix_inner, new int[] {Units, Units, -1}, axis:-1); + var recurrent_z = matrix_inner_splitted[0]; + var recurrent_r = matrix_inner_splitted[0]; + var recurrent_h = matrix_inner_splitted[0]; + + z = _args.RecurrentActivation.Apply(x_z + recurrent_z); + var r = _args.RecurrentActivation.Apply(x_r + recurrent_r); + + if(_args.ResetAfter) + { + recurrent_h = r * recurrent_h; + } + else + { + var startIndex = (int)_recurrent_kernel.AsTensor().shape[0]; + var endIndex = (int)_recurrent_kernel.AsTensor().shape[1]; + var _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(), + new[] { 0, 2*Units }, new[] { startIndex, endIndex - 2 * Units }); + recurrent_h = math_ops.matmul( + r * h_tm1, _recurrent_kernel_slice); + } + hh = _args.Activation.Apply(x_h + recurrent_h); + } + var h = z * h_tm1 + (1 - z) * hh; + if (states.IsNested()) + { + var new_state = new NestList(h); + return new Nest(new INestStructure[] { new NestNode(h), new_state }).ToTensors(); + } + else + { + return new Nest(new INestStructure[] { new NestNode(h), new NestNode(h)}).ToTensors(); + } + + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs new file mode 100644 index 000000000..c766e8d69 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs @@ -0,0 +1,126 @@ +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Common.Types; +using Tensorflow.Common.Extensions; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Long Short-Term Memory layer - Hochreiter 1997. + /// + /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) + /// for details about the usage of RNN API. + /// + public class LSTM : RNN + { + LSTMArgs _args; + InputSpec[] _state_spec; + InputSpec _input_spec; + bool _could_use_gpu_kernel; + public LSTMArgs Args { get => _args; } + public LSTM(LSTMArgs args) : + base(CreateCell(args), args) + { + _args = args; + _input_spec = new InputSpec(ndim: 3); + _state_spec = new[] { args.Units, args.Units }.Select(dim => new InputSpec(shape: (-1, dim))).ToArray(); + _could_use_gpu_kernel = args.Activation == keras.activations.Tanh + && args.RecurrentActivation == keras.activations.Sigmoid + && args.RecurrentDropout == 0 && !args.Unroll && args.UseBias + && ops.executing_eagerly_outside_functions(); + } + + private static IRnnCell CreateCell(LSTMArgs lstmArgs) + { + return new LSTMCell(new LSTMCellArgs() + { + Units = lstmArgs.Units, + Activation = lstmArgs.Activation, + RecurrentActivation = lstmArgs.RecurrentActivation, + UseBias = lstmArgs.UseBias, + KernelInitializer = lstmArgs.KernelInitializer, + RecurrentInitializer = lstmArgs.RecurrentInitializer, + UnitForgetBias = lstmArgs.UnitForgetBias, + BiasInitializer = lstmArgs.BiasInitializer, + // TODO(Rinne): kernel_regularizer + // TODO(Rinne): recurrent_regularizer + // TODO(Rinne): bias_regularizer + // TODO(Rinne): kernel_constriant + // TODO(Rinne): recurrent_constriant + // TODO(Rinne): bias_constriant + Dropout = lstmArgs.Dropout, + RecurrentDropout = lstmArgs.RecurrentDropout, + Implementation = lstmArgs.Implementation, + DType = lstmArgs.DType, + Trainable = lstmArgs.Trainable + }); + } + + protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + // skip the condition of ragged input + + (inputs, initial_state, _) = _process_inputs(inputs, initial_state, null); + + Tensor mask = null; + if(optional_args is RnnOptionalArgs rnnArgs) + { + mask = rnnArgs.Mask; + } + + var single_input = inputs.Single; + var input_shape = single_input.shape; + var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1]; + + _maybe_reset_cell_dropout_mask(Cell); + + Func step = (inputs, states) => + { + var res = Cell.Apply(inputs, states, training is null ? true : training.Value); + var (output, state) = res; + return (output, state); + }; + + var (last_output, outputs, states) = keras.backend.rnn( + step, + inputs, + initial_state, + constants: null, + go_backwards: _args.GoBackwards, + mask: mask, + unroll: _args.Unroll, + input_length: ops.convert_to_tensor(timesteps), + time_major: _args.TimeMajor, + zero_output_for_mask: _args.ZeroOutputForMask, + return_all_outputs: _args.ReturnSequences + ); + + Tensor output; + if (_args.ReturnSequences) + { + output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, _args.GoBackwards); + } + else + { + output = last_output; + } + + if (_args.ReturnState) + { + return new Tensor[] { output }.Concat(states).ToArray().ToTensors(); + } + else + { + return output; + } + } + + public override IKerasConfig get_config() + { + return _args; + } + + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs new file mode 100644 index 000000000..e4fc6dd22 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs @@ -0,0 +1,233 @@ +using Newtonsoft.Json; +using Serilog.Core; +using System.Diagnostics; +using Tensorflow.Common.Extensions; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Cell class for the LSTM layer. + /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) + /// for details about the usage of RNN API. + /// This class processes one step within the whole time sequence input, whereas + /// `tf.keras.layer.LSTM` processes the whole sequence. + /// + public class LSTMCell : DropoutRNNCellMixin + { + LSTMCellArgs _args; + IVariableV1 _kernel; + IVariableV1 _recurrent_kernel; + IInitializer _bias_initializer; + IVariableV1 _bias; + INestStructure _state_size; + INestStructure _output_size; + public override INestStructure StateSize => _state_size; + + public override INestStructure OutputSize => _output_size; + + public override bool SupportOptionalArgs => false; + public LSTMCell(LSTMCellArgs args) + : base(args) + { + _args = args; + if (args.Units <= 0) + { + throw new ValueError( + $"units must be a positive integer, got {args.Units}"); + } + _args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout)); + _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout)); + if (_args.RecurrentDropout != 0f && _args.Implementation != 1) + { + Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set." + + "Using `implementation=1`."); + _args.Implementation = 1; + } + + _state_size = new NestList(_args.Units, _args.Units); + _output_size = new NestNode(_args.Units); + } + + public override void build(KerasShapesWrapper input_shape) + { + base.build(input_shape); + var single_shape = input_shape.ToSingleShape(); + var input_dim = single_shape[-1]; + _kernel = add_weight("kernel", (input_dim, _args.Units * 4), + initializer: _args.KernelInitializer + ); + + _recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units * 4), + initializer: _args.RecurrentInitializer + ); + + if (_args.UseBias) + { + if (_args.UnitForgetBias) + { + Tensor bias_initializer() + { + return keras.backend.concatenate( + new Tensors( + _args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units))), + tf.ones_initializer.Apply(new InitializerArgs(shape: (_args.Units))), + _args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units)))), axis: 0); + } + } + else + { + _bias_initializer = _args.BiasInitializer; + } + _bias = add_weight("bias", (_args.Units * 4), + initializer: _bias_initializer + ); + } + built = true; + } + protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) + { + var h_tm1 = states[0]; // previous memory state + var c_tm1 = states[1]; // previous carry state + + var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 4); + var rec_dp_mask = get_recurrent_dropout_mask_for_cell( + h_tm1, training.Value, count: 4); + + Tensor c; + Tensor o; + if (_args.Implementation == 1) + { + Tensor inputs_i; + Tensor inputs_f; + Tensor inputs_c; + Tensor inputs_o; + if (0f < _args.Dropout && _args.Dropout < 1f) + { + inputs_i = inputs * dp_mask[0]; + inputs_f = inputs * dp_mask[1]; + inputs_c = inputs * dp_mask[2]; + inputs_o = inputs * dp_mask[3]; + } + else + { + inputs_i = inputs; + inputs_f = inputs; + inputs_c = inputs; + inputs_o = inputs; + } + var k = tf.split(_kernel.AsTensor(), num_split: 4, axis: 1); + Tensor k_i = k[0], k_f = k[1], k_c = k[2], k_o = k[3]; + var x_i = math_ops.matmul(inputs_i, k_i); + var x_f = math_ops.matmul(inputs_f, k_f); + var x_c = math_ops.matmul(inputs_c, k_c); + var x_o = math_ops.matmul(inputs_o, k_o); + if (_args.UseBias) + { + var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0); + Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3]; + x_i = gen_nn_ops.bias_add(x_i, b_i); + x_f = gen_nn_ops.bias_add(x_f, b_f); + x_c = gen_nn_ops.bias_add(x_c, b_c); + x_o = gen_nn_ops.bias_add(x_o, b_o); + } + + Tensor h_tm1_i; + Tensor h_tm1_f; + Tensor h_tm1_c; + Tensor h_tm1_o; + if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f) + { + h_tm1_i = h_tm1 * rec_dp_mask[0]; + h_tm1_f = h_tm1 * rec_dp_mask[1]; + h_tm1_c = h_tm1 * rec_dp_mask[2]; + h_tm1_o = h_tm1 * rec_dp_mask[3]; + } + else + { + h_tm1_i = h_tm1; + h_tm1_f = h_tm1; + h_tm1_c = h_tm1; + h_tm1_o = h_tm1; + } + var x = new Tensor[] { x_i, x_f, x_c, x_o }; + var h_tm1_array = new Tensor[] { h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o }; + (c, o) = _compute_carry_and_output(x, h_tm1_array, c_tm1); + } + else + { + if (0f < _args.Dropout && _args.Dropout < 1f) + inputs = inputs * dp_mask[0]; + var z = math_ops.matmul(inputs, _kernel.AsTensor()); + z += math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor()); + if (_args.UseBias) + { + z = tf.nn.bias_add(z, _bias); + } + var z_array = tf.split(z, num_split: 4, axis: 1); + (c, o) = _compute_carry_and_output_fused(z_array, c_tm1); + } + var h = o * _args.Activation.Apply(c); + // 这里是因为 Tensors 类初始化的时候会把第一个元素之后的元素打包成一个数组 + return new Nest(new INestStructure[] { new NestNode(h), new NestList(h, c) }).ToTensors(); + } + + /// + /// Computes carry and output using split kernels. + /// + /// + /// + /// + /// + /// + public Tensors _compute_carry_and_output(Tensor[] x, Tensor[] h_tm1, Tensor c_tm1) + { + Tensor x_i = x[0], x_f = x[1], x_c = x[2], x_o = x[3]; + Tensor h_tm1_i = h_tm1[0], h_tm1_f = h_tm1[1], h_tm1_c = h_tm1[2], + h_tm1_o = h_tm1[3]; + + var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor(); + int startIndex = (int)_recurrent_kernel_tensor.shape[0]; + var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, + new[] { 0, 0 }, new[] { startIndex, _args.Units }); + var i = _args.RecurrentActivation.Apply( + x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice)); + _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, + new[] { 0, _args.Units }, new[] { startIndex, _args.Units}); + var f = _args.RecurrentActivation.Apply( + x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice)); + _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, + new[] { 0, _args.Units * 2 }, new[] { startIndex, _args.Units }); + var c = f * c_tm1 + i * _args.Activation.Apply( + x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice)); + _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, + new[] { 0, _args.Units * 3 }, new[] { startIndex, _args.Units }); + var o = _args.Activation.Apply( + x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice)); + + return new Tensors(c, o); + } + + /// + /// Computes carry and output using fused kernels. + /// + /// + /// + /// + public Tensors _compute_carry_and_output_fused(Tensor[] z, Tensor c_tm1) + { + Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3]; + var i = _args.RecurrentActivation.Apply(z0); + var f = _args.RecurrentActivation.Apply(z1); + var c = f * c_tm1 + i * _args.Activation.Apply(z2); + var o = _args.RecurrentActivation.Apply(z3); + return new Tensors(c, o); + } + } + + +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs new file mode 100644 index 000000000..fec75559c --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -0,0 +1,546 @@ +using OneOf; +using System; +using System.Collections.Generic; +using System.Reflection; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Util; +using Tensorflow.Common.Extensions; +using System.Linq.Expressions; +using Tensorflow.Keras.Utils; +using Tensorflow.Common.Types; +using System.Runtime.CompilerServices; +// from tensorflow.python.distribute import distribution_strategy_context as ds_context; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Base class for recurrent layers. + /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) + /// for details about the usage of RNN API. + /// + public class RNN : RnnBase + { + private RNNArgs _args; + private object _input_spec = null; // or NoneValue?? + private object _state_spec = null; + private object _constants_spec = null; + private Tensors _states = null; + private int _num_constants; + protected IVariableV1 _kernel; + protected IVariableV1 _bias; + private IRnnCell _cell; + + public RNNArgs Args { get => _args; } + public IRnnCell Cell + { + get + { + return _cell; + } + init + { + _cell = value; + _self_tracked_trackables.Add(_cell); + } + } + + public RNN(IRnnCell cell, RNNArgs args) : base(PreConstruct(args)) + { + _args = args; + SupportsMasking = true; + + Cell = cell; + + // get input_shape + _args = PreConstruct(args); + + _num_constants = 0; + } + + public RNN(IEnumerable cells, RNNArgs args) : base(PreConstruct(args)) + { + _args = args; + SupportsMasking = true; + + Cell = new StackedRNNCells(cells, new StackedRNNCellsArgs()); + + // get input_shape + _args = PreConstruct(args); + + _num_constants = 0; + } + + // States is a tuple consist of cell states_size, like (cell1.state_size, cell2.state_size,...) + // state_size can be a single integer, can also be a list/tuple of integers, can also be TensorShape or a list/tuple of TensorShape + public Tensors States + { + get + { + if (_states == null) + { + // CHECK(Rinne): check if this is correct. + var nested = Cell.StateSize.MapStructure(x => null); + _states = nested.AsNest().ToTensors(); + } + return _states; + } + set { _states = value; } + } + + private INestStructure compute_output_shape(Shape input_shape) + { + var batch = input_shape[0]; + var time_step = input_shape[1]; + if (_args.TimeMajor) + { + (batch, time_step) = (time_step, batch); + } + + // state_size is a array of ints or a positive integer + var state_size = Cell.StateSize; + if(state_size?.TotalNestedCount == 1) + { + state_size = new NestList(state_size.Flatten().First()); + } + + Func _get_output_shape = (flat_output_size) => + { + var output_dim = new Shape(flat_output_size).as_int_list(); + Shape output_shape; + if (_args.ReturnSequences) + { + if (_args.TimeMajor) + { + output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim)); + } + else + { + output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim)); + + } + } + else + { + output_shape = new Shape(new int[] { (int)batch }.concat(output_dim)); + } + return output_shape; + }; + + Type type = Cell.GetType(); + PropertyInfo output_size_info = type.GetProperty("output_size"); + INestStructure output_shape; + if (output_size_info != null) + { + output_shape = Nest.MapStructure(_get_output_shape, Cell.OutputSize); + } + else + { + output_shape = new NestNode(_get_output_shape(state_size.Flatten().First())); + } + + if (_args.ReturnState) + { + Func _get_state_shape = (flat_state) => + { + var state_shape = new int[] { (int)batch }.concat(new Shape(flat_state).as_int_list()); + return new Shape(state_shape); + }; + + + var state_shape = Nest.MapStructure(_get_state_shape, state_size); + + return new Nest(new[] { output_shape, state_shape } ); + } + else + { + return output_shape; + } + + } + + private Tensors compute_mask(Tensors inputs, Tensors mask) + { + // Time step masks must be the same for each input. + // This is because the mask for an RNN is of size [batch, time_steps, 1], + // and specifies which time steps should be skipped, and a time step + // must be skipped for all inputs. + + mask = nest.flatten(mask)[0]; + var output_mask = _args.ReturnSequences ? mask : null; + if (_args.ReturnState) + { + var state_mask = new List(); + for (int i = 0; i < len(States); i++) + { + state_mask.Add(null); + } + return new List { output_mask }.concat(state_mask); + } + else + { + return output_mask; + } + } + + public override void build(KerasShapesWrapper input_shape) + { + _buildInputShape = input_shape; + input_shape = new KerasShapesWrapper(input_shape.Shapes[0]); + + InputSpec get_input_spec(Shape shape) + { + var input_spec_shape = shape.as_int_list(); + + var (batch_index, time_step_index) = _args.TimeMajor ? (1, 0) : (0, 1); + if (!_args.Stateful) + { + input_spec_shape[batch_index] = -1; + } + input_spec_shape[time_step_index] = -1; + return new InputSpec(shape: input_spec_shape); + } + + Shape get_step_input_shape(Shape shape) + { + + // return shape[1:] if self.time_major else (shape[0],) + shape[2:] + if (_args.TimeMajor) + { + return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray(); + } + else + { + return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray()); + } + + + } + + object get_state_spec(Shape shape) + { + var state_spec_shape = shape.as_int_list(); + // append bacth dim + state_spec_shape = new int[] { -1 }.concat(state_spec_shape); + return new InputSpec(shape: state_spec_shape); + } + + // Check whether the input shape contains any nested shapes. It could be + // (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from + // numpy inputs. + + + if (Cell is Layer layer && !layer.Built) + { + layer.build(input_shape); + layer.Built = true; + } + + this.built = true; + } + + /// + /// + /// + /// + /// List of initial state tensors to be passed to the first call of the cell + /// + /// + /// + /// + /// + protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; + if(optional_args is not null && rnn_optional_args is null) + { + throw new ArgumentException("The optional args shhould be of type `RnnOptionalArgs`"); + } + Tensors? constants = rnn_optional_args?.Constants; + Tensors? mask = rnn_optional_args?.Mask; + //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs); + // 暂时先不接受ragged tensor + int row_length = 0; // TODO(Rinne): support this param. + bool is_ragged_input = false; + _validate_args_if_ragged(is_ragged_input, mask); + + (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants); + + _maybe_reset_cell_dropout_mask(Cell); + if (Cell is StackedRNNCells) + { + var stack_cell = Cell as StackedRNNCells; + foreach (IRnnCell cell in stack_cell.Cells) + { + _maybe_reset_cell_dropout_mask(cell); + } + } + + if (mask != null) + { + // Time step masks must be the same for each input. + mask = mask.Flatten().First(); + } + + Shape input_shape; + if (!inputs.IsNested()) + { + // In the case of nested input, use the first element for shape check + // input_shape = nest.flatten(inputs)[0].shape; + // TODO(Wanglongzhi2001) + input_shape = inputs.Flatten().First().shape; + } + else + { + input_shape = inputs.shape; + } + + var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1]; + + if (_args.Unroll && timesteps == null) + { + throw new ValueError( + "Cannot unroll a RNN if the " + + "time dimension is undefined. \n" + + "- If using a Sequential model, " + + "specify the time dimension by passing " + + "an `input_shape` or `batch_input_shape` " + + "argument to your first layer. If your " + + "first layer is an Embedding, you can " + + "also use the `input_length` argument.\n" + + "- If using the functional API, specify " + + "the time dimension by passing a `shape` " + + "or `batch_shape` argument to your Input layer." + ); + } + + // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call) + Func step; + bool is_tf_rnn_cell = false; + if (constants is not null) + { + if (!Cell.SupportOptionalArgs) + { + throw new ValueError( + $"RNN cell {Cell} does not support constants." + + $"Received: constants={constants}"); + } + + step = (inputs, states) => + { + constants = new Tensors(states.TakeLast(_num_constants).ToArray()); + states = new Tensors(states.SkipLast(_num_constants).ToArray()); + states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; + var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); + return (output, new_states); + }; + } + else + { + step = (inputs, states) => + { + states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states; + var (output, new_states) = Cell.Apply(inputs, states); + return (output, new_states); + }; + } + + var (last_output, outputs, states) = keras.backend.rnn( + step, + inputs, + initial_state, + constants: constants, + go_backwards: _args.GoBackwards, + mask: mask, + unroll: _args.Unroll, + input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps), + time_major: _args.TimeMajor, + zero_output_for_mask: _args.ZeroOutputForMask, + return_all_outputs: _args.ReturnSequences); + + if (_args.Stateful) + { + throw new NotImplementedException("this argument havn't been developed."); + } + + Tensors output = new Tensors(); + if (_args.ReturnSequences) + { + // TODO(Rinne): add go_backwards parameter and revise the `row_length` param + output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false); + } + else + { + output = last_output; + } + + if (_args.ReturnState) + { + foreach (var state in states) + { + output.Add(state); + } + return output; + } + else + { + //var tapeSet = tf.GetTapeSet(); + //foreach(var tape in tapeSet) + //{ + // tape.Watch(output); + //} + return output; + } + } + + public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool? training = false, IOptionalArgs? optional_args = null) + { + RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; + if (optional_args is not null && rnn_optional_args is null) + { + throw new ArgumentException("The type of optional args should be `RnnOptionalArgs`."); + } + Tensors? constants = rnn_optional_args?.Constants; + (inputs, initial_states, constants) = RnnUtils.standardize_args(inputs, initial_states, constants, _num_constants); + + if(initial_states is null && constants is null) + { + return base.Apply(inputs); + } + + // TODO(Rinne): implement it. + throw new NotImplementedException(); + } + + protected (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants) + { + if (inputs.Length > 1) + { + if (_num_constants != 0) + { + initial_state = new Tensors(inputs.Skip(1).ToArray()); + } + else + { + initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants).ToArray()); + constants = new Tensors(inputs.TakeLast(_num_constants).ToArray()); + } + if (len(initial_state) == 0) + initial_state = null; + inputs = inputs[0]; + } + + + if (_args.Stateful) + { + if (initial_state != null) + { + var tmp = new Tensor[] { }; + foreach (var s in nest.flatten(States)) + { + tmp.add(tf.math.count_nonzero(s.Single())); + } + var non_zero_count = tf.add_n(tmp); + initial_state = tf.cond(non_zero_count > 0, States, initial_state); + if ((int)non_zero_count.numpy() > 0) + { + initial_state = States; + } + } + else + { + initial_state = States; + } + //initial_state = Nest.MapStructure(v => tf.cast(v, this.), initial_state); + } + else if (initial_state is null) + { + initial_state = get_initial_state(inputs); + } + + if (initial_state.Length != States.Length) + { + throw new ValueError($"Layer {this} expects {States.Length} state(s), " + + $"but it received {initial_state.Length} " + + $"initial state(s). Input received: {inputs}"); + } + + return (inputs, initial_state, constants); + } + + protected void _validate_args_if_ragged(bool is_ragged_input, Tensors mask) + { + if (!is_ragged_input) + { + return; + } + + if (_args.Unroll) + { + throw new ValueError("The input received contains RaggedTensors and does " + + "not support unrolling. Disable unrolling by passing " + + "`unroll=False` in the RNN Layer constructor."); + } + if (mask != null) + { + throw new ValueError($"The mask that was passed in was {mask}, which " + + "cannot be applied to RaggedTensor inputs. Please " + + "make sure that there is no mask injected by upstream " + + "layers."); + } + + } + + protected void _maybe_reset_cell_dropout_mask(ILayer cell) + { + if (cell is DropoutRNNCellMixin CellDRCMixin) + { + CellDRCMixin.reset_dropout_mask(); + CellDRCMixin.reset_recurrent_dropout_mask(); + } + } + + private static RNNArgs PreConstruct(RNNArgs args) + { + // If true, the output for masked timestep will be zeros, whereas in the + // false case, output from previous timestep is returned for masked timestep. + var zeroOutputForMask = args.ZeroOutputForMask; + + Shape input_shape; + var propIS = args.InputShape; + var propID = args.InputDim; + var propIL = args.InputLength; + + if (propIS == null && (propID != null || propIL != null)) + { + input_shape = new Shape( + propIL ?? -1, + propID ?? -1); + args.InputShape = input_shape; + } + + return args; + } + + public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null) + { + throw new NotImplementedException(); + } + + protected Tensors get_initial_state(Tensors inputs) + { + var input = inputs[0]; + var input_shape = array_ops.shape(inputs); + var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; + var dtype = input.dtype; + Tensors init_state = Cell.GetInitialState(null, batch_size, dtype); + return init_state; + } + + public override IKerasConfig get_config() + { + return _args; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs new file mode 100644 index 000000000..1419da4b2 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + public abstract class RnnBase: Layer + { + public RnnBase(LayerArgs args): base(args) { } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs new file mode 100644 index 000000000..9c199eb43 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs @@ -0,0 +1,35 @@ +using System.Data; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; +using Tensorflow.Operations.Activation; +using static HDF.PInvoke.H5Z; +using static Tensorflow.ApiDef.Types; + +namespace Tensorflow.Keras.Layers +{ + public class SimpleRNN : RNN + { + SimpleRNNArgs args; + public SimpleRNN(SimpleRNNArgs args) : base(CreateCellForArgs(args), args) + { + this.args = args; + } + + private static SimpleRNNCell CreateCellForArgs(SimpleRNNArgs args) + { + return new SimpleRNNCell(new SimpleRNNCellArgs() + { + Units = args.Units, + Activation = args.Activation, + UseBias = args.UseBias, + KernelInitializer = args.KernelInitializer, + RecurrentInitializer = args.RecurrentInitializer, + BiasInitializer = args.BiasInitializer, + Dropout = args.Dropout, + RecurrentDropout = args.RecurrentDropout, + DType = args.DType, + Trainable = args.Trainable, + }); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs new file mode 100644 index 000000000..e74b56925 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -0,0 +1,119 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Common.Types; +using Tensorflow.Common.Extensions; +using Tensorflow.Keras.Utils; +using Tensorflow.Graphs; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Cell class for SimpleRNN. + /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) + /// for details about the usage of RNN API. + /// This class processes one step within the whole time sequence input, whereas + /// `tf.keras.layer.SimpleRNN` processes the whole sequence. + /// + public class SimpleRNNCell : DropoutRNNCellMixin + { + SimpleRNNCellArgs _args; + IVariableV1 _kernel; + IVariableV1 _recurrent_kernel; + IVariableV1 _bias; + INestStructure _state_size; + INestStructure _output_size; + + public override INestStructure StateSize => _state_size; + public override INestStructure OutputSize => _output_size; + public override bool SupportOptionalArgs => false; + + public SimpleRNNCell(SimpleRNNCellArgs args) : base(args) + { + this._args = args; + if (args.Units <= 0) + { + throw new ValueError( + $"units must be a positive integer, got {args.Units}"); + } + this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout)); + this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout)); + _state_size = new NestNode(args.Units); + _output_size = new NestNode(args.Units); + } + + public override void build(KerasShapesWrapper input_shape) + { + // TODO(Rinne): add the cache. + var single_shape = input_shape.ToSingleShape(); + var input_dim = single_shape[-1]; + + _kernel = add_weight("kernel", (single_shape[-1], _args.Units), + initializer: _args.KernelInitializer + ); + + _recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units), + initializer: _args.RecurrentInitializer + ); + + if (_args.UseBias) + { + _bias = add_weight("bias", (_args.Units), + initializer: _args.BiasInitializer + ); + } + + built = true; + } + + // TODO(Rinne): revise the trining param (with refactoring of the framework) + protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) + { + // TODO(Rinne): check if it will have multiple tensors when not nested. + Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states; + var dp_mask = get_dropout_mask_for_cell(inputs, training.Value); + var rec_dp_mask = get_recurrent_dropout_mask_for_cell(prev_output, training.Value); + + Tensor h; + var ranks = inputs.rank; + if (dp_mask != null) + { + + h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor()); + } + else + { + h = math_ops.matmul(inputs, _kernel.AsTensor()); + } + + if (_bias != null) + { + h = tf.nn.bias_add(h, _bias); + } + + if (rec_dp_mask != null) + { + prev_output = math_ops.multiply(prev_output, rec_dp_mask); + } + Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor()); + + if (_args.Activation != null) + { + output = _args.Activation.Apply(output); + } + if (Nest.IsNested(states)) + { + return new Nest(new List> { + new Nest(new List> { new Nest(output) }), new Nest(output) }) + .ToTensors(); + } + else + { + return new Tensors(output, output); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs new file mode 100644 index 000000000..ece2bc5bf --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs @@ -0,0 +1,159 @@ +using System; +using System.ComponentModel; +using System.Linq; +using Tensorflow.Common.Extensions; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Layers +{ + public class StackedRNNCells : Layer, IRnnCell + { + public IList Cells { get; set; } + public bool _reverse_state_order; + + public StackedRNNCells(IEnumerable cells, StackedRNNCellsArgs args) : base(args) + { + Cells = cells.ToList(); + + _reverse_state_order = args.ReverseStateOrder; + + if (_reverse_state_order) + { + throw new WarningException("reverse_state_order=True in StackedRNNCells will soon " + + "be deprecated. Please update the code to work with the " + + "natural order of states if you rely on the RNN states, " + + "eg RNN(return_state=True)."); + } + } + + public bool SupportOptionalArgs => false; + + public INestStructure StateSize + { + get + { + if (_reverse_state_order) + { + var state_sizes = Cells.Reverse().Select(cell => cell.StateSize); + return new Nest(state_sizes); + } + else + { + var state_sizes = Cells.Select(cell => cell.StateSize); + return new Nest(state_sizes); + } + } + } + + public INestStructure OutputSize + { + get + { + var lastCell = Cells.Last(); + if(lastCell.OutputSize is not null) + { + return lastCell.OutputSize; + } + else if (RnnUtils.is_multiple_state(lastCell.StateSize)) + { + return new NestNode(lastCell.StateSize.Flatten().First()); + } + else + { + return lastCell.StateSize; + } + } + } + + public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) + { + var cells = _reverse_state_order ? Cells.Reverse() : Cells; + List initial_states = new List(); + foreach (var cell in cells) + { + initial_states.Add(cell.GetInitialState(inputs, batch_size, dtype)); + } + return new Tensors(initial_states); + } + + protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) + { + // Recover per-cell states. + var state_size = _reverse_state_order ? new NestList(StateSize.Flatten().Reverse()) : StateSize; + var nested_states = Nest.PackSequenceAs(state_size, Nest.Flatten(states).ToArray()); + + var new_nest_states = Nest.Empty; + // Call the cells in order and store the returned states. + foreach (var (cell, internal_states) in zip(Cells, nested_states)) + { + RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; + Tensors? constants = rnn_optional_args?.Constants; + + Tensors new_states; + (inputs, new_states) = cell.Apply(inputs, internal_states, optional_args: new RnnOptionalArgs() { Constants = constants }); + + new_nest_states = new_nest_states.MergeWith(new_states); + } + return Tensors.FromNest((inputs, Nest.PackSequenceAs(state_size, Nest.Flatten(new_nest_states).ToArray()))); + } + + public override void build(KerasShapesWrapper input_shape) + { + var shape = input_shape.ToSingleShape(); + foreach(var cell in Cells) + { + if(cell is Layer layer && !layer.Built) + { + // ignored the name scope. + layer.build(shape); + layer.Built = true; + } + INestStructure output_dim; + if(cell.OutputSize is not null) + { + output_dim = cell.OutputSize; + } + else if (RnnUtils.is_multiple_state(cell.StateSize)) + { + output_dim = new NestNode(cell.StateSize.Flatten().First()); + } + else + { + output_dim = cell.StateSize; + } + shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.Flatten()).ToArray()); + } + this.Built = true; + } + + public override IKerasConfig get_config() + { + throw new NotImplementedException(); + //def get_config(self): + // cells = [] + // for cell in self.cells: + // cells.append(generic_utils.serialize_keras_object(cell)) + // config = {'cells': cells} + // base_config = super(StackedRNNCells, self).get_config() + // return dict(list(base_config.items()) + list(config.items())) + } + + + public void from_config() + { + throw new NotImplementedException(); + // @classmethod + // def from_config(cls, config, custom_objects = None): + // from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + // cells = [] + // for cell_config in config.pop('cells'): + // cells.append( + // deserialize_layer(cell_config, custom_objects = custom_objects)) + // return cls(cells, **config) + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs new file mode 100644 index 000000000..6dfec3196 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs @@ -0,0 +1,105 @@ +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.Graphs; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; +using Tensorflow.Functions; +using System.Threading; +using Tensorflow.Common.Types; + +namespace Tensorflow.Keras.Layers +{ + public class TensorFlowOpLayer : Layer + { + TensorFlowOpLayerArgs args; + Dictionary constants => args.Constants; + NodeDef node_def => args.NodeDef; + static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_"; + public string OpType => node_def.Op; + + public TensorFlowOpLayer(TensorFlowOpLayerArgs args) + : base(new LayerArgs + { + Name = TF_OP_LAYER_NAME_PREFIX + args.Name, + Trainable = args.Trainable, + DType = args.DType, + Autocast = false + }) + { + this.args = args; + built = true; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + if (tf.Context.executing_eagerly()) + return DeFunCall(inputs); + return MakOp(inputs); + } + + ThreadLocal function = new ThreadLocal(); + Tensors DeFunCall(Tensors inputs) + { + if (function.Value == null) + { + function.Value = new ConcreteFunction(name); + function.Value.Enter(); + + int i = 0; + var graph_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: $"defun_inputs_{i++}")).ToArray(); + var graph_outputs = MakOp(graph_inputs); + graph_outputs = mark_as_return(graph_outputs); + + function.Value.ToGraph(graph_inputs, graph_outputs); + function.Value.Exit(); + } + + var outputs = function.Value.FilteredCall(inputs); + return outputs; + } + + Tensors mark_as_return(Tensors tensors) + { + var result = new Tensors(); + foreach (var tensor in tensors) + result.Add(array_ops.identity(tensor)); + return result; + } + + [AutoGraph] + Tensors _defun_call(Tensors inputs) + => MakOp(inputs); + + Tensors MakOp(Tensors inputs) + { + var graph = inputs.graph; + graph.as_default(); + foreach (var (index, constant) in enumerate(constants)) + { + var value = constant_op.constant(constant, name: node_def.Input[index]); + inputs.Insert(index, value); + } + + var (c_op, op_desc) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); + var op = graph._create_op_from_tf_operation(c_op, desc: op_desc); + op._control_flow_post_processing(); + + // Record the gradient because custom-made ops don't go through the + // code-gen'd eager call path + var op_type = op.node_def.Op; + + tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); + + graph.Exit(); + return op.outputs; + } + + public Layer GetOpLayer(TensorFlowOpLayerArgs args) + => new TensorFlowOpLayer(args); + } +} diff --git a/src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs new file mode 100644 index 000000000..0de50a7ec --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs @@ -0,0 +1,24 @@ +namespace Tensorflow.Keras.Losses; + +public class BinaryCrossentropy : LossFunctionWrapper +{ + float label_smoothing; + + public BinaryCrossentropy( + bool from_logits = false, + float label_smoothing = 0, + string reduction = null, + string name = null) : + base(reduction: reduction, + name: name == null ? "binary_crossentropy" : name, + from_logits: from_logits) + { + this.label_smoothing = label_smoothing; + } + + public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1) + { + var sum = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits); + return keras.backend.mean(sum, axis: axis); + } +} diff --git a/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs new file mode 100644 index 000000000..1af57b552 --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs @@ -0,0 +1,24 @@ +namespace Tensorflow.Keras.Losses; + +public class CategoricalCrossentropy : LossFunctionWrapper +{ + float label_smoothing; + + public CategoricalCrossentropy( + bool from_logits = false, + float label_smoothing = 0, + string reduction = null, + string name = null) : + base(reduction: reduction, + name: name == null ? "categorical_crossentropy" : name, + from_logits: from_logits) + { + this.label_smoothing = label_smoothing; + } + + public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1) + { + // Try to adjust the shape so that rank of labels = rank of logits - 1. + return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits); + } +} diff --git a/src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs b/src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs new file mode 100644 index 000000000..cf9df8d0d --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs @@ -0,0 +1,22 @@ +namespace Tensorflow.Keras.Losses; + +public class CosineSimilarity : LossFunctionWrapper +{ + protected int axis = -1; + + public CosineSimilarity( + string reduction = null, + int axis = -1, + string name = null) : + base(reduction: reduction, name: name == null ? "cosine_similarity" : name) + { + this.axis = axis; + } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1) + { + Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis: this.axis); + Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis); + return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis: constant_op.constant(this.axis)); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Losses/Huber.cs b/src/TensorFlowNET.Keras/Losses/Huber.cs new file mode 100644 index 000000000..61f006d2b --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/Huber.cs @@ -0,0 +1,29 @@ +namespace Tensorflow.Keras.Losses; + +public class Huber : LossFunctionWrapper +{ + protected Tensor delta = tf.Variable(1.0); + + public Huber( + string reduction = null, + Tensor delta = null, + string name = null) : + base(reduction: reduction, name: name == null ? "huber" : name) + { + this.delta = delta == null ? this.delta : delta; + } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT); + Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT); + Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT); + Tensor error = math_ops.subtract(y_pred_cast, y_true_cast); + Tensor abs_error = math_ops.abs(error); + Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype); + return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta, + half * math_ops.pow(error, 2), + half * math_ops.pow(delta, 2) + delta * (abs_error - delta)), + ops.convert_to_tensor(-1)); + } +} diff --git a/src/TensorFlowNET.Keras/Losses/LogCosh.cs b/src/TensorFlowNET.Keras/Losses/LogCosh.cs new file mode 100644 index 000000000..0c7a9b6e2 --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/LogCosh.cs @@ -0,0 +1,20 @@ +namespace Tensorflow.Keras.Losses; + +public class LogCosh : LossFunctionWrapper +{ + public LogCosh( + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? "log_cosh" : name) + { } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + Tensor x = y_pred_dispatch - y_true_cast; + + return gen_math_ops.mean(x + gen_nn_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype), + ops.convert_to_tensor(-1)); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Losses/Loss.cs b/src/TensorFlowNET.Keras/Losses/Loss.cs new file mode 100644 index 000000000..ce77f6d63 --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/Loss.cs @@ -0,0 +1,51 @@ +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Losses; + +/// +/// Loss base class. +/// +public abstract class Loss : ILossFunc +{ + protected string reduction; + protected string name; + bool _allow_sum_over_batch_size; + protected bool from_logits = false; + string _name_scope; + + public string Reduction => reduction; + public string Name => name; + + public Loss(string reduction = ReductionV2.AUTO, + string name = null, + bool from_logits = false) + { + this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction; + this.name = name; + this.from_logits = from_logits; + _allow_sum_over_batch_size = false; + } + + public abstract Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1); + + public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + { + var losses = Apply(y_true, y_pred, from_logits: from_logits); + var reduction = GetReduction(); + return losses_utils.compute_weighted_loss(losses, reduction: reduction, sample_weight: sample_weight); + } + + string GetReduction() + { + return reduction switch + { + ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE, + _ => reduction + }; + } + + void _set_name_scope() + { + _name_scope = name; + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Losses/LossFunctionWrapper.cs b/src/TensorFlowNET.Keras/Losses/LossFunctionWrapper.cs new file mode 100644 index 000000000..f4ee2b346 --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/LossFunctionWrapper.cs @@ -0,0 +1,14 @@ +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Losses; + +public abstract class LossFunctionWrapper : Loss +{ + public LossFunctionWrapper(string reduction = ReductionV2.AUTO, + string name = null, + bool from_logits = false) + : base(reduction: reduction, + name: name, + from_logits: from_logits) + { } +} diff --git a/src/TensorFlowNET.Keras/Losses/LossesApi.cs b/src/TensorFlowNET.Keras/Losses/LossesApi.cs new file mode 100644 index 000000000..79f16a2eb --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/LossesApi.cs @@ -0,0 +1,52 @@ +namespace Tensorflow.Keras.Losses +{ + public class LossesApi : ILossesApi + { + public ILossFunc BinaryCrossentropy(bool from_logits = false, + float label_smoothing = 0, + int axis = -1, + string reduction = "auto", + string name = "binary_crossentropy") + => new BinaryCrossentropy(from_logits: from_logits, + label_smoothing: label_smoothing, + reduction: reduction, + name: name); + + public ILossFunc SparseCategoricalCrossentropy(string reduction = null, string name = null,bool from_logits = false) + => new SparseCategoricalCrossentropy(reduction: reduction, name: name,from_logits: from_logits); + + public ILossFunc CategoricalCrossentropy(string reduction = null, string name = null,bool from_logits = false) + => new CategoricalCrossentropy(reduction: reduction, name: name,from_logits: from_logits); + + public ILossFunc MeanSquaredError(string reduction = null, string name = null) + => new MeanSquaredError(reduction: reduction, name:name); + public ILossFunc MeanSquaredLogarithmicError(string reduction = null, string name = null) + => new MeanSquaredLogarithmicError(reduction: reduction, name: name); + + public ILossFunc MeanAbsolutePercentageError(string reduction = null, string name = null) + => new MeanAbsolutePercentageError(reduction: reduction, name: name); + + public ILossFunc MeanAbsoluteError(string reduction = null, string name = null) + => new MeanAbsoluteError(reduction: reduction, name: name); + + public ILossFunc CosineSimilarity(string reduction = null, int axis = -1, string name = null) + => new CosineSimilarity(reduction: reduction, axis: axis, name: name); + + public ILossFunc Huber(string reduction = null, string name = null, Tensor delta=null) + => new Huber(reduction: reduction, name: name, delta: delta); + + public ILossFunc LogCosh(string reduction = null, string name = null) + => new LogCosh(reduction: reduction, name: name); + + public ILossFunc SigmoidFocalCrossEntropy(bool from_logits = false, + float alpha = 0.25F, + float gamma = 2, + string reduction = "none", + string name = "sigmoid_focal_crossentropy") + => new SigmoidFocalCrossEntropy(from_logits: from_logits, + alpha: alpha, + gamma: gamma, + reduction: reduction, + name: name); + } +} diff --git a/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs b/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs new file mode 100644 index 000000000..19476a68a --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs @@ -0,0 +1,16 @@ +namespace Tensorflow.Keras.Losses; + +public class MeanAbsoluteError : LossFunctionWrapper +{ + public MeanAbsoluteError( + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? "mean_absolute_error" : name){ } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), ops.convert_to_tensor(-1)); + } +} diff --git a/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs b/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs new file mode 100644 index 000000000..226c4237a --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs @@ -0,0 +1,17 @@ +namespace Tensorflow.Keras.Losses; + +public class MeanAbsolutePercentageError : LossFunctionWrapper +{ + public MeanAbsolutePercentageError( + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? "mean_absolute_percentage_error" : name){ } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype)); + return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, ops.convert_to_tensor(-1)); + } +} diff --git a/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs new file mode 100644 index 000000000..a937c1963 --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs @@ -0,0 +1,16 @@ +namespace Tensorflow.Keras.Losses; + +public class MeanSquaredError : LossFunctionWrapper +{ + public MeanSquaredError( + string reduction = null, + string name = null) : + base(reduction: reduction, name: name==null? "mean_squared_error" : name){ } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), ops.convert_to_tensor(-1)); + } +} diff --git a/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs b/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs new file mode 100644 index 000000000..0a4e7d3c5 --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs @@ -0,0 +1,28 @@ +namespace Tensorflow.Keras.Losses; + +public class MeanSquaredLogarithmicError : LossFunctionWrapper +{ + public MeanSquaredLogarithmicError( + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? "mean_squared_logarithmic_error" : name) + { } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + Tensor first_log = null, second_log = null; + if (y_pred_dispatch.dtype == TF_DataType.TF_DOUBLE) + { + first_log = math_ops.log(math_ops.maximum(y_pred_dispatch, 1e-7) + 1.0); + second_log = math_ops.log(math_ops.maximum(y_true_cast, 1e-7) + 1.0); + } + else + { + first_log = math_ops.log(math_ops.maximum(y_pred_dispatch, 1e-7f) + 1.0f); + second_log = math_ops.log(math_ops.maximum(y_true_cast, 1e-7f) + 1.0f); + } + return gen_math_ops.mean(gen_math_ops.squared_difference(first_log, second_log), ops.convert_to_tensor(-1)); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Losses/ReductionV2.cs b/src/TensorFlowNET.Keras/Losses/ReductionV2.cs new file mode 100644 index 000000000..4b6cbbfdb --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/ReductionV2.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Keras.Losses +{ + public class ReductionV2 + { + public const string NONE = "none"; + public const string AUTO = "auto"; + public const string SUM = "sum"; + public const string SUM_OVER_BATCH_SIZE = "sum_over_batch_size"; + public const string WEIGHTED_MEAN = "weighted_mean"; + } +} diff --git a/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs b/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs new file mode 100644 index 000000000..ec6dcedf8 --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs @@ -0,0 +1,47 @@ +using static HDF.PInvoke.H5L.info_t; + +namespace Tensorflow.Keras.Losses; + +public class SigmoidFocalCrossEntropy : LossFunctionWrapper +{ + float _alpha; + float _gamma; + + public SigmoidFocalCrossEntropy(bool from_logits = false, + float alpha = 0.25f, + float gamma = 2.0f, + string reduction = "none", + string name = "sigmoid_focal_crossentropy") : + base(reduction: reduction, + name: name, + from_logits: from_logits) + { + _alpha = alpha; + _gamma = gamma; + } + + public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1) + { + y_true = tf.cast(y_true, dtype: y_pred.dtype); + var ce = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits); + var pred_prob = from_logits ? tf.sigmoid(y_pred) : y_pred; + + var p_t = (y_true * pred_prob) + ((1f - y_true) * (1f - pred_prob)); + Tensor alpha_factor = constant_op.constant(1.0f); + Tensor modulating_factor = constant_op.constant(1.0f); + + if(_alpha > 0) + { + var alpha = tf.cast(constant_op.constant(_alpha), dtype: y_true.dtype); + alpha_factor = y_true * alpha + (1f - y_true) * (1f - alpha); + } + + if (_gamma > 0) + { + var gamma = tf.cast(constant_op.constant(_gamma), dtype: y_true.dtype); + modulating_factor = tf.pow(1f - p_t, gamma); + } + + return tf.reduce_sum(alpha_factor * modulating_factor * ce, axis = -1); + } +} diff --git a/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs new file mode 100644 index 000000000..17ce2d30b --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs @@ -0,0 +1,41 @@ +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Losses; + +public class SparseCategoricalCrossentropy : LossFunctionWrapper +{ + private bool _from_logits = false; + + public SparseCategoricalCrossentropy( + bool from_logits = false, + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name) + { + _from_logits = from_logits; + } + + public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) + { + target = tf.cast(target, dtype: TF_DataType.TF_INT64); + + if (!_from_logits) + { + var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype); + output = tf.clip_by_value(output, epsilon, 1 - epsilon); + output = tf.log(output); + } + + // Try to adjust the shape so that rank of labels = rank of logits - 1. + var output_shape = array_ops.shape_v2(output); + var output_rank = output.shape.ndim; + var target_rank = target.shape.ndim; + var update_shape = target_rank != output_rank - 1; + if (update_shape) + { + target = array_ops.reshape(target, new int[] { -1 }); + output = array_ops.reshape(output, new int[] { -1, output_shape[-1].numpy() }); + } + return tf.nn.sparse_softmax_cross_entropy_with_logits(target, output); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Metrics/Accuracy.cs b/src/TensorFlowNET.Keras/Metrics/Accuracy.cs new file mode 100644 index 000000000..93a724679 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/Accuracy.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Keras.Metrics; + +public class Accuracy : MeanMetricWrapper +{ + public Accuracy(string name = "accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + : base((yt, yp) => metrics_utils.accuracy(yt, yp), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs new file mode 100644 index 000000000..2977588e9 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Keras.Metrics; + +public class BinaryAccuracy : MeanMetricWrapper +{ + public BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 0.5f) + : base((yt, yp) => metrics_utils.binary_matches(yt, yp), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs new file mode 100644 index 000000000..d15cf26c5 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs @@ -0,0 +1,12 @@ +namespace Tensorflow.Keras.Metrics; + +public class CategoricalAccuracy : MeanMetricWrapper +{ + public CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + : base((yt, yp) => metrics_utils.sparse_categorical_matches( + tf.math.argmax(yt, axis: -1), yp), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs new file mode 100644 index 000000000..95720c413 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs @@ -0,0 +1,16 @@ +namespace Tensorflow.Keras.Metrics; + +public class CategoricalCrossentropy : MeanMetricWrapper +{ + public CategoricalCrossentropy(string name = "categorical_crossentropy", + TF_DataType dtype = TF_DataType.TF_FLOAT, + bool from_logits = false, + float label_smoothing = 0f, + Axis? axis = null) + : base((yt, yp) => keras.metrics.categorical_crossentropy( + yt, yp, from_logits: from_logits, label_smoothing: label_smoothing, axis: axis ?? -1), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/CosineSimilarity.cs b/src/TensorFlowNET.Keras/Metrics/CosineSimilarity.cs new file mode 100644 index 000000000..2a26bcdfe --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/CosineSimilarity.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Keras.Metrics; + +public class CosineSimilarity : MeanMetricWrapper +{ + public CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null) + : base((yt, yp) => metrics_utils.cosine_similarity(yt, yp, axis: axis ?? -1), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/F1Score.cs b/src/TensorFlowNET.Keras/Metrics/F1Score.cs new file mode 100644 index 000000000..fc24136d8 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/F1Score.cs @@ -0,0 +1,13 @@ +namespace Tensorflow.Keras.Metrics; + +public class F1Score : FBetaScore +{ + public F1Score(int num_classes, + string? average = null, + float? threshold = null, + string name = "f1_score", + TF_DataType dtype = TF_DataType.TF_FLOAT) + : base(num_classes, average: average, threshold: threshold, beta: 1f, name: name, dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/FBetaScore.cs b/src/TensorFlowNET.Keras/Metrics/FBetaScore.cs new file mode 100644 index 000000000..a40a7caa9 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/FBetaScore.cs @@ -0,0 +1,131 @@ +namespace Tensorflow.Keras.Metrics; + +public class FBetaScore : Metric +{ + int _num_classes; + string? _average; + Tensor _beta; + Tensor _threshold; + Axis _axis; + int[] _init_shape; + + IVariableV1 true_positives; + IVariableV1 false_positives; + IVariableV1 false_negatives; + IVariableV1 weights_intermediate; + + public FBetaScore(int num_classes, + string? average = null, + float beta = 0.1f, + float? threshold = null, + string name = "fbeta_score", + TF_DataType dtype = TF_DataType.TF_FLOAT) + : base(name: name, dtype: dtype) + { + _num_classes = num_classes; + _average = average; + _beta = constant_op.constant(beta); + _dtype = dtype; + + if (threshold.HasValue) + { + _threshold = constant_op.constant(threshold); + } + + _init_shape = new int[0]; + + if (average != "micro") + { + _axis = 0; + _init_shape = new int[] { num_classes }; + } + + true_positives = add_weight("true_positives", shape: _init_shape, initializer: tf.initializers.zeros_initializer()); + false_positives = add_weight("false_positives", shape: _init_shape, initializer: tf.initializers.zeros_initializer()); + false_negatives = add_weight("false_negatives", shape: _init_shape, initializer: tf.initializers.zeros_initializer()); + weights_intermediate = add_weight("weights_intermediate", shape: _init_shape, initializer: tf.initializers.zeros_initializer()); + } + + public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + { + if (_threshold == null) + { + _threshold = tf.reduce_max(y_pred, axis: -1, keepdims: true); + // make sure [0, 0, 0] doesn't become [1, 1, 1] + // Use abs(x) > eps, instead of x != 0 to check for zero + y_pred = tf.logical_and(y_pred >= _threshold, tf.abs(y_pred) > 1e-12f); + } + else + { + y_pred = y_pred > _threshold; + } + + y_true = tf.cast(y_true, _dtype); + y_pred = tf.cast(y_pred, _dtype); + + true_positives.assign_add(_weighted_sum(y_pred * y_true, sample_weight)); + false_positives.assign_add( + _weighted_sum(y_pred * (1 - y_true), sample_weight) + ); + false_negatives.assign_add( + _weighted_sum((1 - y_pred) * y_true, sample_weight) + ); + weights_intermediate.assign_add(_weighted_sum(y_true, sample_weight)); + + return weights_intermediate.AsTensor(); + } + + Tensor _weighted_sum(Tensor val, Tensor? sample_weight = null) + { + if (sample_weight != null) + { + val = tf.math.multiply(val, tf.expand_dims(sample_weight, 1)); + } + + return tf.reduce_sum(val, axis: _axis); + } + + public override Tensor result() + { + var precision = tf.math.divide_no_nan( + true_positives.AsTensor(), true_positives.AsTensor() + false_positives.AsTensor() + ); + var recall = tf.math.divide_no_nan( + true_positives.AsTensor(), true_positives.AsTensor() + false_negatives.AsTensor() + ); + + var mul_value = precision * recall; + var add_value = (tf.math.square(_beta) * precision) + recall; + var mean = tf.math.divide_no_nan(mul_value, add_value); + var f1_score = mean * (1 + tf.math.square(_beta)); + + Tensor weights; + if (_average == "weighted") + { + weights = tf.math.divide_no_nan( + weights_intermediate.AsTensor(), tf.reduce_sum(weights_intermediate.AsTensor()) + ); + f1_score = tf.reduce_sum(f1_score * weights); + } + // micro, macro + else if (_average != null) + { + f1_score = tf.reduce_mean(f1_score); + } + + return f1_score; + } + + public override void reset_states() + { + var reset_value = np.zeros(_init_shape, dtype: _dtype); + keras.backend.batch_set_value( + new List<(IVariableV1, NDArray)> + { + (true_positives, reset_value), + (false_positives, reset_value), + (false_negatives, reset_value), + (weights_intermediate, reset_value) + }); + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/HammingLoss.cs b/src/TensorFlowNET.Keras/Metrics/HammingLoss.cs new file mode 100644 index 000000000..2b65424e9 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/HammingLoss.cs @@ -0,0 +1,15 @@ +namespace Tensorflow.Keras.Metrics; + +public class HammingLoss : MeanMetricWrapper +{ + public HammingLoss(string mode, + NDArray threshold = null, + string name = "hamming_loss", + TF_DataType dtype = TF_DataType.TF_FLOAT) + : base((yt, yp) => metrics_utils.hamming_loss_fn(yt, yp, threshold, mode), + name: name, + dtype: dtype) + { + _dtype = dtype; + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/Mean.cs b/src/TensorFlowNET.Keras/Metrics/Mean.cs new file mode 100644 index 000000000..8a55690b1 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/Mean.cs @@ -0,0 +1,14 @@ +namespace Tensorflow.Keras.Metrics +{ + /// + /// Computes the (weighted) mean of the given values. + /// + public class Mean : Reduce + { + public Mean(string name = "mean", TF_DataType dtype = TF_DataType.TF_FLOAT) + : base(Reduction.WEIGHTED_MEAN, name, dtype: dtype) + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs new file mode 100644 index 000000000..7173aae1d --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs @@ -0,0 +1,27 @@ +using System; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Metrics +{ + public class MeanMetricWrapper : Mean + { + Func _fn = null; + + public MeanMetricWrapper(Func fn, string name, TF_DataType dtype = TF_DataType.TF_FLOAT) + : base(name: name, dtype: dtype) + { + _fn = fn; + } + + public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + { + y_true = math_ops.cast(y_true, _dtype); + y_pred = math_ops.cast(y_pred, _dtype); + + (y_pred, y_true, _) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true: y_true); + + var matches = _fn(y_true, y_pred); + return update_state(matches, sample_weight: sample_weight); + } + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/Metric.cs b/src/TensorFlowNET.Keras/Metrics/Metric.cs new file mode 100644 index 000000000..435eebd48 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/Metric.cs @@ -0,0 +1,69 @@ +using System; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Metrics +{ + /// + /// Encapsulates metric logic and state. + /// + public class Metric : Layer, IMetricFunc + { + protected IVariableV1 total; + protected IVariableV1 count; + protected string _reduction; + protected TF_DataType _dtype; + + public Metric(string name = null, TF_DataType dtype = TF_DataType.DtInvalid) + : base(new LayerArgs + { + Name = name, + DType = dtype + }) + { + stateful = true; + built = true; + } + + protected override IVariableV1 add_weight(string name, + Shape shape = null, + TF_DataType dtype = TF_DataType.TF_FLOAT, + IInitializer initializer = null, + IRegularizer regularizer = null, + VariableSynchronization synchronization = VariableSynchronization.OnRead, + VariableAggregation aggregation = VariableAggregation.Sum, + bool trainable = true, + Func getter = null) + { + if (shape == null) + shape = new Shape(new int[0]); + + return tf_with(ops.init_scope(), delegate + { + return base.add_weight(name, shape, + dtype: dtype, + trainable: false, + initializer: initializer, + synchronization: synchronization, + aggregation: aggregation); + }); + } + + public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + => throw new NotImplementedException(""); + + public virtual void reset_states() + { + foreach (var v in Weights) + v.assign(0); + } + + public virtual Tensor result() + => throw new NotImplementedException(""); + + public override string ToString() + => $"{name} {(float)total.numpy()}/{(float)count.numpy()}"; + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs new file mode 100644 index 000000000..e3881cf1a --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs @@ -0,0 +1,121 @@ +namespace Tensorflow.Keras.Metrics +{ + public class MetricsApi : IMetricsApi + { + public Tensor binary_accuracy(Tensor y_true, Tensor y_pred) + { + float threshold = 0.5f; + y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype); + return keras.backend.mean(math_ops.equal(y_true, y_pred), axis: -1); + } + + public Tensor categorical_accuracy(Tensor y_true, Tensor y_pred) + { + var eql = math_ops.equal(math_ops.argmax(y_true, -1), math_ops.argmax(y_pred, -1)); + return math_ops.cast(eql, TF_DataType.TF_FLOAT); + } + + public Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float label_smoothing = 0, Axis? axis = null) + { + y_true = tf.cast(y_true, y_pred.dtype); + // var label_smoothing_tensor = tf.convert_to_tensor(label_smoothing, dtype: y_pred.dtype); + if (label_smoothing > 0) + { + var num_classes = tf.cast(tf.shape(y_true)[-1], y_pred.dtype); + y_true = y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes); + } + return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis); + } + + public Tensor sparse_categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, int? ignore_class = null, Axis? axis = null) + { + return keras.backend.sparse_categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis ?? -1, ignore_class: ignore_class); + } + + /// + /// Calculates how often predictions matches integer labels. + /// + /// Integer ground truth values. + /// The prediction values. + /// Sparse categorical accuracy values. + public Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred) + { + var y_pred_rank = y_pred.shape.ndim; + var y_true_rank = y_true.shape.ndim; + // If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) + if (y_true_rank != -1 && y_pred_rank != -1 + && y_true.shape.ndim == y_pred.shape.ndim) + y_true = array_ops.squeeze(y_true, axis: new[] { -1 }); + y_pred = math_ops.argmax(y_pred, -1); + + // If the predicted output and actual output types don't match, force cast them + // to match. + if (y_pred.dtype != y_true.dtype) + y_pred = math_ops.cast(y_pred, y_true.dtype); + + return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT); + } + + public Tensor mean_absolute_error(Tensor y_true, Tensor y_pred) + { + y_true = math_ops.cast(y_true, y_pred.dtype); + return keras.backend.mean(math_ops.abs(y_pred - y_true), axis: -1); + } + + public Tensor mean_absolute_percentage_error(Tensor y_true, Tensor y_pred) + { + y_true = math_ops.cast(y_true, y_pred.dtype); + var diff = (y_true - y_pred) / math_ops.maximum(math_ops.abs(y_true), keras.backend.epsilon()); + return 100f * keras.backend.mean(math_ops.abs(diff), axis: -1); + } + + public Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5) + { + return metrics_utils.sparse_top_k_categorical_matches( + tf.math.argmax(y_true, axis: -1), y_pred, k + ); + } + + public IMetricFunc Accuracy(string name = "accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new Accuracy(name: name, dtype: dtype); + + public IMetricFunc BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 5) + => new BinaryAccuracy(); + + public IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new CategoricalAccuracy(name: name, dtype: dtype); + + public IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy", TF_DataType dtype = TF_DataType.TF_FLOAT, bool from_logits = false, float label_smoothing = 0, Axis? axis = null) + => new CategoricalCrossentropy(); + + public IMetricFunc CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null) + => new CosineSimilarity(name: name, dtype: dtype, axis: axis ?? -1); + + public IMetricFunc F1Score(int num_classes, string? average = null, float? threshold = null, string name = "f1_score", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new F1Score(num_classes, average: average, threshold: threshold, name: name, dtype: dtype); + + public IMetricFunc FBetaScore(int num_classes, string? average = null, float beta = 0.1F, float? threshold = null, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new FBetaScore(num_classes, average: average,beta: beta, threshold: threshold, name: name, dtype: dtype); + + public IMetricFunc HammingLoss(string mode, float? threshold = null, string name = "hamming_loss", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new HammingLoss(mode, threshold: threshold, name: name, dtype: dtype); + + public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype); + + public IMetricFunc Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "precision", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new Precision(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype); + + public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new Recall(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype); + + public IMetricFunc SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy", TF_DataType dtype = TF_DataType.TF_FLOAT, bool from_logits = false, int? ignore_class = null, Axis? axis = null) + => new SparseCategoricalCrossentropy(name: name, dtype: dtype, from_logits: from_logits, ignore_class: ignore_class, axis: axis ?? -1); + + public IMetricFunc SparseTopKCategoricalAccuracy(int k = 5, string name = "sparse_top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new SparseTopKCategoricalAccuracy(k: k, name: name, dtype: dtype); + + public IMetricFunc SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + => new SparseCategoricalAccuracy(name: name, dtype: dtype); + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/Precision.cs b/src/TensorFlowNET.Keras/Metrics/Precision.cs new file mode 100644 index 000000000..a01773e0e --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/Precision.cs @@ -0,0 +1,55 @@ +namespace Tensorflow.Keras.Metrics; + +public class Precision : Metric +{ + Tensor _thresholds; + int _top_k; + int _class_id; + IVariableV1 true_positives; + IVariableV1 false_positives; + bool _thresholds_distributed_evenly; + + public Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT) + : base(name: name, dtype: dtype) + { + _thresholds = constant_op.constant(new float[] { thresholds }); + _top_k = top_k; + _class_id = class_id; + true_positives = add_weight("true_positives", shape: 1, initializer: tf.initializers.zeros_initializer()); + false_positives = add_weight("false_positives", shape: 1, initializer: tf.initializers.zeros_initializer()); + } + + public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + { + return metrics_utils.update_confusion_matrix_variables( + new Dictionary + { + { "tp", true_positives }, + { "fp", false_positives }, + }, + y_true, + y_pred, + thresholds: _thresholds, + thresholds_distributed_evenly: _thresholds_distributed_evenly, + top_k: _top_k, + class_id: _class_id, + sample_weight: sample_weight); + } + + public override Tensor result() + { + var result = tf.divide(true_positives.AsTensor(), tf.add(true_positives, false_positives)); + return _thresholds.size == 1 ? result[0] : result; + } + + public override void reset_states() + { + var num_thresholds = (int)_thresholds.size; + keras.backend.batch_set_value( + new List<(IVariableV1, NDArray)> + { + (true_positives, np.zeros(num_thresholds)), + (false_positives, np.zeros(num_thresholds)) + }); + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/Recall.cs b/src/TensorFlowNET.Keras/Metrics/Recall.cs new file mode 100644 index 000000000..9b58bf5f7 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/Recall.cs @@ -0,0 +1,53 @@ +namespace Tensorflow.Keras.Metrics; + +public class Recall : Metric +{ + Tensor _thresholds; + int _top_k; + int _class_id; + IVariableV1 true_positives; + IVariableV1 false_negatives; + bool _thresholds_distributed_evenly; + + public Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT) + : base(name: name, dtype: dtype) + { + _thresholds = constant_op.constant(new float[] { thresholds }); + true_positives = add_weight("true_positives", shape: 1, initializer: tf.initializers.zeros_initializer()); + false_negatives = add_weight("false_negatives", shape: 1, initializer: tf.initializers.zeros_initializer()); + } + + public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + { + return metrics_utils.update_confusion_matrix_variables( + new Dictionary + { + { "tp", true_positives }, + { "fn", false_negatives }, + }, + y_true, + y_pred, + thresholds: _thresholds, + thresholds_distributed_evenly: _thresholds_distributed_evenly, + top_k: _top_k, + class_id: _class_id, + sample_weight: sample_weight); + } + + public override Tensor result() + { + var result = tf.divide(true_positives.AsTensor(), tf.add(true_positives, false_negatives)); + return _thresholds.size == 1 ? result[0] : result; + } + + public override void reset_states() + { + var num_thresholds = (int)_thresholds.size; + keras.backend.batch_set_value( + new List<(IVariableV1, NDArray)> + { + (true_positives, np.zeros(num_thresholds)), + (false_negatives, np.zeros(num_thresholds)) + }); + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/Reduce.cs b/src/TensorFlowNET.Keras/Metrics/Reduce.cs new file mode 100644 index 000000000..8874719de --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/Reduce.cs @@ -0,0 +1,74 @@ +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Metrics +{ + /// + /// Encapsulates metrics that perform a reduce operation on the values. + /// + public class Reduce : Metric + { + public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid) + : base(name: name, dtype: dtype) + { + _reduction = reduction; + _dtype = dtype; + total = add_weight("total", initializer: tf.zeros_initializer); + + if (reduction == Reduction.WEIGHTED_MEAN || + reduction == Reduction.SUM_OVER_BATCH_SIZE) + { + count = add_weight("count", initializer: tf.zeros_initializer); + } + } + + public Tensor update_state(Tensor values, Tensor sample_weight = null) + { + if (sample_weight != null) + { + (values, _, sample_weight) = losses_utils.squeeze_or_expand_dimensions( + values, sample_weight: sample_weight); + + sample_weight = math_ops.cast(sample_weight, dtype: values.dtype); + values = math_ops.multiply(values, sample_weight); + } + + Tensor update_total_op = null; + var value_sum = math_ops.reduce_sum(values); + tf_with(ops.control_dependencies(new[] { value_sum }), ctl => + { + update_total_op = total.assign_add(value_sum); + }); + + // Exit early if the reduction doesn't have a denominator. + if (_reduction == Reduction.SUM) + return update_total_op; + + // Update `count` for reductions that require a denominator. + Tensor num_values = null; + if (_reduction == Reduction.SUM_OVER_BATCH_SIZE) + num_values = math_ops.cast(array_ops.size(values), _dtype); + else if (_reduction == ReductionV2.WEIGHTED_MEAN) + { + if (sample_weight == null) + num_values = math_ops.cast(array_ops.size(values), _dtype); + else + num_values = math_ops.reduce_sum(sample_weight); + } + + return tf_with(ops.control_dependencies(new[] { update_total_op }), ctl + => count.assign_add(num_values)); + } + + public override Tensor result() + { + if (_reduction == Reduction.SUM) + return array_ops.identity(total.AsTensor()); + else if (_reduction == Reduction.WEIGHTED_MEAN || _reduction == Reduction.SUM_OVER_BATCH_SIZE) + return math_ops.div_no_nan(total.AsTensor(), count.AsTensor()); + + return base.result(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs new file mode 100644 index 000000000..6cad9aac3 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Keras.Metrics; + +public class SparseCategoricalAccuracy : MeanMetricWrapper +{ + public SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + : base((yt, yp) => metrics_utils.sparse_categorical_matches(yt, yp), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs new file mode 100644 index 000000000..d517da913 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs @@ -0,0 +1,16 @@ +namespace Tensorflow.Keras.Metrics; + +public class SparseCategoricalCrossentropy : MeanMetricWrapper +{ + public SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy", + TF_DataType dtype = TF_DataType.TF_FLOAT, + bool from_logits = false, + int? ignore_class = null, + Axis? axis = null) + : base((yt, yp) => keras.metrics.sparse_categorical_crossentropy( + yt, yp, from_logits: from_logits, ignore_class: ignore_class, axis: axis ?? -1), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs new file mode 100644 index 000000000..eb6d9f3b3 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Keras.Metrics; + +public class SparseTopKCategoricalAccuracy : MeanMetricWrapper +{ + public SparseTopKCategoricalAccuracy(int k = 5, string name = "sparse_top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + : base((yt, yp) => metrics_utils.sparse_top_k_categorical_matches(yt, yp, k), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/Sum.cs b/src/TensorFlowNET.Keras/Metrics/Sum.cs new file mode 100644 index 000000000..bf69980c6 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/Sum.cs @@ -0,0 +1,6 @@ +namespace Tensorflow.Keras.Metrics +{ + class Sum + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs new file mode 100644 index 000000000..63e941024 --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs @@ -0,0 +1,12 @@ +namespace Tensorflow.Keras.Metrics; + +public class TopKCategoricalAccuracy : MeanMetricWrapper +{ + public TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT) + : base((yt, yp) => metrics_utils.sparse_top_k_categorical_matches( + tf.math.argmax(yt, axis: -1), yp, k), + name: name, + dtype: dtype) + { + } +} diff --git a/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs new file mode 100644 index 000000000..3c2f8a7be --- /dev/null +++ b/src/TensorFlowNET.Keras/Metrics/metrics_utils.cs @@ -0,0 +1,310 @@ +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Metrics; + +public class metrics_utils +{ + public static Tensor accuracy(Tensor y_true, Tensor y_pred) + { + if (y_true.dtype != y_pred.dtype) + y_pred = tf.cast(y_pred, y_true.dtype); + return tf.cast(tf.equal(y_true, y_pred), keras.backend.floatx()); + } + + public static Tensor binary_matches(Tensor y_true, Tensor y_pred, float threshold = 0.5f) + { + y_pred = tf.cast(y_pred > threshold, y_pred.dtype); + return tf.cast(tf.equal(y_true, y_pred), keras.backend.floatx()); + } + + public static Tensor cosine_similarity(Tensor y_true, Tensor y_pred, Axis? axis = null) + { + y_true = tf.linalg.l2_normalize(y_true, axis: axis ?? -1); + y_pred = tf.linalg.l2_normalize(y_pred, axis: axis ?? -1); + return tf.reduce_sum(y_true * y_pred, axis: axis ?? -1); + } + + public static Tensor hamming_loss_fn(Tensor y_true, Tensor y_pred, Tensor threshold, string mode) + { + if (threshold == null) + { + threshold = tf.reduce_max(y_pred, axis: -1, keepdims: true); + // make sure [0, 0, 0] doesn't become [1, 1, 1] + // Use abs(x) > eps, instead of x != 0 to check for zero + y_pred = tf.logical_and(y_pred >= threshold, tf.abs(y_pred) > 1e-12f); + } + else + { + y_pred = y_pred > threshold; + } + + + y_true = tf.cast(y_true, tf.int32); + y_pred = tf.cast(y_pred, tf.int32); + + if (mode == "multiclass") + { + var nonzero = tf.cast(tf.math.count_nonzero(y_true * y_pred, axis: -1), tf.float32); + return 1.0 - nonzero; + } + else + { + var nonzero = tf.cast(tf.math.count_nonzero(y_true - y_pred, axis: -1), tf.float32); + return nonzero / y_true.shape[-1]; + } + } + + /// + /// Creates float Tensor, 1.0 for label-prediction match, 0.0 for mismatch. + /// + /// + /// + /// + public static Tensor sparse_categorical_matches(Tensor y_true, Tensor y_pred) + { + var reshape_matches = false; + var y_true_rank = y_true.shape.ndim; + var y_pred_rank = y_pred.shape.ndim; + var y_true_org_shape = tf.shape(y_true); + + if (y_true_rank > -1 && y_pred_rank > -1 && y_true.ndim == y_pred.ndim ) + { + reshape_matches = true; + y_true = tf.squeeze(y_true, new Shape(-1)); + } + y_pred = tf.math.argmax(y_pred, axis: -1); + y_pred = tf.cast(y_pred, y_true.dtype); + var matches = tf.cast( + tf.equal(y_true, y_pred), + dtype: keras.backend.floatx() + ); + + if (reshape_matches) + { + return tf.reshape(matches, shape: y_true_org_shape); + } + + return matches; + } + + public static Tensor sparse_top_k_categorical_matches(Tensor y_true, Tensor y_pred, int k = 5) + { + var reshape_matches = false; + var y_true_rank = y_true.shape.ndim; + var y_pred_rank = y_pred.shape.ndim; + var y_true_org_shape = tf.shape(y_true); + + if (y_pred_rank > 2) + { + y_pred = tf.reshape(y_pred, (-1, y_pred.shape[-1])); + } + + if (y_true_rank > 1) + { + reshape_matches = true; + y_true = tf.reshape(y_true, new Shape(-1)); + } + + var matches = tf.cast( + tf.math.in_top_k( + predictions: y_pred, targets: tf.cast(y_true, np.int32), k: k + ), + dtype: keras.backend.floatx() + ); + + if (reshape_matches) + { + return tf.reshape(matches, shape: y_true_org_shape); + } + + return matches; + } + + public static Tensor update_confusion_matrix_variables(Dictionary variables_to_update, + Tensor y_true, + Tensor y_pred, + Tensor thresholds, + int top_k, + int class_id, + Tensor sample_weight = null, + bool multi_label = false, + Tensor label_weights = null, + bool thresholds_distributed_evenly = false) + { + var variable_dtype = variables_to_update.Values.First().dtype; + y_true = tf.cast(y_true, dtype: variable_dtype); + y_pred = tf.cast(y_pred, dtype: variable_dtype); + var num_thresholds = thresholds.shape.dims[0]; + + Tensor one_thresh = null; + if (multi_label) + { + one_thresh = tf.equal(tf.cast(constant_op.constant(1), dtype:tf.int32), + tf.rank(thresholds), + name: "one_set_of_thresholds_cond"); + } + else + { + one_thresh = tf.cast(constant_op.constant(true), dtype: dtypes.@bool); + } + + if (sample_weight == null) + { + (y_pred, y_true, _) = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true); + } + else + { + sample_weight = tf.cast(sample_weight, dtype: variable_dtype); + (y_pred, y_true, sample_weight) = losses_utils.squeeze_or_expand_dimensions(y_pred, + y_true, + sample_weight: sample_weight); + } + + if (top_k > 0) + { + y_pred = _filter_top_k(y_pred, top_k); + } + + if (class_id > 0) + { + y_true = y_true[Slice.All, class_id]; + y_pred = y_pred[Slice.All, class_id]; + } + + if (thresholds_distributed_evenly) + { + throw new NotImplementedException(); + } + + var pred_shape = tf.shape(y_pred); + var num_predictions = pred_shape[0]; + + Tensor num_labels; + if (y_pred.shape.ndim == 1) + { + num_labels = constant_op.constant(1); + } + else + { + num_labels = tf.reduce_prod(pred_shape["1:"], axis: 0); + } + var thresh_label_tile = tf.where(one_thresh, num_labels, tf.ones(new int[0], dtype: tf.int32)); + + // Reshape predictions and labels, adding a dim for thresholding. + Tensor predictions_extra_dim, labels_extra_dim; + if (multi_label) + { + predictions_extra_dim = tf.expand_dims(y_pred, 0); + labels_extra_dim = tf.expand_dims(tf.cast(y_true, dtype: tf.@bool), 0); + } + + else + { + // Flatten predictions and labels when not multilabel. + predictions_extra_dim = tf.reshape(y_pred, (1, -1)); + labels_extra_dim = tf.reshape(tf.cast(y_true, dtype: tf.@bool), (1, -1)); + } + + // Tile the thresholds for every prediction. + object[] thresh_pretile_shape, thresh_tiles, data_tiles; + + if (multi_label) + { + thresh_pretile_shape = new object[] { num_thresholds, 1, -1 }; + thresh_tiles = new object[] { 1, num_predictions, thresh_label_tile }; + data_tiles = new object[] { num_thresholds, 1, 1 }; + } + else + { + thresh_pretile_shape = new object[] { num_thresholds, -1 }; + thresh_tiles = new object[] { 1, num_predictions * num_labels }; + data_tiles = new object[] { num_thresholds, 1 }; + } + var thresh_tiled = tf.tile(tf.reshape(thresholds, thresh_pretile_shape), tf.stack(thresh_tiles)); + + // Tile the predictions for every threshold. + var preds_tiled = tf.tile(predictions_extra_dim, data_tiles); + + // Compare predictions and threshold. + var pred_is_pos = tf.greater(preds_tiled, thresh_tiled); + + // Tile labels by number of thresholds + var label_is_pos = tf.tile(labels_extra_dim, data_tiles); + + Tensor weights_tiled = null; + + if (sample_weight != null) + { + /*sample_weight = broadcast_weights( + tf.cast(sample_weight, dtype: variable_dtype), y_pred);*/ + weights_tiled = tf.tile( + tf.reshape(sample_weight, thresh_tiles), data_tiles); + } + + if (label_weights != null && !multi_label) + { + throw new NotImplementedException(); + } + + Func weighted_assign_add + = (label, pred, weights, var) => + { + var label_and_pred = tf.cast(tf.logical_and(label, pred), dtype: var.dtype); + if (weights != null) + { + label_and_pred *= tf.cast(weights, dtype: var.dtype); + } + + return var.assign_add(tf.reduce_sum(label_and_pred, 1)); + }; + + + var loop_vars = new Dictionary + { + { "tp", (label_is_pos, pred_is_pos) } + }; + var update_tn = variables_to_update.ContainsKey("tn"); + var update_fp = variables_to_update.ContainsKey("fp"); + var update_fn = variables_to_update.ContainsKey("fn"); + + Tensor pred_is_neg = null; + if (update_fn || update_tn) + { + pred_is_neg = tf.logical_not(pred_is_pos); + loop_vars["fn"] = (label_is_pos, pred_is_neg); + } + + if(update_fp || update_tn) + { + var label_is_neg = tf.logical_not(label_is_pos); + loop_vars["fp"] = (label_is_neg, pred_is_pos); + if (update_tn) + { + loop_vars["tn"] = (label_is_neg, pred_is_neg); + } + } + + var update_ops = new List(); + foreach (var matrix_cond in loop_vars.Keys) + { + var (label, pred) = loop_vars[matrix_cond]; + if (variables_to_update.ContainsKey(matrix_cond)) + { + var op = weighted_assign_add(label, pred, weights_tiled, variables_to_update[matrix_cond]); + update_ops.append(op); + } + } + + tf.group(update_ops.ToArray()); + return null; + } + + private static Tensor _filter_top_k(Tensor x, int k) + { + var NEG_INF = -1e10; + var (_, top_k_idx) = tf.math.top_k(x, k, sorted: false); + var top_k_mask = tf.reduce_sum( + tf.one_hot(top_k_idx.Single, (int)x.shape[-1], axis: -1), axis: -2); + return x * top_k_mask + NEG_INF * (1 - top_k_mask); + } +} diff --git a/src/TensorFlowNET.Keras/Models/ModelsApi.cs b/src/TensorFlowNET.Keras/Models/ModelsApi.cs new file mode 100644 index 000000000..2605c41e3 --- /dev/null +++ b/src/TensorFlowNET.Keras/Models/ModelsApi.cs @@ -0,0 +1,15 @@ +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Saving.SavedModel; + +namespace Tensorflow.Keras.Models; + +public class ModelsApi: IModelsApi +{ + public Functional from_config(FunctionalConfig config) + => Functional.from_config(config); + + public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null) + { + return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model; + } +} diff --git a/src/TensorFlowNET.Keras/Open.snk b/src/TensorFlowNET.Keras/Open.snk new file mode 100644 index 000000000..22a3cbd25 Binary files /dev/null and b/src/TensorFlowNET.Keras/Open.snk differ diff --git a/src/TensorFlowNET.Keras/Optimizers/Adam.cs b/src/TensorFlowNET.Keras/Optimizers/Adam.cs new file mode 100644 index 000000000..fc5ee4491 --- /dev/null +++ b/src/TensorFlowNET.Keras/Optimizers/Adam.cs @@ -0,0 +1,90 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Optimizers +{ + /// + /// Optimizer that implements the Adam algorithm. + /// Adam optimization is a stochastic gradient descent method that is based on + /// adaptive estimation of first-order and second-order moments. + /// + public class Adam : OptimizerV2 + { + protected override string _name => "Adam"; + float epsilon = 1e-7f; + bool amsgrad = false; + + public Adam(float learning_rate = 0.001f, + float beta_1 = 0.9f, + float beta_2 = 0.999f, + float epsilon = 1e-7f, + bool amsgrad = false, + string name = "Adam") : base(new OptimizerV2Args { }) + { + _set_hyper("learning_rate", learning_rate); + // _set_hyper("decay", _initial_decay); + _set_hyper("beta_1", beta_1); + _set_hyper("beta_2", beta_2); + this.epsilon = epsilon; + this.amsgrad = amsgrad; + } + + protected override void _create_slots(IVariableV1[] var_list) + { + foreach (var var in var_list) + add_slot(var, "m"); + foreach (var var in var_list) + add_slot(var, "v"); + if (amsgrad) + foreach (var var in var_list) + add_slot(var, "vhat"); + } + + protected override void _prepare_local(DeviceDType device_dtype, Dictionary> apply_state) + { + base._prepare_local(device_dtype, apply_state); + var var_dtype = device_dtype.DType; + var var_device = device_dtype.Device; + var local_step = math_ops.cast(iterations + 1, var_dtype); + var beta_1_t = array_ops.identity(_get_hyper("beta_1", var_dtype)); + var beta_2_t = array_ops.identity(_get_hyper("beta_2", var_dtype)); + var beta_1_power = math_ops.pow(beta_1_t, local_step); + var beta_2_power = math_ops.pow(beta_2_t, local_step); + var lr = apply_state[device_dtype]["lr_t"] * (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)); + // update state + apply_state[device_dtype]["lr"] = lr; + apply_state[device_dtype]["epsilon"] = ops.convert_to_tensor(epsilon); + apply_state[device_dtype]["beta_1_t"] = beta_1_t; + apply_state[device_dtype]["beta_1_power"] = beta_1_power; + apply_state[device_dtype]["one_minus_beta_1_t"] = 1 - beta_1_t; + apply_state[device_dtype]["beta_2_t"] = beta_2_t; + apply_state[device_dtype]["beta_2_power"] = beta_2_power; + apply_state[device_dtype]["one_minus_beta_2_t"] = 1 - beta_2_t; + } + + protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary> apply_state) + { + var (var_device, var_dtype) = (var.Device, var.dtype.as_base_dtype()); + var coefficients = apply_state.FirstOrDefault(x => x.Key.Device == var_device && x.Key.DType == var_dtype).Value ?? _fallback_apply_state(var_device, var_dtype); + var m = get_slot(var, "m"); + var v = get_slot(var, "v"); + + if (!amsgrad) + return gen_training_ops.resource_apply_adam(var.Handle, + m.Handle, + v.Handle, + coefficients["beta_1_power"], + coefficients["beta_2_power"], + coefficients["lr_t"], + coefficients["beta_1_t"], + coefficients["beta_2_t"], + coefficients["epsilon"], + grad, + use_locking: _use_locking); + else + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Keras/Optimizers/AdamW.cs b/src/TensorFlowNET.Keras/Optimizers/AdamW.cs new file mode 100644 index 000000000..d111b5d3a --- /dev/null +++ b/src/TensorFlowNET.Keras/Optimizers/AdamW.cs @@ -0,0 +1,64 @@ +namespace Tensorflow.Keras.Optimizers +{ + public class AdamW : Adam + { + string name; + float weight_decay; + DeviceDType deType; + List no_decay_params = null; + public AdamW(float learning_rate= 0.001f, + float weight_decay= 0.004f, + float beta_1= 0.9f, + float beta_2= 0.999f, + float epsilon= 1e-7f, + bool amsgrad = false, + List no_decay_params = null, + string name= "AdamW") : base(learning_rate, beta_1, beta_2, epsilon, amsgrad) + { + this.name = name; + this.weight_decay = weight_decay; + this.no_decay_params = no_decay_params; + } + + protected Operation _decay_weights_op(IVariableV1 var, float learning_rate, Dictionary> apply_state) + { + bool do_decay = _do_use_weight_decay(var.Name); + if (do_decay) return var.assign_add( + -learning_rate * var.AsTensor() * apply_state[deType]["weight_decay"]); + return tf.no_op(); + } + + + protected bool _do_use_weight_decay(string param_name) + { + // Whether to use L2 weight decay for `param_name`. + if (this.weight_decay == 0) + return false; + + if (this.no_decay_params != null) + { + foreach (var name in no_decay_params) + { + if (param_name.Contains(name)) return false; + } + + } + return true; + } + + protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary> apply_state) + { + var decay = _decay_weights_op(var, _hyper["learning_rate"], apply_state); + tf.control_dependencies(new[] { decay }); + return base._resource_apply_dense(var, grad, apply_state); + } + + protected override void _prepare_local(DeviceDType device_dtype, Dictionary> apply_state) + { + this.deType = device_dtype; + base._prepare_local(device_dtype, apply_state); + apply_state[device_dtype]["weight_decay"] = tf.constant( + weight_decay, name: "adam_weight_decay_rate"); + } + } +} diff --git a/src/TensorFlowNET.Keras/Optimizers/DeviceDType.cs b/src/TensorFlowNET.Keras/Optimizers/DeviceDType.cs new file mode 100644 index 000000000..deaaf438b --- /dev/null +++ b/src/TensorFlowNET.Keras/Optimizers/DeviceDType.cs @@ -0,0 +1,23 @@ +using System.Collections.Generic; + +namespace Tensorflow.Keras.Optimizers +{ + public class DeviceDType : IEqualityComparer + { + public string Device { get; set; } + public TF_DataType DType { get; set; } + + public bool Equals(DeviceDType x, DeviceDType y) + { + return x.ToString() == y.ToString(); + } + + public int GetHashCode(DeviceDType obj) + { + return 0; + } + + public override string ToString() + => $"{Device}, {DType}"; + } +} diff --git a/src/TensorFlowNET.Keras/Optimizers/LearningRateSchedule.cs b/src/TensorFlowNET.Keras/Optimizers/LearningRateSchedule.cs new file mode 100644 index 000000000..8d3f8b065 --- /dev/null +++ b/src/TensorFlowNET.Keras/Optimizers/LearningRateSchedule.cs @@ -0,0 +1,10 @@ +namespace Tensorflow.Keras.Optimizers +{ + public class LearningRateSchedule + { + public LearningRateSchedule() + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs b/src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs new file mode 100644 index 000000000..a237499f9 --- /dev/null +++ b/src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs @@ -0,0 +1,77 @@ +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Optimizers +{ + public class OptimizerApi: IOptimizerApi + { + /// + /// Adam optimization is a stochastic gradient descent method that is based on + /// adaptive estimation of first-order and second-order moments. + /// + /// + /// + /// + /// + /// + /// + /// + public IOptimizer Adam(float learning_rate = 0.001f, + float beta_1 = 0.9f, + float beta_2 = 0.999f, + float epsilon = 1e-7f, + bool amsgrad = false, + string name = "Adam") + => new Adam(learning_rate: learning_rate, + beta_1: beta_1, + beta_2: beta_2, + epsilon: epsilon, + amsgrad: amsgrad, + name: name); + + public IOptimizer AdamW(float learning_rate = 0.001f, + float weight_decay = 0.004f, + float beta_1 = 0.9f, + float beta_2 = 0.999f, + float epsilon = 1e-7f, + bool amsgrad = false, + List no_decay_params = null, + string name = "AdamW") => new AdamW(learning_rate: learning_rate, + beta_1: beta_1, + beta_2: beta_2, + epsilon: epsilon, + amsgrad: amsgrad, + name: name, + weight_decay: weight_decay, + no_decay_params: no_decay_params); + + /// + /// Construct a new RMSprop optimizer. + /// + /// + /// + /// + /// + /// + /// + /// + public IOptimizer RMSprop(float learning_rate = 0.001f, + float rho = 0.9f, + float momentum = 0.0f, + float epsilon = 1e-7f, + bool centered = false, + string name = "RMSprop") + => new RMSprop(new RMSpropArgs + { + LearningRate = learning_rate, + RHO = rho, + Momentum = momentum, + Epsilon = epsilon, + Centered = centered, + Name = name + }); + + public IOptimizer SGD(float learning_rate = 0.01f, float momentum = 0f) + => new SGD(learning_rate, momentum); + } +} diff --git a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs new file mode 100644 index 000000000..1e4dbe086 --- /dev/null +++ b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs @@ -0,0 +1,323 @@ +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Optimizers +{ + /// + /// Updated base class for optimizers. + /// + public class OptimizerV2 : Trackable, IOptimizer + { + OptimizerV2Args args; + protected bool _hypers_created; + protected virtual string _name { get; } + + protected IVariableV1 _iterations; + protected ResourceVariable iterations => _iterations as ResourceVariable; + List _weights; + protected Dictionary _hyper; + protected Dictionary _hyper_variables; + protected bool _momentum; + protected float _initial_decay = 0.0f; + protected bool _use_locking = true; + + public IVariableV1 lr + => _hyper_variables["learning_rate"]; + + Dictionary> _slots; + List _slot_names; + + public OptimizerV2(OptimizerV2Args args) : base() + { + this.args = args; + _weights = new List(); + _hyper = new Dictionary(); + _hyper_variables = new Dictionary(); + _slots = new Dictionary>(); + _slot_names = new List(); + + _set_hyper("learning_rate", args.LearningRate); + _set_hyper("decay", args.InitialDecay); + } + + public void apply_gradients((Tensor, IVariableV1) grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true) + => apply_gradients(new[] { grads_and_vars }, + name: name, + experimental_aggregate_gradients: experimental_aggregate_gradients); + + /// + /// Apply gradients to variables. + /// + /// + /// + /// + public void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true) + { + var var_list = grads_and_vars.Select(x => x.Item2).ToArray(); + tf_with(ops.name_scope(_name), delegate + { + ops.init_scope(); + _create_all_weights(var_list); + if (grads_and_vars == null || grads_and_vars.Count() == 0) + return control_flow_ops.no_op(); + + var apply_state = _prepare(var_list); + // if(experimental_aggregate_gradients) + { + // var reduced_grads = _aggregate_gradients(grads_and_vars); + _distributed_apply(grads_and_vars, name, apply_state); + } + + return null; + }); + } + + public void apply_gradients((Tensor, ResourceVariable) grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true) + => apply_gradients(new[] { grads_and_vars }, + name: name, + experimental_aggregate_gradients: experimental_aggregate_gradients); + + /// + /// Apply gradients to variables. + /// + /// + /// + /// + public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true) + { + var var_list = grads_and_vars.Select(x => x.Item2).ToArray(); + tf_with(ops.name_scope(_name), delegate + { + ops.init_scope(); + _create_all_weights(var_list); + if (grads_and_vars == null || grads_and_vars.Count() == 0) + return control_flow_ops.no_op(); + + var apply_state = _prepare(var_list); + // if(experimental_aggregate_gradients) + { + // var reduced_grads = _aggregate_gradients(grads_and_vars); + _distributed_apply(grads_and_vars.Select(x => (x.Item1, (IVariableV1)x.Item2)), name, apply_state); + } + + return null; + }); + } + + void apply_grad_to_update_var(IVariableV1 var, Tensor grad, Dictionary> apply_state) + { + _resource_apply_dense(var, grad, apply_state); + // if var.constraint is not None: + // with ops.control_dependencies([update_op]): + // return var.assign(var.constraint(var)) + } + + protected virtual Operation _resource_apply_dense(IVariableV1 var, + Tensor grad, + Dictionary> _apply_state) + { + throw new NotImplementedException("_resource_apply_dense"); + } + + void _distributed_apply(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, + string name, + Dictionary> _apply_state) + { + tf_with(ops.name_scope(name, "", new { skip_on_eager = true }), delegate + { + foreach (var (grad, var) in grads_and_vars) + { + tf_with(ops.name_scope("update"), delegate + { + apply_grad_to_update_var(var, grad, _apply_state); + }); + } + + _iterations.assign_add(ops.convert_to_tensor(1, dtype: _iterations.dtype)); + }); + } + + public Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars) + { + return grads_and_vars.Select(x => x.Item1).ToArray(); + } + + public Tensor[] clip_gradients(Tensor[] grads) + { + return grads; + } + + protected IVariableV1 get_slot(IVariableV1 var, string slot_name) + { + var slot_dict = _slots[var.UniqueId]; + return slot_dict[slot_name]; + } + + Dictionary> _prepare(IVariableV1[] var_list) + { + var _apply_state = new Dictionary>(); + var keys = var_list.Select(x => new DeviceDType + { + Device = x.Device, + DType = x.dtype.as_base_dtype() + }).Distinct(new DeviceDType()).ToArray(); + + foreach (var device_dtype in keys) + { + _apply_state[device_dtype] = new Dictionary(); + _prepare_local(device_dtype, _apply_state); + } + + return _apply_state; + } + + protected Dictionary _fallback_apply_state(string var_device, TF_DataType var_dtype) + { + throw new NotImplementedException(""); + } + + protected virtual void _prepare_local(DeviceDType device_dtype, + Dictionary> _apply_state) + { + if (_hyper.ContainsKey("learning_rate")) + { + var lr_t = array_ops.identity(_decayed_lr(device_dtype.DType)); + _apply_state[device_dtype]["lr_t"] = lr_t; + } + } + + Tensor _decayed_lr(TF_DataType var_dtype) + { + var lr_t = _get_hyper("learning_rate", var_dtype); + if (_initial_decay > 0.0f) + { + throw new NotImplementedException(""); + } + return lr_t; + } + + protected Tensor _get_hyper(string name, TF_DataType dtype = TF_DataType.DtInvalid) + { + var value = _hyper_variables[name]; + return math_ops.cast(value, dtype); + } + + void _create_all_weights(IVariableV1[] var_list) + { + if (_iterations == null) + { + _iterations = add_weight("iter", + shape: new int[0], + dtype: TF_DataType.TF_INT64, + trainable: false, + aggregation: VariableAggregation.OnlyFirstReplica); + _weights.Add(_iterations); + } + + _create_hypers(); + _create_slots(var_list); + } + + protected void _set_hyper(string name, float value) + { + _hyper[name] = value; + } + + void _create_hypers() + { + if (_hypers_created) + return; + foreach (var dict in _hyper) + { + var name = dict.Key; + var value = dict.Value; + _hyper_variables[name] = add_weight( + name, + shape: new int[0], + trainable: false, + initializer: tf.constant_initializer(value), + aggregation: VariableAggregation.OnlyFirstReplica); + } + _hypers_created = true; + } + + protected virtual void _create_slots(IVariableV1[] var_list) + { + if (_momentum) + { + /*for var in var_list: + self.add_slot(var, "momentum")*/ + } + } + + public IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null) + { + if (initializer == null) + initializer = tf.zeros_initializer; + + if (!_slot_names.Contains(slot_name)) + _slot_names.append(slot_name); + + if (!_slots.ContainsKey(var.UniqueId)) + _slots[var.UniqueId] = new Dictionary(); + var slot_dict = _slots[var.UniqueId]; + if (!slot_dict.ContainsKey(slot_name)) + { + var weight = tf.Variable(initializer, + dtype: var.dtype, + trainable: false, + shape: var.shape, + name: $"{var.Name}/{slot_name}"); + + slot_dict[slot_name] = weight; + _weights.append(weight); + return weight; + } + else + { + return slot_dict[slot_name]; + } + } + + ResourceVariable add_weight(string name, + Shape shape, + TF_DataType dtype = TF_DataType.TF_FLOAT, + IInitializer initializer = null, + bool trainable = false, + VariableSynchronization synchronization = VariableSynchronization.Auto, + VariableAggregation aggregation = VariableAggregation.None) + { + if (initializer == null) + initializer = tf.zeros_initializer; + + if (dtype == TF_DataType.DtInvalid) + dtype = TF_DataType.TF_FLOAT; + + var variable = _add_variable_with_custom_getter(new VariableArgs + { + Name = name, + Shape = shape, + Getter = base_layer_utils.make_variable, + DType = dtype, + Overwrite = true, + Initializer = initializer, + Trainable = trainable, + UseResource = true, + Synchronization = synchronization, + Aggregation = aggregation + }); + + return variable as ResourceVariable; + } + } +} diff --git a/src/TensorFlowNET.Keras/Optimizers/PolynomialDecay.cs b/src/TensorFlowNET.Keras/Optimizers/PolynomialDecay.cs new file mode 100644 index 000000000..b2594f442 --- /dev/null +++ b/src/TensorFlowNET.Keras/Optimizers/PolynomialDecay.cs @@ -0,0 +1,65 @@ +using System; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Optimizers +{ + /// + /// A LearningRateSchedule that uses a polynomial decay schedule. + /// + public class PolynomialDecay : LearningRateSchedule + { + float initial_learning_rate; + float decay_steps; + float end_learning_rate; + float power; + bool cycle; + string name; + + public PolynomialDecay(float initial_learning_rate, + float decay_steps, + float end_learning_rate = 0.0001f, + float power = 1.0f, + bool cycle = false, + string name = null) : base() + { + this.initial_learning_rate = initial_learning_rate; + this.decay_steps = decay_steps; + this.end_learning_rate = end_learning_rate; + this.power = power; + this.cycle = cycle; + this.name = name; + } + + public Tensor __call__(IVariableV1 step) + { + return tf_with(ops.name_scope(name ?? "PolynomialDecay"), scope => + { + name = scope; + var initial_learning_rate_tensor = ops.convert_to_tensor(initial_learning_rate, name: "initial_learning_rate"); + var dtype = initial_learning_rate_tensor.dtype; + var end_learning_rate_tensor = constant_op.constant(end_learning_rate, dtype); + var power_tensor = constant_op.constant(power, dtype); + + var global_step_recomp = constant_op.constant(step, dtype); + var decay_steps_recomp = constant_op.constant(decay_steps, dtype); + + if (cycle) + { + throw new NotImplementedException("PolynomialDecay cycle"); + } + else + { + // Make sure that the global_step used is not bigger than decay_steps. + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps); + } + + var p = tf.divide(global_step_recomp, decay_steps_recomp); + var pow = tf.pow(1 - p, power_tensor); + var m = math_ops.multiply(initial_learning_rate_tensor - end_learning_rate_tensor, pow); + return math_ops.add(m, + end_learning_rate_tensor, + name: name); + }); + } + } +} diff --git a/src/TensorFlowNET.Keras/Optimizers/RMSprop.cs b/src/TensorFlowNET.Keras/Optimizers/RMSprop.cs new file mode 100644 index 000000000..51fffefcd --- /dev/null +++ b/src/TensorFlowNET.Keras/Optimizers/RMSprop.cs @@ -0,0 +1,78 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Optimizers +{ + /// + /// Optimizer that implements the RMSprop algorithm. + /// + public class RMSprop : OptimizerV2 + { + RMSpropArgs args; + bool centered => args.Centered; + protected override string _name => "RMSprop"; + + public RMSprop(RMSpropArgs args) : base(args) + { + this.args = args; + _set_hyper("rho", args.RHO); + _set_hyper("momentum", args.Momentum); + } + + protected override void _create_slots(IVariableV1[] var_list) + { + foreach (var var in var_list) + add_slot(var, "rms"); + if (_momentum) + foreach (var var in var_list) + add_slot(var, "momentum"); + if (centered) + foreach (var var in var_list) + add_slot(var, "mg"); + } + + protected override void _prepare_local(DeviceDType device_dtype, Dictionary> _apply_state) + { + base._prepare_local(device_dtype, _apply_state); + var rho = array_ops.identity(_get_hyper("rho", device_dtype.DType)); + _apply_state[device_dtype]["neg_lr_t"] = -_apply_state[device_dtype]["lr_t"]; + _apply_state[device_dtype]["epsilon"] = ops.convert_to_tensor(args.Epsilon, dtype: device_dtype.DType); + _apply_state[device_dtype]["rho"] = rho; + _apply_state[device_dtype]["momentum"] = array_ops.identity(_get_hyper("momentum", device_dtype.DType)); + _apply_state[device_dtype]["one_minus_rho"] = 1.0f - rho; + } + + protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary> _apply_state) + { + Dictionary coefficients = null; + foreach (var state in _apply_state) + { + if (state.Key.DType == var.dtype.as_base_dtype() + && state.Key.Device == var.Device) + { + coefficients = state.Value; + break; + } + } + + var rms = get_slot(var, "rms"); + if (_momentum) + { + throw new NotImplementedException(""); + } + else + { + var rms_t = coefficients["rho"] * rms.AsTensor() + coefficients["one_minus_rho"] * math_ops.square(grad); + rms_t = state_ops.assign(rms, rms_t, use_locking: _use_locking); + var denom_t = rms_t; + if (centered) + { + throw new NotImplementedException(""); + } + var var_t = var.AsTensor() - coefficients["lr_t"] * grad / (math_ops.sqrt(denom_t) + coefficients["epsilon"]); + return state_ops.assign(var, var_t, use_locking: _use_locking).op; + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Optimizers/RestoredOptimizer.cs b/src/TensorFlowNET.Keras/Optimizers/RestoredOptimizer.cs new file mode 100644 index 000000000..e5cfd2daa --- /dev/null +++ b/src/TensorFlowNET.Keras/Optimizers/RestoredOptimizer.cs @@ -0,0 +1,63 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Saving; +using Tensorflow.Train; +using Tensorflow.Training; + +namespace Tensorflow.Keras.Optimizers +{ + public class RestoredOptimizer: OptimizerV2, ITrackableWrapper, IKerasConfig + { + public String Identifier { get; } = "optimizer"; + public int Version { get; } = 2; + public int MinConsumerVersion { get; } = 1; + public int MinProducerVersion { get; } = 1; + public RestoredOptimizer(): base(new ArgsDefinition.OptimizerV2Args() { Name = "RestoredOptimizer" }) + { + _hypers_created = true; + } + + public IKerasConfig get_config() + { + throw new NotImplementedException("Restoring functional Optimizers from SavedModels is not currently " + + "supported. Please file a feature request if this limitation bothers you."); + } + + public void SetValue(object name, object value) + { + if(name is not String str) + { + throw new TypeError($"The name of value to set must be string, but got {name.GetType()}"); + } + if(value is Trackable trackable) + { + _track_trackable(trackable, str, overwrite: true); + } + if(value is IVariableV1 resource_variable) + { + if (!_hyper_variables.ContainsKey(str)) + { + _hyper_variables[str] = resource_variable; + } + else + { + keras.backend.set_value(resource_variable, value); + } + } + else if (value is float f) + { + _hyper[str] = f; + } + else + { + throw new NotImplementedException(); + } + } + + public Trackable FromProto(SavedUserObject proto) + { + return new RestoredOptimizer(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Optimizers/SGD.cs b/src/TensorFlowNET.Keras/Optimizers/SGD.cs new file mode 100644 index 000000000..1d9ceb810 --- /dev/null +++ b/src/TensorFlowNET.Keras/Optimizers/SGD.cs @@ -0,0 +1,73 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Optimizers +{ + public class SGD : OptimizerV2 + { + protected override string _name => "SGD"; + +#pragma warning disable CS0169 // The field 'SGD.nesterov' is never used + bool nesterov; +#pragma warning restore CS0169 // The field 'SGD.nesterov' is never used + + public SGD(float learning_rate, + float momentum = 0.0f, + bool nesterov = false, + float decay = 0.0f) : base(new OptimizerV2Args { }) + { + _set_hyper("learning_rate", learning_rate); + _set_hyper("decay", decay); + + _momentum = momentum > 0; + if (momentum < 0 || momentum > 1) + throw new ValueError($"momentum must be a number between 0 and 1, got {momentum}."); + + _set_hyper("momentum", momentum); + +#pragma warning disable CS1717 // Assignment made to same variable + nesterov = nesterov; +#pragma warning restore CS1717 // Assignment made to same variable + } + + protected override void _create_slots(IVariableV1[] var_list) + { + if (_momentum) + foreach (var var in var_list) + add_slot(var, "momentum"); + } + + protected override void _prepare_local(DeviceDType device_dtype, + Dictionary> _apply_state) + { + base._prepare_local(device_dtype, _apply_state); + + _apply_state[device_dtype]["momentum"] = array_ops.identity( + _get_hyper("momentum", device_dtype.DType)); + } + + protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary> _apply_state) + { + if (_momentum) + { + var momentum_var = get_slot(var, "momentum"); + return gen_training_ops.resource_apply_keras_momentum( + var.Handle, + momentum_var.Handle, + _get_hyper("learning_rate", var.dtype), + grad, + _get_hyper("momentum", var.dtype), + use_locking: _use_locking, + use_nesterov: nesterov); + } + var device_dtype = _apply_state.Keys.FirstOrDefault(x => x.Device == var.Device && x.DType == var.dtype.as_base_dtype()); + + return gen_training_ops.resource_apply_gradient_descent(var.Handle, + _apply_state[device_dtype]["lr_t"], + grad, + use_locking: _use_locking); + } + } +} diff --git a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.cs b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.cs new file mode 100644 index 000000000..bb17f5941 --- /dev/null +++ b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.cs @@ -0,0 +1,18 @@ +using System; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Preprocessings +{ + public partial class DatasetUtils + { + public IDatasetV2 labels_to_dataset(int[] labels, string label_mode, int num_classes) + { + var label_ds = tf.data.Dataset.from_tensor_slices(labels); + if (label_mode == "binary") + throw new NotImplementedException(""); + else if (label_mode == "categorical") + throw new NotImplementedException(""); + return label_ds; + } + } +} diff --git a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs new file mode 100644 index 000000000..18ca404ef --- /dev/null +++ b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs @@ -0,0 +1,43 @@ +using System; +using System.Linq; + +namespace Tensorflow.Keras.Preprocessings +{ + public partial class DatasetUtils + { + /// + /// Potentially restict samples and labels to a training or validation split. + /// + /// + /// + /// + /// + /// + public (T1[], T2[]) get_training_or_validation_split(T1[] samples, + T2[] labels, + float validation_split, + string subset) + { + if (string.IsNullOrEmpty(subset)) + return (samples, labels); + + var num_val_samples = Convert.ToInt32(samples.Length * validation_split); + if (subset == "training") + { + Binding.tf_output_redirect.WriteLine($"Using {samples.Length - num_val_samples} files for training."); + samples = samples.Take(samples.Length - num_val_samples).ToArray(); + labels = labels.Take(labels.Length - num_val_samples).ToArray(); + } + else if (subset == "validation") + { + Binding.tf_output_redirect.WriteLine($"Using {num_val_samples} files for validation."); + samples = samples.Skip(samples.Length - num_val_samples).ToArray(); + labels = labels.Skip(labels.Length - num_val_samples).ToArray(); + } + else + throw new NotImplementedException(""); + + return (samples, labels); + } + } +} diff --git a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs new file mode 100644 index 000000000..1eb4f431c --- /dev/null +++ b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs @@ -0,0 +1,69 @@ +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Preprocessings +{ + public partial class DatasetUtils + { + /// + /// Make list of all files in the subdirs of `directory`, with their labels. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// file_paths, labels, class_names + /// + public (string[], int[], string[]) index_directory(string directory, + string labels, + string[] formats = null, + string[] class_names = null, + bool shuffle = true, + int? seed = null, + bool follow_links = false) + { + var label_list = new List(); + var file_paths = new List(); + + var class_dirs = Directory.GetDirectories(directory); + class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar).Last()).ToArray(); + + for (var label = 0; label < class_dirs.Length; label++) + { + var files = Directory.GetFiles(class_dirs[label]); + file_paths.AddRange(files); + label_list.AddRange(Enumerable.Range(0, files.Length).Select(x => label)); + } + + var return_labels = label_list.Select(x => x).ToArray(); + var return_file_paths = file_paths.Select(x => x).ToArray(); + + if (shuffle) + { + if (!seed.HasValue) + seed = np.random.randint((int)1e6); + var random_index = np.arange(label_list.Count); + tf.set_random_seed(seed.Value); + np.random.shuffle(random_index); + var index = random_index.ToArray(); + + for (int i = 0; i < label_list.Count; i++) + { + return_labels[i] = label_list[index[i]]; + return_file_paths[i] = file_paths[index[i]]; + } + } + + Binding.tf_output_redirect.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes."); + return (return_file_paths, return_labels, class_names); + } + } +} diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.Resizing.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.Resizing.cs new file mode 100644 index 000000000..0be7f1e6c --- /dev/null +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.Resizing.cs @@ -0,0 +1,26 @@ +using System; +using System.IO; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Layers; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras +{ + public partial class Preprocessing + { + /// + /// Image resizing layer + /// + /// + /// + /// + /// + public ILayer Resizing(int height, int width, string interpolation = "bilinear") + => new Resizing(new ResizingArgs + { + Height = height, + Width = width, + Interpolation = interpolation + }); + } +} diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs new file mode 100644 index 000000000..94fc4a207 --- /dev/null +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs @@ -0,0 +1,30 @@ +using System; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Preprocessings; + +namespace Tensorflow.Keras +{ + public partial class Preprocessing : IPreprocessing + { + public Sequence sequence => new Sequence(); + public DatasetUtils dataset_utils => new DatasetUtils(); + + public TextApi text => _text; + + private static TextApi _text = new TextApi(); + + public ILayer TextVectorization(Func standardize = null, + string split = "whitespace", + int max_tokens = -1, + string output_mode = "int", + int output_sequence_length = -1) => new TextVectorization(new TextVectorizationArgs + { + Standardize = standardize, + Split = split, + MaxTokens = max_tokens, + OutputMode = output_mode, + OutputSequenceLength = output_sequence_length + }); + } +} diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs new file mode 100644 index 000000000..377ac4de7 --- /dev/null +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs @@ -0,0 +1,187 @@ +using static Tensorflow.KerasApi; +using static Tensorflow.Binding; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras +{ + public partial class Preprocessing + { + public static string[] WHITELIST_FORMATS = new[] { ".bmp", ".gif", ".jpeg", ".jpg", ".png" }; + + /// + /// Function that calculates the classification statistics for a given array of classified data. + /// The function takes an array of classified data as input and returns a dictionary containing the count and percentage of each class in the input array. + /// This function can be used to analyze the distribution of classes in a dataset or to evaluate the performance of a classification model. + /// + /// + /// code from copilot + /// + /// + /// + Dictionary get_classification_statistics(int[] label_ids, string[] label_class_names) + { + var countDict = label_ids.GroupBy(x => x) + .ToDictionary(g => g.Key, g => g.Count()); + var totalCount = label_ids.Length; + var ratioDict = label_class_names.ToDictionary(name => name, + name => + (double)(countDict.ContainsKey(Array.IndexOf(label_class_names, name)) + ? countDict[Array.IndexOf(label_class_names, name)] : 0) + / totalCount); + + print("Classification statistics:"); + foreach (string labelName in label_class_names) + { + double ratio = ratioDict[labelName]; + print($"{labelName}: {ratio * 100:F2}%"); + } + + return ratioDict; + } + + /// + /// Generates a `tf.data.Dataset` from image files in a directory. + /// https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory + /// + /// Directory where the data is located. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public IDatasetV2 image_dataset_from_directory(string directory, + string labels = "inferred", + string label_mode = "int", + string[] class_names = null, + string color_mode = "rgb", + int batch_size = 32, + Shape image_size = null, + bool shuffle = true, + int? seed = null, + float validation_split = 0.2f, + string subset = null, + string interpolation = "bilinear", + bool follow_links = false) + { + int num_channels = 0; + if (color_mode == "rgb") + num_channels = 3; + + var (image_paths, label_list, class_name_list) = keras.preprocessing.dataset_utils.index_directory(directory, + labels, + formats: WHITELIST_FORMATS, + class_names: class_names, + shuffle: shuffle, + seed: seed, + follow_links: follow_links); + + (image_paths, label_list) = keras.preprocessing.dataset_utils.get_training_or_validation_split(image_paths, label_list, validation_split, subset); + get_classification_statistics(label_list, class_name_list); + + var dataset = paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation); + if (shuffle) + dataset = dataset.shuffle(batch_size * 8, seed: seed); + dataset = dataset.batch(batch_size); + dataset.class_names = class_name_list; + return dataset; + } + + public IDatasetV2 text_dataset_from_directory(string directory, + string labels = "inferred", + string label_mode = "int", + string[] class_names = null, + int batch_size = 32, + bool shuffle = true, + int max_length = -1, + int? seed = null, + float validation_split = 0.2f, + string subset = null, + bool follow_links = false) + { + var (file_paths, label_list, class_name_list) = dataset_utils.index_directory( + directory, + labels, + formats: new[] { ".txt" }, + class_names: class_names, + shuffle: shuffle, + seed: seed, + follow_links: follow_links); + + (file_paths, label_list) = dataset_utils.get_training_or_validation_split( + file_paths, label_list, validation_split, subset); + + var dataset = paths_and_labels_to_dataset(file_paths, label_list, label_mode, class_name_list.Length); + if (shuffle) + dataset = dataset.shuffle(batch_size * 8, seed: seed); + dataset = dataset.batch(batch_size); + dataset.class_names = class_name_list; + return dataset; + } + + /// + /// Creates a dataset of sliding windows over a timeseries provided as array. + /// + /// + /// + /// + /// + /// + /// + /// + public IDatasetV2 timeseries_dataset_from_array(Tensor data, int sequence_length, + int sequence_stride = 1, int sampling_rate = 1, int batch_size = 128, + bool shuffle = false, int seed = (int)1e6, int start_index = 0, int? end_index = null) + { + if (!end_index.HasValue) + end_index = len(data); + + var num_seqs = end_index.Value - start_index - (sequence_length * sampling_rate) + 1; + var index_dtype = num_seqs < 2147483647 ? tf.int32 : tf.int64; + var start_positions = np.arange(0, num_seqs, sequence_stride); + if (shuffle) + { + tf.set_random_seed(seed); + np.random.shuffle(start_positions); + } + + var sequence_length_tensor = constant_op.constant(sequence_length, dtype: index_dtype); + var sampling_rate_tensor = constant_op.constant(sampling_rate, dtype: index_dtype); + + var start_positions_tensor = tf.constant(start_positions); + var positions_ds = tf.data.Dataset.from_tensors(start_positions_tensor).repeat(); + var r = tf.data.Dataset.range(len(start_positions)); + var z = tf.data.Dataset.zip(r, positions_ds); + var indices = z.map(m => + { + var (i, positions) = m; + return tf.range(positions.Single[i], positions.Single[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor); + }, num_parallel_calls: -1); + var dataset = sequences_from_indices(data, indices, start_index, end_index); + + if (shuffle) + dataset = dataset.shuffle(buffer_size: batch_size * 8, seed: seed); + dataset = dataset.batch(batch_size); + return dataset; + } + + IDatasetV2 sequences_from_indices(Tensor array, IDatasetV2 indices_ds, int start_index, int? end_index) + { + var dataset = tf.data.Dataset.from_tensors(array[new Slice(start: start_index, stop: end_index)]); + dataset = tf.data.Dataset.zip(dataset.repeat(), indices_ds) + .map(x => + { + var (steps, indx) = x; + return array_ops.gather(steps, indx); + }, num_parallel_calls: -1); + return dataset; + } + } +} diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs new file mode 100644 index 000000000..232f81eb5 --- /dev/null +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs @@ -0,0 +1,92 @@ +using System.IO; +using static Tensorflow.Binding; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras +{ + public partial class Preprocessing + { + + /// + /// 图片路径转为数据处理用的dataset + /// 通常用于预测时读取图片 + /// + /// + /// + /// + /// + /// 用于调整大小的插值方法。支持`bilinear`、`nearest`、`bicubic`、`area`、`lanczos3`、`lanczos5`、`gaussian`、`mitchellcubic`。 + /// 默认为`'bilinear'`。 + /// + /// + public IDatasetV2 paths_to_dataset(string[] image_paths, + Shape image_size, + int num_channels = 3, + int num_classes = 6, + string interpolation = "bilinear") + { + var path_ds = tf.data.Dataset.from_tensor_slices(image_paths); + var img_ds = path_ds.map(x => path_to_image(x, image_size, num_channels, interpolation)); + var label_ds = dataset_utils.labels_to_dataset(new int[num_classes] , "", num_classes); + + return img_ds; + } + + public IDatasetV2 paths_and_labels_to_dataset(string[] image_paths, + Shape image_size, + int num_channels, + int[] labels, + string label_mode, + int num_classes, + string interpolation) + { + var path_ds = tf.data.Dataset.from_tensor_slices(image_paths); + var img_ds = path_ds.map(x => path_to_image(x, image_size, num_channels, interpolation)); + + if (label_mode == "int") + { + var label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes); + img_ds = tf.data.Dataset.zip(img_ds, label_ds); + } + + return img_ds; + } + + Tensor path_to_image(Tensor path, Shape image_size, int num_channels, string interpolation) + { + // tf.print(path); + var img = tf.io.read_file(path); + img = tf.image.decode_image( + img, channels: num_channels, expand_animations: false); + img = tf.image.resize_images_v2(img, image_size, method: interpolation); + // img.set_shape((image_size[0], image_size[1], num_channels)); + return img; + } + + public IDatasetV2 paths_and_labels_to_dataset(string[] image_paths, + int[] labels, + string label_mode, + int num_classes, + int max_length = -1) + { + var path_ds = tf.data.Dataset.from_tensor_slices(image_paths); + var string_ds = path_ds.map(x => path_to_string_content(x, max_length)); + + if (label_mode == "int") + { + var label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes); + string_ds = tf.data.Dataset.zip(string_ds, label_ds); + } + + return string_ds; + } + + Tensor path_to_string_content(Tensor path, int max_length) + { + var txt = tf.io.read_file(path); + if (max_length > -1) + txt = tf.strings.substr(txt, 0, max_length); + return txt; + } + } +} diff --git a/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs b/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs new file mode 100644 index 000000000..c103e856c --- /dev/null +++ b/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs @@ -0,0 +1,444 @@ +using Tensorflow.NumPy; +using Serilog.Debugging; +using System; +using System.Collections.Generic; +using System.Collections.Specialized; +using System.Data.SqlTypes; +using System.Linq; +using System.Net.Sockets; +using System.Text; + +namespace Tensorflow.Keras.Text +{ + /// + /// Text tokenization API. + /// This class allows to vectorize a text corpus, by turning each text into either a sequence of integers + /// (each integer being the index of a token in a dictionary) or into a vector where the coefficient for + /// each token could be binary, based on word count, based on tf-idf... + /// + /// + /// This code is a fairly straight port of the Python code for Keras text preprocessing found at: + /// https://github.com/keras-team/keras-preprocessing/blob/master/keras_preprocessing/text.py + /// + public class Tokenizer + { + private readonly int num_words; + private readonly string filters; + private readonly bool lower; + private readonly char split; + private readonly bool char_level; + private readonly string oov_token; + private readonly Func> analyzer; + + private int document_count = 0; + + private Dictionary word_docs = new Dictionary(); + private Dictionary word_counts = new Dictionary(); + + public Dictionary word_index = null; + public Dictionary index_word = null; + + private Dictionary index_docs = null; + + public Tokenizer( + int num_words = -1, + string filters = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n", + bool lower = true, + char split = ' ', + bool char_level = false, + string oov_token = null, + Func> analyzer = null) + { + this.num_words = num_words; + this.filters = filters; + this.lower = lower; + this.split = split; + this.char_level = char_level; + this.oov_token = oov_token; + this.analyzer = analyzer != null ? analyzer : (text) => TextApi.text_to_word_sequence(text, filters, lower, split); + } + + /// + /// Updates internal vocabulary based on a list of texts. + /// + /// A list of strings, each containing one or more tokens. + /// Required before using texts_to_sequences or texts_to_matrix. + public void fit_on_texts(IEnumerable texts) + { + foreach (var text in texts) + { + IEnumerable seq = null; + + document_count += 1; + if (char_level) + { + throw new NotImplementedException("char_level == true"); + } + else + { + seq = analyzer(lower ? text.ToLower() : text); + } + + foreach (var w in seq) + { + var count = 0; + word_counts.TryGetValue(w, out count); + word_counts[w] = count + 1; + } + + foreach (var w in new HashSet(seq)) + { + var count = 0; + word_docs.TryGetValue(w, out count); + word_docs[w] = count + 1; + } + } + + var wcounts = word_counts.AsEnumerable().ToList(); + wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value)); // Note: '-' gives us descending order. + + var sorted_voc = (oov_token == null) ? new List() : new List() { oov_token }; + sorted_voc.AddRange(word_counts.Select(kv => kv.Key)); + + if (num_words > 0 - 1) + { + sorted_voc = sorted_voc.Take((oov_token == null) ? num_words : num_words + 1).ToList(); + } + + word_index = new Dictionary(sorted_voc.Count); + index_word = new Dictionary(sorted_voc.Count); + index_docs = new Dictionary(word_docs.Count); + + for (int i = 0; i < sorted_voc.Count; i++) + { + word_index.Add(sorted_voc[i], i + 1); + index_word.Add(i + 1, sorted_voc[i]); + } + + foreach (var kv in word_docs) + { + var idx = -1; + if (word_index.TryGetValue(kv.Key, out idx)) + { + index_docs.Add(idx, kv.Value); + } + } + } + + /// + /// Updates internal vocabulary based on a list of texts. + /// + /// A list of list of strings, each containing one token. + /// Required before using texts_to_sequences or texts_to_matrix. + public void fit_on_texts(IEnumerable> texts) + { + foreach (var seq in texts) + { + foreach (var w in seq.Select(s => lower ? s.ToLower() : s)) + { + var count = 0; + word_counts.TryGetValue(w, out count); + word_counts[w] = count + 1; + } + + foreach (var w in new HashSet(word_counts.Keys)) + { + var count = 0; + word_docs.TryGetValue(w, out count); + word_docs[w] = count + 1; + } + } + + var wcounts = word_counts.AsEnumerable().ToList(); + wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value)); + + var sorted_voc = (oov_token == null) ? new List() : new List() { oov_token }; + sorted_voc.AddRange(word_counts.Select(kv => kv.Key)); + + if (num_words > 0 - 1) + { + sorted_voc = sorted_voc.Take((oov_token == null) ? num_words : num_words + 1).ToList(); + } + + word_index = new Dictionary(sorted_voc.Count); + index_word = new Dictionary(sorted_voc.Count); + index_docs = new Dictionary(word_docs.Count); + + for (int i = 0; i < sorted_voc.Count; i++) + { + word_index.Add(sorted_voc[i], i + 1); + index_word.Add(i + 1, sorted_voc[i]); + } + + foreach (var kv in word_docs) + { + var idx = -1; + if (word_index.TryGetValue(kv.Key, out idx)) + { + index_docs.Add(idx, kv.Value); + } + } + } + + /// + /// Updates internal vocabulary based on a list of sequences. + /// + /// + /// Required before using sequences_to_matrix (if fit_on_texts was never called). + public void fit_on_sequences(IEnumerable sequences) + { + throw new NotImplementedException("fit_on_sequences"); + } + + /// + /// Transforms each string in texts to a sequence of integers. + /// + /// + /// + /// Only top num_words-1 most frequent words will be taken into account.Only words known by the tokenizer will be taken into account. + public IList texts_to_sequences(IEnumerable texts) + { + return texts_to_sequences_generator(texts).ToArray(); + } + + /// + /// Transforms each token in texts to a sequence of integers. + /// + /// + /// + /// Only top num_words-1 most frequent words will be taken into account.Only words known by the tokenizer will be taken into account. + public IList texts_to_sequences(IEnumerable> texts) + { + return texts_to_sequences_generator(texts).ToArray(); + } + + public IEnumerable texts_to_sequences_generator(IEnumerable texts) + { + int oov_index = -1; + var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index); + + return texts.Select(text => + { + IEnumerable seq = null; + + if (char_level) + { + throw new NotImplementedException("char_level == true"); + } + else + { + seq = analyzer(lower ? text.ToLower() : text); + } + + return ConvertToSequence(oov_index, seq).ToArray(); + }); + } + + public IEnumerable texts_to_sequences_generator(IEnumerable> texts) + { + int oov_index = -1; + var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index); + return texts.Select(seq => ConvertToSequence(oov_index, seq).ToArray()); + } + + private List ConvertToSequence(int oov_index, IEnumerable seq) + { + var vect = new List(); + foreach (var w in seq.Select(s => lower ? s.ToLower() : s)) + { + var i = -1; + if (word_index.TryGetValue(w, out i)) + { + if (num_words != -1 && i >= num_words) + { + if (oov_index != -1) + { + vect.Add(oov_index); + } + } + else + { + vect.Add(i); + } + } + else if (oov_index != -1) + { + vect.Add(oov_index); + } + } + + return vect; + } + + /// + /// Transforms each sequence into a list of text. + /// + /// + /// A list of texts(strings) + /// Only top num_words-1 most frequent words will be taken into account.Only words known by the tokenizer will be taken into account. + public IList sequences_to_texts(IEnumerable sequences) + { + return sequences_to_texts_generator(sequences).ToArray(); + } + + public IEnumerable sequences_to_texts_generator(IEnumerable> sequences) + { + int oov_index = -1; + var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index); + + return sequences.Select(seq => + { + + var bldr = new StringBuilder(); + for (var i = 0; i < seq.Count; i++) + { + if (i > 0) bldr.Append(' '); + + string word = null; + if (index_word.TryGetValue(seq[i], out word)) + { + if (num_words != -1 && i >= num_words) + { + if (oov_index != -1) + { + bldr.Append(oov_token); + } + } + else + { + bldr.Append(word); + } + } + else if (oov_index != -1) + { + bldr.Append(oov_token); + } + } + + return bldr.ToString(); + }); + } + + /// + /// Convert a list of texts to a Numpy matrix. + /// + /// A sequence of strings containing one or more tokens. + /// One of "binary", "count", "tfidf", "freq". + /// + public NDArray texts_to_matrix(IEnumerable texts, string mode = "binary") + { + return sequences_to_matrix(texts_to_sequences(texts), mode); + } + + /// + /// Convert a list of texts to a Numpy matrix. + /// + /// A sequence of lists of strings, each containing one token. + /// One of "binary", "count", "tfidf", "freq". + /// + public NDArray texts_to_matrix(IEnumerable> texts, string mode = "binary") + { + return sequences_to_matrix(texts_to_sequences(texts), mode); + } + + /// + /// Converts a list of sequences into a Numpy matrix. + /// + /// A sequence of lists of integers, encoding tokens. + /// One of "binary", "count", "tfidf", "freq". + /// + public NDArray sequences_to_matrix(IEnumerable> sequences, string mode = "binary") + { + if (!modes.Contains(mode)) throw new InvalidArgumentError($"Unknown vectorization mode: {mode}"); + var word_count = 0; + + if (num_words == -1) + { + if (word_index != null) + { + word_count = word_index.Count + 1; + } + else + { + throw new InvalidOperationException("Specifya dimension ('num_words' arugment), or fit on some text data first."); + } + } + else + { + word_count = num_words; + } + + if (mode == "tfidf" && this.document_count == 0) + { + throw new InvalidOperationException("Fit the Tokenizer on some text data before using the 'tfidf' mode."); + } + + var x = np.zeros((sequences.Count(), word_count)); + + for (int i = 0; i < sequences.Count(); i++) + { + var seq = sequences.ElementAt(i); + if (seq == null || seq.Count == 0) + continue; + + var counts = new Dictionary(); + + var seq_length = seq.Count; + + foreach (var j in seq) + { + if (j >= word_count) + continue; + var count = 0; + counts.TryGetValue(j, out count); + counts[j] = count + 1; + } + + if (mode == "count") + { + foreach (var kv in counts) + { + var j = kv.Key; + var c = kv.Value + 0.0; + x[i, j] = c; + } + } + else if (mode == "freq") + { + foreach (var kv in counts) + { + var j = kv.Key; + var c = kv.Value + 0.0; + x[i, j] = ((double)c) / seq_length; + } + } + else if (mode == "binary") + { + foreach (var kv in counts) + { + var j = kv.Key; + // var c = kv.Value + 0.0; + x[i, j] = 1.0; + } + } + else if (mode == "tfidf") + { + foreach (var kv in counts) + { + var j = kv.Key; + var c = kv.Value + 0.0; + var id = 0; + var _ = index_docs.TryGetValue(j, out id); + var tf = 1.0 + (double)np.log(c); + var idf = np.log(1.0 + document_count / (1 + id)); + x[i, j] = tf * (double)idf; + } + } + } + + return x; + } + + private string[] modes = new string[] { "binary", "count", "tfidf", "freq" }; + } +} diff --git a/src/TensorFlowNET.Keras/Protobuf/ProjectorConfig.cs b/src/TensorFlowNET.Keras/Protobuf/ProjectorConfig.cs new file mode 100644 index 000000000..78ab79f89 --- /dev/null +++ b/src/TensorFlowNET.Keras/Protobuf/ProjectorConfig.cs @@ -0,0 +1,669 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/python/keras/protobuf/projector_config.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace ThirdParty.Tensorflow.Python.Keras.Protobuf { + + /// Holder for reflection information generated from tensorflow/python/keras/protobuf/projector_config.proto + public static partial class ProjectorConfigReflection { + + #region Descriptor + /// File descriptor for tensorflow/python/keras/protobuf/projector_config.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ProjectorConfigReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cjd0ZW5zb3JmbG93L3B5dGhvbi9rZXJhcy9wcm90b2J1Zi9wcm9qZWN0b3Jf", + "Y29uZmlnLnByb3RvEix0aGlyZF9wYXJ0eS50ZW5zb3JmbG93LnB5dGhvbi5r", + "ZXJhcy5wcm90b2J1ZiI+Cg5TcHJpdGVNZXRhZGF0YRISCgppbWFnZV9wYXRo", + "GAEgASgJEhgKEHNpbmdsZV9pbWFnZV9kaW0YAiADKA0izAEKDUVtYmVkZGlu", + "Z0luZm8SEwoLdGVuc29yX25hbWUYASABKAkSFQoNbWV0YWRhdGFfcGF0aBgC", + "IAEoCRIWCg5ib29rbWFya3NfcGF0aBgDIAEoCRIUCgx0ZW5zb3Jfc2hhcGUY", + "BCADKA0STAoGc3ByaXRlGAUgASgLMjwudGhpcmRfcGFydHkudGVuc29yZmxv", + "dy5weXRob24ua2VyYXMucHJvdG9idWYuU3ByaXRlTWV0YWRhdGESEwoLdGVu", + "c29yX3BhdGgYBiABKAkinwEKD1Byb2plY3RvckNvbmZpZxIdChVtb2RlbF9j", + "aGVja3BvaW50X3BhdGgYASABKAkSTwoKZW1iZWRkaW5ncxgCIAMoCzI7LnRo", + "aXJkX3BhcnR5LnRlbnNvcmZsb3cucHl0aG9uLmtlcmFzLnByb3RvYnVmLkVt", + "YmVkZGluZ0luZm8SHAoUbW9kZWxfY2hlY2twb2ludF9kaXIYAyABKAliBnBy", + "b3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SpriteMetadata), global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SpriteMetadata.Parser, new[]{ "ImagePath", "SingleImageDim" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::ThirdParty.Tensorflow.Python.Keras.Protobuf.EmbeddingInfo), global::ThirdParty.Tensorflow.Python.Keras.Protobuf.EmbeddingInfo.Parser, new[]{ "TensorName", "MetadataPath", "BookmarksPath", "TensorShape", "Sprite", "TensorPath" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::ThirdParty.Tensorflow.Python.Keras.Protobuf.ProjectorConfig), global::ThirdParty.Tensorflow.Python.Keras.Protobuf.ProjectorConfig.Parser, new[]{ "ModelCheckpointPath", "Embeddings", "ModelCheckpointDir" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class SpriteMetadata : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SpriteMetadata()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::ThirdParty.Tensorflow.Python.Keras.Protobuf.ProjectorConfigReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SpriteMetadata() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SpriteMetadata(SpriteMetadata other) : this() { + imagePath_ = other.imagePath_; + singleImageDim_ = other.singleImageDim_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SpriteMetadata Clone() { + return new SpriteMetadata(this); + } + + /// Field number for the "image_path" field. + public const int ImagePathFieldNumber = 1; + private string imagePath_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ImagePath { + get { return imagePath_; } + set { + imagePath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "single_image_dim" field. + public const int SingleImageDimFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_singleImageDim_codec + = pb::FieldCodec.ForUInt32(18); + private readonly pbc::RepeatedField singleImageDim_ = new pbc::RepeatedField(); + /// + /// [width, height] of a single image in the sprite. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField SingleImageDim { + get { return singleImageDim_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SpriteMetadata); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SpriteMetadata other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ImagePath != other.ImagePath) return false; + if(!singleImageDim_.Equals(other.singleImageDim_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ImagePath.Length != 0) hash ^= ImagePath.GetHashCode(); + hash ^= singleImageDim_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ImagePath.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ImagePath); + } + singleImageDim_.WriteTo(output, _repeated_singleImageDim_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ImagePath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ImagePath); + } + size += singleImageDim_.CalculateSize(_repeated_singleImageDim_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SpriteMetadata other) { + if (other == null) { + return; + } + if (other.ImagePath.Length != 0) { + ImagePath = other.ImagePath; + } + singleImageDim_.Add(other.singleImageDim_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ImagePath = input.ReadString(); + break; + } + case 18: + case 16: { + singleImageDim_.AddEntriesFrom(input, _repeated_singleImageDim_codec); + break; + } + } + } + } + + } + + public sealed partial class EmbeddingInfo : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new EmbeddingInfo()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::ThirdParty.Tensorflow.Python.Keras.Protobuf.ProjectorConfigReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EmbeddingInfo() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EmbeddingInfo(EmbeddingInfo other) : this() { + tensorName_ = other.tensorName_; + metadataPath_ = other.metadataPath_; + bookmarksPath_ = other.bookmarksPath_; + tensorShape_ = other.tensorShape_.Clone(); + sprite_ = other.sprite_ != null ? other.sprite_.Clone() : null; + tensorPath_ = other.tensorPath_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EmbeddingInfo Clone() { + return new EmbeddingInfo(this); + } + + /// Field number for the "tensor_name" field. + public const int TensorNameFieldNumber = 1; + private string tensorName_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TensorName { + get { return tensorName_; } + set { + tensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "metadata_path" field. + public const int MetadataPathFieldNumber = 2; + private string metadataPath_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string MetadataPath { + get { return metadataPath_; } + set { + metadataPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "bookmarks_path" field. + public const int BookmarksPathFieldNumber = 3; + private string bookmarksPath_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string BookmarksPath { + get { return bookmarksPath_; } + set { + bookmarksPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "tensor_shape" field. + public const int TensorShapeFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_tensorShape_codec + = pb::FieldCodec.ForUInt32(34); + private readonly pbc::RepeatedField tensorShape_ = new pbc::RepeatedField(); + /// + /// Shape of the 2D tensor [N x D]. If missing, it will be inferred from the + /// model checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField TensorShape { + get { return tensorShape_; } + } + + /// Field number for the "sprite" field. + public const int SpriteFieldNumber = 5; + private global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SpriteMetadata sprite_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SpriteMetadata Sprite { + get { return sprite_; } + set { + sprite_ = value; + } + } + + /// Field number for the "tensor_path" field. + public const int TensorPathFieldNumber = 6; + private string tensorPath_ = ""; + /// + /// Path to the TSV file holding the tensor values. If missing, the tensor + /// is assumed to be stored in the model checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TensorPath { + get { return tensorPath_; } + set { + tensorPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as EmbeddingInfo); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(EmbeddingInfo other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (TensorName != other.TensorName) return false; + if (MetadataPath != other.MetadataPath) return false; + if (BookmarksPath != other.BookmarksPath) return false; + if(!tensorShape_.Equals(other.tensorShape_)) return false; + if (!object.Equals(Sprite, other.Sprite)) return false; + if (TensorPath != other.TensorPath) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (TensorName.Length != 0) hash ^= TensorName.GetHashCode(); + if (MetadataPath.Length != 0) hash ^= MetadataPath.GetHashCode(); + if (BookmarksPath.Length != 0) hash ^= BookmarksPath.GetHashCode(); + hash ^= tensorShape_.GetHashCode(); + if (sprite_ != null) hash ^= Sprite.GetHashCode(); + if (TensorPath.Length != 0) hash ^= TensorPath.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (TensorName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(TensorName); + } + if (MetadataPath.Length != 0) { + output.WriteRawTag(18); + output.WriteString(MetadataPath); + } + if (BookmarksPath.Length != 0) { + output.WriteRawTag(26); + output.WriteString(BookmarksPath); + } + tensorShape_.WriteTo(output, _repeated_tensorShape_codec); + if (sprite_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Sprite); + } + if (TensorPath.Length != 0) { + output.WriteRawTag(50); + output.WriteString(TensorPath); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (TensorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TensorName); + } + if (MetadataPath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MetadataPath); + } + if (BookmarksPath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(BookmarksPath); + } + size += tensorShape_.CalculateSize(_repeated_tensorShape_codec); + if (sprite_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Sprite); + } + if (TensorPath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TensorPath); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(EmbeddingInfo other) { + if (other == null) { + return; + } + if (other.TensorName.Length != 0) { + TensorName = other.TensorName; + } + if (other.MetadataPath.Length != 0) { + MetadataPath = other.MetadataPath; + } + if (other.BookmarksPath.Length != 0) { + BookmarksPath = other.BookmarksPath; + } + tensorShape_.Add(other.tensorShape_); + if (other.sprite_ != null) { + if (sprite_ == null) { + Sprite = new global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SpriteMetadata(); + } + Sprite.MergeFrom(other.Sprite); + } + if (other.TensorPath.Length != 0) { + TensorPath = other.TensorPath; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + TensorName = input.ReadString(); + break; + } + case 18: { + MetadataPath = input.ReadString(); + break; + } + case 26: { + BookmarksPath = input.ReadString(); + break; + } + case 34: + case 32: { + tensorShape_.AddEntriesFrom(input, _repeated_tensorShape_codec); + break; + } + case 42: { + if (sprite_ == null) { + Sprite = new global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SpriteMetadata(); + } + input.ReadMessage(Sprite); + break; + } + case 50: { + TensorPath = input.ReadString(); + break; + } + } + } + } + + } + + public sealed partial class ProjectorConfig : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ProjectorConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::ThirdParty.Tensorflow.Python.Keras.Protobuf.ProjectorConfigReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ProjectorConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ProjectorConfig(ProjectorConfig other) : this() { + modelCheckpointPath_ = other.modelCheckpointPath_; + embeddings_ = other.embeddings_.Clone(); + modelCheckpointDir_ = other.modelCheckpointDir_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ProjectorConfig Clone() { + return new ProjectorConfig(this); + } + + /// Field number for the "model_checkpoint_path" field. + public const int ModelCheckpointPathFieldNumber = 1; + private string modelCheckpointPath_ = ""; + /// + /// Path to the checkpoint file. Use either this or model_checkpoint_dir. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ModelCheckpointPath { + get { return modelCheckpointPath_; } + set { + modelCheckpointPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "embeddings" field. + public const int EmbeddingsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_embeddings_codec + = pb::FieldCodec.ForMessage(18, global::ThirdParty.Tensorflow.Python.Keras.Protobuf.EmbeddingInfo.Parser); + private readonly pbc::RepeatedField embeddings_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Embeddings { + get { return embeddings_; } + } + + /// Field number for the "model_checkpoint_dir" field. + public const int ModelCheckpointDirFieldNumber = 3; + private string modelCheckpointDir_ = ""; + /// + /// Path to the checkpoint directory. The directory will be scanned for the + /// latest checkpoint file. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ModelCheckpointDir { + get { return modelCheckpointDir_; } + set { + modelCheckpointDir_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ProjectorConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ProjectorConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ModelCheckpointPath != other.ModelCheckpointPath) return false; + if(!embeddings_.Equals(other.embeddings_)) return false; + if (ModelCheckpointDir != other.ModelCheckpointDir) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ModelCheckpointPath.Length != 0) hash ^= ModelCheckpointPath.GetHashCode(); + hash ^= embeddings_.GetHashCode(); + if (ModelCheckpointDir.Length != 0) hash ^= ModelCheckpointDir.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ModelCheckpointPath.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ModelCheckpointPath); + } + embeddings_.WriteTo(output, _repeated_embeddings_codec); + if (ModelCheckpointDir.Length != 0) { + output.WriteRawTag(26); + output.WriteString(ModelCheckpointDir); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ModelCheckpointPath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ModelCheckpointPath); + } + size += embeddings_.CalculateSize(_repeated_embeddings_codec); + if (ModelCheckpointDir.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ModelCheckpointDir); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ProjectorConfig other) { + if (other == null) { + return; + } + if (other.ModelCheckpointPath.Length != 0) { + ModelCheckpointPath = other.ModelCheckpointPath; + } + embeddings_.Add(other.embeddings_); + if (other.ModelCheckpointDir.Length != 0) { + ModelCheckpointDir = other.ModelCheckpointDir; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ModelCheckpointPath = input.ReadString(); + break; + } + case 18: { + embeddings_.AddEntriesFrom(input, _repeated_embeddings_codec); + break; + } + case 26: { + ModelCheckpointDir = input.ReadString(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs b/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs new file mode 100644 index 000000000..f29f2dec3 --- /dev/null +++ b/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs @@ -0,0 +1,459 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/python/keras/protobuf/saved_metadata.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace ThirdParty.Tensorflow.Python.Keras.Protobuf { + + /// Holder for reflection information generated from tensorflow/python/keras/protobuf/saved_metadata.proto + public static partial class SavedMetadataReflection { + + #region Descriptor + /// File descriptor for tensorflow/python/keras/protobuf/saved_metadata.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static SavedMetadataReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjV0ZW5zb3JmbG93L3B5dGhvbi9rZXJhcy9wcm90b2J1Zi9zYXZlZF9tZXRh", + "ZGF0YS5wcm90bxIsdGhpcmRfcGFydHkudGVuc29yZmxvdy5weXRob24ua2Vy", + "YXMucHJvdG9idWYaL3RlbnNvcmZsb3cvcHl0aG9uL2tlcmFzL3Byb3RvYnVm", + "L3ZlcnNpb25zLnByb3RvIlkKDVNhdmVkTWV0YWRhdGESSAoFbm9kZXMYASAD", + "KAsyOS50aGlyZF9wYXJ0eS50ZW5zb3JmbG93LnB5dGhvbi5rZXJhcy5wcm90", + "b2J1Zi5TYXZlZE9iamVjdCKoAQoLU2F2ZWRPYmplY3QSDwoHbm9kZV9pZBgC", + "IAEoBRIRCglub2RlX3BhdGgYAyABKAkSEgoKaWRlbnRpZmllchgEIAEoCRIQ", + "CghtZXRhZGF0YRgFIAEoCRJJCgd2ZXJzaW9uGAYgASgLMjgudGhpcmRfcGFy", + "dHkudGVuc29yZmxvdy5weXRob24ua2VyYXMucHJvdG9idWYuVmVyc2lvbkRl", + "ZkoECAEQAmIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionsReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedMetadata), global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedMetadata.Parser, new[]{ "Nodes" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject), global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject.Parser, new[]{ "NodeId", "NodePath", "Identifier", "Metadata", "Version" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class SavedMetadata : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedMetadata()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedMetadataReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedMetadata() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedMetadata(SavedMetadata other) : this() { + nodes_ = other.nodes_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedMetadata Clone() { + return new SavedMetadata(this); + } + + /// Field number for the "nodes" field. + public const int NodesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_nodes_codec + = pb::FieldCodec.ForMessage(10, global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject.Parser); + private readonly pbc::RepeatedField nodes_ = new pbc::RepeatedField(); + /// + /// Nodes represent trackable objects in the SavedModel. The data for every + /// Keras object is stored. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Nodes { + get { return nodes_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SavedMetadata); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SavedMetadata other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!nodes_.Equals(other.nodes_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= nodes_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + nodes_.WriteTo(output, _repeated_nodes_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += nodes_.CalculateSize(_repeated_nodes_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SavedMetadata other) { + if (other == null) { + return; + } + nodes_.Add(other.nodes_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + nodes_.AddEntriesFrom(input, _repeated_nodes_codec); + break; + } + } + } + } + + } + + /// + /// Metadata of an individual Keras object. + /// + public sealed partial class SavedObject : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedObject()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedMetadataReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedObject() { + OnConstruction(); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedObject(int nodeId, string nodePath, + global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef version, string identifier, string metadata) + { + OnConstruction(); + nodeId_ = nodeId; + nodePath_ = nodePath; + identifier_ = identifier; + metadata_ = metadata; + version_ = version; + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedObject(SavedObject other) : this() { + nodeId_ = other.nodeId_; + nodePath_ = other.nodePath_; + identifier_ = other.identifier_; + metadata_ = other.metadata_; + version_ = other.version_ != null ? other.version_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedObject Clone() { + return new SavedObject(this); + } + + /// Field number for the "node_id" field. + public const int NodeIdFieldNumber = 2; + private int nodeId_; + /// + /// Index of the node in the SavedModel SavedObjectGraph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NodeId { + get { return nodeId_; } + set { + nodeId_ = value; + } + } + + /// Field number for the "node_path" field. + public const int NodePathFieldNumber = 3; + private string nodePath_ = ""; + /// + /// String path from root (e.g. "root.child_layer") + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string NodePath { + get { return nodePath_; } + set { + nodePath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "identifier" field. + public const int IdentifierFieldNumber = 4; + private string identifier_ = ""; + /// + /// Identifier to determine loading function. + /// Must be one of: + /// _tf_keras_input_layer, _tf_keras_layer, _tf_keras_metric, + /// _tf_keras_model, _tf_keras_network, _tf_keras_rnn_layer, + /// _tf_keras_sequential + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Identifier { + get { return identifier_; } + set { + identifier_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "metadata" field. + public const int MetadataFieldNumber = 5; + private string metadata_ = ""; + /// + /// Metadata containing a JSON-serialized object with the non-TensorFlow + /// attributes for this Keras object. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Metadata { + get { return metadata_; } + set { + metadata_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "version" field. + public const int VersionFieldNumber = 6; + private global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef version_; + /// + /// Version defined by the code serializing this Keras object. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef Version { + get { return version_; } + set { + version_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SavedObject); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SavedObject other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NodeId != other.NodeId) return false; + if (NodePath != other.NodePath) return false; + if (Identifier != other.Identifier) return false; + if (Metadata != other.Metadata) return false; + if (!object.Equals(Version, other.Version)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (NodeId != 0) hash ^= NodeId.GetHashCode(); + if (NodePath.Length != 0) hash ^= NodePath.GetHashCode(); + if (Identifier.Length != 0) hash ^= Identifier.GetHashCode(); + if (Metadata.Length != 0) hash ^= Metadata.GetHashCode(); + if (version_ != null) hash ^= Version.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (NodeId != 0) { + output.WriteRawTag(16); + output.WriteInt32(NodeId); + } + if (NodePath.Length != 0) { + output.WriteRawTag(26); + output.WriteString(NodePath); + } + if (Identifier.Length != 0) { + output.WriteRawTag(34); + output.WriteString(Identifier); + } + if (Metadata.Length != 0) { + output.WriteRawTag(42); + output.WriteString(Metadata); + } + if (version_ != null) { + output.WriteRawTag(50); + output.WriteMessage(Version); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (NodeId != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NodeId); + } + if (NodePath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(NodePath); + } + if (Identifier.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Identifier); + } + if (Metadata.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Metadata); + } + if (version_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Version); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SavedObject other) { + if (other == null) { + return; + } + if (other.NodeId != 0) { + NodeId = other.NodeId; + } + if (other.NodePath.Length != 0) { + NodePath = other.NodePath; + } + if (other.Identifier.Length != 0) { + Identifier = other.Identifier; + } + if (other.Metadata.Length != 0) { + Metadata = other.Metadata; + } + if (other.version_ != null) { + if (version_ == null) { + Version = new global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef(); + } + Version.MergeFrom(other.Version); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 16: { + NodeId = input.ReadInt32(); + break; + } + case 26: { + NodePath = input.ReadString(); + break; + } + case 34: { + Identifier = input.ReadString(); + break; + } + case 42: { + Metadata = input.ReadString(); + break; + } + case 50: { + if (version_ == null) { + Version = new global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef(); + } + input.ReadMessage(Version); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Keras/Protobuf/Versions.cs b/src/TensorFlowNET.Keras/Protobuf/Versions.cs new file mode 100644 index 000000000..ff9a23c62 --- /dev/null +++ b/src/TensorFlowNET.Keras/Protobuf/Versions.cs @@ -0,0 +1,255 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/python/keras/protobuf/versions.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace ThirdParty.Tensorflow.Python.Keras.Protobuf { + + /// Holder for reflection information generated from tensorflow/python/keras/protobuf/versions.proto + public static partial class VersionsReflection { + + #region Descriptor + /// File descriptor for tensorflow/python/keras/protobuf/versions.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static VersionsReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Ci90ZW5zb3JmbG93L3B5dGhvbi9rZXJhcy9wcm90b2J1Zi92ZXJzaW9ucy5w", + "cm90bxIsdGhpcmRfcGFydHkudGVuc29yZmxvdy5weXRob24ua2VyYXMucHJv", + "dG9idWYiSwoKVmVyc2lvbkRlZhIQCghwcm9kdWNlchgBIAEoBRIUCgxtaW5f", + "Y29uc3VtZXIYAiABKAUSFQoNYmFkX2NvbnN1bWVycxgDIAMoBWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef), global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef.Parser, new[]{ "Producer", "MinConsumer", "BadConsumers" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Version information for a piece of serialized data + /// + /// There are different types of versions for each type of data + /// (GraphDef, etc.), but they all have the same common shape + /// described here. + /// + /// Each consumer has "consumer" and "min_producer" versions (specified + /// elsewhere). A consumer is allowed to consume this data if + /// + /// producer >= min_producer + /// consumer >= min_consumer + /// consumer not in bad_consumers + /// + /// LINT.IfChange + /// + public sealed partial class VersionDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new VersionDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionsReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VersionDef() { + OnConstruction(); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VersionDef(int producer, int minConsumer) { + OnConstruction(); + producer_ = producer; + minConsumer_ = minConsumer; + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VersionDef(VersionDef other) : this() { + producer_ = other.producer_; + minConsumer_ = other.minConsumer_; + badConsumers_ = other.badConsumers_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VersionDef Clone() { + return new VersionDef(this); + } + + /// Field number for the "producer" field. + public const int ProducerFieldNumber = 1; + private int producer_; + /// + /// The version of the code that produced this data. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Producer { + get { return producer_; } + set { + producer_ = value; + } + } + + /// Field number for the "min_consumer" field. + public const int MinConsumerFieldNumber = 2; + private int minConsumer_; + /// + /// Any consumer below this version is not allowed to consume this data. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MinConsumer { + get { return minConsumer_; } + set { + minConsumer_ = value; + } + } + + /// Field number for the "bad_consumers" field. + public const int BadConsumersFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_badConsumers_codec + = pb::FieldCodec.ForInt32(26); + private readonly pbc::RepeatedField badConsumers_ = new pbc::RepeatedField(); + /// + /// Specific consumer versions which are disallowed (e.g. due to bugs). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField BadConsumers { + get { return badConsumers_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as VersionDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(VersionDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Producer != other.Producer) return false; + if (MinConsumer != other.MinConsumer) return false; + if(!badConsumers_.Equals(other.badConsumers_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Producer != 0) hash ^= Producer.GetHashCode(); + if (MinConsumer != 0) hash ^= MinConsumer.GetHashCode(); + hash ^= badConsumers_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Producer != 0) { + output.WriteRawTag(8); + output.WriteInt32(Producer); + } + if (MinConsumer != 0) { + output.WriteRawTag(16); + output.WriteInt32(MinConsumer); + } + badConsumers_.WriteTo(output, _repeated_badConsumers_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Producer != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Producer); + } + if (MinConsumer != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinConsumer); + } + size += badConsumers_.CalculateSize(_repeated_badConsumers_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(VersionDef other) { + if (other == null) { + return; + } + if (other.Producer != 0) { + Producer = other.Producer; + } + if (other.MinConsumer != 0) { + MinConsumer = other.MinConsumer; + } + badConsumers_.Add(other.badConsumers_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Producer = input.ReadInt32(); + break; + } + case 16: { + MinConsumer = input.ReadInt32(); + break; + } + case 26: + case 24: { + badConsumers_.AddEntriesFrom(input, _repeated_badConsumers_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Keras/Regularizers.cs b/src/TensorFlowNET.Keras/Regularizers.cs new file mode 100644 index 000000000..73b72a051 --- /dev/null +++ b/src/TensorFlowNET.Keras/Regularizers.cs @@ -0,0 +1,51 @@ +using Tensorflow.Operations.Regularizers; + +namespace Tensorflow.Keras +{ + public class Regularizers: IRegularizerApi + { + private static Dictionary _nameActivationMap; + + public IRegularizer l1(float l1 = 0.01f) + => new L1(l1); + public IRegularizer l2(float l2 = 0.01f) + => new L2(l2); + + //From TF source + //# The default value for l1 and l2 are different from the value in l1_l2 + //# for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2 + //# and no l1 penalty. + public IRegularizer l1l2(float l1 = 0.00f, float l2 = 0.00f) + => new L1L2(l1, l2); + + static Regularizers() + { + _nameActivationMap = new Dictionary(); + _nameActivationMap["L1"] = new L1(); + _nameActivationMap["L1"] = new L2(); + _nameActivationMap["L1"] = new L1L2(); + } + + public IRegularizer L1 => l1(); + + public IRegularizer L2 => l2(); + + public IRegularizer L1L2 => l1l2(); + + public IRegularizer GetRegularizerFromName(string name) + { + if (name == null) + { + throw new Exception($"Regularizer name cannot be null"); + } + if (!_nameActivationMap.TryGetValue(name, out var res)) + { + throw new Exception($"Regularizer {name} not found"); + } + else + { + return res; + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs b/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs new file mode 100644 index 000000000..9c82370a9 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs @@ -0,0 +1,44 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving +{ + public class KerasMetaData + { + [JsonProperty("name")] + public string Name { get; set; } + [JsonProperty("class_name")] + public string ClassName { get; set; } + [JsonProperty("trainable")] + public bool Trainable { get; set; } + [JsonProperty("dtype")] + public TF_DataType DType { get; set; } = TF_DataType.DtInvalid; + [JsonProperty("is_graph_network")] + public bool IsGraphNetwork { get; set; } + [JsonProperty("shared_object_id")] + public int SharedObjectId { get; set; } + [JsonProperty("must_restore_from_config")] + public bool MustRestoreFromConfig { get; set; } + [JsonProperty("config")] + public JObject Config { get; set; } + [JsonProperty("build_input_shape")] + public KerasShapesWrapper BuildInputShape { get; set; } + [JsonProperty("batch_input_shape")] + public KerasShapesWrapper BatchInputShape { get; set; } + [JsonProperty("activity_regularizer")] + public IRegularizer ActivityRegularizer { get; set; } + [JsonProperty("input_spec")] + public JToken InputSpec { get; set; } + [JsonProperty("stateful")] + public bool? Stateful { get; set; } + [JsonProperty("model_config")] + public KerasModelConfig? ModelConfig { get; set; } + [JsonProperty("sparse")] + public bool Sparse { get; set; } + [JsonProperty("ragged")] + public bool Ragged { get; set; } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/KerasModelConfig.cs b/src/TensorFlowNET.Keras/Saving/KerasModelConfig.cs new file mode 100644 index 000000000..256c284a5 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/KerasModelConfig.cs @@ -0,0 +1,16 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving +{ + public class KerasModelConfig + { + [JsonProperty("class_name")] + public string ClassName { get; set; } + [JsonProperty("config")] + public JObject Config { get; set; } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs new file mode 100644 index 000000000..0bd816ccb --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -0,0 +1,795 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; +using System.Collections; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Text.RegularExpressions; +using Tensorflow.Common.Extensions; +using Tensorflow.Framework.Models; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; +using Tensorflow.Training; +using Tensorflow.Training.Saving.SavedModel; +using Tensorflow.Util; +using ThirdParty.Tensorflow.Python.Keras.Protobuf; +using static Tensorflow.ApiDef.Types; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Saving +{ + public class KerasObjectLoader + { + internal static readonly IDictionary PUBLIC_ATTRIBUTES; + private SavedMetadata _metadata; + private SavedObjectGraph _proto; + private Dictionary _node_paths = new Dictionary(); + private Dictionary model_layer_ids_dependencies = new Dictionary(); + private Dictionary model_layer_dependencies = new Dictionary(); + private List _traversed_nodes_from_config = new List(); + private Dictionary)> loaded_nodes; + private List _models_to_reconstruct; + public Dictionary)> LoadedNodes => loaded_nodes; + + static KerasObjectLoader() + { + var endPoints = new CommonEndPoints(); + PUBLIC_ATTRIBUTES = new Dictionary(); + foreach (var key in endPoints._all_checkpointable_objects.Concat(endPoints._all_functions)) + { + PUBLIC_ATTRIBUTES[key] = null; + } + PUBLIC_ATTRIBUTES[SavedModel.Constants.KERAS_ATTR] = null; + } + + public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) + { + _metadata = metadata; + _proto = object_graph_def; + _metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath); + _models_to_reconstruct = new List(); + loaded_nodes = new Dictionary)>(); + } + + /// + /// Load all layer nodes from the metadata. + /// + /// + public void load_layers(bool compile = true) + { + var metric_list = new List(); + foreach (var node_metadata in _metadata.Nodes) + { + if (node_metadata.Identifier == "_tf_keras_metric") + { + metric_list.Add(node_metadata); + continue; + } + + loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); + } + foreach(var node_metadata in metric_list) + { + try + { + if (node_metadata.Identifier.Equals("_tf_keras_metric")) + { + continue; + } + loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, + node_metadata.Metadata); + } + catch(ValueError e) + { + if (compile) + { + throw e; + } + // TODO: add logging.warning. + } + } + } + + public string get_path(int node_id) + { + return _node_paths[node_id]; + } + + /// + /// Finish setting up Keras objects. + /// + /// This function is executed after all objects and functions have been created. + /// Call functions and losses are attached to each layer, and once all layers + /// have been fully set up, graph networks are initialized. + /// + /// Subclassed models that are revived from the SavedModel are treated like + /// layers, and have their call/loss functions attached here. + /// + public void finalize_objects() + { + List layers_revived_from_config = new(); + List layers_revived_from_saved_model = new(); + foreach(var item in loaded_nodes) + { + var node_id = item.Key; + var node = item.Value.Item1; + if(node is not Layer || model_layer_ids_dependencies.ContainsKey(node_id)) + { + continue; + } + + _unblock_model_reconstruction(node_id, node as Layer); + + if(node is InputLayer or Metric) + { + continue; + } + + if(node is RevivedLayer or RevivedInputLayer) + { + layers_revived_from_saved_model.Add(node as Layer); + } + else + { + layers_revived_from_config.Add(node as Layer); + } + } + + _finalize_saved_model_layers(layers_revived_from_saved_model); + _finalize_config_layers(layers_revived_from_config); + + _reconstruct_all_models(); + } + + /// + /// Removes tracked references that are only used when loading the model. + /// Now that the node object has been fully loaded, and the checkpoint has + /// been restored, the object no longer needs to track objects added from + /// SerializedAttributes. (Note that saving a training checkpoint still + /// functions correctly, because layers and variables are tracked + /// separately by the Layer object.) + /// + public void del_tracking() + { + foreach(var (node, _) in loaded_nodes.Values) + { + if(node is not Layer layer) + { + continue; + } + foreach(var name in PUBLIC_ATTRIBUTES.Keys) + { + layer._delete_tracking(name); + } + if(node is Functional functional) + { + foreach(var name in functional.UnconditionalDependencyNames.Keys.ToArray()) + { + if(Regex.Match(name, @"^layer(_with_weights)?-[\d+]").Success) + { + functional._delete_tracking(name); + } + } + } + } + } + + private void _reconstruct_all_models() + { + HashSet all_initialized_models = new(); + for(int i = _models_to_reconstruct.Count - 1; i >= 0; i--) + { + int model_id = _models_to_reconstruct[i]; + all_initialized_models.Add(model_id); + var (model, layers) = model_layer_dependencies[model_id]; + _reconstruct_model(model_id, model, layers.ToList()); + _finalize_config_layers(new List() { model }); + } + + Debug.Assert(all_initialized_models.SequenceEqual(model_layer_dependencies.Keys)); + } + + private void _reconstruct_model(int model_id, Model model, List layers) + { + var config = JsonConvert.DeserializeObject(_metadata.Nodes[model_id].Metadata)["config"]; + + if(model.input is not null && model.input.Length > 0) + { + + } + else if(model is Sequential s) + { + if(layers is null || layers.Count == 0 || layers[0] is not InputLayer) + { + if (config["layers"][0]["class_name"].ToObject() == "InputLayer") + { + layers.Insert(0, new InputLayer(config["layers"][0]["config"].ToObject())); + } + else if (config["layers"][0]["config"]["batch_input_shape"] is not null) + { + // TODO(Rinne): implement it + } + } + + // `model.__init__(layers, config["name"])`InitLayers(layers); + s.InitLayers(layers.Select(x => x as ILayer)); + s.Name = config["name"].ToObject(); + if(s.inputs is null || s.inputs.Length == 0) + { + var first_layer = _get_child_layer_node_ids(model_id)[0]; + var input_specs = _infer_inputs(first_layer); + var input_shapes = _infer_input_shapes(first_layer); + // `model._set_inputs(input_specs)` + s._set_inputs(input_specs); + + // skip the check of input_specs is Dictionary + if (!s.Built) + { + s.build(input_shapes); + } + } + } + else + { + // skip the parameter `created_layers`. + var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(generic_utils.deserialize_model_config(config), + layers.ToDictionary(x => x.Name, x => x as ILayer)); + // skip the `model.__init__` + (model as Functional).Initialize(inputs, outputs, config["name"].ToObject()); + (model as Functional).connect_ancillary_layers(created_layers); + } + + _set_network_attributes_from_metadata(model); + _unblock_model_reconstruction(model_id, model); + } + + private void _set_network_attributes_from_metadata(Model revived_object) + { + var metadata = revived_object.SerializedAttributes["metadata"] as KerasMetaData; + if (metadata.DType != TF_DataType.DtInvalid) + { + // TODO(Rinne): set_dtype_policy. + } + revived_object.args.Trainable = metadata.Trainable; + } + + /// + /// Runs the final steps of loading Keras Layers from config. + /// + /// + private void _finalize_config_layers(List layers) + { + foreach(var layer in layers) + { + if (_is_graph_network(layer)) + { + _restore_layer_unconditional_losses(layer); + } + _restore_layer_activation_loss(layer); + _restore_layer_metrics(layer); + + // TODO(Rinne): deal with RNN. + } + } + + /// + /// Runs the final steps of loading Keras Layers from SavedModel. + /// + /// + private void _finalize_saved_model_layers(List layers) + { + foreach(var layer in layers) + { + layer.Built = true; + var keras_attr = _get_keras_attr(layer); + if(keras_attr is not Trackable trackable) + { + continue; + } + if (trackable.CustomizedFields.TryGetValue("call_and_return_conditional_losses", out var layer_call)) + { + Debug.Assert(layer_call is RestoredFunction); + var concrete_functions = ((RestoredFunction)layer_call).ConcreteFunctions; + if (concrete_functions is not null && concrete_functions.Count() > 0) + { + layer.ReplacedCall = use_wrapped_call(layer, ((RestoredFunction)layer_call).Apply); + } + } + } + + foreach(var layer in layers) + { + // TODO(Rinne): deal with `RevivedNetwork`. + + _restore_layer_unconditional_losses(layer); + _restore_layer_activation_loss(layer); + _restore_layer_metrics(layer); + } + } + + private Func use_wrapped_call(Layer layer, Func call) + { + // TODO(Rinne): revise it. + return call; + } + + private void _restore_layer_unconditional_losses(Layer layer) + { + // TODO(Rinne): implement it. + } + + private void _restore_layer_activation_loss(Layer layer) + { + // TODO(Rinne): implement it. + } + + private void _restore_layer_metrics(Layer layer) + { + // TODO(Rinne): implement it. + } + + /// + /// Removes layer from blocking model reconstruction. + /// + /// + /// + private void _unblock_model_reconstruction(int layer_id, Layer layer) + { + foreach(var depencency in model_layer_ids_dependencies) + { + var layer_ids = depencency.Value.Item2; + var layers = model_layer_dependencies.SetDefault(depencency.Key, + (depencency.Value.Item1, new Layer[depencency.Value.Item2.Length])).Item2; + if (!layer_ids.Contains(layer_id)) + { + continue; + } + layers[Array.IndexOf(layer_ids, layer_id)] = layer; + if (layers.All(x => x is not null)) + { + _models_to_reconstruct.Add(depencency.Key); + } + } + } + + private (Trackable, Action) _load_layer(int node_id, string identifier, string metadata_json) + { + var metadata = JsonConvert.DeserializeObject(metadata_json); + + if (loaded_nodes.ContainsKey(node_id)) + { + var (node, setter) = loaded_nodes[node_id]; + + _maybe_add_serialized_attributes(node as Layer, metadata); + var config = metadata.Config; + if(_is_graph_network(node as Layer) && generic_utils.validate_config(config)) + { + Debug.Assert(node is Model); + var child_nodes = _get_child_layer_node_ids(node_id); + model_layer_ids_dependencies[node_id] = (node as Model, child_nodes); + if(child_nodes is null || child_nodes.Length == 0) + { + _models_to_reconstruct.Add(node_id); + } + } + return (node, setter); + } + else + { + var (obj, setter) = _revive_from_config(identifier, metadata, node_id); + if (obj is null) + { + (obj, setter) = revive_custom_object(identifier, metadata); + } + if(obj is null) + { + throw new ValueError($"Cannot revive {metadata.Name} from the config or customized object."); + } + Debug.Assert(obj is Layer); + _maybe_add_serialized_attributes(obj as Layer, metadata); + return (obj, setter); + } + } + + /// + /// Revives a layer/model from config, or returns None. + /// + /// + /// + /// + private (Trackable, Action) _revive_from_config(string identifier, KerasMetaData metadata, int node_id) + { + Trackable obj; + if(identifier == SavedModel.Constants.METRIC_IDENTIFIER) + { + // TODO(Rinne): implement it. + return (null, null); + //throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + } + else + { + obj = _revive_graph_network(identifier, metadata, node_id); + obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); + } + + if(obj is null) + { + return (null, null); + } + var setter = _config_node_setter(_revive_setter); + _add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id); + return (obj, setter); + } + + private (Trackable, Action) revive_custom_object(string identifier, KerasMetaData metadata) + { + if (identifier == SavedModel.Constants.LAYER_IDENTIFIER) + { + return RevivedLayer.init_from_metadata(metadata); + } + else if(identifier == SavedModel.Constants.MODEL_IDENTIFIER || identifier == SavedModel.Constants.SEQUENTIAL_IDENTIFIER + || identifier == SavedModel.Constants.NETWORK_IDENTIFIER) + { + return RevivedNetwork.init_from_metadata(metadata); + } + else if(identifier == SavedModel.Constants.INPUT_LAYER_IDENTIFIER) + { + return RevivedInputLayer.init_from_metadata(metadata); + } + else + { + throw new ValueError($"Cannot revive the layer {identifier}."); + } + } + + Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) + { + var config = metadata.Config; + var class_name = metadata.ClassName; + Model model = null; + + if(!metadata.IsGraphNetwork && class_name != "Sequential" && class_name != "Functional") + { + return null; + } + + if (class_name == "Sequential") + { + model = new Sequential(new SequentialArgs + { + Name = config.GetValue("name").ToString() + }); + } + else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER) + { + model = new Sequential(new SequentialArgs + { + Name = class_name + }); + } + else + { + model = new Functional(new Tensors(), new Tensors(), config.TryGetOrReturnNull("name")); + } + + // Record this model and its layers. This will later be used to reconstruct + // the model. + var layers = _get_child_layer_node_ids(node_id); + model_layer_ids_dependencies[node_id] = (model, layers); + if(layers is null || layers.Length == 0) + { + _models_to_reconstruct.Add(node_id); + } + return model; + } + + Layer _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) + { + var config = metadata.Config; + var class_name = metadata.ClassName; + var shared_object_id = metadata.SharedObjectId; + var must_restore_from_config = metadata.MustRestoreFromConfig; + + var obj = generic_utils.deserialize_keras_object(class_name, config); + + if(obj is null) + { + return null; + } + obj.Name = metadata.Name; + // TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec` + + + var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); + if (!built) + { + return null; + } + return obj; + } + + private void _revive_setter(object obj, object name, object value) + { + Debug.Assert(name is string); + Debug.Assert(obj is Layer); + Layer layer = (Layer)obj; + if(PUBLIC_ATTRIBUTES.ContainsKey(name as string)) + { + if(value is Trackable) + { + layer._track_trackable(value as Trackable, name as string); + } + if(layer.SerializedAttributes is null) + { + layer.SerializedAttributes = new Dictionary(); + } + layer.SerializedAttributes[name as string] = value; + } + else if(layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) + { + functional._track_trackable(value as Trackable, name as string, overwrite: true); + } + else + { + layer.SetAttr(name as string, value); + } + } + + /// + /// Returns the node ids of each layer in a Sequential/Functional model. + /// + /// + int[] _get_child_layer_node_ids(int node_id) + { + int num_layers = 0; + Dictionary child_layers = new Dictionary(); + foreach (var child in _proto.Nodes[node_id].Children) + { + var m = Regex.Match(child.LocalName, @"layer-(\d+)"); + if (!m.Success) + continue; + var layer_n = int.Parse(m.Groups[1].Value); + num_layers = max(layer_n + 1, num_layers); + child_layers[layer_n] = child.NodeId; + } + + var ordered = new List(); + foreach (var n in range(num_layers)) + { + if (child_layers.ContainsKey(n)) + ordered.Add(child_layers[n]); + else + break; + } + return ordered.ToArray(); + } + + /// + /// Recursively records objects recreated from config. + /// + /// + /// + /// + void _add_children_recreated_from_config(Trackable obj, SavedObject proto, int node_id) + { + if (_traversed_nodes_from_config.Contains(node_id)) + return; + var parent_path = _node_paths[node_id]; + _traversed_nodes_from_config.Add(node_id); + obj._maybe_initialize_trackable(); + + if(obj is Layer layer && !layer.Built) + { + var metadata = JsonConvert.DeserializeObject(_metadata.Nodes[node_id].Metadata); + _try_build_layer(layer, node_id, metadata.BuildInputShape); + } + + + List<(Trackable, int, string)> children = new(); + foreach(var refer in proto.Children) + { + var obj_child = obj._lookup_dependency(refer.LocalName); + children.Add((obj_child, refer.NodeId, refer.LocalName)); + } + + var metric_list_node_id = _search_for_child_node(node_id, new string[] { + SavedModel.Constants.KERAS_ATTR, "layer_metrics" + }); + if(metric_list_node_id is not null && obj is Model model && model.metrics is not null) + { + var obj_metrics = model.metrics.ToDictionary(x => x.Name, x => x); + foreach(var refer in _proto.Nodes[metric_list_node_id.Value].Children) + { + if (obj_metrics.TryGetValue(refer.LocalName, out var metric)) + { + var metric_path = $"{Keras.Saving.SavedModel.Constants.KERAS_ATTR}.layer_metrics.{refer.LocalName}"; + children.Add((metric as Metric, refer.NodeId, metric_path)); + } + } + } + + foreach(var (obj_child, child_id, child_name) in children) + { + if(obj_child is null) + { + continue; + } + var child_proto = _proto.Nodes[child_id]; + + // skip the check for registered identifier + + Action setter; + if (SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) + { + setter = _revive_setter; + } + else + { + setter = Loader.setattr; + } + + if (loaded_nodes.ContainsKey(child_id)) + { + // skip the logging.warning + continue; + } + + if(child_proto.KindCase == SavedObject.KindOneofCase.Variable && !string.IsNullOrEmpty(child_proto.Variable.Name)) + { + (obj_child as BaseResourceVariable).handle_name = child_proto.Variable.Name + ":0"; + } + + if(obj_child is TrackableDataStructure) + { + setter = (x, y, z) => { }; + } + + var child_path = $"{parent_path}.{child_name}"; + _node_paths[child_id] = child_path; + _add_children_recreated_from_config(obj_child, child_proto, child_id); + loaded_nodes[child_id] = (obj_child, setter); + } + } + + private bool _try_build_layer(Layer obj, int node_id, KerasShapesWrapper build_input_shape) + { + if (obj.Built) + return true; + + if(build_input_shape is null) + { + build_input_shape = _infer_input_shapes(node_id); + } + + if(build_input_shape is not null) + { + obj.build(build_input_shape); + // In tf python here is a `base_layer.Layer.build(obj, build_input_shape)`. + // On the one hand, C# does not support call a method from specified parent class. + // On the other hand, currently All class derived from Layer call `Layer.Build` or + // move the implementation of `Layer.build` to its own `build` method. + // Therefore we do not call it here. + // However, it's still quite risky once in the future a certain class derived from + // `Layer` does not call `Layer.build`. + + return true; + } + + return false; + } + + /// + /// Infers input shape of layer from SavedModel functions. + /// + /// + /// + private TensorSpec _infer_inputs(int layer_node_id) + { + var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" }); + if(call_fn_id is null) + { + return null; + } + + var concrete_functions = _proto.Nodes[call_fn_id.Value].Function.ConcreteFunctions; + if(concrete_functions is null) + { + return null; + } + var call_fn_name = concrete_functions[0]; + var call_fn_proto = _proto.ConcreteFunctions[call_fn_name]; + var structured_input_signature = nested_structure_coder.decode_proto(call_fn_proto.CanonicalizedInputSignature); + Debug.Assert(structured_input_signature is IEnumerable); + var first_enumerator = (structured_input_signature as IEnumerable).GetEnumerator(); + first_enumerator.MoveNext(); + var first = first_enumerator.Current; + Debug.Assert(first is IEnumerable); + var inputs_enumerator = (first as IEnumerable).GetEnumerator(); + inputs_enumerator.MoveNext(); + var inputs = inputs_enumerator.Current as TensorSpec; + return inputs; + } + + private KerasShapesWrapper _infer_input_shapes(int layer_node_id) + { + var inputs = _infer_inputs(layer_node_id); + return new KerasShapesWrapper(nest.map_structure(x => x.shape, inputs)); + } + + private int? _search_for_child_node(int parent_id, IEnumerable path_to_child) + { + if(path_to_child is null || path_to_child.Count() == 0) + { + return parent_id; + } + + foreach(var child in _proto.Nodes[parent_id].Children) + { + if(child.LocalName == path_to_child.First()) + { + return _search_for_child_node(child.NodeId, path_to_child.Skip(1)); + } + } + return null; + } + + private bool _is_graph_network(Layer layer) + { + // TODO: deal with `RevivedLayer` + if(layer is Functional) + { + return (layer as Functional).IsGraphNetwork || layer is Sequential; + } + return false; + } + + private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata) + { + if(layer.SerializedAttributes is null || layer.SerializedAttributes.Count == 0) + { + layer.SerializedAttributes = new Dictionary(); + layer.SerializedAttributes["metadata"] = metadata; + } + } + + private static object _get_keras_attr(Layer layer) + { + if((layer.SerializedAttributes ?? new Dictionary()).TryGetValue(SavedModel.Constants.KERAS_ATTR, out var value)) + { + return value; + } + else + { + return null; + } + } + + /// + /// Creates edges for nodes that are recreated from config. + /// + /// + private Action _config_node_setter(Action setter) + { + void setattr_wrapper(object obj, object name, object value) + { + Debug.Assert(obj is Trackable); + Debug.Assert(name is string); + if((obj as Trackable)._lookup_dependency(name as string) is null) + { + setter(obj, name, value); + } + } + return setattr_wrapper; + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs new file mode 100644 index 000000000..3ea4f067e --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs @@ -0,0 +1,41 @@ +using System.Collections.Generic; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public static class Constants +{ + /// + /// Namespace used to store all attributes added during serialization. + /// e.g. the list of layers can be accessed using `loaded.keras_api.layers`, in an + /// object loaded from `tf.saved_model.load()`. + /// + public static readonly string KERAS_ATTR = "keras_api"; + /// + /// Keys for the serialization cache. + /// Maps to the keras serialization dict {Layer --> SerializedAttributes object} + /// + public static readonly string KERAS_CACHE_KEY = "keras_serialized_attributes"; + /// + /// Name of Keras metadata file stored in the SavedModel. + /// + public static readonly string SAVED_METADATA_PATH = "keras_metadata.pb"; + + public static readonly string INPUT_LAYER_IDENTIFIER = "_tf_keras_input_layer"; + public static readonly string LAYER_IDENTIFIER = "_tf_keras_layer"; + public static readonly string METRIC_IDENTIFIER = "_tf_keras_metric"; + public static readonly string MODEL_IDENTIFIER = "_tf_keras_model"; + public static readonly string NETWORK_IDENTIFIER = "_tf_keras_network"; + public static readonly string RNN_LAYER_IDENTIFIER = "_tf_keras_rnn_layer"; + public static readonly string SEQUENTIAL_IDENTIFIER = "_tf_keras_sequential"; + + public static readonly IList KERAS_OBJECT_IDENTIFIERS = new List() + { + INPUT_LAYER_IDENTIFIER, + LAYER_IDENTIFIER, + METRIC_IDENTIFIER, + MODEL_IDENTIFIER, + NETWORK_IDENTIFIER, + RNN_LAYER_IDENTIFIER, + SEQUENTIAL_IDENTIFIER + }; +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs new file mode 100644 index 000000000..6970b04e5 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs @@ -0,0 +1,55 @@ +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System.Text.RegularExpressions; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + internal static class ReviveUtils + { + public static T recursively_deserialize_keras_object(JToken config) + { + throw new NotImplementedException(); + if(config is JObject jobject) + { + if (jobject.ContainsKey("class_name")) + { + + } + } + } + + public static void _revive_setter(object obj, object name, object value) + { + Debug.Assert(name is string); + Debug.Assert(obj is Layer); + Layer layer = (Layer)obj; + if (KerasObjectLoader.PUBLIC_ATTRIBUTES.ContainsKey(name as string)) + { + if (value is Trackable trackable) + { + layer._track_trackable(trackable, name as string); + } + if (layer.SerializedAttributes is null) + { + layer.SerializedAttributes = new Dictionary(); + } + layer.SerializedAttributes[name as string] = value; + } + else if (layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) + { + Debug.Assert(value is Trackable); + functional._track_trackable(value as Trackable, name as string); + } + else + { + layer.SetAttr(name as string, value); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedConfig.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedConfig.cs new file mode 100644 index 000000000..036d517b1 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedConfig.cs @@ -0,0 +1,37 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + [JsonConverter(typeof(CustomizedRevivedConfigJsonConverter))] + public class RevivedConfig: IKerasConfig + { + public JObject Config { get; set; } + } + + public class CustomizedRevivedConfigJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(RevivedConfig); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + ((RevivedConfig)value).Config.WriteTo(writer); + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + var config = (JObject)serializer.Deserialize(reader, typeof(JObject)); + return new RevivedConfig() { Config = config }; + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs new file mode 100644 index 000000000..e2cad8a37 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs @@ -0,0 +1,46 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public class RevivedInputLayer: InputLayer + { + protected RevivedConfig _config = null; + private RevivedInputLayer(InputLayerArgs args): base(args) + { + + } + + public override IKerasConfig get_config() + { + return _config; + } + + public static (RevivedInputLayer, Action) init_from_metadata(KerasMetaData metadata) + { + InputLayerArgs args = new InputLayerArgs() + { + Name = metadata.Name, + DType = metadata.DType, + Sparse = metadata.Sparse, + Ragged = metadata.Ragged, + BatchInputShape = metadata.BatchInputShape + }; + + RevivedInputLayer revived_obj = new RevivedInputLayer(args); + + revived_obj._config = new RevivedConfig() { Config = metadata.Config }; + + return (revived_obj, Loader.setattr); + } + + public override string ToString() + { + return $"Customized keras input layer: {Name}."; + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs new file mode 100644 index 000000000..51e367ce8 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs @@ -0,0 +1,88 @@ +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.Keras.Saving.SavedModel; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public class RevivedLayer: Layer + { + public static (RevivedLayer, Action) init_from_metadata(KerasMetaData metadata) + { + LayerArgs args = new LayerArgs() + { + Name = metadata.Name, + Trainable = metadata.Trainable + }; + if(metadata.DType != TF_DataType.DtInvalid) + { + args.DType = metadata.DType; + } + if(metadata.BatchInputShape is not null) + { + args.BatchInputShape = metadata.BatchInputShape; + } + + RevivedLayer revived_obj = new RevivedLayer(args); + + // TODO(Rinne): implement `expects_training_arg`. + var config = metadata.Config; + if (generic_utils.validate_config(config)) + { + revived_obj._config = new RevivedConfig() { Config = config }; + } + if(metadata.InputSpec is not null) + { + throw new NotImplementedException(); + } + if(metadata.ActivityRegularizer is not null) + { + throw new NotImplementedException(); + } + // TODO(Rinne): `_is_feature_layer` + if(metadata.Stateful is not null) + { + revived_obj.stateful = metadata.Stateful.Value; + } + + return (revived_obj, ReviveUtils._revive_setter); + } + + protected RevivedConfig _config = null; + + public object keras_api + { + get + { + if (SerializedAttributes.TryGetValue(SavedModel.Constants.KERAS_ATTR, out var value)) + { + return value; + } + else + { + return null; + } + } + } + + protected RevivedLayer(LayerArgs args): base(args) + { + + } + + public override string ToString() + { + return $"Customized keras layer: {Name}."; + } + + public override IKerasConfig get_config() + { + return _config; + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedNetwork.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedNetwork.cs new file mode 100644 index 000000000..1860c8c75 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedNetwork.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public class RevivedNetwork: RevivedLayer + { + private RevivedNetwork(LayerArgs args) : base(args) + { + + } + + public static (RevivedNetwork, Action) init_from_metadata(KerasMetaData metadata) + { + RevivedNetwork revived_obj = new(new LayerArgs() { Name = metadata.Name }); + + // TODO(Rinne): with utils.no_automatic_dependency_tracking_scope(revived_obj) + // TODO(Rinne): revived_obj._expects_training_arg + var config = metadata.Config; + if (generic_utils.validate_config(config)) + { + revived_obj._config = new RevivedConfig() { Config = config }; + } + if(metadata.ActivityRegularizer is not null) + { + throw new NotImplementedException(); + } + + return (revived_obj, ReviveUtils._revive_setter); + } + + public override string ToString() + { + return $"Customized keras Network: {Name}."; + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs new file mode 100644 index 000000000..331b283a0 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -0,0 +1,169 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Google.Protobuf; +using Tensorflow.Functions; +using Tensorflow.Keras.Engine; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Keras.Optimizers; +using ThirdParty.Tensorflow.Python.Keras.Protobuf; +using static Tensorflow.Binding; +using Tensorflow.Training; +using System.Diagnostics; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public partial class KerasSavedModelUtils +{ + public static void save_model(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, + SaveOptions? options, bool save_traces = true) + { + if (!overwrite && File.Exists(filepath)) + { + throw new Exception("The file already exists but is not allowed to overwrite it."); + } + + if (save_traces) + { + if(should_skip_serialization(model)) + { + throw new NotImplementedException(); + } + } + + IOptimizer? orig_optimizer = null; + if (!include_optimizer) + { + orig_optimizer = model.Optimizer; + model.Optimizer = null; + model._delete_tracking("optimizer"); + } + + IList saved_nodes; + IDictionary> node_paths; + // skip two scopes of python + using (KerasSavedModelUtils.keras_option_scope(save_traces)) + { + (saved_nodes, node_paths) = Tensorflow.SavedModelUtils.save_and_return_nodes(model, filepath, signatures, options); + } + + var metadata = generate_keras_metadata(saved_nodes, node_paths); + File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray()); + //File.WriteAllText(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToString()); + + if (!include_optimizer) + { + model.Optimizer = orig_optimizer!; + } + } + + public static SavedMetadata generate_keras_metadata(IList saved_nodes, + IDictionary> node_paths) + { + var metadata = new SavedMetadata(); + for (int i = 0; i < saved_nodes.Count; i++) + { + var node = saved_nodes[i]; + if (node is not Layer) + { + continue; + } + + Layer layer = (Layer)node; + + var path = node_paths[node]; + string node_path; + if (path is null || path.Count() == 0) + { + node_path = "root"; + } + else + { + node_path = $"root.{string.Join(".", path.Select(x => x.Name))}"; + } + + ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject saved_object = new() + { + NodeId = i, + NodePath = node_path, + Version = new ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef() + { + Producer = 2, + MinConsumer = 1, + BadConsumers = { } + }, + Identifier = layer.ObjectIdentifier, + Metadata = layer.GetTrackingMetadata() + }; + + metadata.Nodes.Add(saved_object); + } + + return metadata; + } + + public static bool should_skip_serialization(object layer) + { + return false; + } + + /// + /// Returns extra trackable objects to attach to the serialized layer. + /// + /// + /// + /// + public static IDictionary wrap_layer_objects(Layer layer, IDictionary> serialization_cache) + { + // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. + + // TODO: change the inherits of `Variable` and revise the implmentation. + var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + }).ToArray()); + var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + }).ToArray()); + var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + }).ToArray()); + var layers = TrackableDataStructure.wrap_or_unwrap(list_all_layers(layer).Select(x => x.GetTrackable()).ToArray()); + + Dictionary res = new(); + Debug.Assert(variables is Trackable); + Debug.Assert(trainable_variables is Trackable); + Debug.Assert(non_trainable_variables is Trackable); + Debug.Assert(layers is Trackable); + res["variables"] = variables as Trackable; + res["trainable_variables"] = trainable_variables as Trackable; + res["non_trainable_variables"] = non_trainable_variables as Trackable; + res["layers"] = layers as Trackable; + + return res; + } + + /// + /// Returns dict of wrapped layer call function and losses in tf.functions. + /// + /// + /// + /// + public static IDictionary wrap_layer_functions(Layer layer, IDictionary> serialization_cache) + { + + // high priority + // TODO: deal with type `RevivedLayer` and `Sequential`. + + // skip the process because of lack of APIs of `Layer`. + + return new Dictionary(); + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs new file mode 100644 index 000000000..eb88c8953 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -0,0 +1,37 @@ +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.Engine; +using Newtonsoft.Json; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public abstract class SavedModelSaver +{ + protected Trackable _obj; + public SavedModelSaver(Trackable obj) + { + _obj = obj; + } + + public abstract string ObjectIdentifier { get; } + public abstract string TrackingMetadata { get; } + + public abstract IDictionary objects_to_serialize( + IDictionary> serialization_cache); + + public abstract IDictionary functions_to_serialize( + IDictionary> serialization_cache); + + public IDictionary trackable_children(IDictionary> serialization_cache) + { + if (!KerasSavedModelUtils.ShouldHaveTraces) + { + return new Dictionary(); + } + + var children = objects_to_serialize(serialization_cache); + return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) + .ToDictionary(x => x.Key, x => x.Value); + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs new file mode 100644 index 000000000..03693cb57 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -0,0 +1,165 @@ +using System.Collections.Generic; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public class LayerSavedModelSaver: SavedModelSaver +{ + private Layer _layer; + public LayerSavedModelSaver(Layer obj): base(obj) + { + _obj = obj; + _layer = obj; + } + public override string ObjectIdentifier + { + get => Constants.LAYER_IDENTIFIER; + } + + public override IDictionary objects_to_serialize(IDictionary> serialization_cache) + { + return get_serialized_attributes(serialization_cache).ObjectsToSerialize; + } + + public override IDictionary functions_to_serialize(IDictionary> serialization_cache) + { + return get_serialized_attributes(serialization_cache).FunctionsToSerialize; + } + + /// + /// Generates or retrieves serialized attributes from cache. + /// + /// + protected ISerializedAttributes get_serialized_attributes(IDictionary> serialization_cache) + { + // TODO: deal with cache. + IDictionary keras_cache; + if(serialization_cache is not null && serialization_cache.ContainsKey(Constants.KERAS_CACHE_KEY)) + { + keras_cache = serialization_cache[Constants.KERAS_CACHE_KEY]; + } + else + { + serialization_cache![Constants.KERAS_CACHE_KEY] = keras_cache = new Dictionary(); + } + if (keras_cache.ContainsKey(_obj)) return keras_cache[_obj]; + + var serialized_attr = keras_cache[_obj] = SerializedAttributes.Create(_obj); + + // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. + if (KerasSavedModelUtils.should_skip_serialization(_obj)) + { + return serialized_attr; + } + + var (object_dict, function_dict) = get_serialized_attributes_internal(serialization_cache); + + serialized_attr.set_and_validate_objects(object_dict); + serialized_attr.set_and_validate_functions(function_dict); + return serialized_attr; + } + + /// + /// Returns dictionary of serialized attributes. + /// + /// + private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary> serialization_cache) + { + var objects = KerasSavedModelUtils.wrap_layer_objects(_layer, serialization_cache); + var functions = KerasSavedModelUtils.wrap_layer_functions(_layer, serialization_cache); + + functions["_default_save_signature"] = null; + + return (objects, functions); + } + + public override string TrackingMetadata + { + get + { + JObject metadata = new JObject(); + metadata["name"] = _layer.Name; + metadata["trainable"] = _layer.Trainable; + // TODO: implement `expects_training_arg`. + metadata["expects_training_arg"] = false; + metadata["dtype"] = _layer.DType.as_python_name(); + metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape); + // metadata["stateful"] = _obj.stateful; + // metadata["must_restore_from_config"] = _obj.must_restore_from_config; + // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; + metadata["autocast"] = _layer.AutoCast; + + if(_layer.InputSpec is not null) + { + metadata["input_spec"] = generic_utils.serialize_keras_object(_layer.InputSpec); + } + + metadata.Merge(get_serialized(_layer), new JsonMergeSettings + { + // Handle conflicts by using values from obj2 + MergeArrayHandling = MergeArrayHandling.Merge + }); + // skip the check of `input_spec` and `build_input_shape` for the lack of members. + // skip the check of `activity_regularizer` for the type problem. + if(_layer.BuildInputShape is not null) + { + metadata["build_input_shape"] = JToken.FromObject(_layer.BuildInputShape); + } + return metadata.ToString(); + } + } + + public static JObject get_serialized(Layer obj) + { + return generic_utils.serialize_keras_object(obj); + } +} + +public class InputLayerSavedModelSaver: SavedModelSaver +{ + public InputLayerSavedModelSaver(Layer obj) : base(obj) + { + + } + public override string ObjectIdentifier => Constants.INPUT_LAYER_IDENTIFIER; + + public override IDictionary functions_to_serialize(IDictionary> serialization_cache) + { + return new Dictionary(); + } + + public override IDictionary objects_to_serialize(IDictionary> serialization_cache) + { + return new Dictionary(); + } + + public override string TrackingMetadata + { + get + { + if(_obj is not InputLayer) + { + throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer."); + } + var layer = (InputLayer)_obj; + var config = (layer.get_config() as InputLayerArgs)!; + var info = new + { + class_name = layer.GetType().Name, + name = layer.Name, + dtype = layer.DType, + sparse = config.Sparse, + ragged = config.Ragged, + batch_input_shape = layer.BatchInputShape, + config = layer.get_config() + }; + return JsonConvert.SerializeObject(info); + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs new file mode 100644 index 000000000..091dbb810 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs @@ -0,0 +1,89 @@ +using System.IO; +using Tensorflow.Train; +using ThirdParty.Tensorflow.Python.Keras.Protobuf; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public class KerasLoadModelUtils +{ + /// + /// Corresponding to keras/saving/save.py/load_model + /// + /// + /// + /// + /// + /// + public static Trackable load_model(string filepath, IDictionary? custom_objects = null, + bool compile = true, LoadOptions? options = null) + { + using var savingScope = SharedObjectSavingScope.Enter(); + + using var ctx = LoadContext.load_context(options); + + if (!File.Exists(filepath) && !Directory.Exists(filepath)) + { + throw new IOException($"No file or directory found at {filepath}."); + } + + if (Directory.Exists(filepath)) + { + return load(filepath, compile, options); + } + else + { + throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed."); + } + } + + private static Trackable load(string path, bool compile = true, LoadOptions? options = null) + { + SavedMetadata metadata; + var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; + var object_graph_def = meta_graph_def.ObjectGraphDef; + string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH); + if (File.Exists(path_to_metadata_pb)) + { + using var stream = new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read); + metadata = SavedMetadata.Parser.ParseFrom(stream); + } + else + { + throw new NotImplementedException("SavedModel saved prior to TF 2.5 detected when loading Keras model, please" + + " use higher version or submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues. to let us know you need it."); + } + + if (metadata.Nodes is null || metadata.Nodes.Count == 0) + { + return Loader.load(path, options: options) as Model; + } + + var keras_loader = new KerasObjectLoader(metadata, object_graph_def); + keras_loader.load_layers(compile: compile); + + Dictionary)> nodes_to_load = new(); + nodes_to_load["root"] = (null, null); + foreach(var item in keras_loader.LoadedNodes) + { + nodes_to_load[keras_loader.get_path(item.Key)] = item.Value; + } + var loaded = Loader.load_partial(path, nodes_to_load, options); + + keras_loader.finalize_objects(); + keras_loader.del_tracking(); + + var model = loaded["root"]; + + if (model is Model && compile) + { + // TODO(Rinne): implement it. + } + + if (!tf.Context.executing_eagerly()) + { + // TODO(Rinne): implement it. + } + + return model; + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs new file mode 100644 index 000000000..11b1201d0 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs @@ -0,0 +1,69 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using Tensorflow.Training.Saving.SavedModel; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + // TODO: remove this class to common project. + public class ContextHandler: IDisposable + { + public Action DisposeCallBack { get; set; } + public void Dispose() + { + DisposeCallBack.Invoke(true); + } + } + public class LoadContext + { + private bool _entered_load_context; + private LoadOptions? _load_options; + private static ThreadLocal _load_context = new(); + private LoadContext() + { + _entered_load_context = false; + _load_options = null; + } + + public void set_load_options(LoadOptions load_options) + { + _load_options = load_options; + _entered_load_context = true; + } + + private void clear_load_options() + { + _load_options = null; + _entered_load_context = false; + } + + private LoadOptions? load_options() + { + return _load_options; + } + + public static ContextHandler load_context(LoadOptions? load_options) + { + if(_load_context.Value is null) + { + _load_context.Value = new LoadContext(); + } + _load_context.Value.set_load_options(load_options); + return new ContextHandler() + { + DisposeCallBack = _ => _load_context.Value.clear_load_options() + }; + } + + public static LoadOptions? get_load_option() + { + return _load_context.Value.load_options(); + } + + public static bool in_load_context() + { + return _load_context.Value._entered_load_context; + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs new file mode 100644 index 000000000..325d3327a --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -0,0 +1,284 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Metrics; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + // TODO: revise the name of these "Attributes". Since "Attribute" is a significant feature of C#, + // Using the name "Attributes" may be quite confusing. + /// + /// Class that tracks and validates all serialization attributes. + /// + public abstract class SerializedAttributes: ISerializedAttributes + { + protected IDictionary _object_dict; + protected IDictionary _function_dict; + protected AutoTrackable _keras_trackable; + internal HashSet _all_functions; + internal HashSet _all_checkpointable_objects; + + private SerializedAttributes() + { + _object_dict= new Dictionary(); + _function_dict= new Dictionary(); + _keras_trackable= new AutoTrackable(); + _all_functions= new HashSet(); + _all_checkpointable_objects= new HashSet(); + } + + protected SerializedAttributes(IEnumerable checkpointable_objects, IEnumerable functions) + { + _object_dict = new Dictionary(); + _function_dict = new Dictionary(); + _keras_trackable = new AutoTrackable(); + + _all_checkpointable_objects = new HashSet(checkpointable_objects); + _all_functions = new HashSet(functions); + } + + protected SerializedAttributes((IEnumerable, IEnumerable) objects_and_functions) + { + _object_dict = new Dictionary(); + _function_dict = new Dictionary(); + _keras_trackable = new AutoTrackable(); + + _all_checkpointable_objects = new HashSet(objects_and_functions.Item1); + _all_functions = new HashSet(objects_and_functions.Item2); + } + + public IDictionary Functions => _function_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + + public IDictionary CheckpointableObjects => _object_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + + /// + /// Returns functions to attach to the root object during serialization. + /// + public IDictionary FunctionsToSerialize + { + get + { + Dictionary functions = new(); + foreach(var pair in Functions) + { + if (_all_functions.Contains(pair.Key)) + { + // TODO: deal with `LayerCall`. + functions[pair.Key] = pair.Value; + } + } + return functions; + } + } + + /// + /// Returns objects to attach to the root object during serialization. + /// + public IDictionary ObjectsToSerialize + { + get + { + var objects = CheckpointableObjects.Where( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value); + objects[Constants.KERAS_ATTR] = _keras_trackable; + return objects; + } + } + + /// + /// Saves function dictionary, and validates dictionary values. + /// + /// + public IDictionary set_and_validate_functions(IDictionary function_dict) + { + foreach(var key in _all_functions) + { + if (function_dict.ContainsKey(key)) + { + // TODO: deal with type `LayerCall`. + var fn = function_dict[key]; + if (fn is not null && (fn is not Function)) + { + throw new ValueError($"Function dictionary contained a non-function object: {function_dict[key]} (for key {key})."); + } + _function_dict[key] = fn; + + var tf_fn = fn; // TODO: deal with type `LayerCall`. + + // Warning: this implmentation should be considered again. + var properties = _keras_trackable.GetType().GetProperties(); + foreach (var property in properties) + { + if(property.Name == key) + { + property.SetValue(_keras_trackable, tf_fn); + break; + } + } + } + else + { + // high priority + // TODO(Rinne): complete the implementation. + continue; + //throw new ValueError($"Function {key} missing from serialized function dict."); + } + } + return Functions; + } + + /// + /// Saves objects to a dictionary, and validates the values. + /// + /// + public IDictionary set_and_validate_objects(IDictionary object_dict) + { + foreach(var key in _all_checkpointable_objects) + { + if (object_dict.ContainsKey(key)) + { + _object_dict[key] = object_dict[key]; + // Warning: this implmentation should be considered again. + var properties = _keras_trackable.GetType().GetProperties(); + foreach (var property in properties) + { + if (property.Name == key) + { + property.SetValue(_keras_trackable, object_dict[key]); + break; + } + } + } + else + { + // high priority. + // TODO(Rinne): Add the implementation. + continue; + //throw new ValueError($"Object {key} missing from serialized object dict."); + } + } + return CheckpointableObjects; + } + + /// + /// Returns a new SerializedAttribute object (corresponding to `new` of tensorflow python). + /// + /// + public static SerializedAttributes Create(Trackable obj) + { + if(obj is Model) + { + return new ModelAttributes(); + } + else if(obj is Metric) + { + return new MetricAttributes(); + } + else if(obj is RNN) + { + return new RNNAttributes(); + } + else if(obj is Layer) + { + return new LayerAttributes(); + } + else + { + throw new TypeError($"Internal error during serialization: Expected Keras Layer object, got {obj} of type {obj.GetType()}"); + } + } + + protected virtual (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + return (checkpointable_objects ?? (new List()), functions ?? (new List())); + } + } + + // Note that the current implementation still has some potential risks. + // The tensorflow python says that this class is "Common endpoints shared by all models loadable by Keras". + // However, currently it's just a normal class. + public class CommonEndPoints: SerializedAttributes + { + public CommonEndPoints(IEnumerable checkpointable_objects, IEnumerable functions) : + base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), + functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })) + { + + } + + public CommonEndPoints() : + base(new string[] { "variables", "trainable_variables", "regularization_losses" }, + new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) + { + + } + } + + public class LayerAttributes: CommonEndPoints + { + public LayerAttributes(IEnumerable checkpointable_objects, IEnumerable functions) : + //base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), + // functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) + base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}), + functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" })) + { + + } + + public LayerAttributes() : + //base(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }, + // new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) + base(new string[] { "non_trainable_variables", "layers" }, + new string[] { }) + { + + } + } + + public class ModelAttributes: LayerAttributes + { + public ModelAttributes(IEnumerable checkpointable_objects, IEnumerable functions): + base(checkpointable_objects, functions) + { + + } + + public ModelAttributes(): base() + { + + } + } + + public class MetricAttributes : SerializedAttributes + { + public MetricAttributes(IEnumerable checkpointable_objects, IEnumerable functions) : + base(checkpointable_objects.Concat(new string[] { "variables" }), functions) + { + + } + + public MetricAttributes() : + base(new string[] { "variables" }, new string[] {}) + { + + } + } + + public class RNNAttributes: LayerAttributes + { + public RNNAttributes(IEnumerable checkpointable_objects, IEnumerable functions) : + base(checkpointable_objects, functions.Concat(new string[] {"states"})) + { + + } + + public RNNAttributes() : + base(new string[] { }, new string[] { "states" }) + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs new file mode 100644 index 000000000..51f8d2c91 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs @@ -0,0 +1,47 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public partial class KerasSavedModelUtils +{ + public static bool ShouldHaveTraces { get; internal set; } = true; + + public static SaveOptionsContext keras_option_scope(bool save_traces) + { + var res = new SaveOptionsContext(ShouldHaveTraces); + ShouldHaveTraces = save_traces; + return res; + } + + public static IEnumerable list_all_layers(Layer layer) + { + if(layer is Model) + { + return (layer as Model).Layers; + } + else + { + return new List(layer._flatten_layers(false, false)); + } + } +} + +/// +/// Implementation of this class is different with that of python. +/// But it could be used with `using` the same as `with` of python. +/// +public class SaveOptionsContext: IDisposable +{ + public bool _old_value; + public SaveOptionsContext(bool old_value) + { + _old_value = old_value; + } + + public void Dispose() + { + KerasSavedModelUtils.ShouldHaveTraces = _old_value; + } +} diff --git a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs new file mode 100644 index 000000000..68b73953d --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs @@ -0,0 +1,354 @@ +using System; +using System.Collections.Generic; +using System.Text; +using HDF.PInvoke; +using Tensorflow.NumPy; +using HDF5CSharp; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using System.Linq; +using System.Text.RegularExpressions; + +namespace Tensorflow.Keras.Saving +{ + public class hdf5_format + { + private static int HDF5_OBJECT_HEADER_LIMIT = 64512; + public static void load_model_from_hdf5(string filepath = "", Dictionary custom_objects = null, bool compile = false) + { + long root = Hdf5.OpenFile(filepath,true); + load_model_from_hdf5(root, custom_objects, compile); + } + public static void load_model_from_hdf5(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + //long fileId = filepath; + //try + //{ + // groupId = H5G.open(fileId, "/"); + // (bool success, string[] attrId) = Hdf5.ReadStringAttributes(groupId, "model_config", ""); + // H5G.close(groupId); + // if (success == true) { + // Console.WriteLine(attrId[0]); + // } + //} + //catch (Exception ex) + //{ + // if (filepath != -1) { + // Hdf5.CloseFile(filepath); + // } + // if (groupId != -1) { + // H5G.close(groupId); + // } + // throw new Exception(ex.ToString()); + //} + + } + public static void save_model_to_hdf5(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + + /// + /// Preprocess layer weights between different Keras formats. + /// + /// + /// + /// + /// + public static List preprocess_weights_for_loading(ILayer layer, List weights, string original_keras_version = null, string original_backend = null) + { + // convert CuDNN layers + return _convert_rnn_weights(layer, weights); + } + + /// + /// Converts weights for RNN layers between native and CuDNN format. + /// + /// + /// + static List _convert_rnn_weights(ILayer layer, List weights) + { + var target_class = layer.GetType().Name; + return weights; + } + + public static void save_optimizer_weights_to_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + + public static void load_optimizer_weights_from_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + + public static List<(IVariableV1, NDArray)> load_weights_from_hdf5_group(long f, List layers) + { + string original_keras_version = "2.5.0"; + string original_backend = null; + var (success, attr) = Hdf5.ReadStringAttributes(f, "keras_version", "", true); + if (success) + original_keras_version = attr.First(); + // keras version should be 2.5.0+ + var ver_major = int.Parse(original_keras_version.Split('.')[0]); + var ver_minor = int.Parse(original_keras_version.Split('.')[1]); + if (ver_major < 2 || (ver_major == 2 && ver_minor < 5)) + throw new ValueError("keras version should be 2.5.0 or later."); + + (success, attr) = Hdf5.ReadStringAttributes(f, "backend", "", true); + if (success) + original_backend = attr.First(); + + var filtered_layers = new List(); + foreach (var layer in layers) + { + var weights = _legacy_weights(layer); + if (weights.Count > 0) + filtered_layers.append(layer); + } + + string[] layer_names = load_attributes_from_hdf5_group(f, "layer_names"); + var filtered_layer_names = new List(); + foreach(var name in layer_names) + { + if (!filtered_layers.Select(x => x.Name).Contains(name)) + continue; + long g = H5G.open(f, name); + var weight_names = load_attributes_from_hdf5_group(g, "weight_names"); + if (weight_names.Count() > 0) + filtered_layer_names.Add(name); + H5G.close(g); + } + + layer_names = filtered_layer_names.ToArray(); + if (layer_names.Length != filtered_layers.Count()) + throw new ValueError("You are trying to load a weight file " + + $"containing {layer_names}" + + $" layers into a model with {filtered_layers.Count} layers."); + + var weight_value_tuples = new List<(IVariableV1, NDArray)>(); + foreach (var (k, name) in enumerate(layer_names)) + { + var weight_values = new List(); + long g = H5G.open(f, name); + var weight_names = load_attributes_from_hdf5_group(g, "weight_names"); + foreach (var i_ in weight_names) + { + (success, Array result) = Hdf5.ReadDataset(g, i_); + if (success) + weight_values.Add(np.array(result)); + } + H5G.close(g); + var layer = filtered_layers[k]; + var symbolic_weights = _legacy_weights(layer); + preprocess_weights_for_loading(layer, weight_values, original_keras_version, original_backend); + if (weight_values.Count() != symbolic_weights.Count()) + throw new ValueError($"Layer #{k} (named {layer.Name}" + + "in the current model) was found to " + + $"correspond to layer {name} in the save file." + + $"However the new layer {layer.Name} expects " + + $"{symbolic_weights.Count()} weights, but the saved weights have " + + $"{weight_values.Count()} elements."); + weight_value_tuples.AddRange(zip(symbolic_weights, weight_values)); + } + + return weight_value_tuples; + } + + public static void toarrayf4(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + + public static void load_weights_from_hdf5_group_by_name(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + + public static void save_weights_to_hdf5_group(long f, List layers) + { + List layerName=new List(); + foreach (var layer in layers) + { + layerName.Add(layer.Name); + } + save_attributes_to_hdf5_group(f, "layer_names", layerName.ToArray()); + Hdf5.WriteAttribute(f, "backend", "tensorflow"); + Hdf5.WriteAttribute(f, "keras_version", "2.5.0"); + + foreach (var layer in layers) + { + var weights = _legacy_weights(layer); + if (weights.Count == 0) + continue; + + var weight_names = new List(); + // weight_values= keras.backend.batch_get_value(weights); + foreach (var weight in weights) + weight_names.Add(weight.Name); + + var g = Hdf5.CreateOrOpenGroup(f, Hdf5Utils.NormalizedName(layer.Name)); + save_attributes_to_hdf5_group(g, "weight_names", weight_names.ToArray()); + foreach (var (name, val) in zip(weight_names, weights)) + { + var tensor = val.AsTensor(); + if (name.IndexOf("/") > 1) + { + var crDataGroup = g; + string[] name_split = name.Split('/'); + for(int i = 0; i < name_split.Length - 1; i++) + { + crDataGroup = Hdf5.CreateOrOpenGroup(crDataGroup, Hdf5Utils.NormalizedName(name_split[i])); + } + WriteDataset(crDataGroup, name_split[name_split.Length - 1], tensor); + Hdf5.CloseGroup(crDataGroup); + } + else + { + WriteDataset(g, name, tensor); + } + } + Hdf5.CloseGroup(g); + } + } + + private static void save_attributes_to_hdf5_group(long f, string name, Array data) + { + int num_chunks = 1; + + var chunked_data = Split(data, num_chunks); + int getSize = 0; + + string getType = data.Length > 0 ? data.GetValue(0).GetType().Name.ToLower() : "string"; + + switch (getType) + { + case "single": + getSize = sizeof(float); + break; + case "double": + getSize = sizeof(double); + break; + case "string": + getSize = -1; + break; + case "int32": + getSize = sizeof(int); + break; + case "int64": + getSize = sizeof(long); + break; + default: + getSize = -1; + break; + } + int getCount = chunked_data.Count; + + if (getSize != -1) + { + num_chunks = (int)Math.Ceiling((double)(getCount * getSize) / HDF5_OBJECT_HEADER_LIMIT); + if (num_chunks > 1) chunked_data = Split(data, num_chunks); + } + + if (num_chunks > 1) + { + foreach (var (chunk_id, chunk_data) in enumerate(chunked_data)) + WriteAttrs(f, getType, $"{name}{chunk_id}", chunk_data.ToArray()); + } + else + { + WriteAttrs(f, getType, name, data); + } + } + + private static void WriteDataset(long f, string name, Tensor data) + { + switch (data.dtype) + { + case TF_DataType.TF_FLOAT: + Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMultiDimArray()); + break; + case TF_DataType.TF_DOUBLE: + Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMultiDimArray()); + break; + case TF_DataType.TF_INT32: + Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMultiDimArray()); + break; + case TF_DataType.TF_INT64: + Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMultiDimArray()); + break; + default: + Hdf5.WriteDatasetFromArray(f, name, data.numpy().ToMultiDimArray()); + break; + } + } + + private static void WriteAttrs(long f,string typename, string name, Array data) + { + switch (typename) + { + case "single": + Hdf5.WriteAttributes(f, name, data); + break; + case "double": + Hdf5.WriteAttributes(f, name, data); + break; + case "string": + Hdf5.WriteAttributes(f, name, data); + break; + case "int32": + Hdf5.WriteAttributes(f, name, data); + break; + case "int64": + Hdf5.WriteAttributes(f, name, data); + break; + default: + Hdf5.WriteAttributes(f, name,data); + break; + } + } + + private static List> Split(Array list, int chunkSize) + { + var splitList = new List>(); + var chunkCount = (int)Math.Ceiling((double)list.Length / (double)chunkSize); + + for (int c = 0; c < chunkCount; c++) + { + var skip = c * chunkSize; + var take = skip + chunkSize; + var chunk = new List(chunkSize); + + for (int e = skip; e < take && e < list.Length; e++) + { + chunk.Add(list.GetValue(e)); + } + splitList.Add(chunk); + } + + return splitList; + } + + public static string[] load_attributes_from_hdf5_group(long group, string name) + { + var (success, attr) = Hdf5.ReadStringAttributes(group, name, "", true); + if (success) + return attr.ToArray(); + + return null; + } + + public static void load_attributes_from_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + + public static List _legacy_weights(ILayer layer) + { + var weights = layer.TrainableWeights.Select(x => x).ToList(); + weights.AddRange(layer.NonTrainableWeights); + return weights; + } + } +} + diff --git a/src/TensorFlowNET.Keras/Saving/serialization.cs b/src/TensorFlowNET.Keras/Saving/serialization.cs new file mode 100644 index 000000000..d5e46d11c --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/serialization.cs @@ -0,0 +1,125 @@ +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using Tensorflow.Keras.Saving.SavedModel; + +namespace Tensorflow.Keras.Saving +{ + // TODO: make it thread safe. + public class SharedObjectSavingScope: IDisposable + { + private class WeakReferenceEqualityComparer: IEqualityComparer> + { + public bool Equals(WeakReference x, WeakReference y) + { + if(!x.TryGetTarget(out var tx)) + { + return false; + } + if(!y.TryGetTarget(out var ty)) + { + return false; + } + return tx.Equals(ty); + } + public int GetHashCode(WeakReference obj) + { + if (!obj.TryGetTarget(out var w)) + { + return 0; + } + return w.GetHashCode(); + } + } + private static SharedObjectSavingScope? _instance = null; + private readonly Dictionary, int> _shared_object_ids= new Dictionary, int>(); + private int _currentId = 0; + /// + /// record how many times the scope is nested. + /// + private int _nestedDepth = 0; + private SharedObjectSavingScope() + { + + } + + public static SharedObjectSavingScope Enter() + { + if(_instance is not null) + { + _instance._nestedDepth++; + return _instance; + } + else + { + _instance = new SharedObjectSavingScope(); + _instance._nestedDepth++; + return _instance; + } + } + + public static SharedObjectSavingScope GetScope() + { + return _instance; + } + + public int GetId(object? obj) + { + if(obj is null) + { + return _currentId++; + } + var maybe_key = _shared_object_ids.Keys.SingleOrDefault(x => new WeakReferenceEqualityComparer().Equals(x, new WeakReference(obj))); + if (maybe_key is not null) + { + return _shared_object_ids[maybe_key]; + } + _shared_object_ids[new WeakReference(obj)] = _currentId++; + return _currentId; + } + + public void Dispose() + { + _nestedDepth--; + if(_nestedDepth== 0) + { + _instance = null; + } + } + } + + public static class serialize_utils + { + public static readonly string SHARED_OBJECT_KEY = "shared_object_id"; + /// + /// Returns the serialization of the class with the given config. + /// + /// + /// + /// + /// + /// + public static JObject serialize_keras_class_and_config(string class_name, JToken config, object? obj = null, int? shared_object_id = null) + { + JObject res = new JObject(); + res["class_name"] = class_name; + res["config"] = config; + + if(shared_object_id is not null) + { + res[SHARED_OBJECT_KEY] = shared_object_id!; + } + + var scope = SharedObjectSavingScope.GetScope(); + if(scope is not null && obj is not null) + { + res[SHARED_OBJECT_KEY] = scope.GetId(obj); + } + + return res; + } + } +} diff --git a/src/TensorFlowNET.Keras/Sequence.cs b/src/TensorFlowNET.Keras/Sequence.cs new file mode 100644 index 000000000..cda3f30fe --- /dev/null +++ b/src/TensorFlowNET.Keras/Sequence.cs @@ -0,0 +1,75 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Keras +{ + public class Sequence + { + /// + /// Pads sequences to the same length. + /// https://keras.io/preprocessing/sequence/ + /// https://faroit.github.io/keras-docs/1.2.0/preprocessing/sequence/ + /// + /// List of lists, where each element is a sequence. + /// Int, maximum length of all sequences. + /// Type of the output sequences. + /// String, 'pre' or 'post': + /// String, 'pre' or 'post' + /// Float or String, padding value. + /// + public NDArray pad_sequences(IEnumerable sequences, + int? maxlen = null, + string dtype = "int32", + string padding = "pre", + string truncating = "pre", + object value = null) + { + if (value != null) throw new NotImplementedException("padding with a specific value."); + if (padding != "pre" && padding != "post") throw new InvalidArgumentError("padding must be 'pre' or 'post'."); + if (truncating != "pre" && truncating != "post") throw new InvalidArgumentError("truncating must be 'pre' or 'post'."); + + var length = sequences.Select(s => s.Length); + + if (maxlen == null) + maxlen = length.Max(); + + if (value == null) + value = 0f; + + var type = dtypes.tf_dtype_from_name(dtype); + var nd = np.zeros((length.Count(), maxlen.Value), dtype: type); + + for (int i = 0; i < nd.dims[0]; i++) + { + var s = sequences.ElementAt(i); + if (s.Length > maxlen.Value) + { + s = (truncating == "pre") ? s.Skip(s.Length - maxlen.Value).ToArray() : s.Take(maxlen.Value).ToArray(); + } + var sliceString = (padding == "pre") ? $"{i},{maxlen - s.Length}:" : $"{i},:{s.Length}"; + var slices = sliceString.Split(',').Select(x => new Slice(x)).ToArray(); + nd[slices] = np.array(s); + } + + return nd; + } + } +} diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj new file mode 100644 index 000000000..eb8ebf93c --- /dev/null +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -0,0 +1,163 @@ + + + + netstandard2.0;net6.0 + Tensorflow.Keras + 10.0 + enable + Tensorflow.Keras + AnyCPU;x64 + 0.15.0 + Haiping Chen + Keras for .NET + Apache 2.0, Haiping Chen since 2018 + TensorFlow.Keras + https://github.com/SciSharp/TensorFlow.NET + https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + https://github.com/SciSharp/TensorFlow.NET + + Keras for .NET is a C# version of Keras ported from the python version. + + * Support CIFAR-10 dataset in keras.datasets. + * Support Conv2D functional API. + * Support BatchNormalization layer. + * Building keras model in subclass, functional and sequential api + * Implemented backward_function. + * Support model.load_weights. + * Add Subtract layer + * Text preprocessing + * Preprocessing.timeseries_dataset_from_array + * Fixed memory leak for YOLOv3 model. + * Support RNN and LSTM models + * Support Transformer model + * Support BERT model + + Keras for .NET + +Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages. + SciSharp STACK + False + tensorflow, keras, deep learning, machine learning + true + packages + Git + False + Open.snk + 0.15.0.0 + 0.15.0.0 + LICENSE + Debug;Release;GPU + + + + DEBUG;TRACE + false + + + + DEBUG;TRACE + false + + + + false + + + + Tensorflow.Keras.xml + + + + Tensorflow.Keras.xml + + + + + + + + True + 1 + $(NoWarn),1573,1591,1712,8602,8603,8625,CS0612 + + + + True + 1 + $(NoWarn),1573,1591,1712,8602,8603,8625,CS0612 + + + + True + 1 + $(NoWarn),1573,1591,1712,8602,8603,8625,CS0612 + + + + True + 1 + $(NoWarn),1573,1591,1712,8602,8603,8625,CS0612 + + + + False + 1 + $(NoWarn),1573,1591,1712,8602,8603,8625,CS0612 + + + + False + 1 + $(NoWarn),1573,1591,1712,8602,8603,8625,CS0612 + + + + False + 1 + $(NoWarn),1573,1591,1712,8602,8603,8625,CS0612 + + + + False + 1 + $(NoWarn),1573,1591,1712,8602,8603,8625,CS0612 + + + + 1 + $(NoWarn),1573,1591,1712,8602,8603,8625,CS0612 + + + + 1 + $(NoWarn),1573,1591,1712,8602,8603,8625,CS0612 + + + + 1 + $(NoWarn),1573,1591,1712,8602,8603,8625,CS0612 + + + + 1 + $(NoWarn),1573,1591,1712,8602,8603,8625,CS0612 + + + + + + + + + + + True + + + + + + + + + diff --git a/src/TensorFlowNET.Keras/TextApi.cs b/src/TensorFlowNET.Keras/TextApi.cs new file mode 100644 index 000000000..8ce8d6859 --- /dev/null +++ b/src/TensorFlowNET.Keras/TextApi.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Text; + +namespace Tensorflow.Keras +{ + public class TextApi + { + public Tensorflow.Keras.Text.Tokenizer Tokenizer( + int num_words = -1, + string filters = DefaultFilter, + bool lower = true, + char split = ' ', + bool char_level = false, + string oov_token = null, + Func> analyzer = null) + { + return new Keras.Text.Tokenizer(num_words, filters, lower, split, char_level, oov_token, analyzer); + } + + public static IEnumerable text_to_word_sequence(string text, string filters = DefaultFilter, bool lower = true, char split = ' ') + { + if (lower) + { + text = text.ToLower(); + } + var newText = new String(text.Where(c => !filters.Contains(c)).ToArray()); + return newText.Split(split); + } + + private const string DefaultFilter = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n"; + } +} diff --git a/src/TensorFlowNET.Keras/Utils/Compress.cs b/src/TensorFlowNET.Keras/Utils/Compress.cs new file mode 100644 index 000000000..397108868 --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/Compress.cs @@ -0,0 +1,105 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using ICSharpCode.SharpZipLib.Core; +using ICSharpCode.SharpZipLib.GZip; +using ICSharpCode.SharpZipLib.Tar; +using System; +using System.IO; +using System.IO.Compression; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace Tensorflow.Keras.Utils +{ + public class Compress + { + public static void ExtractGZip(string gzipFileName, string targetDir) + { + // Use a 4K buffer. Any larger is a waste. + byte[] dataBuffer = new byte[4096]; + + using (System.IO.Stream fs = new FileStream(gzipFileName, FileMode.Open, FileAccess.Read)) + { + using (GZipInputStream gzipStream = new GZipInputStream(fs)) + { + // Change this to your needs + string fnOut = Path.Combine(targetDir, Path.GetFileNameWithoutExtension(gzipFileName)); + + using (FileStream fsOut = File.Create(fnOut)) + { + StreamUtils.Copy(gzipStream, fsOut, dataBuffer); + } + } + } + } + + public static void UnZip(String gzArchiveName, String destFolder) + { + var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin"; + if (File.Exists(Path.Combine(destFolder, flag))) return; + + var destFileName = gzArchiveName.Replace(".zip", string.Empty); + if (File.Exists(destFileName)) return; + + Binding.tf_output_redirect.WriteLine($"Extracting."); + var task = Task.Run(() => + { + ZipFile.ExtractToDirectory(gzArchiveName, destFolder); + }); + + while (!task.IsCompleted) + { + Thread.Sleep(200); + Binding.tf_output_redirect.Write("."); + } + + File.Create(Path.Combine(destFolder, flag)); + Binding.tf_output_redirect.WriteLine(""); + Binding.tf_output_redirect.WriteLine("Extracting is completed."); + } + + public static void ExtractTGZ(String gzArchiveName, String destFolder) + { + var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin"; + if (File.Exists(Path.Combine(destFolder, flag))) return; + + Binding.tf_output_redirect.WriteLine($"Extracting."); + var task = Task.Run(() => + { + using (var inStream = File.OpenRead(gzArchiveName)) + { + using (var gzipStream = new GZipInputStream(inStream)) + { + using (TarArchive tarArchive = TarArchive.CreateInputTarArchive(gzipStream)) + tarArchive.ExtractContents(destFolder); + } + } + }); + + while (!task.IsCompleted) + { + Thread.Sleep(200); + Binding.tf_output_redirect.Write("."); + } + + File.Create(Path.Combine(destFolder, flag)); + Binding.tf_output_redirect.WriteLine(""); + Binding.tf_output_redirect.WriteLine("Extracting is completed."); + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/KerasUtils.cs b/src/TensorFlowNET.Keras/Utils/KerasUtils.cs new file mode 100644 index 000000000..567bee91e --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/KerasUtils.cs @@ -0,0 +1,42 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Utils +{ + public class KerasUtils + { + /// + /// Downloads a file from a URL if it not already in the cache. + /// + /// Name of the file + /// Original URL of the file + /// + /// + /// + /// + /// + /// + /// + /// + /// + public string get_file(string fname, string origin, + bool untar = false, + string md5_hash = null, + string file_hash = null, + string cache_subdir = "datasets", + string hash_algorithm = "auto", + bool extract = false, + string archive_format = "auto", + string cache_dir = null) + => data_utils.get_file(fname, origin, + untar: untar, + md5_hash: md5_hash, + file_hash: file_hash, + cache_subdir: cache_subdir, + hash_algorithm: hash_algorithm, + extract: extract, + archive_format: archive_format, + cache_dir: cache_dir); + } +} diff --git a/src/TensorFlowNET.Keras/Utils/RnnUtils.cs b/src/TensorFlowNET.Keras/Utils/RnnUtils.cs new file mode 100644 index 000000000..1e9f6d845 --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/RnnUtils.cs @@ -0,0 +1,103 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.Layers; +using Tensorflow.Common.Extensions; + +namespace Tensorflow.Keras.Utils +{ + internal static class RnnUtils + { + internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, INestStructure state_size, TF_DataType dtype) + { + Func create_zeros = (unnested_state_size) => + { + var flat_dims = new Shape(unnested_state_size).dims; + var init_state_size = new Tensor[] { batch_size_tensor }. + Concat(flat_dims.Select(x => tf.constant(x, dtypes.int32))).ToArray(); + return array_ops.zeros(init_state_size, dtype: dtype); + }; + + // TODO(Rinne): map structure with nested tensors. + if(state_size.TotalNestedCount > 1) + { + return new Tensors(state_size.Flatten().Select(s => create_zeros(s)).ToArray()); + } + else + { + return create_zeros(state_size.Flatten().First()); + } + + } + + internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype) + { + if (inputs is not null) + { + batch_size = array_ops.shape(inputs)[0]; + dtype = inputs.dtype; + } + return generate_zero_filled_state(batch_size, cell.StateSize, dtype); + } + + /// + /// Standardizes `__call__` to a single list of tensor inputs. + /// + /// When running a model loaded from a file, the input tensors + /// `initial_state` and `constants` can be passed to `RNN.__call__()` as part + /// of `inputs` instead of by the dedicated keyword arguments.This method + /// makes sure the arguments are separated and that `initial_state` and + /// `constants` are lists of tensors(or None). + /// + /// Tensor or list/tuple of tensors. which may include constants + /// and initial states.In that case `num_constant` must be specified. + /// Tensor or list of tensors or None, initial states. + /// Tensor or list of tensors or None, constant tensors. + /// Expected number of constants (if constants are passed as + /// part of the `inputs` list. + /// + internal static (Tensors, Tensors, Tensors) standardize_args(Tensors inputs, Tensors initial_state, Tensors constants, int num_constants) + { + if(inputs.Length > 1) + { + // There are several situations here: + // In the graph mode, __call__ will be only called once. The initial_state + // and constants could be in inputs (from file loading). + // In the eager mode, __call__ will be called twice, once during + // rnn_layer(inputs=input_t, constants=c_t, ...), and second time will be + // model.fit/train_on_batch/predict with real np data. In the second case, + // the inputs will contain initial_state and constants as eager tensor. + // + // For either case, the real input is the first item in the list, which + // could be a nested structure itself. Then followed by initial_states, which + // could be a list of items, or list of list if the initial_state is complex + // structure, and finally followed by constants which is a flat list. + Debug.Assert(initial_state is null && constants is null); + if(num_constants > 0) + { + constants = inputs.TakeLast(num_constants).ToArray().ToTensors(); + inputs = inputs.SkipLast(num_constants).ToArray().ToTensors(); + } + if(inputs.Length > 1) + { + initial_state = inputs.Skip(1).ToArray().ToTensors(); + inputs = inputs.Take(1).ToArray().ToTensors(); + } + } + + return (inputs, initial_state, constants); + } + + /// + /// Check whether the state_size contains multiple states. + /// + /// + /// + public static bool is_multiple_state(INestStructure state_size) + { + return state_size.TotalNestedCount > 1; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/Web.cs b/src/TensorFlowNET.Keras/Utils/Web.cs new file mode 100644 index 000000000..9f10feb8b --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/Web.cs @@ -0,0 +1,57 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.IO; +using System.Linq; +using System.Net; +using System.Threading; +using System.Threading.Tasks; + +namespace Tensorflow.Keras.Utils +{ + public class Web + { + public static bool Download(string url, string destDir, string destFileName) + { + if (destFileName == null) + destFileName = url.Split(Path.DirectorySeparatorChar).Last(); + + Directory.CreateDirectory(destDir); + + string relativeFilePath = Path.Combine(destDir, destFileName); + + if (File.Exists(relativeFilePath)) + { + Binding.tf_output_redirect.WriteLine($"{relativeFilePath} already exists."); + return false; + } + + var wc = new WebClient(); + Binding.tf_output_redirect.WriteLine($"Downloading from {url}"); + var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath)); + while (!download.IsCompleted) + { + Thread.Sleep(1000); + Binding.tf_output_redirect.Write("."); + } + Binding.tf_output_redirect.WriteLine(""); + Binding.tf_output_redirect.WriteLine($"Downloaded to {relativeFilePath}"); + + return true; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs new file mode 100644 index 000000000..e6c9ed422 --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs @@ -0,0 +1,194 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Utils +{ + public class base_layer_utils + { + /// + /// Adds a new variable to the layer. + /// + /// + /// + public static IVariableV1 make_variable(VariableArgs args) + { +#pragma warning disable CS0219 // Variable is assigned but its value is never used + var initializing_from_value = false; +#pragma warning restore CS0219 // Variable is assigned but its value is never used + + Func init_val = () => args.Initializer.Apply(new InitializerArgs(args.Shape, dtype: args.DType)); + + var variable_dtype = args.DType.as_base_dtype(); + return tf.Variable(init_val, + dtype: variable_dtype, + shape: args.Shape, + name: args.Name, + trainable: args.Trainable, + validate_shape: args.ValidateShape, + use_resource: args.UseResource); + } + + /// + /// Makes a layer name (or arbitrary string) unique within a TensorFlow graph. (correponding to `backend.unique_object_name` of python.) + /// + /// + /// + public static string unique_layer_name(string name, Dictionary name_uid_map = null, + string[] avoid_names = null, bool zero_based = false) + { + if (name_uid_map == null) + name_uid_map = get_default_graph_uid_map(); + if (avoid_names == null) + avoid_names = new string[0]; + + string proposed_name = null; + while (proposed_name == null || avoid_names.Contains(proposed_name)) + { + if (!name_uid_map.ContainsKey(name)) + name_uid_map[name] = 0; + + if (zero_based) + { + int number = name_uid_map[name]; + if (number > 0) + proposed_name = $"{name}_{number}"; + else + proposed_name = name; + + name_uid_map[name] += 1; + } + else + { + name_uid_map[name] += 1; + proposed_name = $"{name}_{name_uid_map[name]}"; + } + } + + return proposed_name; + } + + public static Dictionary get_default_graph_uid_map() + { + var graph = ops.get_default_graph(); + Dictionary name_uid_map = null; + if (keras.backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) + { + name_uid_map = keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph]; + } + else + { + name_uid_map = new Dictionary(); + keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map; + } + + return name_uid_map; + } + + public static bool needs_keras_history(Tensors inputs) + { + if (inputs.Any(x => x.KerasHistory == null)) + return true; + + return false; + } + + public static Layer[] create_keras_history(Tensors inputs) + { + var processed_ops = new List(); + var created_layers = new List(); + CreateKerasHistoryHelper(inputs, processed_ops, created_layers); + return created_layers.ToArray(); + } + + public static void CreateKerasHistoryHelper(Tensors tensors, List processed_ops, List created_layers) + { + foreach (var tensor in tensors) + { + if (tensor.KerasHistory != null) + continue; + + var op = tensor.op; + if (!processed_ops.Contains(op)) + { + var layer_inputs = new List(); + var constants = new Dictionary(); + foreach (var (i, op_input) in enumerate(op.inputs._inputs)) + { + if (uses_keras_history(op_input)) + layer_inputs.Add(op_input); + else + { + tf_with(ops.init_scope(), delegate + { + constants[i] = keras.backend.eval_in_eager_or_function(op_input); + }); + } + } + + // recursively + CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); + var opLayerArgs = new TensorFlowOpLayerArgs + { + NodeDef = op.node_def, + Constants = constants, + Name = op.name + }; + var op_layer = new TensorFlowOpLayer(opLayerArgs); + created_layers.Add(op_layer); + op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); + processed_ops.Add(op); + } + } + } + + public static bool has_weights(object obj) + { + var obj_type = obj.GetType(); + return obj_type.GetField("trainable_weights") is not null && + obj_type.GetField("non_trainable_weights") is not null && + obj is not Type; + } + + public static Tensor generate_placeholders_from_shape(Shape shape) + { + return array_ops.placeholder(keras.backend.floatx(), shape); + } + + // recusive + static bool uses_keras_history(Tensor op_input) + { + if (op_input.KerasHistory != null) + return true; + + foreach (var input in op_input.op.inputs._inputs) + if (uses_keras_history(input)) + return true; + + return false; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/compile_utils.cs b/src/TensorFlowNET.Keras/Utils/compile_utils.cs new file mode 100644 index 000000000..cd4112616 --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/compile_utils.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Framework.Models; +using Tensorflow.Util; + +namespace Tensorflow.Keras.Utils +{ + internal static class compile_utils + { + public static List create_pseudo_input_names(TensorSpec inputs) + { + return _create_pseudo_names(inputs, "input_"); + } + + private static List _create_pseudo_names(TensorSpec tensors, string prefix) + { + // TODO(Rinne): align with tensorflow + return new List() { $"{prefix}1" }; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/conv_utils.cs b/src/TensorFlowNET.Keras/Utils/conv_utils.cs new file mode 100644 index 000000000..baedca925 --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/conv_utils.cs @@ -0,0 +1,97 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Utils +{ + public class conv_utils + { + public static string convert_data_format(string data_format, int ndim) + { + if (data_format == "channels_last") + if (ndim == 3) + return "NWC"; + else if (ndim == 4) + return "NHWC"; + else if (ndim == 5) + return "NDHWC"; + else + throw new ValueError($"Input rank not supported: {ndim}"); + else if (data_format == "channels_first") + if (ndim == 3) + return "NCW"; + else if (ndim == 4) + return "NCHW"; + else if (ndim == 5) + return "NCDHW"; + else + throw new ValueError($"Input rank not supported: {ndim}"); + else + throw new ValueError($"Invalid data_format: {data_format}"); + } + + public static int[] normalize_tuple(int[] value, int n, string name) + { + if (value.Length == 1) + return Enumerable.Range(0, n).Select(x => value[0]).ToArray(); + else + return value; + } + + public static string normalize_padding(string value) + { + return value.ToLower(); + } + + public static string normalize_data_format(string value) + { + if (string.IsNullOrEmpty(value)) + return ImageDataFormat.channels_last.ToString(); + return value.ToLower(); + } + + public static int deconv_output_length(int input_length, + int filter_size, + string padding, + int output_padding = -1, + int stride = 0, + int dilation = 1) + { + // Get the dilated kernel size + filter_size = filter_size + (filter_size - 1) * (dilation - 1); + + // Infer length if output padding is None, else compute the exact length + int length = -1; + if (output_padding == -1) + { + if (padding == "valid") + length = input_length * stride + max(filter_size - stride, 0); + else if (padding == "full") + length = input_length * stride - (stride + filter_size - 2); + else if (padding == "same") + length = input_length * stride; + } + else + { + throw new NotImplementedException(""); + } + return length; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/data_utils.cs b/src/TensorFlowNET.Keras/Utils/data_utils.cs new file mode 100644 index 000000000..b0bc15540 --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/data_utils.cs @@ -0,0 +1,92 @@ +using System; +using System.Linq; +using System.Collections.Generic; +using System.IO; +using System.Text; + +namespace Tensorflow.Keras.Utils +{ + public class data_utils + { + public static string get_file(string fname, string origin, + bool untar = false, + string md5_hash = null, + string file_hash = null, + string cache_subdir = "datasets", + string hash_algorithm = "auto", + bool extract = false, + string archive_format = "auto", + string cache_dir = null) + { + if (string.IsNullOrEmpty(cache_dir)) + cache_dir = Path.GetTempPath(); + var datadir_base = cache_dir; + Directory.CreateDirectory(datadir_base); + + var datadir = Path.Combine(datadir_base, cache_subdir); + Directory.CreateDirectory(datadir); + + Web.Download(origin, datadir, fname); + + var archive = Path.Combine(datadir, fname); + + if (untar) + Compress.ExtractTGZ(archive, datadir); + else if (extract && fname.EndsWith(".gz")) + Compress.ExtractGZip(archive, datadir); + else if (extract && fname.EndsWith(".zip")) + Compress.UnZip(archive, datadir); + + return datadir; + } + + public static (int[,], long[]) _remove_long_seq(int maxlen, int[,] seq, long[] label) + { + /*Removes sequences that exceed the maximum length. + + Args: + maxlen: Int, maximum length of the output sequences. + seq: List of lists, where each sublist is a sequence. + label: List where each element is an integer. + + Returns: + new_seq, new_label: shortened lists for `seq` and `label`. + + */ + var nRow = seq.GetLength(0); + var nCol = seq.GetLength(1); + List new_seq = new List(); + List new_label = new List(); + + for (var i = 0; i < nRow; i++) + { + if (maxlen < nCol && seq[i, maxlen] != 0) + continue; + int[] sentence = new int[maxlen]; + for (var j = 0; j < maxlen && j < nCol; j++) + { + sentence[j] = seq[i, j]; + } + new_seq.Add(sentence); + new_label.Add(label[i]); + } + + int[,] new_seq_array = new int[new_seq.Count, maxlen]; + long[] new_label_array = new long[new_label.Count]; + + for (var i = 0; i < new_seq.Count; i++) + { + for (var j = 0; j < maxlen; j++) + { + new_seq_array[i, j] = new_seq[i][j]; + } + } + + for (var i = 0; i < new_label.Count; i++) + { + new_label_array[i] = new_label[i]; + } + return (new_seq_array, new_label_array); + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs new file mode 100644 index 000000000..20937e2e5 --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -0,0 +1,162 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Data; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Security.AccessControl; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Saving; +using Tensorflow.Train; +using System.Text.RegularExpressions; + +namespace Tensorflow.Keras.Utils +{ + public class generic_utils + { + private static readonly string _LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config"; + /// + /// This method does not have corresponding method in python. It's close to `serialize_keras_object`. + /// + /// + /// + public static LayerConfig serialize_layer_to_config(ILayer instance) + { + var config = instance.get_config(); + Debug.Assert(config is LayerArgs); + return new LayerConfig + { + Config = config as LayerArgs, + ClassName = instance.GetType().Name + }; + } + + public static JObject serialize_keras_object(IKerasConfigable instance) + { + var config = JToken.FromObject(instance.get_config()); + // TODO: change the class_name to registered name, instead of system class name. + return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); + } + + public static Layer deserialize_keras_object(string class_name, JToken config) + { + var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); + if(argType is null) + { + return null; + } + var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) + .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); + var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); + var args = deserializationGenericMethod.Invoke(config, null); + var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); + Debug.Assert(layer is Layer); + + // TODO(Rinne): _shared_object_loading_scope().set(shared_object_id, deserialized_obj) + + return layer as Layer; + } + + public static Layer deserialize_keras_object(string class_name, LayerArgs args) + { + var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); + if (layer is null) + { + return null; + } + Debug.Assert(layer is Layer); + + // TODO(Rinne): _shared_object_loading_scope().set(shared_object_id, deserialized_obj) + + return layer as Layer; + } + + public static LayerArgs deserialize_layer_args(string class_name, JToken config) + { + var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); + var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) + .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); + var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); + var args = deserializationGenericMethod.Invoke(config, null); + Debug.Assert(args is LayerArgs); + return args as LayerArgs; + } + + public static FunctionalConfig deserialize_model_config(JToken json) + { + FunctionalConfig config = new FunctionalConfig(); + config.Name = json["name"].ToObject(); + config.Layers = new List(); + var layersToken = json["layers"]; + foreach (var token in layersToken) + { + var args = deserialize_layer_args(token["class_name"].ToObject(), token["config"]); + + List nodeConfig = null; //python tensorflow sometimes exports inbound nodes in an extra nested array + if (token["inbound_nodes"].Count() > 0 && token["inbound_nodes"][0].Count() > 0 && token["inbound_nodes"][0][0].Count() > 0) + { + nodeConfig = token["inbound_nodes"].ToObject>>().FirstOrDefault() ?? new List(); + } + else + { + nodeConfig = token["inbound_nodes"].ToObject>(); + } + + config.Layers.Add(new LayerConfig() + { + Config = args, + Name = token["name"].ToObject(), + ClassName = token["class_name"].ToObject(), + InboundNodes = nodeConfig, + }); + } + config.InputLayers = json["input_layers"].ToObject>(); + config.OutputLayers = json["output_layers"].ToObject>(); + return config; + } + + public static string to_snake_case(string name) + { + string intermediate = Regex.Replace(name, "(.)([A-Z][a-z0-9]+)", "$1_$2"); + string insecure = Regex.Replace(intermediate, "([a-z])([A-Z])", "$1_$2").ToLower(); + + if (insecure[0] != '_') + { + return insecure; + } + + return "private" + insecure; + } + + /// + /// Determines whether config appears to be a valid layer config. + /// + /// + /// + public static bool validate_config(JObject config) + { + return !config.ContainsKey(_LAYER_UNDEFINED_CONFIG_KEY); + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/layer_utils.cs b/src/TensorFlowNET.Keras/Utils/layer_utils.cs new file mode 100644 index 000000000..07d9f685e --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/layer_utils.cs @@ -0,0 +1,220 @@ +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Utils +{ + internal class layer_utils + { + public static void print_summary(Model model, int line_length = -1, float[] positions = null) + { + bool sequential_like = model is Sequential; + // || model.IsGraphNetwork; + + if (!sequential_like) + { + sequential_like = true; + var nodes = new List(); + + foreach (var v in model.NodesByDepth) + { + // if the model has multiple nodes + // or if the nodes have multiple inbound_layers + // the model is no longer sequential + if (v.Value.Count > 1 || (v.Value.Count == 1 && v.Value[0].KerasInputs.Count > 1)) + { + sequential_like = false; + break; + } + + nodes.AddRange(v.Value); + } + + if (sequential_like) + { + // search for shared layers + foreach (var layer in model.Layers) + { + var flag = false; + foreach (var node in layer.InboundNodes) + { + if (nodes.Contains(node)) + { + if (flag) + { + sequential_like = false; + break; + } + else + flag = true; + } + } + if (!sequential_like) + break; + } + } + } + + string[] to_display; + var relevant_nodes = new List(); + + if (sequential_like) + { + if (line_length < 0) + line_length = 65; + if (positions == null) + positions = new[] { 0.45f, 0.85f, 1.0f }; + if (positions.Last() <= 1) + positions = positions.Select(p => line_length * p).ToArray(); + to_display = new[] { "Layer (type)", "Output Shape", "Param #" }; + } + else + { + if (line_length < 0) + line_length = 98; + if (positions == null) + positions = new[] { 0.33f, 0.55f, 0.67f, 1.0f }; + if (positions.Last() <= 1) + positions = positions.Select(p => line_length * p).ToArray(); + to_display = new[] { "Layer (type)", "Output Shape", "Param #", "Connected to" }; + + foreach (var v in model.NodesByDepth) + relevant_nodes.AddRange(v.Value); + } + + int[] positions_int = positions.Select(x => Convert.ToInt32(x)).ToArray(); + print($"Model: {model.Name}"); + print(string.Join("", range(line_length).Select(x => "_"))); + print_row(to_display, positions_int); + print(string.Join("", range(line_length).Select(x => "="))); + + foreach (var (i, layer) in enumerate(model.Layers)) + { + if (sequential_like) + print_layer_summary(layer, positions_int); + else + print_layer_summary_with_connections(layer, positions_int, relevant_nodes); + if (i == model.Layers.Count - 1) + print(string.Join("", range(line_length).Select(x => "="))); + else + print(string.Join("", range(line_length).Select(x => "_"))); + } + + var trainable_count = count_params(model, model.TrainableVariables); + var non_trainable_count = count_params(model, model.NonTrainableVariables); + + print($"Total params: {trainable_count + non_trainable_count}"); + print($"Trainable params: {trainable_count}"); + print($"Non-trainable params: {non_trainable_count}"); + print(string.Join("", range(line_length).Select(x => "_"))); + } + + static void print_row(string[] fields, int[] positions) + { + var line = ""; + foreach (var i in range(fields.Length)) + { + if (i > 0) + line = line + " "; + line += fields[i]; + line = string.Join("", line.Take(positions[i])); + line += string.Join("", range(positions[i] - len(line)).Select(x => " ")); + } + print(line); + } + + /// + /// Prints a summary for a single layer. + /// + /// + static void print_layer_summary(ILayer layer, int[] positions) + { + var name = layer.Name; + + var fields = new string[] + { + $"{name} ({layer.GetType().Name})", + $"{layer.OutputShape}", + $"{layer.count_params()}" + }; + + print_row(fields, positions); + } + + static void print_layer_summary_with_connections(ILayer layer, int[] positions, List relevant_nodes) + { + var connections = new List(); + foreach (var node in layer.InboundNodes) + { + if (!relevant_nodes.Contains(node)) + continue; + + foreach (var (inbound_layer, node_index, tensor_index, _) in node.iterate_inbound()) + connections.append($"{inbound_layer.Name}[{node_index}][{tensor_index}]"); + } + + var name = layer.Name; + string first_connection = ""; + if (connections.Count > 0) + first_connection = connections[0]; + + var fields = new string[] + { + $"{name}({layer.GetType().Name})", + $"{layer.OutputShape}", + $"{layer.count_params()}", + first_connection + }; + + print_row(fields, positions); + + if (connections.Count > 1) + { + foreach (var i in range(1, connections.Count)) + { + fields = new string[] { "", "", "", connections[i] }; + print_row(fields, positions); + } + } + } + + public static int count_params(Layer layer, List weights) + { + var weight_shapes = weights.Select(x => x.shape).ToArray(); + var total = weight_shapes.Select(p => (int)p.size).Sum(); + return total; + } + + public static Tensors get_source_inputs(Tensor tensor, ILayer layer = null, int node_index = -1) + { + if (layer == null) + (layer, node_index, _) = tensor.KerasHistory; + if (layer.InboundNodes == null || layer.InboundNodes.Count == 0) + return tensor; + else + { + var node = layer.InboundNodes[node_index]; + if (node.is_input) + return node.input_tensors; + else + { + var source_tensors = new List(); + foreach (var _layer in node.iterate_inbound()) + { + (layer, node_index, tensor) = (_layer.Item1, _layer.Item2, _layer.Item4); + var previous_sources = get_source_inputs(tensor, layer, node_index); + foreach(var x in previous_sources) + { + // should be check if exist? + source_tensors.append(x); + } + } + return source_tensors; + } + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/losses_utils.cs b/src/TensorFlowNET.Keras/Utils/losses_utils.cs new file mode 100644 index 000000000..9ba40ca04 --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/losses_utils.cs @@ -0,0 +1,117 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Xml.Linq; +using Tensorflow.Keras.Losses; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Utils +{ + public class losses_utils + { + public static Tensor compute_weighted_loss(Tensor losses, Tensor sample_weight = null, string reduction = null, string name = null) + { + return tf_with(ops.name_scope("weighted_loss"), scope => + { + if (sample_weight == null) + sample_weight = losses.dtype == TF_DataType.TF_DOUBLE ? tf.constant(1.0) : tf.constant(1.0f); + var weighted_losses = math_ops.multiply(losses, sample_weight); + // Apply reduction function to the individual weighted losses. + var loss = reduce_weighted_loss(weighted_losses, reduction); + // Convert the result back to the input type. + // loss = math_ops.cast(loss, losses.dtype); + return loss; + }); + } + + public static (Tensor, Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null) + { + var y_pred_shape = y_pred.shape; + var y_pred_rank = y_pred_shape.ndim; + if (y_true != null) + { + var y_true_shape = y_true.shape; + var y_true_rank = y_true_shape.ndim; + if (y_true_rank > -1 && y_pred_rank > -1) + { + if (y_pred_rank - y_true_rank != 1 || y_pred_shape[-1] == 1) + { + (y_true, y_pred) = remove_squeezable_dimensions(y_true, y_pred); + } + } + } + + if (sample_weight == null) + { + return (y_pred, y_true, sample_weight); + } + + var weights_shape = sample_weight.shape; + var weights_rank = weights_shape.ndim; + if (weights_rank == 0) + return (y_pred, y_true, sample_weight); + + if (y_pred_rank > -1 && weights_rank > -1) + { + if (weights_rank - y_pred_rank == 1) + { + sample_weight = tf.squeeze(sample_weight, -1); + } + else if (y_pred_rank - weights_rank == 1) + { + sample_weight = tf.expand_dims(sample_weight, -1); + } + + return (y_pred, y_true, sample_weight); + } + + throw new NotImplementedException(""); + } + + public static (Tensor, Tensor) remove_squeezable_dimensions(Tensor labels, Tensor predictions, int expected_rank_diff = 0, string name = null) + { + return (labels, predictions); + } + + public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction) + { + if (reduction == ReductionV2.NONE) + return weighted_losses; + else + { + var loss = math_ops.reduce_sum(weighted_losses); + if (reduction == ReductionV2.SUM_OVER_BATCH_SIZE) + loss = _safe_mean(loss, _num_elements(weighted_losses)); + return loss; + } + } + + static Tensor _safe_mean(Tensor losses, Tensor num_present) + { + var total_loss = math_ops.reduce_sum(losses); + return math_ops.div_no_nan(total_loss, num_present, name: "value"); + } + + static Tensor _num_elements(Tensor losses) + { + return tf_with(ops.name_scope("num_elements"), scope => + { + return math_ops.cast(array_ops.size(losses, name: scope), dtype: losses.dtype); + }); + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/np_utils.cs b/src/TensorFlowNET.Keras/Utils/np_utils.cs new file mode 100644 index 000000000..ef29b0464 --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/np_utils.cs @@ -0,0 +1,31 @@ +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Utils +{ + public class np_utils + { + /// + /// Converts a class vector (integers) to binary class matrix. + /// + /// + /// + /// + /// + public static NDArray to_categorical(NDArray y, int num_classes = -1, TF_DataType dtype = TF_DataType.TF_FLOAT) + { + var y1 = y.astype(np.int32).ToArray(); + // var input_shape = y.shape[..^1]; + var categorical = np.zeros(((int)y.size, num_classes), dtype: dtype); + // categorical[np.arange(y.size), y] = 1; + for (var i = 0; i < (int)y.size; i++) + { + categorical[i, y1[i]] = 1.0f; + } + + return categorical; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/tf_utils.cs b/src/TensorFlowNET.Keras/Utils/tf_utils.cs new file mode 100644 index 000000000..ad31fd7ca --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/tf_utils.cs @@ -0,0 +1,98 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Framework; +using Tensorflow.Framework.Models; + +namespace Tensorflow.Keras.Utils +{ + public class tf_utils + { + public static bool are_all_symbolic_tensors(Tensor[] tensors) + { + return tensors.Select(x => is_symbolic_tensor(x)).Count() == tensors.Length; + } + + public static bool? constant_value(Tensor pred) + { + return smart_module.smart_constant_value(pred); + } + + public static bool is_symbolic_tensor(Tensor tensor) + { + return true; + } + + public static Tensor[] smart_cond(IVariableV1 pred, + Func true_fn = null, + Func false_fn = null, + string name = null) + { + return control_flow_ops.cond(pred.AsTensor(), + true_fn: true_fn, + false_fn: false_fn, + name: name); + } + + public static Tensor[] smart_cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, + string name = null) + { + return smart_module.smart_cond(pred, + true_fn: true_fn, + false_fn: false_fn, + name: name); + } + + public static Tensor smart_cond(bool pred, + Func true_fn = null, + Func false_fn = null, + string name = null) + { + return smart_module.smart_cond(pred, + true_fn: true_fn, + false_fn: false_fn, + name: name); + } + + public static TensorSpec get_tensor_spec(Tensor t, bool dynamic_batch = false, string name = null) + { + throw new NotImplementedException("The function is waited to be implemented in the future."); + } + + public static TensorSpec get_tensor_spec(TensorSpec t, bool dynamic_batch = false, string name = null) + { + var spec = t; + if (!dynamic_batch) + { + return spec; + } + var dynamic_batch_spec = new TensorSpec(t.shape, t.dtype, t.name); + var shape = dynamic_batch_spec.shape; + if(shape.rank > 0) + { + var shape_list = shape.as_int_list(); + // TODO(Rinne): check if -1 is equivalent to None in python. + shape_list[0] = -1; + dynamic_batch_spec.shape = new Shape(shape_list); + } + return dynamic_batch_spec; + } + } +} diff --git a/src/TensorFlowNET.Keras/defaultdict.cs b/src/TensorFlowNET.Keras/defaultdict.cs new file mode 100644 index 000000000..9c1f2df60 --- /dev/null +++ b/src/TensorFlowNET.Keras/defaultdict.cs @@ -0,0 +1,36 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace System.Collections.Generic +{ + public class defaultdict : Dictionary where TValue : new() + { + public new TValue this[TKey key] + { + get + { + TValue val; + if (!TryGetValue(key, out val)) + { + val = default(TValue); + Add(key, val); + } + return val; + } + set { base[key] = value; } + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/tf.layers.cs b/src/TensorFlowNET.Keras/tf.layers.cs new file mode 100644 index 000000000..da7c23471 --- /dev/null +++ b/src/TensorFlowNET.Keras/tf.layers.cs @@ -0,0 +1,248 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Layers; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras +{ + public class tensorflow_layers + { + public layers_internal layers { get; } = new layers_internal(); + + public class layers_internal + { + public Tensor conv2d(Tensor inputs, + int filters, + int[] kernel_size, + int[] strides = null, + string padding = "valid", + string data_format = "channels_last", + int[] dilation_rate = null, + bool use_bias = true, + Activation activation = null, + IInitializer kernel_initializer = null, + IInitializer bias_initializer = null, + bool trainable = true, + string name = null) + { + if (strides == null) + strides = new int[] { 1, 1 }; + if (dilation_rate == null) + dilation_rate = new int[] { 1, 1 }; + if (bias_initializer == null) + bias_initializer = tf.zeros_initializer; + + var layer = new Conv2D(new Conv2DArgs + { + Filters = filters, + KernelSize = kernel_size, + Strides = strides, + Padding = padding, + DataFormat = data_format, + DilationRate = dilation_rate, + Activation = activation, + UseBias = use_bias, + KernelInitializer = kernel_initializer, + BiasInitializer = bias_initializer, + Trainable = trainable, + Name = name + }); + + return layer.Apply(inputs); + } + + /// + /// Functional interface for the batch normalization layer. + /// http://arxiv.org/abs/1502.03167 + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Tensors batch_normalization(Tensor inputs, + int axis = -1, + float momentum = 0.99f, + float epsilon = 0.001f, + bool center = true, + bool scale = true, + IInitializer beta_initializer = null, + IInitializer gamma_initializer = null, + IInitializer moving_mean_initializer = null, + IInitializer moving_variance_initializer = null, + Tensor training = null, + bool trainable = true, + string name = null, + bool renorm = false, + float renorm_momentum = 0.99f) + { + var layer = new BatchNormalization(new BatchNormalizationArgs + { + Axis = axis, + Momentum = momentum, + Epsilon = epsilon, + Center = center, + Scale = scale, + BetaInitializer = beta_initializer, + GammaInitializer = gamma_initializer, + MovingMeanInitializer = moving_mean_initializer, + MovingVarianceInitializer = moving_variance_initializer, + Renorm = renorm, + RenormMomentum = renorm_momentum, + Trainable = trainable, + Name = name + }); + + return layer.Apply(inputs); + } + + /// + /// Max pooling layer for 2D inputs (e.g. images). + /// + /// The tensor over which to pool. Must have rank 4. + /// + /// + /// + /// + /// + /// + public Tensor MaxPooling2D(Tensor inputs, + int[] pool_size, + int[] strides, + string padding = "valid", + string data_format = "channels_last", + string name = null) + { + var layer = new MaxPooling2D(new MaxPooling2DArgs + { + PoolSize = pool_size, + Strides = strides, + Padding = padding, + DataFormat = data_format, + Name = name + }); + + return layer.Apply(inputs); + } + + /// + /// Densely-connected layer class. aka fully-connected

+ /// `outputs = activation(inputs * kernel + bias)` + ///
+ /// + /// Python integer, dimensionality of the output space. + /// + /// Boolean, whether the layer uses a bias. + /// + /// + /// + /// + /// + /// + public Tensor dense(Tensor inputs, + int units, + Activation activation = null, + bool use_bias = true, + IInitializer kernel_initializer = null, + IInitializer bias_initializer = null, + bool trainable = true, + string name = null, + bool? reuse = null) + { + if (bias_initializer == null) + bias_initializer = tf.zeros_initializer; + + var layer = new Dense(new DenseArgs + { + Units = units, + Activation = activation, + UseBias = use_bias, + BiasInitializer = bias_initializer, + KernelInitializer = kernel_initializer, + Trainable = trainable, + Name = name + }); + + return layer.Apply(inputs); + } + + /// + /// Flattens an input tensor while preserving the batch axis (axis 0). + /// + /// Tensor input. + /// The name of the layer. + /// + /// A string, one of `channels_last` (default) or `channels_first`.

+ /// The ordering of the dimensions in the inputs.

+ /// `channels_last` corresponds to inputs with shape

+ /// `(batch, height, width, channels)` while `channels_first` corresponds to

+ /// inputs with shape `(batch, channels, height, width)`. + /// + /// + public Tensor flatten(Tensor inputs, + string name = null, + string data_format = "channels_last") + { + var input_shape = inputs.shape; + if (inputs.shape.ndim == 0) + throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()"); + + var premutation = new List() { 0 }; + if (data_format == "channels_first" && inputs.ndim > 1) + { + premutation.AddRange(Binding.range(2, inputs.ndim)); + premutation.Add(1); + inputs = array_ops.transpose(inputs, premutation.ToArray()); + } + + var ret = array_ops.reshape(inputs, compute_output_shape(input_shape)); + //ret.set_shape(compute_output_shape(ret.shape)); + return ret; + + int[] compute_output_shape(int[] inputshape) + { + if (inputshape == null || inputshape.Length == 0) + inputshape = new int[] { 1 }; + + if (inputshape.Skip(1).All(d => d > 0)) + { + int[] output_shape = new int[2]; + output_shape[0] = inputshape[0]; + output_shape[1] = inputshape.Skip(1).Aggregate(1, (acc, rhs) => acc * rhs); //calculate size of all the rest dimensions + return output_shape; + } + else + return new int[] { inputshape[0], -1 }; //-1 == Binding.None + } + } + } + } +} diff --git a/src/TensorFlowNET.Keras/tf.optimizers.cs b/src/TensorFlowNET.Keras/tf.optimizers.cs new file mode 100644 index 000000000..aa61cfd96 --- /dev/null +++ b/src/TensorFlowNET.Keras/tf.optimizers.cs @@ -0,0 +1,28 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow.Keras +{ + public class tensorflow_backup + { + public KerasOptimizers optimizers => new KerasOptimizers(); + + public class KerasOptimizers + { + + } + } +} diff --git a/src/TensorFlowNET.Native/BUILD b/src/TensorFlowNET.Native/BUILD new file mode 100644 index 000000000..8a5815436 --- /dev/null +++ b/src/TensorFlowNET.Native/BUILD @@ -0,0 +1,19 @@ +# Description: +# CSharp Native Interface library intended for implementing the +# TensorFlow .NET Standard API using the TensorFlow C library. + +licenses(["notice"]) # Apache 2.0 + +cc_import( + name = "libtensorflow", + hdrs = ["c_api.h"], + interface_library = "libtensorflow.lib", + shared_library = "libtensorflow.dll", +) + +cc_binary( + name = "csni", + srcs = ["csni.cc"], + deps = [":libtensorflow"], + linkstatic = 0 +) diff --git a/docs/The-Definitive-Guide/CH_2 Constant.md b/src/TensorFlowNET.Native/WORKSPACE similarity index 100% rename from docs/The-Definitive-Guide/CH_2 Constant.md rename to src/TensorFlowNET.Native/WORKSPACE diff --git a/src/TensorFlowNET.Native/csni.cc b/src/TensorFlowNET.Native/csni.cc new file mode 100644 index 000000000..40dfe7dc2 --- /dev/null +++ b/src/TensorFlowNET.Native/csni.cc @@ -0,0 +1,21 @@ +#include +#include +#include "c_api.h" +typedef char* (__stdcall *TFFunc)(); + +int main() { + HINSTANCE hinstLib = LoadLibrary(TEXT("libtensorflow.dll")); + if (!hinstLib) { + std::cout << "could not load the dynamic library" << std::endl; + return EXIT_FAILURE; + } + + TFFunc version = (TFFunc) GetProcAddress(hinstLib, "TF_Version"); + if (!version) { + std::cout << "could not locate the function" << std::endl; + return EXIT_FAILURE; + } + + printf("Hello from TensorFlow C library version %s", version()); + return 0; +} \ No newline at end of file diff --git a/src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj b/src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj new file mode 100644 index 000000000..e3374f958 --- /dev/null +++ b/src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj @@ -0,0 +1,22 @@ + + + + netstandard2.0 + 0.0.1 + TensorFlow Recommenders is a library for building recommender system models using TensorFlow. + LICENSE + AnyCPU;x64 + + + + + True + + + + + + + + + diff --git a/src/TensorFlowNET.Text/Enums/Reduction.cs b/src/TensorFlowNET.Text/Enums/Reduction.cs new file mode 100644 index 000000000..aa7252290 --- /dev/null +++ b/src/TensorFlowNET.Text/Enums/Reduction.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Text +{ + public enum Reduction + { + None, + STRING_JOIN + } +} diff --git a/src/TensorFlowNET.Text/Enums/WordShape.cs b/src/TensorFlowNET.Text/Enums/WordShape.cs new file mode 100644 index 000000000..c11173122 --- /dev/null +++ b/src/TensorFlowNET.Text/Enums/WordShape.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Text +{ + public enum WordShape + { + HAS_TITLE_CASE, + IS_UPPERCASE, + HAS_SOME_PUNCT_OR_SYMBOL, + IS_NUMERIC_VALUE + } +} diff --git a/src/TensorFlowNET.Text/Operations/TextOps.ngrams.cs b/src/TensorFlowNET.Text/Operations/TextOps.ngrams.cs new file mode 100644 index 000000000..0ea953dd4 --- /dev/null +++ b/src/TensorFlowNET.Text/Operations/TextOps.ngrams.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Text +{ + public partial class TextOps + { + public static Tensor ngrams(Tensor input, int width, + int axis = -1, + Reduction reduction_type = Reduction.None, + string string_separator = " ", + string name = null) + => throw new NotImplementedException(""); + } +} diff --git a/src/TensorFlowNET.Text/Operations/TextOps.wordshape.cs b/src/TensorFlowNET.Text/Operations/TextOps.wordshape.cs new file mode 100644 index 000000000..b0b2bf4fb --- /dev/null +++ b/src/TensorFlowNET.Text/Operations/TextOps.wordshape.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Text +{ + public partial class TextOps + { + public static Tensor wordshape(Tensor input, WordShape pattern, string name = null) + => throw new NotImplementedException(""); + } +} diff --git a/src/TensorFlowNET.Text/Tensorflow.Text.csproj b/src/TensorFlowNET.Text/Tensorflow.Text.csproj new file mode 100644 index 000000000..f27f680e2 --- /dev/null +++ b/src/TensorFlowNET.Text/Tensorflow.Text.csproj @@ -0,0 +1,32 @@ + + + + netstandard2.0 + Tensorflow.Text + Tensorflow.Text + true + 0.0.1 + LICENSE + AnyCPU;x64 + + + + DEBUG;TRACE + + + + DEBUG;TRACE + + + + + True + + + + + + + + + diff --git a/src/TensorFlowNET.Text/TextApi.cs b/src/TensorFlowNET.Text/TextApi.cs new file mode 100644 index 000000000..68a9c740f --- /dev/null +++ b/src/TensorFlowNET.Text/TextApi.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Text; + +namespace Tensorflow +{ + public class TextApi + { + public static TextInterface text { get; } = new TextInterface(); + } +} diff --git a/src/TensorFlowNET.Text/TextInterface.cs b/src/TensorFlowNET.Text/TextInterface.cs new file mode 100644 index 000000000..a631bd570 --- /dev/null +++ b/src/TensorFlowNET.Text/TextInterface.cs @@ -0,0 +1,37 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Text.Tokenizers; + +namespace Tensorflow.Text +{ + public class TextInterface + { + public ITokenizer WhitespaceTokenizer() + => new WhitespaceTokenizer(); + + public Tensor wordshape(Tensor input, WordShape pattern, string name = null) + => TextOps.wordshape(input, pattern, name: name); + + /// + /// Create a tensor of n-grams based on the input data `data`. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor ngrams(Tensor input, int width, + int axis = -1, + Reduction reduction_type = Reduction.None, + string string_separator = " ", + string name = null) + => TextOps.ngrams(input, width, + axis: axis, + reduction_type: reduction_type, + string_separator: string_separator, + name: name); + } +} diff --git a/src/TensorFlowNET.Text/Tokenizers/ITokenizer.cs b/src/TensorFlowNET.Text/Tokenizers/ITokenizer.cs new file mode 100644 index 000000000..8b585d4df --- /dev/null +++ b/src/TensorFlowNET.Text/Tokenizers/ITokenizer.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Text.Tokenizers +{ + public interface ITokenizer + { + Tensor tokenize(Tensor input); + } +} diff --git a/src/TensorFlowNET.Text/Tokenizers/UnicodeScriptTokenizer.cs b/src/TensorFlowNET.Text/Tokenizers/UnicodeScriptTokenizer.cs new file mode 100644 index 000000000..c9c84525b --- /dev/null +++ b/src/TensorFlowNET.Text/Tokenizers/UnicodeScriptTokenizer.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Text.Tokenizers +{ + public class UnicodeScriptTokenizer : ITokenizer + { + public Tensor tokenize(Tensor input) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs b/src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs new file mode 100644 index 000000000..46231546e --- /dev/null +++ b/src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs @@ -0,0 +1,45 @@ +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Text.Tokenizers +{ + public class WhitespaceTokenizer : ITokenizer + { + /// + /// Tokenizes a tensor of UTF-8 strings on whitespaces. + /// + /// + /// + public Tensor tokenize(Tensor input) + { + tokenize_with_offsets(input); + throw new NotImplementedException(""); + } + + Tensor[] tokenize_with_offsets(Tensor input) + { + tf_with(ops.name_scope(null, "WhitespaceTokenize"), scope => + { + _whitespace_tokenize_with_offsets_encode_decode_wrapper(input); + }); + throw new NotImplementedException(""); + } + + Tensor _whitespace_tokenize_with_offsets_encode_decode_wrapper(Tensor input_tensor) + { + // Decode the strings and get byte offsets + var (codepoints, byte_start_offsets) = tf.strings.unicode_decode_with_offsets(input_tensor, "UTF-8"); + var byte_end_offsets = array_ops.concat(new Tensor[] + { + byte_start_offsets[Slice.All, new Slice(1)], + math_ops.cast( + array_ops.expand_dims(tf.strings.string_length(input_tensor), 1), + dtypes.int64) + }, 1); + return input_tensor; + } + } +} diff --git a/src/TensorflowNET.Hub/GcsCompressedFileResolver.cs b/src/TensorflowNET.Hub/GcsCompressedFileResolver.cs new file mode 100644 index 000000000..f3e1b9723 --- /dev/null +++ b/src/TensorflowNET.Hub/GcsCompressedFileResolver.cs @@ -0,0 +1,57 @@ +using System.IO; +using System.Threading.Tasks; + +namespace Tensorflow.Hub +{ + public class GcsCompressedFileResolver : IResolver + { + const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; + public string Call(string handle) + { + var module_dir = _module_dir(handle); + + return resolver.atomic_download_async(handle, download, module_dir, LOCK_FILE_TIMEOUT_SEC) + .GetAwaiter().GetResult(); + } + public bool IsSupported(string handle) + { + return handle.StartsWith("gs://") && _is_tarfile(handle); + } + + private async Task download(string handle, string tmp_dir) + { + new resolver.DownloadManager(handle).download_and_uncompress( + new FileStream(handle, FileMode.Open, FileAccess.Read), tmp_dir); + await Task.Run(() => { }); + } + + private static string _module_dir(string handle) + { + var cache_dir = resolver.tfhub_cache_dir(use_temp: true); + var sha1 = ComputeSha1(handle); + return resolver.create_local_module_dir(cache_dir, sha1); + } + + private static bool _is_tarfile(string filename) + { + return filename.EndsWith(".tar") || filename.EndsWith(".tar.gz") || filename.EndsWith(".tgz"); + } + + private static string ComputeSha1(string s) + { + using (var sha = new System.Security.Cryptography.SHA1Managed()) + { + var bytes = System.Text.Encoding.UTF8.GetBytes(s); + var hash = sha.ComputeHash(bytes); + var stringBuilder = new System.Text.StringBuilder(hash.Length * 2); + + foreach (var b in hash) + { + stringBuilder.Append(b.ToString("x2")); + } + + return stringBuilder.ToString(); + } + } + } +} diff --git a/src/TensorflowNET.Hub/HttpCompressedFileResolver.cs b/src/TensorflowNET.Hub/HttpCompressedFileResolver.cs new file mode 100644 index 000000000..a127b28c0 --- /dev/null +++ b/src/TensorflowNET.Hub/HttpCompressedFileResolver.cs @@ -0,0 +1,78 @@ +using System; +using System.Net.Http; +using System.Threading.Tasks; + +namespace Tensorflow.Hub +{ + public class HttpCompressedFileResolver : HttpResolverBase + { + const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; // 10 minutes + + private static readonly (string, string) _COMPRESSED_FORMAT_QUERY = + ("tf-hub-format", "compressed"); + + private static string _module_dir(string handle) + { + var cache_dir = resolver.tfhub_cache_dir(use_temp: true); + var sha1 = ComputeSha1(handle); + return resolver.create_local_module_dir(cache_dir, sha1); + } + + public override bool IsSupported(string handle) + { + if (!is_http_protocol(handle)) + { + return false; + } + var load_format = resolver.model_load_format(); + return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.COMPRESSED) + || load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.AUTO); + } + + public override string Call(string handle) + { + var module_dir = _module_dir(handle); + + return resolver.atomic_download_async( + handle, + download, + module_dir, + LOCK_FILE_TIMEOUT_SEC + ).GetAwaiter().GetResult(); + } + + private async Task download(string handle, string tmp_dir) + { + var client = new HttpClient(); + + var response = await client.GetAsync(_append_compressed_format_query(handle)); + + using (var httpStream = await response.Content.ReadAsStreamAsync()) + { + new resolver.DownloadManager(handle).download_and_uncompress(httpStream, tmp_dir); + } + } + + private string _append_compressed_format_query(string handle) + { + return append_format_query(handle, _COMPRESSED_FORMAT_QUERY); + } + + private static string ComputeSha1(string s) + { + using (var sha = new System.Security.Cryptography.SHA1Managed()) + { + var bytes = System.Text.Encoding.UTF8.GetBytes(s); + var hash = sha.ComputeHash(bytes); + var stringBuilder = new System.Text.StringBuilder(hash.Length * 2); + + foreach (var b in hash) + { + stringBuilder.Append(b.ToString("x2")); + } + + return stringBuilder.ToString(); + } + } + } +} diff --git a/src/TensorflowNET.Hub/HttpUncompressedFileResolver.cs b/src/TensorflowNET.Hub/HttpUncompressedFileResolver.cs new file mode 100644 index 000000000..09a497484 --- /dev/null +++ b/src/TensorflowNET.Hub/HttpUncompressedFileResolver.cs @@ -0,0 +1,65 @@ +using System; +using System.Net; + +namespace Tensorflow.Hub +{ + public class HttpUncompressedFileResolver : HttpResolverBase + { + private readonly PathResolver _pathResolver; + + public HttpUncompressedFileResolver() + { + _pathResolver = new PathResolver(); + } + + public override string Call(string handle) + { + handle = AppendUncompressedFormatQuery(handle); + var gsLocation = RequestGcsLocation(handle); + return _pathResolver.Call(gsLocation); + } + + public override bool IsSupported(string handle) + { + if (!is_http_protocol(handle)) + { + return false; + } + + var load_format = resolver.model_load_format(); + return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.UNCOMPRESSED); + } + + protected virtual string AppendUncompressedFormatQuery(string handle) + { + return append_format_query(handle, ("tf-hub-format", "uncompressed")); + } + + protected virtual string RequestGcsLocation(string handleWithParams) + { + var request = WebRequest.Create(handleWithParams); + var response = request.GetResponse() as HttpWebResponse; + + if (response == null) + { + throw new Exception("Failed to get a response from the server."); + } + + var statusCode = (int)response.StatusCode; + + if (statusCode != 303) + { + throw new Exception($"Expected 303 for GCS location lookup but got HTTP {statusCode} {response.StatusDescription}"); + } + + var location = response.Headers["Location"]; + + if (!location.StartsWith("gs://")) + { + throw new Exception($"Expected Location:GS path but received {location}"); + } + + return location; + } + } +} \ No newline at end of file diff --git a/src/TensorflowNET.Hub/KerasLayer.cs b/src/TensorflowNET.Hub/KerasLayer.cs new file mode 100644 index 000000000..20d9851b1 --- /dev/null +++ b/src/TensorflowNET.Hub/KerasLayer.cs @@ -0,0 +1,158 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Common.Types; +using Tensorflow.Keras.Engine; +using Tensorflow.Train; +using Tensorflow.Training; +using Tensorflow.Training.Saving.SavedModel; +using static Tensorflow.Binding; + +namespace Tensorflow.Hub +{ + public class KerasLayer : Layer + { + private string _handle; + private LoadOptions? _load_options; + private Trackable _func; + private Func _callable; + + public KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) : + base(new Keras.ArgsDefinition.LayerArgs() { Trainable = trainable }) + { + _handle = handle; + _load_options = load_options; + + _func = load_module(_handle, _load_options); + _track_trackable(_func, "_func"); + // TODO(Rinne): deal with _is_hub_module_v1. + + _callable = _get_callable(); + _setup_layer(trainable); + } + + private void _setup_layer(bool trainable = false) + { + HashSet trainable_variables; + if (_func is Layer layer) + { + foreach (var v in layer.TrainableVariables) + { + _add_existing_weight(v, true); + } + trainable_variables = new HashSet(layer.TrainableVariables.Select(v => v.UniqueId)); + } + else if (_func.CustomizedFields.TryGetValue("trainable_variables", out var obj) && obj is IEnumerable trackables) + { + foreach (var trackable in trackables) + { + if (trackable is IVariableV1 v) + { + _add_existing_weight(v, true); + } + } + trainable_variables = new HashSet(trackables.Where(t => t is IVariableV1).Select(t => (t as IVariableV1).UniqueId)); + } + else + { + trainable_variables = new HashSet(); + } + + if (_func is Layer) + { + layer = (Layer)_func; + foreach (var v in layer.Variables) + { + if (!trainable_variables.Contains(v.UniqueId)) + { + _add_existing_weight(v, false); + } + } + } + else if (_func.CustomizedFields.TryGetValue("variables", out var obj) && obj is IEnumerable total_trackables) + { + foreach (var trackable in total_trackables) + { + if (trackable is IVariableV1 v && !trainable_variables.Contains(v.UniqueId)) + { + _add_existing_weight(v, false); + } + } + } + + if (_func.CustomizedFields.ContainsKey("regularization_losses")) + { + if ((_func.CustomizedFields["regularization_losses"] as ListWrapper)?.Count > 0) + { + throw new NotImplementedException("The regularization_losses loading has not been supported yet, " + + "please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues to let us know and add a feature."); + } + } + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optionalArgs = null) + { + _check_trainability(); + + // TODO(Rinne): deal with training_argument + + var result = _callable(inputs); + + return _apply_output_shape_if_set(inputs, result); + } + + private void _check_trainability() + { + if (!Trainable) return; + + // TODO(Rinne): deal with _is_hub_module_v1 and signature + + if (TrainableWeights is null || TrainableWeights.Count == 0) + { + tf.Logger.Error("hub.KerasLayer is trainable but has zero trainable weights."); + } + } + + private Tensors _apply_output_shape_if_set(Tensors inputs, Tensors result) + { + // TODO(Rinne): implement it. + return result; + } + + private void _add_existing_weight(IVariableV1 weight, bool? trainable = null) + { + bool is_trainable; + if (trainable is null) + { + is_trainable = weight.Trainable; + } + else + { + is_trainable = trainable.Value; + } + add_weight(weight.Name, weight.shape, weight.dtype, trainable: is_trainable, getter: x => weight); + } + + private Func _get_callable() + { + if (_func is Layer layer) + { + return x => layer.Apply(x); + } + if (_func.CustomizedFields.ContainsKey("__call__")) + { + if (_func.CustomizedFields["__call__"] is RestoredFunction function) + { + return x => function.Apply(x); + } + } + throw new ValueError("Cannot get the callable from the model."); + } + + private static Trackable load_module(string handle, LoadOptions? load_options = null) + { + //var set_load_options = load_options ?? LoadContext.get_load_option(); + return module_v2.load(handle, load_options); + } + } +} diff --git a/src/TensorflowNET.Hub/Tensorflow.Hub.csproj b/src/TensorflowNET.Hub/Tensorflow.Hub.csproj new file mode 100644 index 000000000..efa37598d --- /dev/null +++ b/src/TensorflowNET.Hub/Tensorflow.Hub.csproj @@ -0,0 +1,36 @@ + + + + netstandard2.0;net6 + 10 + enable + 1.0.0 + TensorFlow.Hub + Apache2.0 + true + true + Yaohui Liu, Haiping Chen + SciSharp STACK + true + Apache 2.0, Haiping Chen $([System.DateTime]::UtcNow.ToString(yyyy)) + https://github.com/SciSharp/TensorFlow.NET + git + http://scisharpstack.org + https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + TensorFlow, SciSharp, Machine Learning, Deep Learning, Transfer Learning, TensorFlow Hub, TensorFlow.NET, TF.NET, AI + packages + + Google's TensorFlow Hub full binding in .NET Standard. + A library for transfer learning with TensorFlow.NET. + + + + + + + + + + + + diff --git a/src/TensorflowNET.Hub/file_utils.cs b/src/TensorflowNET.Hub/file_utils.cs new file mode 100644 index 000000000..3e959afef --- /dev/null +++ b/src/TensorflowNET.Hub/file_utils.cs @@ -0,0 +1,74 @@ +using SharpCompress.Common; +using SharpCompress.Readers; +using System; +using System.IO; + +namespace Tensorflow.Hub +{ + internal static class file_utils + { + //public static void extract_file(TarInputStream tgz, TarEntry tarInfo, string dstPath, uint bufferSize = 10 << 20, Action logFunction = null) + //{ + // using (var src = tgz.GetNextEntry() == tarInfo ? tgz : null) + // { + // if (src is null) + // { + // return; + // } + + // using (var dst = File.Create(dstPath)) + // { + // var buffer = new byte[bufferSize]; + // int count; + + // while ((count = src.Read(buffer, 0, buffer.Length)) > 0) + // { + // dst.Write(buffer, 0, count); + // logFunction?.Invoke(count); + // } + // } + // } + //} + + public static void extract_tarfile_to_destination(Stream fileobj, string dst_path, Action logFunction = null) + { + using (IReader reader = ReaderFactory.Open(fileobj)) + { + while (reader.MoveToNextEntry()) + { + if (!reader.Entry.IsDirectory) + { + reader.WriteEntryToDirectory( + dst_path, + new ExtractionOptions() { ExtractFullPath = true, Overwrite = true } + ); + } + } + } + } + + public static string merge_relative_path(string dstPath, string relPath) + { + var cleanRelPath = Path.GetFullPath(relPath).TrimStart('/', '\\'); + + if (cleanRelPath == ".") + { + return dstPath; + } + + if (cleanRelPath.StartsWith("..") || Path.IsPathRooted(cleanRelPath)) + { + throw new InvalidDataException($"Relative path '{relPath}' is invalid."); + } + + var merged = Path.Combine(dstPath, cleanRelPath); + + if (!merged.StartsWith(dstPath)) + { + throw new InvalidDataException($"Relative path '{relPath}' is invalid. Failed to merge with '{dstPath}'."); + } + + return merged; + } + } +} diff --git a/src/TensorflowNET.Hub/hub.cs b/src/TensorflowNET.Hub/hub.cs new file mode 100644 index 000000000..4fefe0cc2 --- /dev/null +++ b/src/TensorflowNET.Hub/hub.cs @@ -0,0 +1,17 @@ +using Tensorflow.Hub; + +namespace Tensorflow +{ + public static class HubAPI + { + public static HubMethods hub { get; } = new HubMethods(); + } + + public class HubMethods + { + public KerasLayer KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) + { + return new KerasLayer(handle, trainable, load_options); + } + } +} diff --git a/src/TensorflowNET.Hub/module_v2.cs b/src/TensorflowNET.Hub/module_v2.cs new file mode 100644 index 000000000..a8e67311b --- /dev/null +++ b/src/TensorflowNET.Hub/module_v2.cs @@ -0,0 +1,33 @@ +using System.IO; +using Tensorflow.Train; + +namespace Tensorflow.Hub +{ + internal static class module_v2 + { + public static Trackable load(string handle, LoadOptions? options) + { + var module_path = resolve(handle); + + // TODO(Rinne): deal with is_hub_module_v1 + + var saved_model_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PB); + var saved_model_pb_txt_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PBTXT); + if (!File.Exists(saved_model_path) && !Directory.Exists(saved_model_path) && !File.Exists(saved_model_pb_txt_path) + && !Directory.Exists(saved_model_pb_txt_path)) + { + throw new ValueError($"Trying to load a model of incompatible/unknown type. " + + $"'{module_path}' contains neither '{Constants.SAVED_MODEL_FILENAME_PB}' " + + $"nor '{Constants.SAVED_MODEL_FILENAME_PBTXT}'."); + } + + var obj = Loader.load(module_path, options: options); + return obj; + } + + public static string resolve(string handle) + { + return MultiImplRegister.GetResolverRegister().Call(handle); + } + } +} diff --git a/src/TensorflowNET.Hub/registry.cs b/src/TensorflowNET.Hub/registry.cs new file mode 100644 index 000000000..cdc4589b2 --- /dev/null +++ b/src/TensorflowNET.Hub/registry.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Hub +{ + internal class MultiImplRegister + { + private static MultiImplRegister resolver = new MultiImplRegister("resolver", new IResolver[0]); + private static MultiImplRegister loader = new MultiImplRegister("loader", new IResolver[0]); + + static MultiImplRegister() + { + resolver.add_implementation(new PathResolver()); + resolver.add_implementation(new HttpUncompressedFileResolver()); + resolver.add_implementation(new GcsCompressedFileResolver()); + resolver.add_implementation(new HttpCompressedFileResolver()); + } + + string _name; + List _impls; + public MultiImplRegister(string name, IEnumerable impls) + { + _name = name; + _impls = impls.ToList(); + } + + public void add_implementation(IResolver resolver) + { + _impls.Add(resolver); + } + + public string Call(string handle) + { + foreach (var impl in _impls.Reverse()) + { + if (impl.IsSupported(handle)) + { + return impl.Call(handle); + } + } + throw new RuntimeError($"Cannot resolve the handle {handle}"); + } + + public static MultiImplRegister GetResolverRegister() + { + return resolver; + } + + public static MultiImplRegister GetLoaderRegister() + { + return loader; + } + } +} diff --git a/src/TensorflowNET.Hub/resolver.cs b/src/TensorflowNET.Hub/resolver.cs new file mode 100644 index 000000000..2f8c45ba6 --- /dev/null +++ b/src/TensorflowNET.Hub/resolver.cs @@ -0,0 +1,580 @@ +using ICSharpCode.SharpZipLib.Tar; +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Net.Security; +using System.Security.Authentication; +using System.Threading.Tasks; +using System.Web; +using static Tensorflow.Binding; + +namespace Tensorflow.Hub +{ + internal static class resolver + { + public enum ModelLoadFormat + { + [Description("COMPRESSED")] + COMPRESSED, + [Description("UNCOMPRESSED")] + UNCOMPRESSED, + [Description("AUTO")] + AUTO + } + public class DownloadManager + { + private readonly string _url; + private double _last_progress_msg_print_time; + private long _total_bytes_downloaded; + private int _max_prog_str; + + private bool _interactive_mode() + { + return !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("_TFHUB_DOWNLOAD_PROGRESS")); + } + + private void _print_download_progress_msg(string msg, bool flush = false) + { + if (_interactive_mode()) + { + // Print progress message to console overwriting previous progress + // message. + _max_prog_str = Math.Max(_max_prog_str, msg.Length); + Console.Write($"\r{msg.PadRight(_max_prog_str)}"); + Console.Out.Flush(); + + //如果flush参数为true,则输出换行符减少干扰交互式界面。 + if (flush) + Console.WriteLine(); + + } + else + { + // Interactive progress tracking is disabled. Print progress to the + // standard TF log. + tf.Logger.Information(msg); + } + } + + private void _log_progress(long bytes_downloaded) + { + // Logs progress information about ongoing module download. + + _total_bytes_downloaded += bytes_downloaded; + var now = DateTime.Now.Ticks / TimeSpan.TicksPerSecond; + if (_interactive_mode() || now - _last_progress_msg_print_time > 15) + { + // Print progress message every 15 secs or if interactive progress + // tracking is enabled. + _print_download_progress_msg($"Downloading {_url}:" + + $"{tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true)}"); + _last_progress_msg_print_time = now; + } + } + + public DownloadManager(string url) + { + _url = url; + _last_progress_msg_print_time = DateTime.Now.Ticks / TimeSpan.TicksPerSecond; + _total_bytes_downloaded = 0; + _max_prog_str = 0; + } + + public void download_and_uncompress(Stream fileobj, string dst_path) + { + // Streams the content for the 'fileobj' and stores the result in dst_path. + + try + { + file_utils.extract_tarfile_to_destination(fileobj, dst_path, _log_progress); + var total_size_str = tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true); + _print_download_progress_msg($"Downloaded {_url}, Total size: {total_size_str}", flush: true); + } + catch (TarException ex) + { + throw new IOException($"{_url} does not appear to be a valid module. Inner message:{ex.Message}", ex); + } + } + } + private static Dictionary _flags = new(); + private static readonly string _TFHUB_CACHE_DIR = "TFHUB_CACHE_DIR"; + private static readonly string _TFHUB_DOWNLOAD_PROGRESS = "TFHUB_DOWNLOAD_PROGRESS"; + private static readonly string _TFHUB_MODEL_LOAD_FORMAT = "TFHUB_MODEL_LOAD_FORMAT"; + private static readonly string _TFHUB_DISABLE_CERT_VALIDATION = "TFHUB_DISABLE_CERT_VALIDATION"; + private static readonly string _TFHUB_DISABLE_CERT_VALIDATION_VALUE = "true"; + + static resolver() + { + set_new_flag("tfhub_model_load_format", "AUTO"); + set_new_flag("tfhub_cache_dir", null); + } + + public static string model_load_format() + { + return get_env_setting(_TFHUB_MODEL_LOAD_FORMAT, "tfhub_model_load_format"); + } + + public static string? get_env_setting(string env_var, string flag_name) + { + string value = System.Environment.GetEnvironmentVariable(env_var); + if (string.IsNullOrEmpty(value)) + { + if (_flags.ContainsKey(flag_name)) + { + return _flags[flag_name]; + } + else + { + return null; + } + } + else + { + return value; + } + } + + public static string tfhub_cache_dir(string default_cache_dir = null, bool use_temp = false) + { + var cache_dir = get_env_setting(_TFHUB_CACHE_DIR, "tfhub_cache_dir") ?? default_cache_dir; + if (string.IsNullOrWhiteSpace(cache_dir) && use_temp) + { + // Place all TF-Hub modules under /tfhub_modules. + cache_dir = Path.Combine(Path.GetTempPath(), "tfhub_modules"); + } + if (!string.IsNullOrWhiteSpace(cache_dir)) + { + Console.WriteLine("Using {0} to cache modules.", cache_dir); + } + return cache_dir; + } + + public static string create_local_module_dir(string cache_dir, string module_name) + { + Directory.CreateDirectory(cache_dir); + return Path.Combine(cache_dir, module_name); + } + + public static void set_new_flag(string name, string value) + { + string[] tokens = new string[] {_TFHUB_CACHE_DIR, _TFHUB_DISABLE_CERT_VALIDATION, + _TFHUB_DISABLE_CERT_VALIDATION_VALUE, _TFHUB_DOWNLOAD_PROGRESS, _TFHUB_MODEL_LOAD_FORMAT}; + if (!tokens.Contains(name)) + { + tf.Logger.Warning($"You are settinng a flag '{name}' that cannot be recognized. The flag you set" + + "may not affect anything in tensorflow.hub."); + } + _flags[name] = value; + } + + public static string _merge_relative_path(string dstPath, string relPath) + { + return file_utils.merge_relative_path(dstPath, relPath); + } + + public static string _module_descriptor_file(string moduleDir) + { + return $"{moduleDir}.descriptor.txt"; + } + + public static void _write_module_descriptor_file(string handle, string moduleDir) + { + var readme = _module_descriptor_file(moduleDir); + var content = $"Module: {handle}\nDownload Time: {DateTime.Now}\nDownloader Hostname: {Environment.MachineName} (PID:{Process.GetCurrentProcess().Id})"; + tf_utils.atomic_write_string_to_file(readme, content, overwrite: true); + } + + public static string _lock_file_contents(string task_uid) + { + return $"{Environment.MachineName}.{Process.GetCurrentProcess().Id}.{task_uid}"; + } + + public static string _lock_filename(string moduleDir) + { + return tf_utils.absolute_path(moduleDir) + ".lock"; + } + + private static string _module_dir(string lockFilename) + { + var path = Path.GetDirectoryName(Path.GetFullPath(lockFilename)); + if (!string.IsNullOrEmpty(path)) + { + return Path.Combine(path, "hub_modules"); + } + + throw new Exception("Unable to resolve hub_modules directory from lock file name."); + } + + private static string _task_uid_from_lock_file(string lockFilename) + { + // Returns task UID of the task that created a given lock file. + var lockstring = File.ReadAllText(lockFilename); + return lockstring.Split('.').Last(); + } + + private static string _temp_download_dir(string moduleDir, string taskUid) + { + // Returns the name of a temporary directory to download module to. + return $"{Path.GetFullPath(moduleDir)}.{taskUid}.tmp"; + } + + private static long _dir_size(string directory) + { + // Returns total size (in bytes) of the given 'directory'. + long size = 0; + foreach (var elem in Directory.EnumerateFileSystemEntries(directory)) + { + var stat = new FileInfo(elem); + size += stat.Length; + if ((stat.Attributes & FileAttributes.Directory) != 0) + size += _dir_size(stat.FullName); + } + return size; + } + + public static long _locked_tmp_dir_size(string lockFilename) + { + //Returns the size of the temp dir pointed to by the given lock file. + var taskUid = _task_uid_from_lock_file(lockFilename); + try + { + return _dir_size(_temp_download_dir(_module_dir(lockFilename), taskUid)); + } + catch (DirectoryNotFoundException) + { + return 0; + } + } + + private static void _wait_for_lock_to_disappear(string handle, string lockFile, double lockFileTimeoutSec) + { + long? lockedTmpDirSize = null; + var lockedTmpDirSizeCheckTime = DateTime.Now; + var lockFileContent = ""; + + while (File.Exists(lockFile)) + { + try + { + Console.WriteLine($"Module '{handle}' already being downloaded by '{File.ReadAllText(lockFile)}'. Waiting."); + + if ((DateTime.Now - lockedTmpDirSizeCheckTime).TotalSeconds > lockFileTimeoutSec) + { + var curLockedTmpDirSize = _locked_tmp_dir_size(lockFile); + var curLockFileContent = File.ReadAllText(lockFile); + + if (curLockedTmpDirSize == lockedTmpDirSize && curLockFileContent == lockFileContent) + { + Console.WriteLine($"Deleting lock file {lockFile} due to inactivity."); + File.Delete(lockFile); + break; + } + + lockedTmpDirSize = curLockedTmpDirSize; + lockedTmpDirSizeCheckTime = DateTime.Now; + lockFileContent = curLockFileContent; + } + } + catch (FileNotFoundException) + { + // Lock file or temp directory were deleted during check. Continue + // to check whether download succeeded or we need to start our own + // download. + } + + System.Threading.Thread.Sleep(5000); + } + } + + public static async Task atomic_download_async( + string handle, + Func downloadFn, + string moduleDir, + int lock_file_timeout_sec = 10 * 60) + { + var lockFile = _lock_filename(moduleDir); + var taskUid = Guid.NewGuid().ToString("N"); + var lockContents = _lock_file_contents(taskUid); + var tmpDir = _temp_download_dir(moduleDir, taskUid); + + // Function to check whether model has already been downloaded. + Func checkModuleExists = () => + Directory.Exists(moduleDir) && + Directory.EnumerateFileSystemEntries(moduleDir).Any(); + + // Check whether the model has already been downloaded before locking + // the destination path. + if (checkModuleExists()) + { + return moduleDir; + } + + // Attempt to protect against cases of processes being cancelled with + // KeyboardInterrupt by using a try/finally clause to remove the lock + // and tmp_dir. + while (true) + { + try + { + tf_utils.atomic_write_string_to_file(lockFile, lockContents, false); + // Must test condition again, since another process could have created + // the module and deleted the old lock file since last test. + if (checkModuleExists()) + { + // Lock file will be deleted in the finally-clause. + return moduleDir; + } + if (Directory.Exists(moduleDir)) + { + Directory.Delete(moduleDir, true); + } + break; // Proceed to downloading the module. + } + // These errors are believed to be permanent problems with the + // module_dir that justify failing the download. + catch (FileNotFoundException) + { + throw; + } + catch (UnauthorizedAccessException) + { + throw; + } + catch (IOException) + { + throw; + } + // All other errors are retried. + // TODO(b/144424849): Retrying an AlreadyExistsError from the atomic write + // should be good enough, but see discussion about misc filesystem types. + // TODO(b/144475403): How atomic is the overwrite=False check? + catch (Exception) + { + } + + // Wait for lock file to disappear. + _wait_for_lock_to_disappear(handle, lockFile, lock_file_timeout_sec); + // At this point we either deleted a lock or a lock got removed by the + // owner or another process. Perform one more iteration of the while-loop, + // we would either terminate due tf.compat.v1.gfile.Exists(module_dir) or + // because we would obtain a lock ourselves, or wait again for the lock to + // disappear. + } + + // Lock file acquired. + tf.Logger.Information($"Downloading TF-Hub Module '{handle}'..."); + Directory.CreateDirectory(tmpDir); + await downloadFn(handle, tmpDir); + // Write module descriptor to capture information about which module was + // downloaded by whom and when. The file stored at the same level as a + // directory in order to keep the content of the 'model_dir' exactly as it + // was define by the module publisher. + // + // Note: The descriptor is written purely to help the end-user to identify + // which directory belongs to which module. The descriptor is not part of the + // module caching protocol and no code in the TF-Hub library reads its + // content. + _write_module_descriptor_file(handle, moduleDir); + try + { + Directory.Move(tmpDir, moduleDir); + Console.WriteLine($"Downloaded TF-Hub Module '{handle}'."); + } + catch (IOException e) + { + Console.WriteLine(e.Message); + Console.WriteLine($"Failed to move {tmpDir} to {moduleDir}"); + // Keep the temp directory so we will retry building vocabulary later. + } + + // Temp directory is owned by the current process, remove it. + try + { + Directory.Delete(tmpDir, true); + } + catch (DirectoryNotFoundException) + { + } + + // Lock file exists and is owned by this process. + try + { + var contents = File.ReadAllText(lockFile); + if (contents == lockContents) + { + File.Delete(lockFile); + } + } + catch (Exception) + { + } + + return moduleDir; + } + } + internal interface IResolver + { + string Call(string handle); + bool IsSupported(string handle); + } + + internal class PathResolver : IResolver + { + public string Call(string handle) + { + if (!File.Exists(handle) && !Directory.Exists(handle)) + { + throw new IOException($"{handle} does not exist in file system."); + } + return handle; + } + public bool IsSupported(string handle) + { + return true; + } + } + + public abstract class HttpResolverBase : IResolver + { + private readonly HttpClient httpClient; + private SslProtocol sslProtocol; + private RemoteCertificateValidationCallback certificateValidator; + + protected HttpResolverBase() + { + httpClient = new HttpClient(); + _maybe_disable_cert_validation(); + } + + public abstract string Call(string handle); + public abstract bool IsSupported(string handle); + + protected async Task GetLocalFileStreamAsync(string filePath) + { + try + { + var fs = new FileStream(filePath, FileMode.Open, FileAccess.Read); + return await Task.FromResult(fs); + } + catch (Exception ex) + { + Console.WriteLine($"Failed to read file stream: {ex.Message}"); + return null; + } + } + + protected async Task GetFileStreamAsync(string filePath) + { + if (!is_http_protocol(filePath)) + { + // If filePath is not an HTTP(S) URL, delegate to a file resolver. + return await GetLocalFileStreamAsync(filePath); + } + + var request = new HttpRequestMessage(HttpMethod.Get, filePath); + var response = await _call_urlopen(request); + + if (response.IsSuccessStatusCode) + { + return await response.Content.ReadAsStreamAsync(); + } + else + { + Console.WriteLine($"Failed to fetch file stream: {response.StatusCode} - {response.ReasonPhrase}"); + return null; + } + } + + protected void SetUrlContext(SslProtocol protocol, RemoteCertificateValidationCallback validator) + { + sslProtocol = protocol; + certificateValidator = validator; + } + + public static string append_format_query(string handle, (string, string) formatQuery) + { + var parsed = new Uri(handle); + + var queryBuilder = HttpUtility.ParseQueryString(parsed.Query); + queryBuilder.Add(formatQuery.Item1, formatQuery.Item2); + + parsed = new UriBuilder(parsed.Scheme, parsed.Host, parsed.Port, parsed.AbsolutePath, + "?" + queryBuilder.ToString()).Uri; + + return parsed.ToString(); + } + + protected bool is_http_protocol(string handle) + { + return handle.StartsWith("http://") || handle.StartsWith("https://"); + } + + protected async Task _call_urlopen(HttpRequestMessage request) + { + if (sslProtocol != null) + { + var handler = new HttpClientHandler() + { + SslProtocols = sslProtocol.AsEnum(), + }; + if (certificateValidator != null) + { + handler.ServerCertificateCustomValidationCallback = (x, y, z, w) => + { + return certificateValidator(x, y, z, w); + }; + } + + var client = new HttpClient(handler); + return await client.SendAsync(request); + } + else + { + return await httpClient.SendAsync(request); + } + } + + protected void _maybe_disable_cert_validation() + { + if (Environment.GetEnvironmentVariable("_TFHUB_DISABLE_CERT_VALIDATION") == "_TFHUB_DISABLE_CERT_VALIDATION_VALUE") + { + ServicePointManager.ServerCertificateValidationCallback = (_, _, _, _) => true; + Console.WriteLine("Disabled certificate validation for resolving handles."); + } + } + } + + public class SslProtocol + { + private readonly string protocolString; + + public static readonly SslProtocol Tls = new SslProtocol("TLS"); + public static readonly SslProtocol Tls11 = new SslProtocol("TLS 1.1"); + public static readonly SslProtocol Tls12 = new SslProtocol("TLS 1.2"); + + private SslProtocol(string protocolString) + { + this.protocolString = protocolString; + } + + public SslProtocols AsEnum() + { + switch (protocolString.ToUpper()) + { + case "TLS": + return SslProtocols.Tls; + case "TLS 1.1": + return SslProtocols.Tls11; + case "TLS 1.2": + return SslProtocols.Tls12; + default: + throw new ArgumentException($"Unknown SSL/TLS protocol: {protocolString}"); + } + } + } +} diff --git a/src/TensorflowNET.Hub/tf_utils.cs b/src/TensorflowNET.Hub/tf_utils.cs new file mode 100644 index 000000000..96d8c92d6 --- /dev/null +++ b/src/TensorflowNET.Hub/tf_utils.cs @@ -0,0 +1,80 @@ +using System; +using System.IO; + +namespace Tensorflow.Hub +{ + internal class tf_utils + { + public static string bytes_to_readable_str(long? numBytes, bool includeB = false) + { + if (numBytes == null) return numBytes.ToString(); + + var num = (double)numBytes; + + if (num < 1024) + { + return $"{(long)num}{(includeB ? "B" : "")}"; + } + + num /= 1 << 10; + if (num < 1024) + { + return $"{num:F2}k{(includeB ? "B" : "")}"; + } + + num /= 1 << 10; + if (num < 1024) + { + return $"{num:F2}M{(includeB ? "B" : "")}"; + } + + num /= 1 << 10; + return $"{num:F2}G{(includeB ? "B" : "")}"; + } + + public static void atomic_write_string_to_file(string filename, string contents, bool overwrite) + { + var tempPath = $"{filename}.tmp.{Guid.NewGuid():N}"; + + using (var fileStream = new FileStream(tempPath, FileMode.Create)) + { + using (var writer = new StreamWriter(fileStream)) + { + writer.Write(contents); + writer.Flush(); + } + } + + try + { + if (File.Exists(filename)) + { + if (overwrite) + { + File.Delete(filename); + File.Move(tempPath, filename); + } + } + else + { + File.Move(tempPath, filename); + } + } + catch + { + File.Delete(tempPath); + throw; + } + } + + public static string absolute_path(string path) + { + if (path.Contains("://")) + { + return path; + } + + return Path.GetFullPath(path); + } + } +} diff --git a/src/python/.vscode/launch.json b/src/python/.vscode/launch.json new file mode 100644 index 000000000..4d4e27495 --- /dev/null +++ b/src/python/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/xor_keras.py", + "console": "integratedTerminal", + "justMyCode": false + } + ] +} \ No newline at end of file diff --git a/src/python/simple_rnn.py b/src/python/simple_rnn.py new file mode 100644 index 000000000..c5f3b1f2c --- /dev/null +++ b/src/python/simple_rnn.py @@ -0,0 +1,17 @@ +import numpy as np +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +# tf.experimental.numpy +inputs = np.arange(6 * 10 * 8).reshape([6, 10, 8]).astype(np.float32) +# simple_rnn = tf.keras.layers.SimpleRNN(4) + +# output = simple_rnn(inputs) # The output has shape `[6, 4]`. + +simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences=True, return_state=True) + +# whole_sequence_output has shape `[6, 10, 4]`. +# final_state has shape `[6, 4]`. +whole_sequence_output, final_state = simple_rnn(inputs) +print(whole_sequence_output) +print(final_state) \ No newline at end of file diff --git a/src/python/subclassing.py b/src/python/subclassing.py new file mode 100644 index 000000000..bccbef292 --- /dev/null +++ b/src/python/subclassing.py @@ -0,0 +1,154 @@ +from __future__ import absolute_import, division, print_function + +import tensorflow as tf +from tensorflow.keras import Model, layers +import numpy as np + +# MNIST dataset parameters. +num_classes = 10 # total classes (0-9 digits). + +# Training parameters. +learning_rate = 0.001 +training_steps = 100 +batch_size = 128 +display_step = 10 + +# Network parameters. +conv1_filters = 32 # number of filters for 1st conv layer. +conv2_filters = 64 # number of filters for 2nd conv layer. +fc1_units = 1024 # number of neurons for 1st fully-connected layer. + +# Prepare MNIST data. +from tensorflow.keras.datasets import mnist +(x_train, y_train), (x_test, y_test) = mnist.load_data() +# Convert to float32. +x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32) +# Normalize images value from [0, 255] to [0, 1]. +x_train, x_test = x_train / 255., x_test / 255. + +# Use tf.data API to shuffle and batch data. +train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)) +train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1) + +# Create TF Model. +class ConvNet(Model): + # Set layers. + def __init__(self): + super(ConvNet, self).__init__() + # Convolution Layer with 32 filters and a kernel size of 5. + self.conv1 = layers.Conv2D(32, kernel_size=5, activation=tf.nn.relu) + # Max Pooling (down-sampling) with kernel size of 2 and strides of 2. + self.maxpool1 = layers.MaxPool2D(2, strides=2) + + # Convolution Layer with 64 filters and a kernel size of 3. + self.conv2 = layers.Conv2D(64, kernel_size=3, activation=tf.nn.relu) + # Max Pooling (down-sampling) with kernel size of 2 and strides of 2. + self.maxpool2 = layers.MaxPool2D(2, strides=2) + + # Flatten the data to a 1-D vector for the fully connected layer. + self.flatten = layers.Flatten() + + # Fully connected layer. + self.fc1 = layers.Dense(1024) + # Apply Dropout (if is_training is False, dropout is not applied). + self.dropout = layers.Dropout(rate=0.5) + + # Output layer, class prediction. + self.out = layers.Dense(num_classes) + + # Set forward pass. + def call(self, x, is_training=False): + x = tf.reshape(x, [-1, 28, 28, 1]) + x = self.conv1(x) + x = self.maxpool1(x) + x = self.conv2(x) + x = self.maxpool2(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.dropout(x) + x = self.out(x) + if not is_training: + # tf cross entropy expect logits without softmax, so only + # apply softmax when not training. + x = tf.nn.softmax(x) + return x +''' +# Build neural network model. +conv_net = ConvNet() + +# Cross-Entropy Loss. +# Note that this will apply 'softmax' to the logits. +def cross_entropy_loss(x, y): + # Convert labels to int 64 for tf cross-entropy function. + y = tf.cast(y, tf.int64) + # Apply softmax to logits and compute cross-entropy. + loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x) + # Average loss across the batch. + return tf.reduce_mean(loss) + +# Accuracy metric. +def accuracy(y_pred, y_true): + # Predicted class is the index of highest score in prediction vector (i.e. argmax). + correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64)) + return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1) + +# Stochastic gradient descent optimizer. +optimizer = tf.optimizers.Adam(learning_rate) + +# Optimization process. +def run_optimization(x, y): + # Wrap computation inside a GradientTape for automatic differentiation. + with tf.GradientTape() as g: + # Forward pass. + pred = conv_net(x, is_training=True) + # Compute loss. + loss = cross_entropy_loss(pred, y) + + # Variables to update, i.e. trainable variables. + trainable_variables = conv_net.trainable_variables + + # Compute gradients. + gradients = g.gradient(loss, trainable_variables) + + # Update W and b following gradients. + optimizer.apply_gradients(zip(gradients, trainable_variables)) + +# Run training for the given number of steps. + +for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1): + # Run the optimization to update W and b values. + run_optimization(batch_x, batch_y) + + if step % display_step == 0: + pred = conv_net(batch_x) + loss = cross_entropy_loss(pred, batch_y) + acc = accuracy(pred, batch_y) + print("step: %i, loss: %f, accuracy: %f" % (step, loss, acc)) + +# Test model on validation set. +pred = conv_net(x_test) +print("Test Accuracy: %f" % accuracy(pred, y_test)) + +conv_net.save_weights('weights.h5') +''' + +conv_net = ConvNet() +conv_net.build(x_test.shape) +conv_net.load_weights('weights.h5') +# Test model on validation set. +pred = conv_net(x_test) +# print("Test Accuracy: %f" % accuracy(pred, y_test)) + +# Visualize predictions. +import matplotlib.pyplot as plt + +# Predict 5 images from validation set. +n_images = 5 +test_images = x_test[:n_images] +predictions = conv_net(test_images) + +# Display image and model prediction. +for i in range(n_images): + plt.imshow(np.reshape(test_images[i], [28, 28]), cmap='gray') + plt.show() + print("Model prediction: %i" % np.argmax(predictions.numpy()[i])) \ No newline at end of file diff --git a/src/python/xor_keras.py b/src/python/xor_keras.py new file mode 100644 index 000000000..e73886050 --- /dev/null +++ b/src/python/xor_keras.py @@ -0,0 +1,24 @@ +import os +import numpy as np +import tensorflow as tf + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +print(tf.__version__) +# https://playground.tensorflow.org/ +# tf.compat.v1.enable_eager_execution() +# tf.debugging.set_log_device_placement(True); +tf.config.run_functions_eagerly(True) + +x = np.array([[ 0, 0 ], [ 0, 1 ], [ 1, 0 ], [ 1, 1 ]]) +y = np.array([[ 0 ], [ 1 ], [ 1 ], [ 0 ] ]) + +model = tf.keras.Sequential() +model.add(tf.keras.Input(2)) +model.add(tf.keras.layers.Dense(32, "relu")) +model.add(tf.keras.layers.Dense(1, "sigmoid")) +model.compile(optimizer = tf.keras.optimizers.Adam(), + loss = tf.keras.losses.MeanSquaredError(), + metrics = ["accuracy"]) +model.fit(x, y, 1, 100) +result = model.evaluate(x, y) +print(model.predict(x, 4)) \ No newline at end of file diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md deleted file mode 100644 index ea6f02d5e..000000000 --- a/tensorflowlib/README.md +++ /dev/null @@ -1,7 +0,0 @@ -Here are some pre-built TensorFlow binaries you can use for each platform: - -- Linux - - CPU-only: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.12.0.tar.gz - - GPU-enabled: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.12.0.tar.gz -- Mac: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.12.0.tar.gz -- Windows: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-1.12.0.zip \ No newline at end of file diff --git a/test/TensorFlow.Kernel.UnitTest/TensorFlow.Kernel.UnitTest.csproj b/test/TensorFlow.Kernel.UnitTest/TensorFlow.Kernel.UnitTest.csproj new file mode 100644 index 000000000..461993408 --- /dev/null +++ b/test/TensorFlow.Kernel.UnitTest/TensorFlow.Kernel.UnitTest.csproj @@ -0,0 +1,24 @@ + + + + net6.0 + enable + enable + + false + true + + + + + + + + + + + + + + + diff --git a/test/TensorFlow.Kernel.UnitTest/array_ops/concat_op_test.cs b/test/TensorFlow.Kernel.UnitTest/array_ops/concat_op_test.cs new file mode 100644 index 000000000..67d0aa602 --- /dev/null +++ b/test/TensorFlow.Kernel.UnitTest/array_ops/concat_op_test.cs @@ -0,0 +1,63 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace TensorFlow.Kernel.UnitTest +{ + [TestClass] + public class concat_op_test + { + [TestMethod] + public void testConcatEmpty() + { + var t1 = tf.constant(new int[] { }); + var t2 = tf.constant(new int[] { }); + var c = array_ops.concat(new[] { t1, t2 }, 0); + var expected = np.array(new int[] { }); + Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray(), c.numpy().ToArray())); + } + + [TestMethod] + public void testConcatNegativeAxis() + { + var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }); + var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } }); + var c = array_ops.concat(new[] { t1, t2 }, -2); + var expected = np.array(new int[,,] { { { 1, 2, 3 }, { 4, 5, 6 } }, { { 7, 8, 9 }, { 10, 11, 12 } } }); + Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray(), c.numpy().ToArray())); + + c = array_ops.concat(new[] { t1, t2 }, -1); + expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } }); + Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray(), c.numpy().ToArray())); + } + + [TestMethod] + [DataRow(TF_DataType.TF_INT32)] + [DataRow(TF_DataType.TF_INT64)] + [DataRow(TF_DataType.TF_UINT32)] + [DataRow(TF_DataType.TF_UINT64)] + public void testConcatDtype(TF_DataType dtype) + { + var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }, dtype: dtype); + var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } }, dtype: dtype); + var c = array_ops.concat(new[] { t1, t2 }, 1); + var expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } }); + Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray(), tf.cast(c, TF_DataType.TF_INT32).numpy().ToArray())); + + } + + [TestMethod] + [DataRow(TF_DataType.TF_INT32)] + [DataRow(TF_DataType.TF_INT64)] + public void testConcatAxisType(TF_DataType dtype) + { + var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }); + var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } }); + var c = array_ops.concat(new[] { t1, t2 }, tf.constant(1, dtype: dtype)); + var expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } }); + Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray(), tf.cast(c, TF_DataType.TF_INT32).numpy().ToArray())); + } + + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.Examples/HelloWorld.cs b/test/TensorFlowNET.Examples/HelloWorld.cs deleted file mode 100644 index 2913c6b60..000000000 --- a/test/TensorFlowNET.Examples/HelloWorld.cs +++ /dev/null @@ -1,30 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; -using Tensorflow; - -namespace TensorFlowNET.Examples -{ - /// - /// Simple hello world using TensorFlow - /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/1_Introduction/helloworld.py - /// - public class HelloWorld : IExample - { - public void Run() - { - /* # Create a Constant op - The op is added as a node to the default graph. - - The value returned by the constructor represents the output - of the Constant op.*/ - var hello = tf.constant("Hello, TensorFlow!"); - - // Start tf session - var sess = tf.Session(); - - // Run the op - sess.run(hello); - } - } -} diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs deleted file mode 100644 index fdb0c2bbd..000000000 --- a/test/TensorFlowNET.Examples/Program.cs +++ /dev/null @@ -1,19 +0,0 @@ -using System; -using System.Linq; -using System.Reflection; - -namespace TensorFlowNET.Examples -{ - class Program - { - static void Main(string[] args) - { - var assembly = Assembly.GetEntryAssembly(); - foreach(Type type in assembly.GetTypes().Where(x => x.GetInterfaces().Contains(typeof(IExample)))) - { - var example = (IExample)Activator.CreateInstance(type); - example.Run(); - } - } - } -} diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj deleted file mode 100644 index 627d10888..000000000 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ /dev/null @@ -1,12 +0,0 @@ - - - - Exe - netcoreapp2.1 - - - - - - - diff --git a/test/TensorFlowNET.Graph.UnitTest/Basics/QueueTest.cs b/test/TensorFlowNET.Graph.UnitTest/Basics/QueueTest.cs new file mode 100644 index 000000000..21c5fdbfe --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/Basics/QueueTest.cs @@ -0,0 +1,105 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class QueueTest : GraphModeTestBase + { + [TestMethod] + public void PaddingFIFOQueue() + { + var numbers = tf.placeholder(tf.int32); + var queue = tf.PaddingFIFOQueue(10, tf.int32, new Shape(-1)); + var enqueue = queue.enqueue(numbers); + var dequeue_many = queue.dequeue_many(n: 3); + + var sess = tf.Session(); + sess.run(enqueue, (numbers, new[] { 1 })); + sess.run(enqueue, (numbers, new[] { 2, 3 })); + sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); + + var result = sess.run(dequeue_many[0]); + + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray())); + } + + [TestMethod] + public void FIFOQueue() + { + // create a first in first out queue with capacity up to 2 + // and data type set as int32 + var queue = tf.FIFOQueue(2, tf.int32); + // init queue, push 3 elements into queue. + var init = queue.enqueue_many(new[] { 10, 20 }); + // pop out the first element + var x = queue.dequeue(); + // add 1 + var y = x + 1; + // push back into queue + var inc = queue.enqueue(y); + + var sess = tf.Session(); + // init queue + init.run(); + + // pop out first element and push back calculated y + (int dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(10, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(20, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(11, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(21, dequeued); + + // thread will hang or block if you run sess.run(x) again + // until queue has more element. + } + + [TestMethod] + public void PriorityQueue() + { + var queue = tf.PriorityQueue(3, tf.@string); + var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); + var x = queue.dequeue(); + + var sess = tf.Session(); + init.run(); + + var result = sess.run(x); + Assert.AreEqual(result[0], 2L); + + result = sess.run(x); + Assert.AreEqual(result[0], 3L); + + result = sess.run(x); + Assert.AreEqual(result[0], 4L); + } + + [TestMethod] + public void RandomShuffleQueue() + { + var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32); + var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }); + var x = queue.dequeue(); + + string results = ""; + var sess = tf.Session(); + init.run(); + + foreach (var i in range(9)) + results += (int)sess.run(x) + "."; + + // output in random order + Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9."); + } + } +} diff --git a/test/TensorFlowNET.Graph.UnitTest/Basics/SessionTest.cs b/test/TensorFlowNET.Graph.UnitTest/Basics/SessionTest.cs new file mode 100644 index 000000000..2300b0948 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/Basics/SessionTest.cs @@ -0,0 +1,116 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class SessionTest : GraphModeTestBase + { + [TestMethod] + public void EvalTensor() + { + lock (this) + { + var a = constant_op.constant(np.array(3.0).reshape((1, 1))); + var b = constant_op.constant(np.array(2.0).reshape((1, 1))); + var c = math_ops.matmul(a, b, name: "matmul"); + var sess = tf.Session(); + var result = c.eval(sess); + Assert.AreEqual(result[0], 6.0); + } + } + + [TestMethod] + public void Eval_SmallString_Scalar() + { + var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING); + var c = tf.strings.substr(a, 4, 8); + var sess = tf.Session(); + var result = c.eval(sess).StringData(); + Assert.AreEqual(result[0], "heythere"); + } + + [TestMethod] + public void Eval_LargeString_Scalar() + { + lock (this) + { + const int size = 30_000; + var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING); + var c = tf.strings.substr(a, 0, size - 5000); + var sess = tf.Session(); + var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray()); + Console.WriteLine(result); + } + } + + [TestMethod] + public void Autocast_Case0() + { + var sess = tf.Session().as_default(); + ITensorOrOperation operation = tf.global_variables_initializer(); + // the cast to ITensorOrOperation is essential for the test of this method signature + var ret = sess.run(operation); + } + + [TestMethod] + public void Autocast_Case1() + { + var sess = tf.Session().as_default(); + var input = tf.placeholder(tf.int32, shape: new Shape(6)); + var op = tf.reshape(input, new int[] { 2, 3 }); + sess.run(tf.global_variables_initializer()); + var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6))); + + Assert.AreEqual(ret.shape, (2, 3)); + assertAllEqual(ret.ToArray(), new[] { 1, 2, 3, 4, 5, 6 }); + print(ret.dtype); + print(ret); + } + + [TestMethod] + public void Autocast_Case2() + { + var sess = tf.Session().as_default(); + var input = tf.placeholder(tf.float32, shape: new Shape(6)); + var op = tf.reshape(input, new int[] { 2, 3 }); + sess.run(tf.global_variables_initializer()); + var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f)); + } + + [TestMethod, Ignore] + public void Autocast_Case3() + { + var sess = tf.Session().as_default(); + var input = tf.placeholder(tf.float32, shape: new Shape(6)); + var op = tf.reshape(input, new int[] { 2, 3 }); + sess.run(tf.global_variables_initializer()); + var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f)); + + Assert.AreEqual(ret.shape, (2, 3)); + Assert.AreEqual(ret, new[] { 1, 2, 3, 4, 5, 6 }); + print(ret.dtype); + print(ret); + } + + [TestMethod, Ignore] + public void Autocast_Case4() + { + var sess = tf.Session().as_default(); + var input = tf.placeholder(tf.byte8, shape: new Shape(6)); + var op = tf.reshape(input, new int[] { 2, 3 }); + sess.run(tf.global_variables_initializer()); + var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f)); + + Assert.AreEqual(ret.shape, (2, 3)); + Assert.AreEqual(ret, new[] { 1, 2, 3, 4, 5, 6 }); + print(ret.dtype); + print(ret); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs b/test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs new file mode 100644 index 000000000..8093c1f23 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs @@ -0,0 +1,75 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.Linq; +using static Tensorflow.Binding; +using Tensorflow; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class TensorTest : GraphModeTestBase + { + [TestMethod, Ignore] + public void sparse_to_dense() + { + var indices = tf.reshape(tf.range(0, 5), new int[] { 5, 1 }); + var labels = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }), 1); + var st = tf.concat(values: new[] { indices, labels }, axis: 1); + var onehot = tf.sparse_to_dense(st, (5, 5), 1); + var sess = tf.Session(); + var result = sess.run(onehot); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray())); + } + + [TestMethod, Ignore] + public void sparse_tensor_to_dense() + { + var decoded_list = tf.SparseTensor(new[,] + { + { 0L, 0L }, + { 1L, 2L } + }, + new int[] { 1, 2 }, + new[] { 3L, 4L }); + + var onehot = tf.sparse_tensor_to_dense(decoded_list); + var sess = tf.Session(); + var result = sess.run(onehot); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray())); + } + + [TestMethod] + public void batch_to_space_nd() + { + var inputs = np.arange(24).reshape((4, 2, 3)); + var block_shape = new[] { 2, 2 }; + int[,] crops = { { 0, 0 }, { 0, 0 } }; + var tensor = tf.batch_to_space_nd(inputs, block_shape, crops); + + var sess = tf.Session(); + var result = sess.run(tensor); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray())); + } + + [TestMethod] + public void boolean_mask() + { + if (!tf.executing_eagerly()) + tf.enable_eager_execution(); + var tensor = new[] { 0, 1, 2, 3 }; + var mask = np.array(new[] { true, false, true, false }); + var masked = tf.boolean_mask(tensor, mask); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray())); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.Graph.UnitTest/Basics/VariableTest.cs b/test/TensorFlowNET.Graph.UnitTest/Basics/VariableTest.cs new file mode 100644 index 000000000..3c95501db --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/Basics/VariableTest.cs @@ -0,0 +1,26 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class VariableTest : GraphModeTestBase + { + [TestMethod] + public void InitVariable() + { + var v = tf.Variable(new[] { 1, 2 }); + var init = tf.compat.v1.global_variables_initializer(); + + var sess = tf.compat.v1.Session(); + sess.run(init); + // Usage passing the session explicitly. + print(v.eval(sess)); + // Usage with the default session. The 'with' block + // above makes 'sess' the default session. + print(v.eval()); + } + } +} diff --git a/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs b/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs new file mode 100644 index 000000000..abb44eeed --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs @@ -0,0 +1,201 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; +using Tensorflow.Keras.UnitTest; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class ComplexTest : EagerModeTestBase + { + // Tests for Complex128 + + [TestMethod] + public void complex128_basic() + { + double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 }; + double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 }; + + Tensor t_real = tf.constant(d_real, dtype:TF_DataType.TF_DOUBLE); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); + + Tensor t_complex = tf.complex(t_real, t_imag); + + Tensor t_real_result = tf.math.real(t_complex); + Tensor t_imag_result = tf.math.imag(t_complex); + + NDArray n_real_result = t_real_result.numpy(); + NDArray n_imag_result = t_imag_result.numpy(); + + double[] d_real_result =n_real_result.ToArray(); + double[] d_imag_result = n_imag_result.ToArray(); + + Assert.IsTrue(base.Equal(d_real_result, d_real)); + Assert.IsTrue(base.Equal(d_imag_result, d_imag)); + } + [TestMethod] + public void complex128_abs() + { + tf.enable_eager_execution(); + + double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 }; + double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 }; + + double[] d_abs = new double[] { 5.0, 13.0, 17.0, 25.0 }; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); + + Tensor t_complex = tf.complex(t_real, t_imag); + + Tensor t_abs_result = tf.abs(t_complex); + + double[] d_abs_result = t_abs_result.numpy().ToArray(); + Assert.IsTrue(base.Equal(d_abs_result, d_abs)); + } + [TestMethod] + public void complex128_conj() + { + double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 }; + double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 }; + + double[] d_real_expected = new double[] { -3.0, -5.0, 8.0, 7.0 }; + double[] d_imag_expected = new double[] { 4.0, -12.0, 15.0, -24.0 }; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); + + Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128); + + Tensor t_result = tf.math.conj(t_complex); + + NDArray n_real_result = tf.math.real(t_result).numpy(); + NDArray n_imag_result = tf.math.imag(t_result).numpy(); + + double[] d_real_result = n_real_result.ToArray(); + double[] d_imag_result = n_imag_result.ToArray(); + + Assert.IsTrue(base.Equal(d_real_result, d_real_expected)); + Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected)); + } + [TestMethod] + public void complex128_angle() + { + double[] d_real = new double[] { 0.0, 1.0, -1.0, 0.0 }; + double[] d_imag = new double[] { 1.0, 0.0, -2.0, -3.0 }; + + double[] d_expected = new double[] { 1.5707963267948966, 0, -2.0344439357957027, -1.5707963267948966 }; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); + + Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128); + + Tensor t_result = tf.math.angle(t_complex); + + NDArray n_result = t_result.numpy(); + + double[] d_result = n_result.ToArray(); + + Assert.IsTrue(base.Equal(d_result, d_expected)); + } + + // Tests for Complex64 + [TestMethod] + public void complex64_basic() + { + tf.init_scope(); + float[] d_real = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + float[] d_imag = new float[] { -1.0f, -3.0f, 5.0f, 7.0f }; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); + + Tensor t_complex = tf.complex(t_real, t_imag); + + Tensor t_real_result = tf.math.real(t_complex); + Tensor t_imag_result = tf.math.imag(t_complex); + + // Convert the EagerTensors to NumPy arrays directly + float[] d_real_result = t_real_result.numpy().ToArray(); + float[] d_imag_result = t_imag_result.numpy().ToArray(); + + Assert.IsTrue(base.Equal(d_real_result, d_real)); + Assert.IsTrue(base.Equal(d_imag_result, d_imag)); + } + [TestMethod] + public void complex64_abs() + { + tf.enable_eager_execution(); + + float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f }; + float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f }; + + float[] d_abs = new float[] { 5.0f, 13.0f, 17.0f, 25.0f }; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); + + Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64); + + Tensor t_abs_result = tf.abs(t_complex); + + NDArray n_abs_result = t_abs_result.numpy(); + + float[] d_abs_result = n_abs_result.ToArray(); + Assert.IsTrue(base.Equal(d_abs_result, d_abs)); + + } + [TestMethod] + public void complex64_conj() + { + float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f }; + float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f }; + + float[] d_real_expected = new float[] { -3.0f, -5.0f, 8.0f, 7.0f }; + float[] d_imag_expected = new float[] { 4.0f, -12.0f, 15.0f, -24.0f }; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); + + Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64); + + Tensor t_result = tf.math.conj(t_complex); + + NDArray n_real_result = tf.math.real(t_result).numpy(); + NDArray n_imag_result = tf.math.imag(t_result).numpy(); + + float[] d_real_result = n_real_result.ToArray(); + float[] d_imag_result = n_imag_result.ToArray(); + + Assert.IsTrue(base.Equal(d_real_result, d_real_expected)); + Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected)); + + } + [TestMethod] + public void complex64_angle() + { + float[] d_real = new float[] { 0.0f, 1.0f, -1.0f, 0.0f }; + float[] d_imag = new float[] { 1.0f, 0.0f, -2.0f, -3.0f }; + + float[] d_expected = new float[] { 1.5707964f, 0f, -2.0344439f, -1.5707964f }; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); + + Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64); + + Tensor t_result = tf.math.angle(t_complex); + + NDArray n_result = t_result.numpy(); + + float[] d_result = n_result.ToArray(); + + Assert.IsTrue(base.Equal(d_result, d_expected)); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/CondTestCases.cs b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/CondTestCases.cs new file mode 100644 index 000000000..7063c22cf --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/CondTestCases.cs @@ -0,0 +1,85 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.ControlFlowTest +{ + /// + /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py + /// + [TestClass] + public class CondTestCases : GraphModeTestBase + { + [Ignore("Dependent on UpdateEdge")] + [TestMethod] + public void testCondTrue_ConstOnly() + { + var graph = tf.Graph().as_default(); + + var sess = tf.Session(graph); + var x = tf.constant(2, name: "x"); + var y = tf.constant(5, name: "y"); + + var z = control_flow_ops.cond(tf.less(x, y), + () => tf.constant(22, name: "t22"), + () => tf.constant(55, name: "f55")); + + int result = z.eval(sess); + assertEquals(result, 22); + } + + [TestMethod] + public void testCondFalse_ConstOnly() + { + var graph = tf.Graph().as_default(); + + var sess = tf.Session(graph); + var x = tf.constant(2, name: "x"); + var y = tf.constant(1, name: "y"); + + var z = control_flow_ops.cond(tf.less(x, y), + () => tf.constant(22, name: "t22"), + () => tf.constant(11, name: "f11")); + + int result = z.eval(sess); + assertEquals(result, 11); + } + + [Ignore("Dependent on UpdateEdge")] + [TestMethod] + public void testCondTrue() + { + tf.Graph().as_default(); + + var x = tf.constant(2, name: "x"); + var y = tf.constant(5, name: "y"); + + var z = control_flow_ops.cond(tf.less(x, y), + () => tf.multiply(x, 17), + () => tf.add(y, 23)); + + var result = evaluate(z); + assertEquals(result, 34); + } + + [Ignore("Dependent on UpdateEdge")] + [TestMethod] + public void testCondFalse() + { + tf.Graph().as_default(); + + var x = tf.constant(2); + var y = tf.constant(1); + + var z = control_flow_ops.cond(tf.less(x, y), + () => tf.multiply(x, 17), + () => tf.add(y, 23)); + + var result = evaluate(z); + assertEquals(result, 24); + } + + // NOTE: all other python test cases of this class are either not needed due to strong typing or test a deprecated api + + } +} diff --git a/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/ShapeTestCase.cs b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/ShapeTestCase.cs new file mode 100644 index 000000000..667f336f8 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/ShapeTestCase.cs @@ -0,0 +1,23 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; + +namespace TensorFlowNET.UnitTest.ControlFlowTest +{ + /// + /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py + /// + [TestClass] + public class ShapeTestCase : GraphModeTestBase + { + + [TestMethod] + public void testShape() + { + var tensor = constant_op.constant(new[] { 1.0, 2.0 }); + self.assertEquals(new long[] { 2 }, tensor.shape.dims); + self.assertEquals(new long[] { 2 }, + control_flow_ops.with_dependencies(new[] { constant_op.constant(1.0).op }, tensor).shape.dims); + } + + } +} diff --git a/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs new file mode 100644 index 000000000..e93324f3e --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs @@ -0,0 +1,50 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.ControlFlowTest +{ + [TestClass] + public class WhileContextTestCase : GraphModeTestBase + { + /// + /// https://www.tensorflow.org/api_docs/python/tf/while_loop + /// + [TestMethod] + public void SimpleWhileLoop() + { + var i = constant_op.constant(0, name: "i"); + var c = new Func(x => tf.less(x, 10, name: "c")); + var b = new Func(x => tf.add(x, 1, name: "c")); + // var r = control_flow_ops.while_loop(c, b, i); + } + + private void _testWhileContextHelper(int maximum_iterations) + { + // TODO: implement missing code dependencies + using var sess = this.cached_session(); + var i = constant_op.constant(0, name: "i"); + var c = new Func(x => gen_math_ops.less(x, ops.convert_to_tensor(10), name: "c")); + var b = new Func(x => math_ops.add(x, 1, name: "c")); + //control_flow_ops.while_loop( + // c, b, i , maximum_iterations: tf.constant(maximum_iterations)); + foreach (Operation op in sess.graph.get_operations()) + { + var control_flow_context = op._get_control_flow_context(); + /*if (control_flow_context != null) + self.assertProtoEquals(control_flow_context.to_proto(), + WhileContext.from_proto( + control_flow_context.to_proto()).to_proto(), "");*/ + } + } + + [Ignore("TODO")] + [TestMethod] + public void testWhileContextWithMaximumIterations() + { + _testWhileContextHelper(maximum_iterations: 10); + } + } +} diff --git a/test/TensorFlowNET.Graph.UnitTest/FunctionalOpsTest/ScanTestCase.cs b/test/TensorFlowNET.Graph.UnitTest/FunctionalOpsTest/ScanTestCase.cs new file mode 100644 index 000000000..88b0b0b73 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/FunctionalOpsTest/ScanTestCase.cs @@ -0,0 +1,41 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.FunctionalOpsTest +{ + /// + /// https://www.tensorflow.org/api_docs/python/tf/scan + /// + [TestClass] + public class ScanTestCase : GraphModeTestBase + { + [TestMethod, Ignore("need UpdateEdge API")] + public void ScanForward() + { + var fn = new Func((a, x) => tf.add(a, x)); + + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new Shape(6)); + var scan = functional_ops.scan(fn, input); + var result = sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6))); + Assert.AreEqual(result, np.array(1, 3, 6, 10, 15, 21)); + } + + [TestMethod, Ignore("need UpdateEdge API")] + public void ScanReverse() + { + var fn = new Func((a, x) => tf.add(a, x)); + + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new Shape(6)); + var scan = functional_ops.scan(fn, input, reverse: true); + var result = sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6))); + Assert.AreEqual(result, np.array(21, 20, 18, 15, 11, 6)); + } + } +} diff --git a/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs new file mode 100644 index 000000000..cea6de172 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs @@ -0,0 +1,815 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; +using Tensorflow.Framework; + +namespace TensorFlowNET.UnitTest.Gradient +{ + [TestClass] + public class GradientTest : GraphModeTestBase + { + [TestMethod] + public void BroadcastToGrad() + { + var x = tf.constant(2, dtype: dtypes.float32); + var y = tf.broadcast_to(x, (2, 4, 3)); + var grad = tf.gradients(y, x); + + var sess = tf.Session(graph); + float result = sess.run(grad[0]); + Assert.AreEqual(result, 24.0f); + } + + [TestMethod] + public void CumsumGrad() + { + var x = tf.constant(2, dtype: dtypes.float32); + var y = tf.broadcast_to(x, (2, 4, 3)); + var z = tf.cumsum(y, axis: 1); + var grad = tf.gradients(z, x); + + var sess = tf.Session(graph); + float result = sess.run(grad[0]); + Assert.AreEqual(result, 60.0f); + } + + [TestMethod, Ignore] + public void testGradients() + { + var inp = tf.constant(1.0, shape: new[] { 32, 100 }, name: "in"); + var w = tf.constant(1.0, shape: new[] { 100, 10 }, name: "w"); + var b = tf.Variable(1.0, shape: new[] { 10 }, name: "b"); + var xw = math_ops.matmul(inp, w, name: "xw"); + var h = nn_ops.bias_add(xw, b, name: "h"); + var w_grad = gradients_impl.gradients(new[] { h }, new[] { w })[0]; + self.assertEquals("MatMul", w_grad.op.type); + // TODO: Operation._original_op + //self.assertEquals(w_grad.op._original_op, xw.op); + self.assertTrue((bool)w_grad.op.get_attr("transpose_a")); + self.assertFalse((bool)w_grad.op.get_attr("transpose_b")); + } + + [TestMethod] + public void testBatchMatMulGradient() + { + var a = tf.constant(np.array(Enumerable.Range(1, 18).Select(elem => (float)elem).ToArray()), shape: new[] { 2, 3, 3 }); + var b = tf.divide(a, tf.constant(2.0f)); + var c = tf.batch_matmul(a, b); + var g = tf.gradients(c, new[] { a, b }, stop_gradients: new[] { a, b }); + var checkG = new[] + { + 3.0f, 7.5f, 12.0f, + 3.0f, 7.5f, 12.0f, + 3.0f, 7.5f, 12.0f, + 16.5f, 21.0f, 25.5f, + 16.5f, 21.0f, 25.5f, + 16.5f, 21.0f, 25.5f, + 12.0f, 12.0f, 12.0f, + 15.0f, 15.0f, 15.0f, + 18.0f, 18.0f, 18.0f, + 39.0f, 39.0f, 39.0f, + 42.0f, 42.0f, 42.0f, + 45.0f, 45.0f, 45.0f + }; + var sess = tf.Session(); + var result = sess.run(g); + var resultList = result[0].ToArray().ToList(); + resultList.AddRange(result[1].ToArray()); + Console.WriteLine(result.ToString()); + CollectionAssert.AreEqual(resultList.ToArray(), checkG); + } + + [TestMethod] + public void testSimpleGradients() + { + (T, T) evaluateDerivatives(Func f, T xval) where T : unmanaged + { + var x = tf.constant(xval); + var y = f(x); + var g = tf.gradients(y, x); + + var session = tf.Session(); + var result = session.run(new[] { y, g[0] }); + return (result[0].ToArray()[0], result[1].ToArray()[0]); + } + + void test(string name, Func tfF, Func targetF, double[] values) + { + foreach (var x in values) + { + var (expectedY, expectedDY) = targetF(x); + + { + var (actualY, actualDY) = evaluateDerivatives(tfF, x); + self.assertFloat64Equal(expectedY, actualY, $"value {name}/float64 at {x}"); + self.assertFloat64Equal(expectedDY, actualDY, $"derivative {name}/float64 at {x}"); + } + + { + var (actualY, actualDY) = evaluateDerivatives(tfF, (float)x); + self.assertFloat32Equal((float)expectedY, actualY, $"value {name}/float32 at {x}"); + self.assertFloat32Equal((float)expectedDY, actualDY, $"derivative {name}/float32 at {x}"); + } + } + } + + test("tf.exp", + x => tf.exp(5 * x), + x => (Math.Exp(5.0 * x), 5.0 * Math.Exp(5.0 * x)), + new[] { -1.0, 0.0, 1.0, 1.5 }); + + test("tf.log", + x => tf.log(x), + x => (Math.Log(x), 1.0 / x), + new[] { 0.5, 1.0, 1.5, 2.0 }); + + test("tf.sqrt", + x => tf.sqrt(x), + x => (Math.Sqrt(x), 0.5 / Math.Sqrt(x)), + new[] { 0.5, 1.0, 1.1, 1.5, 2.0 }); + + test("tf.sin", + x => tf.sin(x), + x => (Math.Sin(x), Math.Cos(x)), + new[] { -1.0, 0.0, 1.0, 1.5, 2.0 }); + + test("tf.sinh", + x => tf.sinh(x), + x => (Math.Sinh(x), Math.Cosh(x)), + new[] { -1.0, 0.0, 1.0, 1.5, 2.0 }); + + test("tf.cos", + x => tf.cos(x), + x => (Math.Cos(x), -Math.Sin(x)), + new[] { -1.0, 0.0, 1.0, 1.5, 2.0 }); + + test("tf.cosh", + x => tf.cosh(x), + x => (Math.Cosh(x), Math.Sinh(x)), + new[] { -1.0, 0.0, 1.0, 1.5, 2.0 }); + + test("tf.tanh", + x => tf.tanh(x), + x => (Math.Tanh(x), 1.0 - Math.Pow(Math.Tanh(x), 2.0)), + new[] { -1.0, 0.0, 1.0, 1.5, 2.0 }); + + test("tf.maximum", + x => tf.maximum(x, tf.constant(0.0, dtype: x.dtype)), + x => (Math.Max(x, 0.0), (x > 0.0) ? 1.0 : 0.0), + new[] { -1.0, 1.0 }); + + test("tf.minimum", + x => tf.minimum(x, tf.constant(0.0, dtype: x.dtype)), + x => (Math.Min(x, 0.0), (x < 0.0) ? 1.0 : 0.0), + new[] { -1.0, 1.0 }); + } + + [TestMethod] + public void testReduceSumGradients() + { + /* python code + import tensorflow.compat.v1 as tf + tf.disable_v2_behavior() + + x = tf.placeholder(tf.float64, shape = (1, 1)) + m = tf.broadcast_to(x, (2, 3)) + g0 = tf.gradients(tf.reduce_sum(m), x)[0] + g1 = tf.gradients(tf.reduce_sum(m, axis = 0)[0], x)[0] + g2 = tf.gradients(tf.reduce_sum(m, axis = 1)[0], x)[0] + with tf.compat.v1.Session() as sess: + (r0, r1, r2) = sess.run((g0, g1, g2), {x: [[1.0]]}) + */ + + var x = tf.placeholder(tf.float64, shape: new Shape(1, 1)); + var m = tf.broadcast_to(x, new Shape(2, 3)); + var g0 = tf.gradients(tf.reduce_sum(m), x)[0]; + var g1 = tf.gradients(tf.reduce_sum(m, axis: 0)[0], x)[0]; + var g2 = tf.gradients(tf.reduce_sum(m, axis: 1)[0], x)[0]; + + var session = tf.Session(); + var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } })); + self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)"); + self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)"); + self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)"); + } + + [TestMethod] + public void testTanhGradient() + { + var a = tf.constant(1f); + var b = tf.tanh(a); + var g = tf.gradients(b, a); + var sess = tf.Session(); + var result = sess.run(g); + var actual = result[0]; + Assert.AreEqual(actual, 0.41997434127f); + } + + + [TestMethod] + public void testLgammaGrad() + { + var a = tf.constant(5f); + var b = tf.lgamma(a); + var g = tf.gradients(b, a); + var sess = tf.Session(); + var result = sess.run(new object[] { g, b }); + var actualDeriv = result[0]; + var actual = result[1]; + Assert.AreEqual(actualDeriv, 1.5061177f); + Assert.AreEqual(actual, 3.17805386f); + } + + [TestMethod] + public void testSliceGrad() + { + var a = tf.tanh(tf.constant(new[] { 2f, 3f }, shape: new[] { 2, 1 })); + var b = tf.strided_slice(a, + tf.constant(new[] { 0 }, tf.int32, new[] { 1 }), + tf.constant(new[] { 1 }, tf.int32, new[] { 1 }), + tf.constant(new[] { 1 }, tf.int32, new[] { 1 }) + ); + var g = tf.gradients(b, a); + var sess = tf.Session(); + var result = sess.run(new object[] { g, b }); + var actualDeriv = np.squeeze(result[0]); + var actual = np.squeeze(result[1]); + Assert.AreEqual(actualDeriv, new float[] { 1, 0 }); + Assert.AreEqual(actual, 0.9640276f); + } + + [TestMethod] + public void testConcatGrad() + { + var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); + var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); + var a = tf.concat(new List(new[] { a1, a2 }), 0); + var g = tf.gradients(a, a1); + var sess = tf.Session(); + var result = sess.run(new object[] { g, a }); + var actualDeriv = result[0][0]; + var actual = result[1][0]; + Assert.AreEqual(actualDeriv, 1f); + Assert.AreEqual(actual, 2f); + } + + [TestMethod] + public void testStopGradientFunction() + { + var ap = tf.constant(1f); + var b = tf.tanh(ap) + array_ops.stop_gradient(ap); + var g = tf.gradients(b, ap); + var sess = tf.Session(); + var result = sess.run(g); + var actual = result[0]; + Assert.AreEqual(actual, 0.41997434127f); + } + + [Ignore("TODO")] + [TestMethod] + public void testUnusedOutput() + { + //def testUnusedOutput(self): + // with ops.Graph().as_default(): + // w = constant(1.0, shape=[2, 2]) + // x = constant(1.0, shape=[2, 2]) + // wx = math_ops.matmul(w, x) + // split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0) + // c = math_ops.reduce_sum(split_wx[1]) + // gw = gradients.gradients(c, [w])[0] + // self.assertEquals("MatMul", gw.op.type) + } + + [Ignore("TODO")] + [TestMethod] + public void testColocateGradients() + { + + //def testColocateGradients(self): + // with ops.Graph().as_default() as g: + // w = constant(1.0, shape=[1, 1]) + // x = constant(1.0, shape=[1, 2]) + // with g.device("/device:GPU:0"): + // wx = math_ops.matmul(w, x) + // gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0] + // self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups()) + } + + [Ignore("TODO")] + [TestMethod] + public void testColocateGradientsWithAggregation() + { + //def testColocateGradientsWithAggregation(self): + // with ops.Graph().as_default() as g: + // with g.device("/device:GPU:1"): + // w = constant(1.0, shape=[1, 1]) + // x = constant(1.0, shape=[1, 2]) + // y = constant(1.0, shape=[1, 2]) + // wx = math_ops.matmul(w, x) + // wy = math_ops.matmul(w, y) + // with g.device("/device:GPU:0"): + // z = wx + wy + + // gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0] + // self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups()) + + // gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0] + // self.assertTrue(wx.op.colocation_groups() != gw2.op.colocation_groups()) + + } + + [Ignore("TODO")] + [TestMethod] + public void testColocateGradientsWithAggregationInMultipleDevices() + { + //def testColocateGradientsWithAggregationInMultipleDevices(self): + // with ops.Graph().as_default() as g: + // with g.device("/device:GPU:1"): + // w = constant(1.0, shape=[1, 1]) + // x = constant(1.0, shape=[1, 2]) + // y = constant(1.0, shape=[1, 2]) + // with g.device("/task:1"): + // wx = math_ops.matmul(w, x) + // with g.device("/task:2"): + // wy = math_ops.matmul(w, y) + // with g.device("/device:GPU:0"): + // z = wx + wy + + // gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0] + // self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups()) + + // gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0] + // self.assertTrue(w.op.colocation_groups() != gw2.op.colocation_groups()) + } + + + [Ignore("TODO")] + [TestMethod] + public void testColocateGradientsWithGateGradients() + { + + //def testColocateGradientsWithGateGradients(self): + // if not test_util.is_gpu_available(): + // self.skipTest("No GPU available") + // with ops.Graph().as_default() as g: + // with g.device("/device:CPU:0"): + // x = constant(1.0, shape=[1, 1]) + // y = constant(1.0, shape=[1, 1]) + // s = x + y + // with g.device("/device:GPU:0"): + // z = math_ops.reduce_sum(s) + + // gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True, + // gate_gradients=True)[0] + // with session.Session(): + // # Make sure the placer doesn't complain. + // self.evaluate(gz_x) + + } + + [Ignore("TODO")] + [TestMethod] + public void testBoundaryStop() + { + //def testBoundaryStop(self): + // # Test that we don't differentiate 'x'. The gradient function for 'x' is + // # set explicitly to None so we will get an exception if the gradient code + // # tries to differentiate 'x'. + // with ops.Graph().as_default(): + // c = constant(1.0) + // x = array_ops.identity(c) + // y = x + 1.0 + // z = y + 1 + // grads = gradients.gradients(z, [x]) + // self.assertTrue(all(x is not None for x in grads)) + + } + + [TestMethod] + public void testBoundaryContinue() + { + // Test that we differentiate both 'x' and 'y' correctly when x is a + // predecessor of y. + + //TODO: @test_util.run_v1_only("b/120545219") + + using (self.cached_session()) + { + var x = tf.constant(1.0); + var y = x * 2.0; + var z = y * 3.0; + var grads = tf.gradients(z, new[] { x, y }); + self.assertTrue(all(grads.Select(x => x != null))); + self.assertEqual(6.0, grads[0].eval()); + } + } + + [TestMethod] + public void testAggregationMethodAccumulateN() + { + //TODO: @test_util.run_v1_only("b/120545219") + + using (self.cached_session()) + { + var x = tf.constant(1.0); + var y = x * 2.0; + var z = y + y + y + y + y + y + y + y + y + y; + var grads = tf.gradients(z, new[] { x, y }, + aggregation_method: AggregationMethod.EXPERIMENTAL_ACCUMULATE_N); + self.assertTrue(all(grads.Select(x => x != null))); + self.assertEqual(20.0, grads[0].eval()); + self.assertEqual(10.0, grads[1].eval()); + } + } + + [TestMethod] + public void testAggregationMethodAddN() + { + //TODO: @test_util.run_v1_only("b/120545219") + + using (self.cached_session()) + { + var x = tf.constant(1.0); + var y = x * 2.0; + var z = y + y + y + y + y + y + y + y + y + y; + var grads = tf.gradients(z, new[] { x, y }, + aggregation_method: AggregationMethod.ADD_N); + self.assertTrue(grads.All(x => x != null)); + self.assertEqual(20.0, grads[0].eval()); + self.assertEqual(10.0, grads[1].eval()); + } + } + + [TestMethod] + public void testAggregationMethodTree() + { + //TODO: @test_util.run_v1_only("b/120545219") + + using (self.cached_session()) + { + var x = tf.constant(1.0); + var y = x * 2.0; + var z = y + y + y + y + y + y + y + y + y + y; + var grads = tf.gradients(z, new[] { x, y }, + aggregation_method: AggregationMethod.EXPERIMENTAL_TREE); + self.assertTrue(grads.All(x => x != null)); + self.assertEqual(20.0, grads[0].eval()); + self.assertEqual(10.0, grads[1].eval()); + } + } + + [Ignore("TODO")] + [TestMethod] + public void testNoGradientForStringOutputs() + { + + //def testNoGradientForStringOutputs(self): + // with ops.Graph().as_default(): + + // def _TestOpGrad(_, float_grad, string_grad): + // """Gradient function for TestStringOutput.""" + // self.assertEquals(float_grad.dtype, dtypes.float32) + // self.assertFalse(string_grad) + // return float_grad + + // ops.RegisterGradient("TestStringOutput")(_TestOpGrad) + + // c = constant(1.0) + // x, _ = test_ops.test_string_output(c) + // z = x * 2.0 + // w = z * 3.0 + // grads = gradients.gradients(z, [c]) + // self.assertTrue(isinstance(grads[0], ops.Tensor)) + // grads = gradients.gradients(w, [c]) + // self.assertTrue(isinstance(grads[0], ops.Tensor)) + } + + [Ignore("TODO: CompositeTensors are not supported yet.")] + [TestMethod] + public void testSingletonIndexedSlices() + { + tf.Graph().as_default(); + + // TODO: uncomment when CompositeTensors are supported. + /* + var x = tf.placeholder(TF_DataType.TF_FLOAT); + var y = tf.identity(x); + var dy_indices = tf.placeholder(TF_DataType.TF_INT32); + var dy_values = tf.placeholder(TF_DataType.TF_FLOAT); + var dy = new IndexedSlices(dy_values, dy_indices); + + var dx = tf.gradients(new[] { y }, new[] { x }, grad_ys: new[] { dy })[0]; + // The IndexedSlices gradient of tf.identity is the identity map. + using (var sess = self.cached_session()) + { + var feed_dict = new FeedItem[] + { + ( x, new Tensor(new float[] { 1.0f }) ), + (dy_indices, new Tensor(new int[] { 0 })), + (dy_values, new Tensor(new float[] { 2.0f })) + }; + var result = sess.run(new[] { dx, dy }, feed_dict); + var vdx = result[0]; + var vdy = result[1]; + self.assertEqual(vdx, vdy); + } + */ + + } + + [Ignore("TODO")] + [TestMethod] + public void testNonDifferentiableSwitchInWhileLoop() + { + + + //@test_util.run_v1_only("b/120545219") + //def testNonDifferentiableSwitchInWhileLoop(self): + // with ops.Graph().as_default(): + // v = array_ops.placeholder(dtypes.float32, []) + + // def _Step(i, a, ta): + // a += math_ops.cast(v, dtypes.int32) + // return (i + 1, a, ta.write(i, a)) + + // n = 4 + // i, _, ta = control_flow_ops.while_loop( + // lambda i, *_: i < n, + // _Step, [0, 0, tensor_array_ops.TensorArray( + // dtypes.int32, size=n)]) + // target = ta.read(i - 1) + // grad, = gradients.gradients(target, v) + // self.assertIsNone(grad) + + } + + [Ignore("TODO")] + [TestMethod] + public void testVariableReadValueGradient() + { + + //def testVariableReadValueGradient(self): + // with ops.Graph().as_default(): + // init = constant_op.constant(100.0) + // var = variables.Variable(init) + // gradient = gradients.gradients(var.read_value(), var) + // self.assertIsNotNone(gradient) + } + + [Ignore("TODO")] + [TestMethod] + public void testVariableAsGraphElementGradient() + { + //def testVariableAsGraphElementGradient(self): + // with ops.Graph().as_default() as graph: + // init = constant_op.constant(100.0) + // var = variables.Variable(init) + // gradient = gradients.gradients(graph.as_graph_element(var), var) + // self.assertIsNotNone(gradient) + } + + [Ignore("TODO")] + [TestMethod] + public void testVariableRefGradient() + { + + //@test_util.run_v1_only("b/120545219") + //def testVariableRefGradient(self): + // with ops.Graph().as_default(): + // init = constant_op.constant(100.0) + // var = variables.VariableV1(init) + // gradient = gradients.gradients(var._ref(), var) + // self.assertIsNotNone(gradient) + } + + [TestMethod] + public void testDependentYs() + { + //TODO: @test_util.run_v1_only("b/120545219") + using (self.cached_session()) + { + var x = constant_op.constant(3.0); + var y = math_ops.square(x); + var y1 = math_ops.square(y); + var y2 = math_ops.square(y1); + var g = tf.gradients(new[] { y, y2 }, new[] { x }); + self.assertAllClose(17502.0, g[0].eval()); + g = tf.gradients(y + y2, x); + self.assertAllClose(17502.0, g[0].eval()); + var z = array_ops.identity(y); + var z2 = array_ops.identity(y2); + g = tf.gradients(new[] { z, z2 }, new[] { x }); + self.assertAllClose(17502.0, g[0].eval()); + } + } + + [Ignore("TODO")] + [TestMethod] + public void testPartialDerivatives() + { + + //TODO: @test_util.run_v1_only("b/120545219") + using (self.cached_session()) + { + var x = tf.constant(1.0); + var y = 2 * x; + var z = x + y; + var totalg = tf.gradients(z, new[] { x, y }); + self.assertEqual(new[] { 3.0, 1.0 }, totalg.Select(g => g.eval())); + var partialg = tf.gradients(z, new[] { x, y }, stop_gradients: new[] { x, y }); + self.assertEqual(new[] { 1.0, 1.0 }, partialg.Select(g => g.eval())); + } + } + + private struct Case + { + public Tensor[] grad1; + public Tensor[] grad2; + public string constants; + public string variables; + } + + [Ignore("FIXME")] + [TestMethod] + public void testStopGradients() + { + + //TODO: @test_util.run_v1_only("b/120545219") + Dictionary makeGraph(RandomizedImpl rng, string stop_gradients) + { + Tensor functionOf(Tensor[] xs, int k) + { + var shape = new Shape(k, k); + // TODO: replace by DefaultIfEmpty() before Aggregate(). + if (!xs.Any()) + { + return rng.random(shape).astype(np.float32); + } + return xs.Select(x => gen_math_ops.mat_mul(rng.random(shape).astype(np.float32), x)) + .Aggregate((t1, t2) => t1 + t2) + + rng.random(shape).astype(np.float32); + } + + var a = functionOf(Array.Empty(), 3); + if (stop_gradients.Contains('a')) a = array_ops.stop_gradient(a); + var b = functionOf(new Tensor[] { a }, 3); + if (stop_gradients.Contains('b')) b = array_ops.stop_gradient(b); + var c = functionOf(new Tensor[] { a, b }, 3); + if (stop_gradients.Contains('c')) c = array_ops.stop_gradient(c); + var d = functionOf(new Tensor[] { b, c }, 3); + if (stop_gradients.Contains('d')) d = array_ops.stop_gradient(d); + + return new Dictionary + { + { 'a', a }, + { 'b', b }, + { 'c', c }, + { 'd', d } + }; + } + + Tensor[] gradients(Tensor[] ys, Tensor[] xs, Tensor[] stop_gradients = null) + { + var dydxs = tf.gradients(ys, xs, stop_gradients); + dydxs = dydxs.Select((dydx, i) => dydx == null ? xs[i] * 0 : dydx).ToArray(); + return dydxs; + } + + var seed = np.random.randint(1000); + // TODO: remove next line when np.random.RandomState implemented. + tf.set_random_seed(seed); + var cases = new List(); + // TODO: add "" case. + var subsets = new List { "" }.Concat("a b c d ab ac ad bc bd cd abc abd acd bcd abcd".Split()); + // TODO: pass np.random.RandomState(seed) instead of np.random + var graph = makeGraph(np.random, string.Empty); + foreach (var constants in subsets) + { + var graphWithStops = makeGraph(np.random, constants); + foreach (var variables_ in subsets) + { + // compute the gradient when stopped using tf.stop_gradients + var grad1 = gradients( + new[] { graphWithStops['d'] }, + variables_.ToCharArray().Select(v => graphWithStops[v]).ToArray() + ); + // compute the gradient when stopped using the stop_gradients from args + var grad2 = gradients( + new[] { graph['d'] }, + variables_.ToCharArray().Select(v => graph[v]).ToArray(), + constants.ToCharArray().Select(c => graph[c]).DefaultIfEmpty(null)?.ToArray() + ); + cases.Add(new Case + { + grad1 = grad1, + grad2 = grad2, + variables = variables_, + constants = constants, + }) ; + } + } + + // evaluate all tensors in one call to session.run for speed + using (var sess = self.cached_session()) + { + var results = sess.run( + cases.Select(case_ => ( + case_.grad1, + case_.grad2 + )).ToArray() + ); + + foreach (var (result, case_) in results.Zip(cases)) + { + var npgrad1 = result[0]; + var npgrad2 = result[1]; + foreach (var (a, b) in npgrad1.Zip(npgrad2)) + { + self.assertAllClose(a, b); + } + } + } + } + + + + [Ignore("TODO: Unconnected gradients are not implemented")] + [TestMethod] + public void testUnconnectedGradientsNoneUnconnectedGradients() + { + + + //def testUnconnectedGradientsNoneUnconnectedGradients(self): + // with ops.Graph().as_default(): + // x = constant(1.0, shape=[2, 2]) + // y = constant(3.0, shape=[3, 1]) + // grad = gradients.gradients( + // [y], [x], unconnected_gradients="none") + // self.assertIsNone(grad[0]) + } + + [Ignore("TODO: Unconnected gradients are not implemented")] + [TestMethod] + public void testUnconnectedGradientsZerosUnconnectedGradients() + { + //def testUnconnectedGradientsZerosUnconnectedGradients(self): + // with ops.Graph().as_default(): + // x = constant(1.0, shape=[2, 2]) + // y = constant(3.0, shape=[3, 1]) + // grads = gradients.gradients( + // [y], [x], unconnected_gradients="zero") + // with self.cached_session() as sess: + // self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0]) + + // tf.Graph().as_default(); + // var x = tf.constant(1.0, shape: new long[] { 2, 2 }); + // var y = tf.constant(3.0, shape: new long[] { 3, 1 }); + // var grads = tf.gradients(new[] { y }, new[] { x }, unconnected_gradients: "zero"); + // using (self.cached_session()) + // { + // self.assertAllEqual(new[,] { { 0.0, 0.0 }, { 0.0, 0.0 } }, self.evaluate(grads)[0]); + // } + } + + [Ignore("TODO: Unconnected gradients are not implemented")] + [TestMethod] + public void testUnconnectedGradientsZeroConnectedGradients() + { + //def testUnconnectedGradientsZeroConnectedGradients(self): + // with ops.Graph().as_default(): + // x = constant(1.0) + // y = x * 3.0 + // grad = gradients.gradients( + // [y], [x], unconnected_gradients="zero") + // with self.cached_session() as sess: + // self.assertEquals(3.0, self.evaluate(grad)[0]) + + // tf.Graph().as_default(); + + // var x = tf.constant(1.0f); + // var y = x * 3.0f; + // var grad = tf.gradients(new [] { y }, new [] { x }, unconnected_gradients: "zero"); + // using (var sess = tf.Session()) + // { + // self.assertEquals(3.0, self.evaluate(grad)[0]); + // } + } + + [Ignore("TODO: Unconnected gradients are not implemented")] + [TestMethod] + public void testUnknownUnconnectedGradientsValueGiven() + { + //def testUnknownUnconnectedGradientsValueGiven(self): + // with ops.Graph().as_default(): + // x = constant(1.0) + // y = constant(1.0) + // with self.assertRaisesRegexp( + // ValueError, "Unknown value for unconnected_gradients: 'nonsense'"): + // gradients.gradients([y], [x], unconnected_gradients="nonsense") + } + } +} diff --git a/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs b/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs new file mode 100644 index 000000000..a8bb079e3 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs @@ -0,0 +1,23 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + public class GraphModeTestBase : PythonTest + { + protected Graph graph; + [TestInitialize] + public void TestInit() + { + tf.compat.v1.disable_eager_execution(); + graph = tf.Graph().as_default(); + } + + [TestCleanup] + public void TestClean() + { + graph.Exit(); + } + } +} diff --git a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs new file mode 100644 index 000000000..127b65bf6 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs @@ -0,0 +1,258 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; +using System; +using System.IO; + +namespace TensorFlowNET.UnitTest +{ + /// + /// Find more examples in https://www.programcreek.com/python/example/90444/tensorflow.read_file + /// + [TestClass] + public class ImageTest : GraphModeTestBase + { + string imgPath = "shasta-daisy.jpg"; + Tensor contents; + + [TestInitialize] + public void Initialize() + { + imgPath = TestHelper.GetFullPathFromDataDir(imgPath); + contents = tf.io.read_file(imgPath); + } + + [TestMethod] + public void adjust_contrast() + { + var input = np.array(0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f); + var image = tf.reshape(input, new int[] { 3, 3, 1 }); + + var init = tf.global_variables_initializer(); + var sess = tf.Session(); + sess.run(init); + var adjust_contrast = tf.image.adjust_contrast(image, 2.0f); + var result = sess.run(adjust_contrast); + var res = np.array(-4f, -2f, 0f, 2f, 4f, 6f, 8f, 10f, 12f).reshape((3,3,1)); + Assert.AreEqual(result.numpy(), res); + } + + [Ignore] + [TestMethod] + public void adjust_hue() + { + var image = tf.constant(new int[] {1,2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16,17,18}); + image = tf.reshape(image, new int[] { 3, 2, 3 }); + var adjusted_image = tf.image.adjust_hue(image, 0.2f); + var res = tf.constant(new int[] {2,1,3, 4, 5, 6,8,7,9,11,10,12,14,13,15,17,16,18}); + res = tf.reshape(res,(3,2,3)); + Assert.AreEqual(adjusted_image, res); + } + + [TestMethod] + public void combined_non_max_suppression() + { + var boxesX = tf.constant(new float[,] { { 200, 100, 150, 100 }, { 220, 120, 150, 100 }, { 190, 110, 150, 100 }, { 210, 112, 150, 100 } }); + var boxes1 = tf.reshape(boxesX, (1, 4, 1, 4)); + var scoresX = tf.constant(new float[,] { { 0.2f, 0.7f, 0.1f }, { 0.1f, 0.8f, 0.1f }, { 0.3f, 0.6f, 0.1f }, { 0.05f, 0.9f, 0.05f } }); + var scores1 = tf.reshape(scoresX, (1, 4, 3)); + + var init = tf.global_variables_initializer(); + var sess = tf.Session(); + sess.run(init); + + var (boxes, scores, classes, valid_detections) = tf.image.combined_non_max_suppression(boxes1, scores1, 10, 10, 0.5f, 0.2f, clip_boxes: false); + var result = sess.run((boxes, scores, classes, valid_detections)); + + var boxes_gt = tf.constant(new float[,] { { 210f, 112f, 150f, 100f }, { 200f, 100f, 150f, 100f }, { 190f, 110f, 150f, 100f }, + { 0f, 0f, 0f, 0f},{ 0f, 0f, 0f, 0f},{ 0f, 0f, 0f, 0f},{ 0f, 0f, 0f , 0f},{ 0f, 0f, 0f, 0f},{ 0f , 0f, 0f, 0f},{ 0f, 0f, 0f, 0f} }); + boxes_gt = tf.reshape(boxes_gt, (1, 10, 4)); + Assert.AreEqual(result.Item1.numpy(), boxes_gt.numpy()); + var scores_gt = tf.constant(new float[,] { { 0.9f, 0.7f, 0.3f, 0f, 0f, 0f, 0f, 0f, 0f, 0f } }); + scores_gt = tf.reshape(scores_gt, (1, 10)); + Assert.AreEqual(result.Item2.numpy(), scores_gt.numpy()); + var classes_gt = tf.constant(new float[,] { { 1f, 1f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f } }); + classes_gt = tf.reshape(classes_gt, (1, 10)); + Assert.AreEqual(result.Item3.numpy(), classes_gt.numpy()); + var valid_detections_gt = tf.constant(new int[,] { { 3 } }); + valid_detections_gt = tf.reshape(valid_detections_gt, (1)); + Assert.AreEqual(result.Item4.numpy(), valid_detections_gt.numpy()); + } + + [TestMethod] + public void crop_and_resize() + { + int BATCH_SIZE = 1; + int NUM_BOXES = 5; + int IMAGE_HEIGHT = 256; + int IMAGE_WIDTH = 256; + int CHANNELS = 3; + var crop_size = tf.constant(new int[] { 24, 24 }); + var image = tf.random.uniform((BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS)); + var boxes = tf.random.uniform((NUM_BOXES, 4)); + var box_ind = tf.random.uniform((NUM_BOXES), minval: 0, maxval: BATCH_SIZE, dtype: TF_DataType.TF_INT32); + var output = tf.image.crop_and_resize(image, boxes, box_ind, crop_size); + Assert.AreEqual((5,24,24,3), output.shape); + } + + [TestMethod] + public void decode_image() + { + var img = tf.image.decode_image(contents); + Assert.AreEqual(img.name, "decode_image/DecodeImage:0"); + } + + [TestMethod] + public void resize_image() + { + tf.enable_eager_execution(); + var image = tf.constant(new int[5, 5] + { + {1, 0, 0, 0, 0 }, + {0, 1, 0, 0, 0 }, + {0, 0, 1, 0, 0 }, + {0, 0, 0, 1, 0 }, + {0, 0, 0, 0, 1 } + }); + image = image[tf.newaxis, tf.ellipsis, tf.newaxis]; + image = tf.image.resize(image, (3, 5)); + image = image[0, tf.ellipsis, 0]; + Assert.IsTrue(Enumerable.SequenceEqual(new float[] { 0.6666667f, 0.3333333f, 0, 0, 0 }, + image[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new float[] { 0, 0, 1, 0, 0 }, + image[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new float[] { 0, 0, 0, 0.3333335f, 0.6666665f }, + image[2].ToArray())); + tf.compat.v1.disable_eager_execution(); + } + + [TestMethod] + public void TestCropAndResize() + { + var graph = tf.Graph().as_default(); + + // 3x3 'Image' with numbered coordinates + var input = np.array(0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f); + var image = tf.reshape(input, new int[] { 1, 3, 3, 1 }); + + // 4x4 'Image' with numbered coordinates + var input2 = np.array(0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f); + var image2 = tf.reshape(input2, new int[] { 1, 4, 4, 1 }); + // create one box over the full image that flips it (y1 > y2) + var box = tf.reshape(np.array(1f, 0f, 0f, 1f), new int[] { 1, 4 }); + var boxInd = tf.Variable(np.array(0)); + // crop first 3x3 imageto size 1x1 + var cropSize1_1 = tf.Variable(np.array(1, 1)); + // don't crop second 4x4 image + var cropSize2_2 = tf.Variable(np.array(4, 4)); + + var init = tf.global_variables_initializer(); + var sess = tf.Session(); + sess.run(init); + + var cropped = tf.image.crop_and_resize(image, box, boxInd, cropSize1_1); + + var result = sess.run(cropped); + // check if cropped to 1x1 center was succesfull + Assert.AreEqual(result.size, 1ul); + Assert.AreEqual(result[0, 0, 0, 0], 4f); + + cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2); + result = sess.run(cropped); + // check if flipped and no cropping occured + Assert.AreEqual(result.size, 16ul); + Assert.AreEqual(result[0, 0, 0, 0], 12f); + } + + [TestMethod] + public void ImageSaveTest() + { + var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp"); + var jpegImgPath = TestHelper.GetFullPathFromDataDir("img001.jpeg"); + var pngImgPath = TestHelper.GetFullPathFromDataDir("img001.png"); + + File.Delete(jpegImgPath); + File.Delete(pngImgPath); + + var contents = tf.io.read_file(imgPath); + var bmp = tf.image.decode_image(contents); + Assert.AreEqual(bmp.name, "decode_image/DecodeImage:0"); + + var jpeg = tf.image.encode_jpeg(bmp); + var op1 = tf.io.write_file(jpegImgPath, jpeg); + + var png = tf.image.encode_png(bmp); + var op2 = tf.io.write_file(pngImgPath, png); + + this.session().run(op1); + this.session().run(op2); + + Assert.IsTrue(File.Exists(jpegImgPath), "not find file:" + jpegImgPath); + Assert.IsTrue(File.Exists(pngImgPath), "not find file:" + pngImgPath); + + // 如果要测试图片正确性,需要注释下面两行代码 + File.Delete(jpegImgPath); + File.Delete(pngImgPath); + } + + [TestMethod] + public void ImageFlipTest() + { + var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp"); + + var contents = tf.io.read_file(imgPath); + var bmp = tf.image.decode_image(contents); + + // 左右翻转 + var lrImgPath = TestHelper.GetFullPathFromDataDir("img001_lr.png"); + File.Delete(lrImgPath); + + var lr = tf.image.flip_left_right(bmp); + var png = tf.image.encode_png(lr); + var op = tf.io.write_file(lrImgPath, png); + this.session().run(op); + + Assert.IsTrue(File.Exists(lrImgPath), "not find file:" + lrImgPath); + + // 上下翻转 + var updownImgPath = TestHelper.GetFullPathFromDataDir("img001_updown.png"); + File.Delete(updownImgPath); + + var updown = tf.image.flip_up_down(bmp); + var pngupdown = tf.image.encode_png(updown); + var op2 = tf.io.write_file(updownImgPath, pngupdown); + this.session().run(op2); + Assert.IsTrue(File.Exists(updownImgPath)); + + + // 暂时先人工观测图片是否翻转,观测时需要删除下面这两行代码 + File.Delete(lrImgPath); + File.Delete(updownImgPath); + + // 多图翻转 + // 目前直接通过 bmp 拿到 shape ,这里先用默认定义图片大小来构建了 + var mImg = tf.stack(new[] { bmp, lr }, axis:0); + print(mImg.shape); + + var up2 = tf.image.flip_up_down(mImg); + + var updownImgPath_m1 = TestHelper.GetFullPathFromDataDir("img001_m_ud.png"); // 直接上下翻转 + File.Delete(updownImgPath_m1); + + var img001_updown_m2 = TestHelper.GetFullPathFromDataDir("img001_m_lr_ud.png"); // 先左右再上下 + File.Delete(img001_updown_m2); + + var png2 = tf.image.encode_png(up2[0]); + tf.io.write_file(updownImgPath_m1, png2); + + png2 = tf.image.encode_png(up2[1]); + tf.io.write_file(img001_updown_m2, png2); + + // 如果要测试图片正确性,需要注释下面两行代码 + File.Delete(updownImgPath_m1); + File.Delete(img001_updown_m2); + } + } +} diff --git a/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs new file mode 100644 index 000000000..4b92d0210 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs @@ -0,0 +1,267 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class MultithreadingTests : GraphModeTestBase + { + [TestMethod] + public void SessionCreation() + { + ops.uid(); //increment id by one + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + Assert.IsNull(tf.peak_default_graph()); + + var sess = tf.Session(); + var default_graph = tf.get_default_graph(); + var sess_graph = sess.graph; + Assert.IsNotNull(default_graph); + Assert.IsNotNull(sess_graph); + Assert.AreEqual(default_graph, sess_graph); + } + } + + [TestMethod] + public void SessionCreation_x2() + { + ops.uid(); //increment id by one + + MultiThreadedUnitTestExecuter.Run(16, Core); + + //the core method + void Core(int tid) + { + Assert.IsNull(tf.peak_default_graph()); + //tf.Session created an other graph + var sess = tf.Session(); + var default_graph = tf.get_default_graph(); + var sess_graph = sess.graph; + Assert.IsNotNull(default_graph); + Assert.IsNotNull(sess_graph); + Assert.AreEqual(default_graph, sess_graph); + } + } + + [TestMethod] + public void GraphCreation() + { + ops.uid(); //increment id by one + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + Assert.IsNull(tf.peak_default_graph()); + var beforehand = tf.get_default_graph(); //this should create default automatically. + beforehand.as_default(); + Assert.IsNotNull(tf.peak_default_graph()); + + var sess = tf.Session(); + var default_graph = tf.peak_default_graph(); + var sess_graph = sess.graph; + Assert.IsNotNull(default_graph); + Assert.IsNotNull(sess_graph); + Assert.AreEqual(default_graph, sess_graph); + } + } + + + [TestMethod] + public void Marshal_AllocHGlobal() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + for (int i = 0; i < 100; i++) + { + Marshal.FreeHGlobal(Marshal.AllocHGlobal(sizeof(int))); + } + } + } + + [TestMethod] + public void TensorCreation() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + var sess = tf.Session(); + for (int i = 0; i < 100; i++) + { + var t = new Tensor(1); + } + } + } + + [TestMethod] + public void TensorCreation_Array() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + //tf.Session created an other graph + var sess = tf.Session(); + for (int i = 0; i < 100; i++) + { + var t = new Tensor(new int[] { 1, 2, 3 }); + } + } + } + + [TestMethod] + public void SessionRun() + { + MultiThreadedUnitTestExecuter.Run(2, Core); + + //the core method + void Core(int tid) + { + tf.compat.v1.disable_eager_execution(); + var graph = tf.Graph().as_default(); + + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); + var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); + var math = a1 + a2; + var sess = tf.Session(graph); + for (int i = 0; i < 100; i++) + { + var result = sess.run(math); + Assert.AreEqual(result[0], 5f); + } + } + } + + [TestMethod] + public void SessionRun_InsideSession() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + tf.compat.v1.disable_eager_execution(); + var graph = tf.Graph().as_default(); + + var sess = tf.Session(graph); + Assert.IsNotNull(tf.get_default_graph()); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); + var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); + var math = a1 + a2; + + var result = sess.run(math); + Assert.AreEqual(result[0], 5f); + } + } + + [TestMethod] + public void SessionRun_Initialization() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + var sess = tf.Session(); + Assert.IsNotNull(tf.get_default_graph()); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); + var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); + var math = a1 + a2; + } + } + + [TestMethod] + public void SessionRun_Initialization_OutsideSession() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + Assert.IsNull(tf.peak_default_graph()); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); + var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); + var math = a1 + a2; + } + } + + [TestMethod] + public void TF_GraphOperationByName() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + Assert.IsNull(tf.peak_default_graph()); + + tf.compat.v1.disable_eager_execution(); + var graph = tf.Graph().as_default(); + + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); + var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }, name: "ConstantK"); + var math = a1 + a2; + for (int i = 0; i < 100; i++) + { + var op = tf.get_default_graph().OperationByName("ConstantK"); + } + } + } + + private static readonly string modelPath = Path.GetFullPath("./Utilities/models/example1/"); + + [Ignore] + public void TF_GraphOperationByName_FromModel() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + Console.WriteLine(); + for (int j = 0; j < 100; j++) + { + var sess = Session.LoadFromSavedModel(modelPath).as_default(); + var inputs = new[] { "sp", "fuel" }; + + var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray(); + var outp = sess.graph.OperationByName("softmax_tensor").output; + + for (var i = 0; i < 8; i++) + { + var data = new float[96]; + FeedItem[] feeds = new FeedItem[2]; + + for (int f = 0; f < 2; f++) + feeds[f] = new FeedItem(inp[f], new NDArray(data)); + + sess.run(outp, feeds); + } + } + } + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs new file mode 100644 index 000000000..253a3259d --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs @@ -0,0 +1,78 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class NameScopeTest : GraphModeTestBase + { + string name = ""; + + [TestMethod] + public void NestedNameScope() + { + Graph g = tf.Graph().as_default(); + + tf_with(new ops.NameScope("scope1"), scope1 => + { + name = scope1; + Assert.AreEqual("scope1", g._name_stack); + Assert.AreEqual("scope1/", name); + + var const1 = tf.constant(1.0); + Assert.AreEqual("scope1/Const:0", const1.name); + + tf_with(new ops.NameScope("scope2"), scope2 => + { + name = scope2; + Assert.AreEqual("scope1/scope2", g._name_stack); + Assert.AreEqual("scope1/scope2/", name); + + var const2 = tf.constant(2.0); + Assert.AreEqual("scope1/scope2/Const:0", const2.name); + }); + + Assert.AreEqual("scope1", g._name_stack); + var const3 = tf.constant(2.0); + Assert.AreEqual("scope1/Const_1:0", const3.name); + }); + + g.Exit(); + + Assert.AreEqual("", g._name_stack); + } + + [TestMethod, Ignore("Unimplemented Usage")] + public void NestedNameScope_Using() + { + Graph g = tf.Graph().as_default(); + + using (var name = new ops.NameScope("scope1")) + { + Assert.AreEqual("scope1", g._name_stack); + Assert.AreEqual("scope1/", name); + + var const1 = tf.constant(1.0); + Assert.AreEqual("scope1/Const:0", const1.name); + + using (var name2 = new ops.NameScope("scope2")) + { + Assert.AreEqual("scope1/scope2", g._name_stack); + Assert.AreEqual("scope1/scope2/", name); + + var const2 = tf.constant(2.0); + Assert.AreEqual("scope1/scope2/Const:0", const2.name); + } + + Assert.AreEqual("scope1", g._name_stack); + var const3 = tf.constant(2.0); + Assert.AreEqual("scope1/Const_1:0", const3.name); + }; + + g.Exit(); + + Assert.AreEqual("", g._name_stack); + } + } +} diff --git a/test/TensorFlowNET.Graph.UnitTest/OperationsTest.cs b/test/TensorFlowNET.Graph.UnitTest/OperationsTest.cs new file mode 100644 index 000000000..47887e29c --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/OperationsTest.cs @@ -0,0 +1,1294 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; +using Buffer = Tensorflow.Buffer; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class OperationsTest : GraphModeTestBase + { + /// + /// Port from tensorflow\c\c_api_test.cc + /// `TEST(CAPI, GetAllOpList)` + /// + [TestMethod] + public void GetAllOpList() + { + var handle = c_api.TF_GetAllOpList(); + var buffer = new Buffer(handle); + var op_list = OpList.Parser.ParseFrom(buffer.ToArray()); + + var _registered_ops = new Dictionary(); + foreach (var op_def in op_list.Op) + _registered_ops[op_def.Name] = op_def; + + // r1.14 added NN op + var op = _registered_ops.FirstOrDefault(x => x.Key == "NearestNeighbors"); + Assert.IsTrue(op_list.Op.Count > 1000); + } + + [TestMethod] + public void addInPlaceholder() + { + var a = tf.placeholder(tf.float32); + var b = tf.placeholder(tf.float32); + var c = tf.add(a, b); + + var sess = tf.Session(); + var o = sess.run(c, + new FeedItem(a, 3.0f), + new FeedItem(b, 2.0f)); + Assert.AreEqual(o, 5.0f); + } + + [TestMethod] + public void addInConstant() + { + var a = tf.constant(4.0f); + var b = tf.constant(5.0f); + var c = tf.add(a, b); + + var sess = tf.Session(); + var o = sess.run(c); + Assert.AreEqual(o, 9.0f); + } + + [TestMethod] + public void isFinite() + { + var a = tf.constant(new[] { 1, np.nan, 2, np.nan, 3, np.nan, 4, np.nan }); + var b = tf.cast(tf.is_finite(a), tf.float32); + var check = np.array(1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f); + + var sess = tf.Session(); + var o = sess.run(b); + Assert.IsTrue(np.array_equal(o, check)); + } + + [TestMethod] + public void isNan() + { + var a = tf.constant(new[] { 1, np.nan, 2, np.nan, 3, np.nan, 4, np.nan }); + var b = tf.cast(tf.is_nan(a), tf.float32); + var check = np.array(0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f); + + var sess = tf.Session(); + var o = sess.run(b); + Assert.IsTrue(np.array_equal(o, check)); + } + + [TestMethod] + public void cumSumTest() + { + var a = tf.constant(new[] { 1, 1, 2, 3, 4, 5 }); + var b = tf.cumsum(a); + var check = np.array(1, 2, 4, 7, 11, 16); + + var sess = tf.Session(); + var o = sess.run(b); + Assert.IsTrue(np.array_equal(o, check)); + + b = tf.cumsum(a, exclusive: true); + check = np.array(0, 1, 2, 4, 7, 11); + + sess = tf.Session(); + o = sess.run(b); + Assert.IsTrue(np.array_equal(o, check)); + + b = tf.cumsum(a, reverse: true); + check = np.array(16, 15, 14, 12, 9, 5); + + sess = tf.Session(); + o = sess.run(b); + Assert.IsTrue(np.array_equal(o, check)); + + b = tf.cumsum(a, exclusive: true, reverse: true); + check = np.array(15, 14, 12, 9, 5, 0); + + sess = tf.Session(); + o = sess.run(b); + Assert.IsTrue(np.array_equal(o, check)); + } + + [TestMethod] + public void logicalOpsTest() + { + var a = tf.constant(new[] { 1f, 2f, 3f, 4f, -4f, -3f, -2f, -1f }); + var b = tf.less(a, 0f); + var c = tf.greater(a, 0f); + var d = tf.cast(tf.logical_and(b, c), tf.int32); + var check = np.array(new[] { 0, 0, 0, 0, 0, 0, 0, 0 }); + + var sess = tf.Session(); + var o = sess.run(d); + Assert.IsTrue(np.array_equal(o, check)); + + d = tf.cast(tf.logical_not(b), tf.int32); + check = np.array(new[] { 1, 1, 1, 1, 0, 0, 0, 0 }); + + sess = tf.Session(); + o = sess.run(d); + Assert.IsTrue(np.array_equal(o, check)); + + d = tf.cast(tf.logical_or(b, c), tf.int32); + check = np.array(new[] { 1, 1, 1, 1, 1, 1, 1, 1 }); + + sess = tf.Session(); + o = sess.run(d); + Assert.IsTrue(np.array_equal(o, check)); + + d = tf.cast(tf.logical_xor(b, c), tf.int32); + check = np.array(new[] { 1, 1, 1, 1, 1, 1, 1, 1 }); + + sess = tf.Session(); + o = sess.run(d); + Assert.IsTrue(np.array_equal(o, check)); + } + + [TestMethod] + public void addOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int firstIntVal = 2; + const int secondIntVal = 3; + + var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray(); + var intResult = firstIntFeed.Sum() + secondIntFeed.Sum(); + + var a = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.add(a, b), 1)); + + var sess = tf.Session(); + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual(o, intResult); + + // Testing `operator +(Tensor x, Tensor y)` + c = tf.reduce_sum(tf.reduce_sum(a + b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual(o, intResult); + + // Testing `operator +(Tensor x, int y)` + c = tf.reduce_sum(tf.reduce_sum(a + secondIntVal, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual(o, intResult); + + // Testing `operator +(int x, Tensor y)` + c = tf.reduce_sum(tf.reduce_sum(secondIntVal + a, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual(o, intResult); + #endregion + + #region floatTest + const float firstFloatVal = 2.0f; + const float secondFloatVal = 3.0f; + + var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray(); + var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray(); + var floatResult = firstFloatFeed.Sum() + secondFloatFeed.Sum(); + + a = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.add(a, b), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual(o, floatResult); + + // Testing `operator +(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a + b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual(o, floatResult); + + // Testing `operator +(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(a + secondFloatVal, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual(o, floatResult); + + // Testing `operator +(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(secondFloatVal + a, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual(o, floatResult); + #endregion + + #region doubleTest + const double firstDoubleVal = 2.0; + const double secondDoubleVal = 3.0; + + var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray(); + var doubleResult = firstDoubleFeed.Sum() + secondDoubleFeed.Sum(); + + a = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.add(a, b), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + + // Testing `operator +(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a + b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual(o, doubleResult); + + // Testing `operator +(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(a + secondDoubleVal, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual(o, doubleResult); + + // Testing `operator +(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(secondDoubleVal + a, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual(o, doubleResult); + #endregion + } + + [TestMethod] + public void subOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int firstIntVal = -2; + const int secondIntVal = 3; + + var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray(); + var intResult = firstIntFeed.Sum() - secondIntFeed.Sum(); + var intResultTwo = -firstIntFeed.Sum(); + + var a = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.sub(a, b), 1)); + + var sess = tf.Session(); + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator -(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a - b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator -(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(a - secondIntVal, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator -(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(secondIntVal - a, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, Math.Abs(intResult)); + + // Testing `operator -(Tensor x) + c = tf.reduce_sum(tf.reduce_sum(-a, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResultTwo); + #endregion + + #region floatTest + const float firstFloatVal = -2.0f; + const float secondFloatVal = 3.0f; + + var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray(); + var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray(); + var floatResult = firstFloatFeed.Sum() - secondFloatFeed.Sum(); + var floatResultTwo = -firstFloatFeed.Sum(); + + a = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.sub(a, b), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + + // Testing `operator -(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a - b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + + // Testing `operator -(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(a - secondFloatVal, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + + // Testing `operator -(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(secondFloatVal - a, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, Math.Abs(floatResult)); + + // Testing `operator -(Tensor x) + c = tf.reduce_sum(tf.reduce_sum(-a, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResultTwo); + #endregion + + #region doubleTest + const double firstDoubleVal = -2.0; + const double secondDoubleVal = 3.0; + + var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray(); + var doubleResult = firstDoubleFeed.Sum() - secondDoubleFeed.Sum(); + var doubleResultTwo = -firstDoubleFeed.Sum(); + + a = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.sub(a, b), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + + // Testing `operator -(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a - b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + + // Testing `operator -(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(a - secondDoubleVal, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + + // Testing `operator -(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(secondDoubleVal - a, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, Math.Abs(doubleResult)); + + // Testing `operator -(Tensor x) + c = tf.reduce_sum(tf.reduce_sum(-a, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResultTwo); + #endregion + } + + private IEnumerable MultiplyArray(IReadOnlyCollection first, IReadOnlyCollection second) + { + if (first.Count != second.Count) + throw new ArgumentException("Arrays should be of equal size!"); + + var firstEnumerator = first.GetEnumerator(); + var secondEnumerator = second.GetEnumerator(); + var result = new List(); + while (firstEnumerator.MoveNext()) + { + secondEnumerator.MoveNext(); + result.Add(firstEnumerator.Current * secondEnumerator.Current); + } + + firstEnumerator.Dispose(); + secondEnumerator.Dispose(); + + return result; + } + private IEnumerable MultiplyArray(IReadOnlyCollection first, IReadOnlyCollection second) + { + if (first.Count != second.Count) + throw new ArgumentException("Arrays should be of equal size!"); + + var firstEnumerator = first.GetEnumerator(); + var secondEnumerator = second.GetEnumerator(); + var result = new List(); + while (firstEnumerator.MoveNext()) + { + secondEnumerator.MoveNext(); + result.Add(firstEnumerator.Current * secondEnumerator.Current); + } + + firstEnumerator.Dispose(); + secondEnumerator.Dispose(); + + return result; + } + private IEnumerable MultiplyArray(IReadOnlyCollection first, IReadOnlyCollection second) + { + if (first.Count != second.Count) + throw new ArgumentException("Arrays should be of equal size!"); + + var firstEnumerator = first.GetEnumerator(); + var secondEnumerator = second.GetEnumerator(); + var result = new List(); + while (firstEnumerator.MoveNext()) + { + secondEnumerator.MoveNext(); + result.Add(firstEnumerator.Current * secondEnumerator.Current); + } + + firstEnumerator.Dispose(); + secondEnumerator.Dispose(); + + return result; + } + + [TestMethod] + public void mulOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int firstIntVal = 2; + const int secondIntVal = 3; + + var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray(); + var intResult = MultiplyArray(firstIntFeed, secondIntFeed).Sum(); + + var a = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.multiply(a, b), 1)); + + var sess = tf.Session(); + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator *(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a * b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator *(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(a * secondIntVal, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator *(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(firstIntVal * b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + #endregion + + #region floatTest + const float firstFloatVal = 2.0f; + const float secondFloatVal = 3.0f; + + var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray(); + var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray(); + var floatResult = MultiplyArray(firstFloatFeed, secondFloatFeed).Sum(); + + a = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.multiply(a, b), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + + // Testing `operator *(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a * b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + + // Testing `operator *(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(a * secondFloatVal, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + + // Testing `operator *(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(firstFloatVal * b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + #endregion + + #region doubleTest + const double firstDoubleVal = 2.0; + const double secondDoubleVal = 3.0; + + var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray(); + var doubleResult = MultiplyArray(firstDoubleFeed, secondDoubleFeed).Sum(); + + a = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.multiply(a, b), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + + // Testing `operator *(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a * b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + + // Testing `operator *(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(a * secondDoubleVal, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + + // Testing `operator *(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(firstDoubleVal * b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + #endregion + } + + [Ignore] + [TestMethod] + public void divOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int firstIntVal = 6; + const int secondIntVal = 3; + + var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray(); + var intResult = (int)(firstIntFeed.Sum() / (float)secondIntVal); + + var a = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(gen_math_ops.floor_div(a, b), 1)); + + var sess = tf.Session(); + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator /(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a / b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator /(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(a / secondIntVal, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator /(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(firstIntVal / b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + #endregion + + #region floatTest + const float firstFloatVal = 6.0f; + const float secondFloatVal = 3.0f; + + var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray(); + var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray(); + var floatResult = MultiplyArray(firstFloatFeed, secondFloatFeed.Select(x => 1 / x).ToArray()).Sum(); + + a = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.divide(a, b), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + + // Testing `operator /(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a / b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + + // Testing `operator /(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(a / secondFloatVal, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + + // Testing `operator /(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(firstFloatVal / b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + #endregion + + #region doubleTest + const double firstDoubleVal = 6.0; + const double secondDoubleVal = 3.0; + + var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray(); + var doubleResult = MultiplyArray(firstDoubleFeed, secondDoubleFeed.Select(x => 1 / x).ToArray()).Sum(); + + a = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.divide(a, b), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + + // Testing `operator /(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a / b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + + // Testing `operator /(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(a / secondFloatVal, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + + // Testing `operator /(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(firstFloatVal / b, 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + #endregion + } + + [TestMethod] + public void greaterThanOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int intThreshold = 10; + + var firstIntFeed = Enumerable.Range(0, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(intThreshold, rows * cols).ToArray(); + var intResult = firstIntFeed.Count(elem => elem > intThreshold); + var intResultTwo = firstIntFeed.Count(elem => elem < intThreshold); + + var a = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater(a, b), tf.int32), 1)); + + var sess = tf.Session(); + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator >(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > b, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator >(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > intThreshold, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator >(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(intThreshold > a, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResultTwo); + #endregion + + #region floatTest + const float floatThreshold = 10.0f; + + var firstFloatFeed = Enumerable.Range(0, rows * cols).Select(elem => (float)elem).ToArray(); + var secondFloatFeed = Enumerable.Repeat(floatThreshold, rows * cols).ToArray(); + var floatResult = firstFloatFeed.Count(elem => elem > floatThreshold); + var floatResultTwo = firstFloatFeed.Count(elem => elem < floatThreshold); + + a = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater(a, b), tf.int32), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + + // Testing `operator >(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > b, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + + // Testing `operator >(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > floatThreshold, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + + // Testing `operator >(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(floatThreshold > a, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResultTwo); + #endregion + + #region doubleTest + const double doubleThreshold = 10.0; + + var firstDoubleFeed = Enumerable.Repeat(0, rows * cols).Select(elem => (double)elem).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(doubleThreshold, rows * cols).ToArray(); + var doubleResult = firstDoubleFeed.Count(elem => elem > doubleThreshold); + var doubleResultTwo = firstDoubleFeed.Count(elem => elem < doubleThreshold); + + a = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater(a, b), tf.int32), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + + // Testing `operator >(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > b, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + + // Testing `operator >(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > doubleThreshold, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + + // Testing `operator >(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(doubleThreshold > a, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResultTwo); + #endregion + } + + [TestMethod] + public void lessThanOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int intThreshold = 10; + + var firstIntFeed = Enumerable.Range(0, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(intThreshold, rows * cols).ToArray(); + var intResult = firstIntFeed.Count(elem => elem < intThreshold); + var intResultTwo = firstIntFeed.Count(elem => elem > intThreshold); + + var a = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less(a, b), tf.int32), 1)); + + var sess = tf.Session(); + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator <(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < b, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator <(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < intThreshold, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator <(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(intThreshold < a, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResultTwo); + #endregion + + #region floatTest + const float floatThreshold = 10.0f; + + var firstFloatFeed = Enumerable.Range(0, rows * cols).Select(elem => (float)elem).ToArray(); + var secondFloatFeed = Enumerable.Repeat(floatThreshold, rows * cols).ToArray(); + var floatResult = firstFloatFeed.Count(elem => elem < floatThreshold); + var floatResultTwo = firstFloatFeed.Count(elem => elem > floatThreshold); + + a = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less(a, b), tf.int32), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + + // Testing `operator <(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < b, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + + // Testing `operator <(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < floatThreshold, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + + // Testing `operator <(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(floatThreshold < a, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResultTwo); + #endregion + + #region doubleTest + const double doubleThreshold = 10.0; + + var firstDoubleFeed = Enumerable.Repeat(0, rows * cols).Select(elem => (double)elem).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(doubleThreshold, rows * cols).ToArray(); + var doubleResult = firstDoubleFeed.Count(elem => elem < doubleThreshold); + var doubleResultTwo = firstDoubleFeed.Count(elem => elem > doubleThreshold); + + a = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less(a, b), tf.int32), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + + // Testing `operator <(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < b, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + + // Testing `operator <(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < doubleThreshold, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + + // Testing `operator <(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(doubleThreshold < a, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResultTwo); + #endregion + } + + [TestMethod] + public void greaterOrEqualThanOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int intThreshold = 10; + + var firstIntFeed = Enumerable.Range(0, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(intThreshold, rows * cols).ToArray(); + var intResult = firstIntFeed.Count(elem => elem >= intThreshold); + var intResultTwo = firstIntFeed.Count(elem => elem <= intThreshold); + + var a = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater_equal(a, b), tf.int32), 1)); + + var sess = tf.Session(); + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator >=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= b, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator >=(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= intThreshold, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator >=(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(intThreshold >= a, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResultTwo); + #endregion + + #region floatTest + const float floatThreshold = 10.0f; + + var firstFloatFeed = Enumerable.Range(0, rows * cols).Select(elem => (float)elem).ToArray(); + var secondFloatFeed = Enumerable.Repeat(floatThreshold, rows * cols).ToArray(); + var floatResult = firstFloatFeed.Count(elem => elem >= floatThreshold); + var floatResultTwo = firstFloatFeed.Count(elem => elem <= floatThreshold); + + a = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater_equal(a, b), tf.int32), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + + // Testing `operator >=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= b, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + + // Testing `operator >=(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= floatThreshold, tf.int32), 1)); + sess = tf.Session(); + sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + + // Testing `operator >=(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(floatThreshold >= a, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResultTwo); + #endregion + + #region doubleTest + const double doubleThreshold = 10.0; + + var firstDoubleFeed = Enumerable.Repeat(0, rows * cols).Select(elem => (double)elem).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(doubleThreshold, rows * cols).ToArray(); + var doubleResult = firstDoubleFeed.Count(elem => elem >= doubleThreshold); + var doubleResultTwo = firstDoubleFeed.Count(elem => elem <= doubleThreshold); + + a = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater_equal(a, b), tf.int32), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + + // Testing `operator >=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= b, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + + // Testing `operator >=(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= doubleThreshold, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + + // Testing `operator >=(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(doubleThreshold >= a, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResultTwo); + #endregion + } + + [TestMethod] + public void lessOrEqualThanOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int intThreshold = 10; + + var firstIntFeed = Enumerable.Range(0, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(intThreshold, rows * cols).ToArray(); + var intResult = firstIntFeed.Count(elem => elem <= intThreshold); + var intResultTwo = firstIntFeed.Count(elem => elem >= intThreshold); + + var a = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new Shape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less_equal(a, b), tf.int32), 1)); + + var sess = tf.Session(); + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator <=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a <= b, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator <=(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a <= intThreshold, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + + // Testing `operator <=(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(intThreshold <= a, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResultTwo); + #endregion + + #region floatTest + const float floatThreshold = 10.0f; + + var firstFloatFeed = Enumerable.Range(0, rows * cols).Select(elem => (float)elem).ToArray(); + var secondFloatFeed = Enumerable.Repeat(floatThreshold, rows * cols).ToArray(); + var floatResult = firstFloatFeed.Count(elem => elem <= floatThreshold); + var floatResultTwo = firstFloatFeed.Count(elem => elem >= floatThreshold); + + a = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less_equal(a, b), tf.int32), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + + // Testing `operator <=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a <= b, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + + // Testing `operator <=(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a <= floatThreshold, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + + // Testing `operator <=(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(floatThreshold <= a, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResultTwo); + #endregion + + #region doubleTest + const double doubleThreshold = 10.0; + + var firstDoubleFeed = Enumerable.Repeat(0, rows * cols).Select(elem => (double)elem).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(doubleThreshold, rows * cols).ToArray(); + var doubleResult = firstDoubleFeed.Count(elem => elem <= doubleThreshold); + var doubleResultTwo = firstDoubleFeed.Count(elem => elem >= doubleThreshold); + + a = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new Shape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less_equal(a, b), tf.int32), 1)); + + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + + // Testing `operator <=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a <= b, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + + // Testing `operator <=(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a <= doubleThreshold, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + + // Testing `operator <=(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(doubleThreshold <= a, tf.int32), 1)); + sess = tf.Session(); + o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResultTwo); + #endregion + } + + [Ignore("Not finished yet")] + [TestMethod] + public void map_fn() + { + var a = tf.constant(new[] { 1, 2, 3, 4 }); + var b = tf.constant(new[] { 17, 12, 11, 10 }); + var ab = tf.stack(new[] { a, b }, 1); + + Func map_operation = (value_ab) => + { + var value_a = value_ab[0]; + var value_b = value_ab[1]; + return value_a + value_b; + }; + + var map_result = tf.map_fn(map_operation, ab); + } + } +} diff --git a/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs b/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs new file mode 100644 index 000000000..cc09b101d --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs @@ -0,0 +1,102 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; +using Tensorflow.Keras.UnitTest; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class SignalTest : EagerModeTestBase + { + [TestMethod] + public void fft() + { + double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 }; + double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 }; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); + + Tensor t_complex = tf.complex(t_real, t_imag); + + Tensor t_frequency_domain = tf.signal.fft(t_complex); + Tensor f_time_domain = tf.signal.ifft(t_frequency_domain); + + Tensor t_real_result = tf.math.real(f_time_domain); + Tensor t_imag_result = tf.math.imag(f_time_domain); + + NDArray n_real_result = t_real_result.numpy(); + NDArray n_imag_result = t_imag_result.numpy(); + + double[] d_real_result = n_real_result.ToArray(); + double[] d_imag_result = n_imag_result.ToArray(); + + Assert.IsTrue(base.Equal(d_real_result, d_real)); + Assert.IsTrue(base.Equal(d_imag_result, d_imag)); + } + [TestMethod] + public void fft2d() + { + double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 }; + double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 }; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); + + Tensor t_complex = tf.complex(t_real, t_imag); + + Tensor t_complex_2d = tf.reshape(t_complex,new int[] { 2, 2 }); + + Tensor t_frequency_domain_2d = tf.signal.fft2d(t_complex_2d); + Tensor t_time_domain_2d = tf.signal.ifft2d(t_frequency_domain_2d); + + Tensor t_time_domain = tf.reshape(t_time_domain_2d, new int[] { 4 }); + + Tensor t_real_result = tf.math.real(t_time_domain); + Tensor t_imag_result = tf.math.imag(t_time_domain); + + NDArray n_real_result = t_real_result.numpy(); + NDArray n_imag_result = t_imag_result.numpy(); + + double[] d_real_result = n_real_result.ToArray(); + double[] d_imag_result = n_imag_result.ToArray(); + + Assert.IsTrue(base.Equal(d_real_result, d_real)); + Assert.IsTrue(base.Equal(d_imag_result, d_imag)); + } + [TestMethod] + public void fft3d() + { + double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0, -3.0, -2.0, -1.0, -4.0 }; + double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0, 6.0, 4.0, 2.0, 0.0}; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); + + Tensor t_complex = tf.complex(t_real, t_imag); + + Tensor t_complex_3d = tf.reshape(t_complex, new int[] { 2, 2, 2 }); + + Tensor t_frequency_domain_3d = tf.signal.fft2d(t_complex_3d); + Tensor t_time_domain_3d = tf.signal.ifft2d(t_frequency_domain_3d); + + Tensor t_time_domain = tf.reshape(t_time_domain_3d, new int[] { 8 }); + + Tensor t_real_result = tf.math.real(t_time_domain); + Tensor t_imag_result = tf.math.imag(t_time_domain); + + NDArray n_real_result = t_real_result.numpy(); + NDArray n_imag_result = t_imag_result.numpy(); + + double[] d_real_result = n_real_result.ToArray(); + double[] d_imag_result = n_imag_result.ToArray(); + + Assert.IsTrue(base.Equal(d_real_result, d_real)); + Assert.IsTrue(base.Equal(d_imag_result, d_imag)); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj b/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj new file mode 100644 index 000000000..40dd53f74 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj @@ -0,0 +1,43 @@ + + + + net6.0 + 9.0 + false + TensorFlowNET.UnitTest + AnyCPU;x64 + + + + DEBUG;TRACE + true + + + + DEBUG;TRACE + false + + + + true + + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + diff --git a/test/TensorFlowNET.Graph.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs b/test/TensorFlowNET.Graph.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs new file mode 100644 index 000000000..295bc0488 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs @@ -0,0 +1,177 @@ +using System; +using System.Diagnostics; +using System.Threading; + +namespace TensorFlowNET.UnitTest +{ + public delegate void MultiThreadedTestDelegate(int threadid); + + /// + /// Creates a synchronized eco-system of running code. + /// + public class MultiThreadedUnitTestExecuter : IDisposable + { + public int ThreadCount { get; } + public Thread[] Threads { get; } + public Exception[] Exceptions { get; } + private readonly SemaphoreSlim barrier_threadstarted; + private readonly ManualResetEventSlim barrier_corestart; + private readonly SemaphoreSlim done_barrier2; + + public Action PostRun { get; set; } + + #region Static + + [DebuggerHidden] + public static void Run(int threadCount, MultiThreadedTestDelegate workload) + { + if (workload == null) throw new ArgumentNullException(nameof(workload)); + if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount)); + new MultiThreadedUnitTestExecuter(threadCount).Run(workload); + } + + [DebuggerHidden] + public static void Run(int threadCount, params MultiThreadedTestDelegate[] workloads) + { + if (workloads == null) throw new ArgumentNullException(nameof(workloads)); + if (workloads.Length == 0) throw new ArgumentException("Value cannot be an empty collection.", nameof(workloads)); + if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount)); + new MultiThreadedUnitTestExecuter(threadCount).Run(workloads); + } + + [DebuggerHidden] + public static void Run(int threadCount, MultiThreadedTestDelegate workload, Action postRun) + { + if (workload == null) throw new ArgumentNullException(nameof(workload)); + if (postRun == null) throw new ArgumentNullException(nameof(postRun)); + if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount)); + new MultiThreadedUnitTestExecuter(threadCount) { PostRun = postRun }.Run(workload); + } + + #endregion + + + /// Initializes a new instance of the class. + public MultiThreadedUnitTestExecuter(int threadCount) + { + if (threadCount <= 0) + throw new ArgumentOutOfRangeException(nameof(threadCount)); + ThreadCount = threadCount; + Threads = new Thread[ThreadCount]; + Exceptions = new Exception[ThreadCount]; + done_barrier2 = new SemaphoreSlim(0, threadCount); + barrier_corestart = new ManualResetEventSlim(); + barrier_threadstarted = new SemaphoreSlim(0, threadCount); + } + + [DebuggerHidden] + public void Run(params MultiThreadedTestDelegate[] workloads) + { + if (workloads == null) + throw new ArgumentNullException(nameof(workloads)); + if (workloads.Length != 1 && workloads.Length % ThreadCount != 0) + throw new InvalidOperationException($"Run method must accept either 1 workload or n-threads workloads. Got {workloads.Length} workloads."); + + if (ThreadCount == 1) + { + Exception ex = null; + new Thread(() => + { + try + { + workloads[0](0); + } + catch (Exception e) + { + if (Debugger.IsAttached) + throw; + ex = e; + } + finally + { + done_barrier2.Release(1); + } + }).Start(); + + done_barrier2.Wait(); + + if (ex != null) + throw new Exception($"Thread 0 has failed: ", ex); + + PostRun?.Invoke(this); + + return; + } + + //thread core + Exception ThreadCore(MultiThreadedTestDelegate core, int threadid) + { + barrier_threadstarted.Release(1); + barrier_corestart.Wait(); + //workload + try + { + core(threadid); + } + catch (Exception e) + { + if (Debugger.IsAttached) + throw; + return e; + } + finally + { + done_barrier2.Release(1); + } + + return null; + } + + //initialize all threads + if (workloads.Length == 1) + { + var workload = workloads[0]; + for (int i = 0; i < ThreadCount; i++) + { + var i_local = i; + Threads[i] = new Thread(() => Exceptions[i_local] = ThreadCore(workload, i_local)); + } + } + else + { + for (int i = 0; i < ThreadCount; i++) + { + var i_local = i; + var workload = workloads[i_local % workloads.Length]; + Threads[i] = new Thread(() => Exceptions[i_local] = ThreadCore(workload, i_local)); + } + } + + //run all threads + for (int i = 0; i < ThreadCount; i++) Threads[i].Start(); + //wait for threads to be started and ready + for (int i = 0; i < ThreadCount; i++) barrier_threadstarted.Wait(); + + //signal threads to start + barrier_corestart.Set(); + + //wait for threads to finish + for (int i = 0; i < ThreadCount; i++) done_barrier2.Wait(); + + //handle fails + for (int i = 0; i < ThreadCount; i++) + if (Exceptions[i] != null) + throw new Exception($"Thread {i} has failed: ", Exceptions[i]); + + //checks after ended + PostRun?.Invoke(this); + } + + public void Dispose() + { + barrier_threadstarted.Dispose(); + barrier_corestart.Dispose(); + done_barrier2.Dispose(); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.Graph.UnitTest/Utilities/TestHelper.cs b/test/TensorFlowNET.Graph.UnitTest/Utilities/TestHelper.cs new file mode 100644 index 000000000..d1cda7286 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/Utilities/TestHelper.cs @@ -0,0 +1,22 @@ +using System; +using System.IO; + +namespace TensorFlowNET.UnitTest +{ + public class TestHelper + { + public static string GetFullPathFromDataDir(string fileName) + { + var dataDir = GetRootContentDir(Directory.GetCurrentDirectory()); + return Path.Combine(dataDir, fileName); + } + + static string GetRootContentDir(string dir) + { + var path = Path.GetFullPath(Path.Combine(dir, "data")); + if (Directory.Exists(path)) + return path; + return GetRootContentDir(Path.GetFullPath(Path.Combine(dir, ".."))); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/fingerprint.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/fingerprint.pb new file mode 100644 index 000000000..c37cc37bd --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/fingerprint.pb @@ -0,0 +1 @@ +̟땐͉ Σ(ռ2 \ No newline at end of file diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/keras_metadata.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/keras_metadata.pb new file mode 100644 index 000000000..5fe8f1a65 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/keras_metadata.pb @@ -0,0 +1,7 @@ + +&root"_tf_keras_sequential*&{"name": "sequential", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": false, "class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 5, 3]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}, {"class_name": "LSTM", "config": {"name": "lstm", "trainable": true, "dtype": "float32", "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "time_major": false, "units": 32, "activation": "tanh", "recurrent_activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 1}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 2}}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 1, "activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "shared_object_id": 9, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 5, 3]}, "ndim": 3, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 5, 3]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 5, 3]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 5, 3]}, "float32", "input_1"]}, "keras_version": "2.12.0", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 5, 3]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "shared_object_id": 0}, {"class_name": "LSTM", "config": {"name": "lstm", "trainable": true, "dtype": "float32", "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "time_major": false, "units": 32, "activation": "tanh", "recurrent_activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 1}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 2}, "shared_object_id": 5}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 1, "activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 6}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 7}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 8}]}}, "training_config": {"loss": "binary_crossentropy", "metrics": [[{"class_name": "MeanMetricWrapper", "config": {"name": "accuracy", "dtype": "float32", "fn": "binary_accuracy"}, "shared_object_id": 11}]], "weighted_metrics": null, "loss_weights": null, "optimizer_config": {"class_name": "Custom>Adam", "config": {"name": "Adam", "weight_decay": null, "clipnorm": null, "global_clipnorm": null, "clipvalue": null, "use_ema": false, "ema_momentum": 0.99, "ema_overwrite_frequency": null, "jit_compile": false, "is_legacy_optimizer": false, "learning_rate": 0.0010000000474974513, "beta_1": 0.9, "beta_2": 0.999, "epsilon": 1e-07, "amsgrad": false}}}}2 + root.layer_with_weights-0"_tf_keras_rnn_layer* {"name": "lstm", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "LSTM", "config": {"name": "lstm", "trainable": true, "dtype": "float32", "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "time_major": false, "units": 32, "activation": "tanh", "recurrent_activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 1}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 2}, "shared_object_id": 5, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, null, 3]}, "ndim": 3, "max_ndim": null, "min_ndim": null, "axes": {}}, "shared_object_id": 12}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 5, 3]}}2 +root.layer_with_weights-1"_tf_keras_layer*{"name": "dense", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 1, "activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 6}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 7}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 8, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 32}}, "shared_object_id": 13}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 32]}}2 +root.layer_with_weights-0.cell"_tf_keras_layer*{"name": "lstm_cell", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "LSTMCell", "config": {"name": "lstm_cell", "trainable": true, "dtype": "float32", "units": 32, "activation": "tanh", "recurrent_activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 1}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 2}, "shared_object_id": 4, "build_input_shape": {"class_name": "__tuple__", "items": [null, 3]}}2 +Rroot.keras_api.metrics.0"_tf_keras_metric*{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 14}2 +Sroot.keras_api.metrics.1"_tf_keras_metric*{"class_name": "MeanMetricWrapper", "name": "accuracy", "dtype": "float32", "config": {"name": "accuracy", "dtype": "float32", "fn": "binary_accuracy"}, "shared_object_id": 11}2 \ No newline at end of file diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/saved_model.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/saved_model.pb new file mode 100644 index 000000000..618c800eb Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/saved_model.pb differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/variables/variables.data-00000-of-00001 b/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/variables/variables.data-00000-of-00001 new file mode 100644 index 000000000..ea67db4f4 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/variables/variables.data-00000-of-00001 differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/variables/variables.index b/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/variables/variables.index new file mode 100644 index 000000000..11f13d165 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/lstm_from_sequential/variables/variables.index differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/fingerprint.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/fingerprint.pb new file mode 100644 index 000000000..361ca3a8a Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/fingerprint.pb differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/keras_metadata.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/keras_metadata.pb new file mode 100644 index 000000000..b98e17337 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/keras_metadata.pb @@ -0,0 +1,6 @@ + +root"_tf_keras_sequential*{"name": "sequential", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": false, "class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 784]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}, {"class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}}]}, "shared_object_id": 3, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 784]}, "ndim": 2, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 784]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 784]}, "float32", "input_1"]}, "keras_version": "2.11.0", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 784]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "shared_object_id": 0}, {"class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}, "shared_object_id": 1}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "shared_object_id": 2}]}}, "training_config": {"loss": "sparse_categorical_crossentropy", "metrics": [[{"class_name": "MeanMetricWrapper", "config": {"name": "accuracy", "dtype": "float32", "fn": "categorical_accuracy"}, "shared_object_id": 5}]], "weighted_metrics": null, "loss_weights": null, "optimizer_config": {"class_name": "Custom>Adam", "config": {"name": "Adam", "weight_decay": null, "clipnorm": null, "global_clipnorm": null, "clipvalue": null, "use_ema": false, "ema_momentum": 0.99, "ema_overwrite_frequency": null, "jit_compile": false, "is_legacy_optimizer": false, "learning_rate": 0.0010000000474974513, "beta_1": 0.9, "beta_2": 0.999, "epsilon": 1e-07, "amsgrad": false}}}}2 +root.layer_with_weights-0"_tf_keras_layer*{"name": "transformer", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}, "shared_object_id": 1, "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}}2 + root.layer-1"_tf_keras_layer*{"name": "softmax", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "shared_object_id": 2, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10]}}2 +9root.keras_api.metrics.0"_tf_keras_metric*{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 6}2 +:root.keras_api.metrics.1"_tf_keras_metric*{"class_name": "MeanMetricWrapper", "name": "accuracy", "dtype": "float32", "config": {"name": "accuracy", "dtype": "float32", "fn": "categorical_accuracy"}, "shared_object_id": 5}2 \ No newline at end of file diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/saved_model.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/saved_model.pb new file mode 100644 index 000000000..f22755e07 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/saved_model.pb differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.data-00000-of-00001 b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.data-00000-of-00001 new file mode 100644 index 000000000..399265af6 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.data-00000-of-00001 differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.index b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.index new file mode 100644 index 000000000..e0b0e800a Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.index differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/bias0.npy b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/bias0.npy new file mode 100644 index 000000000..b5a8f8b32 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/bias0.npy differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/fingerprint.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/fingerprint.pb new file mode 100644 index 000000000..b62a57c3d Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/fingerprint.pb differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/keras_metadata.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/keras_metadata.pb new file mode 100644 index 000000000..e1aab781a --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/keras_metadata.pb @@ -0,0 +1,9 @@ + +$root"_tf_keras_network*${"name": "model", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": false, "class_name": "Functional", "config": {"name": "model", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "name": "input_1", "inbound_nodes": []}, {"class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "name": "flatten", "inbound_nodes": [[["input_1", 0, 0, {}]]]}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense", "inbound_nodes": [[["flatten", 0, 0, {}]]]}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense_1", "inbound_nodes": [[["dense", 0, 0, {}]]]}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "name": "softmax", "inbound_nodes": [[["dense_1", 0, 0, {}]]]}], "input_layers": [["input_1", 0, 0]], "output_layers": [["softmax", 0, 0]]}, "shared_object_id": 9, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "ndim": 4, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "float32", "input_1"]}, "keras_version": "2.11.0", "backend": "tensorflow", "model_config": {"class_name": "Functional", "config": {"name": "model", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "name": "input_1", "inbound_nodes": [], "shared_object_id": 0}, {"class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "name": "flatten", "inbound_nodes": [[["input_1", 0, 0, {}]]], "shared_object_id": 1}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense", "inbound_nodes": [[["flatten", 0, 0, {}]]], "shared_object_id": 4}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 5}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 6}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense_1", "inbound_nodes": [[["dense", 0, 0, {}]]], "shared_object_id": 7}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "name": "softmax", "inbound_nodes": [[["dense_1", 0, 0, {}]]], "shared_object_id": 8}], "input_layers": [["input_1", 0, 0]], "output_layers": [["softmax", 0, 0]]}}}2 + root.layer-0"_tf_keras_input_layer*{"class_name": "InputLayer", "name": "input_1", "dtype": "float32", "sparse": false, "ragged": false, "batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}2 + root.layer-1"_tf_keras_layer*{"name": "flatten", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "inbound_nodes": [[["input_1", 0, 0, {}]]], "shared_object_id": 1, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 1, "axes": {}}, "shared_object_id": 14}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 28, 28, 1]}}2 +root.layer_with_weights-0"_tf_keras_layer*{"name": "dense", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["flatten", 0, 0, {}]]], "shared_object_id": 4, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 784}}, "shared_object_id": 15}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}}2 +root.layer_with_weights-1"_tf_keras_layer*{"name": "dense_1", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 5}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 6}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["dense", 0, 0, {}]]], "shared_object_id": 7, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 100}}, "shared_object_id": 16}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 100]}}2 + root.layer-4"_tf_keras_layer*{"name": "softmax", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "inbound_nodes": [[["dense_1", 0, 0, {}]]], "shared_object_id": 8, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10]}}2 +Troot.keras_api.metrics.0"_tf_keras_metric*{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 17}2 +Uroot.keras_api.metrics.1"_tf_keras_metric*{"class_name": "MeanMetricWrapper", "name": "sparse_categorical_accuracy", "dtype": "float32", "config": {"name": "sparse_categorical_accuracy", "dtype": "float32", "fn": "sparse_categorical_accuracy"}, "shared_object_id": 18}2 \ No newline at end of file diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/kernel1.npy b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/kernel1.npy new file mode 100644 index 000000000..dd70331cf Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/kernel1.npy differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/saved_model.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/saved_model.pb new file mode 100644 index 000000000..771a58c62 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/saved_model.pb differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.data-00000-of-00001 b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.data-00000-of-00001 new file mode 100644 index 000000000..0061f3865 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.data-00000-of-00001 differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.index b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.index new file mode 100644 index 000000000..06ba4b293 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.index differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs b/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs new file mode 100644 index 000000000..29648790f --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs @@ -0,0 +1,71 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Collections.Generic; +using Tensorflow.Keras.Callbacks; +using Tensorflow.Keras.Engine; +using Tensorflow.NumPy; +using static Tensorflow.KerasApi; + + +namespace Tensorflow.Keras.UnitTest.Callbacks +{ + [TestClass] + public class EarlystoppingTest + { + [TestMethod] + // Because loading the weight variable into the model has not yet been implemented, + // so you'd better not set patience too large, because the weights will equal to the last epoch's weights. + public void Earlystopping() + { + var layers = keras.layers; + var model = keras.Sequential(new List + { + layers.Rescaling(1.0f / 255, input_shape: (28, 28, 1)), + layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu), + layers.MaxPooling2D(), + layers.Flatten(), + layers.Dense(128, activation: keras.activations.Relu), + layers.Dense(10) + }); + + + model.summary(); + + model.compile(optimizer: keras.optimizers.RMSprop(1e-3f), + loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true), + metrics: new[] { "acc" }); + + var num_epochs = 3; + var batch_size = 8; + + var data_loader = new MnistModelLoader(); + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 59900, + }).Result; + + NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1)); + NDArray x2 = x1; + + var x = new NDArray[] { x1, x2 }; + + // define a CallbackParams first, the parameters you pass al least contain Model and Epochs. + CallbackParams callback_parameters = new CallbackParams + { + Model = model, + Epochs = num_epochs, + }; + // define your earlystop + ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy"); + // define a callbcaklist, then add the earlystopping to it. + var callbacks = new List{ earlystop}; + model.fit(x, dataset.Train.Labels, batch_size, num_epochs, callbacks: callbacks); + } + + } + + +} + diff --git a/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs new file mode 100644 index 000000000..635f13a54 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs @@ -0,0 +1,84 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.UnitTest +{ + public class EagerModeTestBase + { + [TestInitialize] + public void TestInit() + { + if (!tf.executing_eagerly()) + tf.enable_eager_execution(); + tf.Context.ensure_initialized(); + } + + [TestCleanup] + public void TestClean() + { + } + + public bool Equal(float[] f1, float[] f2) + { + bool ret = false; + var tolerance = .000001f; + for (var i = 0; i < f1.Length; i++) + { + ret = Math.Abs(f1[i] - f2[i]) <= tolerance; + if (!ret) + break; + } + + return ret; + } + + + public void AssertArray(int[] f1, int[] f2) + { + bool ret = false; + for (var i = 0; i < f1.Length; i++) + { + ret = f1[i] == f2[i]; + if (!ret) + break; + } + + if (!ret) + { + Assert.Fail($"Array not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]"); + } + } + + public void AssertArray(float[] f1, float[] f2) + { + bool ret = false; + var tolerance = .00001f; + for (var i = 0; i < f1.Length; i++) + { + ret = Math.Abs(f1[i] - f2[i]) <= tolerance; + if (!ret) + break; + } + + if (!ret) + { + Assert.Fail($"Array float not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]"); + } + } + + public bool Equal(double[] d1, double[] d2) + { + bool ret = false; + var tolerance = .000000000000001f; + for (var i = 0; i < d1.Length; i++) + { + ret = Math.Abs(d1[i] - d2[i]) <= tolerance; + if (!ret) + break; + } + + return ret; + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/GradientTest.cs b/test/TensorFlowNET.Keras.UnitTest/GradientTest.cs new file mode 100644 index 000000000..162aa1c5e --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/GradientTest.cs @@ -0,0 +1,75 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Linq; +using Tensorflow.Keras.Engine; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest; + +[TestClass] +public class GradientTest : EagerModeTestBase +{ + public IModel get_actor(int num_states) + { + var inputs = tf.keras.layers.Input(shape: num_states); + var outputs = tf.keras.layers.Dense(1, activation: keras.activations.Tanh).Apply(inputs); + + var model = tf.keras.Model(inputs, outputs); + + return model; + } + + public IModel get_critic(int num_states, int num_actions) + { + // State as input + var state_input = keras.layers.Input(shape: num_states); + + // Action as input + var action_input = keras.layers.Input(shape: num_actions); + + var concat = keras.layers.Concatenate(axis: 1).Apply(new Tensors(state_input, action_input)); + + var outputs = keras.layers.Dense(1).Apply(concat); + + var model = tf.keras.Model(new Tensors(state_input, action_input), outputs); + model.summary(); + + return model; + } + + [TestMethod] + public void GetGradientTest() + { + var numStates = 3; + var numActions = 1; + var batchSize = 64; + var gamma = 0.99f; + + var target_actor_model = get_actor(numStates); + var target_critic_model = get_critic(numStates, numActions); + var critic_model = get_critic(numStates, numActions); + + Tensor state_batch = tf.convert_to_tensor(np.zeros((batchSize, numStates)), TF_DataType.TF_FLOAT); + Tensor action_batch = tf.convert_to_tensor(np.zeros((batchSize, numActions)), TF_DataType.TF_FLOAT); + Tensor reward_batch = tf.convert_to_tensor(np.zeros((batchSize, 1)), TF_DataType.TF_FLOAT); + Tensor next_state_batch = tf.convert_to_tensor(np.zeros((batchSize, numStates)), TF_DataType.TF_FLOAT); + + using (var tape = tf.GradientTape()) + { + var target_actions = target_actor_model.Apply(next_state_batch, training: true); + var target_critic_value = target_critic_model.Apply(new Tensors(new Tensor[] { next_state_batch, target_actions }), training: true); + + var y = reward_batch + tf.multiply(gamma, target_critic_value); + + var critic_value = critic_model.Apply(new Tensors(new Tensor[] { state_batch, action_batch }), training: true); + + var critic_loss = math_ops.reduce_mean(math_ops.square(y - critic_value)); + + var critic_grad = tape.gradient(critic_loss, critic_model.TrainableVariables); + + Assert.IsNotNull(critic_grad); + Assert.IsNotNull(critic_grad.First()); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs b/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs new file mode 100644 index 000000000..e145ce585 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.UnitTest.Helpers +{ + public class RandomDataSet : DataSetBase + { + private Shape _shape; + + public RandomDataSet(Shape shape, int count) + { + _shape = shape; + Debug.Assert(_shape.ndim == 3); + long[] dims = new long[4]; + dims[0] = count; + for (int i = 1; i < 4; i++) + { + dims[i] = _shape[i - 1]; + } + Shape s = new Shape(dims); + Data = np.random.normal(0, 2, s); + Labels = np.random.uniform(0, 1, (count, 1)); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/InitLayerNameTest.cs b/test/TensorFlowNET.Keras.UnitTest/InitLayerNameTest.cs new file mode 100644 index 000000000..256eb69c1 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/InitLayerNameTest.cs @@ -0,0 +1,33 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Keras.Layers; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest +{ + [TestClass] + public class InitLayerNameTest + { + [TestMethod] + public void RNNLayerNameTest() + { + var simpleRnnCell = keras.layers.SimpleRNNCell(1); + Assert.AreEqual("simple_rnn_cell", simpleRnnCell.Name); + var simpleRnn = keras.layers.SimpleRNN(2); + Assert.AreEqual("simple_rnn", simpleRnn.Name); + var lstmCell = keras.layers.LSTMCell(2); + Assert.AreEqual("lstm_cell", lstmCell.Name); + var lstm = keras.layers.LSTM(3); + Assert.AreEqual("lstm", lstm.Name); + } + + [TestMethod] + public void ConvLayerNameTest() + { + var conv2d = keras.layers.Conv2D(8, activation: "linear"); + Assert.AreEqual("conv2d", conv2d.Name); + var conv2dTranspose = keras.layers.Conv2DTranspose(8); + Assert.AreEqual("conv2d_transpose", conv2dTranspose.Name); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs b/test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs new file mode 100644 index 000000000..b26b69309 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs @@ -0,0 +1,15 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.UnitTest; + +[TestClass] +public class InitializerTest : EagerModeTestBase +{ + [TestMethod] + public void Orthogonal() + { + var initializer = tf.keras.initializers.Orthogonal(); + var values = initializer.Apply(new Tensorflow.InitializerArgs((2, 2))); + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs new file mode 100644 index 000000000..cc99f4a04 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs @@ -0,0 +1,107 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class ActivationTest : EagerModeTestBase + { + [TestMethod] + public void LeakyReLU() + { + var layer = keras.layers.LeakyReLU(); + Tensor output = layer.Apply(np.array(-3.0f, -1.0f, 0.0f, 2.0f)); + Equal(new[] { -0.9f, -0.3f, 0.0f, 2.0f }, output.ToArray()); + } + + [TestMethod] + public void ELU() + { + Tensors input = tf.constant(new float[] { -3f, -2f, -1f, 0f, 1f, 2f }); + Tensor output = keras.layers.ELU().Apply(input); + NDArray expected = new NDArray(new float[] { -0.0950213f, -0.08646648f, -0.06321206f, 0f, 1f, 2f }); + Assert.AreEqual(expected.numpy(), output.numpy()); + } + + [TestMethod] + public void SELU() + { + Tensor input = tf.constant(new float[] { -3f, -2f, -1f, 0f, 1f, 2f }); + Tensor output = keras.layers.SELU().Apply(input); + NDArray expected = new NDArray(new float[] { -1.6705688f, -1.5201665f, -1.1113307f, 0f, 1.050701f, 2.101402f }); + Assert.AreEqual(expected.numpy(), output.numpy()); + } + + [TestMethod] + public void Softmax() + { + Tensor input = tf.constant(new float[] { -3f, -2f, -1f, 0f, 1f, 2f }); + Tensor output = keras.layers.Softmax(new Axis(-1)).Apply(input); + var expected = new float[] { 0.0042697787f, 0.011606461f, 0.031549633f, 0.085760795f, 0.23312202f, 0.6336913f }; + Assert.IsTrue(Equal(expected, output.ToArray())); + } + + [TestMethod] + public void Softplus() + { + Tensor input = tf.constant(new float[] { -3f, -2f, -1f, 0f, 1f, 2f }); + Tensor output = keras.layers.Softplus().Apply(input); + NDArray expected = new NDArray(new float[] { 0.04858733f, 0.12692805f, 0.31326166f, 0.6931472f, 1.3132616f, 2.126928f }); + Assert.IsTrue(expected == output.numpy()); + } + + [TestMethod] + public void Softsign() + { + Tensor input = tf.constant(new float[] { -3f, -2f, -1f, 0f, 1f, 2f }); + Tensor output = keras.layers.Softsign().Apply(input); + NDArray expected = new NDArray(new float[] { -0.75f, -0.66666667f, -0.5f, 0f, 0.5f, 0.66666667f }); + Assert.AreEqual(expected, output.numpy()); + } + + + [TestMethod] + public void Exponential() + { + Tensor input = tf.constant(new float[] { -3f, -2f, -1f, 0f, 1f, 2f }); + Tensor output = keras.layers.Exponential().Apply(input); + var expected = new float[] { 0.049787067f, 0.13533528f, 0.36787945f, 1f, 2.7182817f, 7.389056f }; + Assert.IsTrue(Equal(expected, output.ToArray())); + } + + [TestMethod] + public void HardSigmoid() + { + Tensor input = tf.constant(new float[] { -3f, -2f, -1f, 0f, 1f, 2f }); + Tensor output = keras.layers.HardSigmoid().Apply(input); + // Note, this should be [0, 0.1, 0.3, 0.5, 0.7, 0.9] + // But somehow the second element will have 0.099999994 + // Probably because there is an accuracy loss somewhere + NDArray expected = new NDArray(new float[] { 0f, 0.099999994f, 0.3f, 0.5f, 0.7f, 0.9f }); + Assert.AreEqual(expected, output.numpy()); + } + + + [TestMethod] + public void Swish() + { + Tensor input = tf.constant(new float[] { -3f, -2f, -1f, 0f, 1f, 2f }); + Tensor output = keras.layers.Swish().Apply(input); + NDArray expected = new NDArray(new float[] { -0.14227762f, -0.23840584f, -0.26894143f, 0f, 0.7310586f, 1.761594f }); + Assert.AreEqual(expected, output.numpy()); + } + + /// + /// https://www.tensorflow.org/addons/api_docs/python/tfa/activations/mish + /// + [TestMethod] + public void Mish() + { + var x = tf.constant(new[] { 1.0, 0.0, 1.0 }, dtype: tf.float32); + var output = keras.activations.Mish.Apply(x); + Assert.AreEqual(new[] { 0.86509836f, 0f, 0.86509836f }, output.numpy()); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs new file mode 100644 index 000000000..95ef923eb --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs @@ -0,0 +1,174 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Utils; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class AttentionTest : EagerModeTestBase + { + #region BaseDenseAttention + + [TestMethod] + public void test_multi_dim_with_mask() + { + // Scores tensor of shape [1, 1, 3] + var scores = np.array(new[, ,] { { { 1f, 0f, 1f } } }, dtype: np.float32); + // Value tensor of shape [1, 3, 1] + var v = np.array(new[, ,] { { { 1.6f }, { 0.7f }, { -0.8f } } }, dtype: np.float32); + // Scores mask tensor of shape [1, 1, 3] + var scores_mask = np.array(new[, ,] { { { true, true, false } } }, dtype: np.@bool); + var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask); + var actual = _tup_1.Item1; + var actual_scores = _tup_1.Item2; + // Expected softmax scores = softmax(scores) with zeros in positions where + // v_mask == False. + // => softmax_scores000 = exp(1)/(exp(1) + exp(0)) = 0.73105857863 + // softmax_scores001 = exp(0)/(exp(1) + exp(0)) = 0.26894142137 + // softmax_scores002 = 0 + var expected_scores = np.array(new[, ,] { { { 0.73105857863f, 0.26894142137f, 0f } } }, dtype: np.float32); + Assert.AreEqual(expected_scores, actual_scores.numpy()); + // Expected tensor of shape [1, 1, 1]. + // expected000 = 0.73105857863 * 1.6 + 0.26894142137 * 0.7 - 0 * 0.8 + // = 1.35795272077 + //Actually the output is 1.3579528 + var expected = np.array(new[, ,] { { { 1.3579528f } } }, dtype: np.float32); + Assert.AreEqual(expected, actual.numpy()); + } + + [TestMethod] + public void test_one_dim_batch_size_two() + { + // Scores tensor of shape [2, 1, 1] + var scores = np.array(new[, ,] { { { 1.1f } }, { { 2.1f } } }, dtype: np.float32); + // Value tensor of shape [2, 1, 1] + var v = np.array(new[, ,] { { { 1.6f } }, { { 2.6f } } }, dtype: np.float32); + // Scpres mask tensor of shape [2, 1, 1] + var scores_mask = np.array(new[, ,] { { { true } }, { { true } } }, dtype: np.@bool); + var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask); + var actual = _tup_1.Item1; + var actual_scores = _tup_1.Item2; + // Expected softmax_scores = [[[1]], [[1]]] + var expected_scores = np.array(new[, ,] { { { 1f } }, { { 1f } } }, dtype: np.float32); + Assert.AreEqual(expected_scores, actual_scores.numpy()); + // Expected tensor of shape [2, 1, 1]. + // expected000 = softmax_scores[0, 0] * 1.6 = 1.6 + // expected100 = softmax_scores[1, 0] * 2.6 = 2.6 + var expected = np.array(new[, ,] { { { 1.6f } }, { { 2.6f } } }, dtype: np.float32); + Assert.AreEqual(expected, actual.numpy()); + } + #endregion + // ------------------------------------------------------------------ + #region Attention + + + [TestMethod] + public void test_calculate_scores_multi_dim() + { + // Query tensor of shape [1, 2, 4] + var q = np.array(new[, ,] { { + { 1f, 1.1f, 1.2f, 1.3f }, + { 2f, 2.1f, 2.2f, 2.3f } + } }, dtype: np.float32); + // Key tensor of shape [1, 3, 4] + var k = np.array(new[, ,] { { + { 1.5f, 1.6f, 1.7f, 1.8f }, + { 2.5f, 2.6f, 2.7f, 2.8f }, + { 3.5f, 3.6f, 3.7f, 3.8f } + } }, dtype: np.float32); + var attention_layer = (Attention)keras.layers.Attention(); + //attention_layer.build(((1, 2, 4), (1, 3, 4))); + var actual = attention_layer._calculate_scores(query: q, key: k); + // Expected tensor of shape [1, 2, 3]. + // expected000 = 1.*1.5+1.1*1.6+1.2*1.7+1.3*1.8 = 7.64 + // expected001 = 1.*2.5+1.1*2.6+1.2*2.7+1.3*2.8 = 12.24 + // expected002 = 1.*3.5+1.1*3.6+1.2*3.7+1.3*3.8 = 16.84 + // expected010 = 2.*1.5+2.1*1.6+2.2*1.7+2.3*1.8 = 14.24 + // expected011 = 2.*2.5+2.1*2.6+2.2*2.7+2.3*2.8 = 22.84 + // expected012 = 2.*3.5+2.1*3.6+2.2*3.7+2.3*3.8 = 31.44 + // Actually the output000 is 7.6400003, the output012 is 31.439999 + var expected = np.array(new[, ,] { { + { 7.6400003f, 12.24f, 16.84f }, + { 14.24f, 22.84f, 31.439999f } + } }, dtype: np.float32); + Assert.IsTrue(expected == actual.numpy()); + } + + [TestMethod] + [Ignore] + public void test_calculate_scores_multi_dim_concat() + { + // Query tensor of shape [1, 2, 4] + var q = np.array(new[, ,] { { + { 1f, 1.1f, 1.2f, 1.3f }, + { 2f, 2.1f, 2.2f, 2.3f } + } }, dtype: np.float32); + // Key tensor of shape [1, 3, 4] + var k = np.array(new[, ,] { { + { 1.5f, 1.6f, 1.7f, 1.8f }, + { 2.5f, 2.6f, 2.7f, 2.8f }, + { 3.5f, 3.6f, 3.7f, 3.8f } + } }, dtype: np.float32); + var attention_layer = (Attention)keras.layers.Attention(score_mode: "concat"); + //attention_layer.concat_score_weight = 1; + attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() + { + Name = "concat_score_weight", + Shape = (1), + DType = TF_DataType.TF_FLOAT, + Getter = base_layer_utils.make_variable, + Overwrite = true, + Initializer = tf.ones_initializer, + Synchronization = VariableSynchronization.Auto, + Aggregation = VariableAggregation.None, + Trainable = true + }); + //attention_layer.build(((1, 2, 4), (1, 3, 4))); + //var actual = keras.backend.get_value(attention_layer._calculate_scores(query: q, key: k)); + var actual = attention_layer._calculate_scores(query: q, key: k); + // pylint:disable=line-too-long + // expected000 = tanh(1.+1.5) + tanh(1.1+1.6) + tanh(1.2+1.7) + tanh(1.3+1.8) = 3.96753427840 + // expected001 = tanh(1.+2.5) + tanh(1.1+2.6) + tanh(1.2+2.7) + tanh(1.3+2.8) = 3.99558784825 + // expected002 = tanh(1.+3.5) + tanh(1.1+3.6) + tanh(1.2+3.7) + tanh(1.3+3.8) = 3.99940254147 + // expected010 = tanh(2.+1.5) + tanh(2.1+1.6) + tanh(2.2+1.7) + tanh(2.3+1.8) = 3.99558784825 + // expected011 = tanh(2.+2.5) + tanh(2.1+2.6) + tanh(2.2+2.7) + tanh(2.3+2.8) = 3.99940254147 + // expected012 = tanh(2.+3.5) + tanh(2.1+3.6) + tanh(2.2+3.7) + tanh(2.3+3.8) = 3.99991913657 + //Actually the output012 is 3.9999194 + var expected = np.array(new[, ,] { { + { 3.96753427840f, 3.99558784825f, 3.99940254147f }, + { 3.99558784825f, 3.99940254147f, 3.9999194f } + } }, dtype: np.float32); + Assert.AreEqual(expected, actual.numpy()); + } + #endregion + // ------------------------------------------------------------------ + #region MultiHeadAttention + [TestMethod] + public void test_masked_attention() + { + var batch_size = 3; + + var query = keras.Input(shape: (4, 8)); + var value = keras.Input(shape: (2, 8)); + var mask_tensor = keras.Input(shape: (4, 2)); + var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2); + attention_layer.Apply(new Tensor[] { query, value, mask_tensor }); + + var from_data = 10 * np.random.randn(batch_size, 4, 8); + var to_data = 10 * np.random.randn(batch_size, 2, 8); + + var mask_data = np.random.randint(2, size: (batch_size, 4, 2)); + var masked_output_data = attention_layer.Apply(new[] { from_data, to_data, mask_data }); + + var null_mask_data = np.ones((batch_size, 4, 2)); + var unmasked_output_data = attention_layer.Apply(new[] { from_data, to_data, null_mask_data }); + + Assert.AreNotEqual(masked_output_data, unmasked_output_data); + } + #endregion + } + +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/CosineSimilarity.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/CosineSimilarity.Test.cs new file mode 100644 index 000000000..5294a838c --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/CosineSimilarity.Test.cs @@ -0,0 +1,74 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Keras.Losses; +using Tensorflow.NumPy; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class CosineSimilarity + { + //https://keras.io/api/losses/regression_losses/ + + NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 1.0f, 1.0f } }; + NDArray y_pred_float = new float[,] { { 1.0f, 0.0f }, { 1.0f, 1.0f } }; + + [TestMethod] + + public void _Default() + { + //>>> # Using 'auto'/'sum_over_batch_size' reduction type. + //>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis = 1) + //>>> # l2_norm(y_true) = [[0., 1.], [1./1.414], 1./1.414]]] + //>>> # l2_norm(y_pred) = [[1., 0.], [1./1.414], 1./1.414]]] + //>>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]] + //>>> # loss = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1)) + //>>> # = -((0. + 0.) + (0.5 + 0.5)) / 2 + //-0.5 + var loss = keras.losses.CosineSimilarity(axis: 1); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)(-0.49999997f), call.numpy()); + } + + [TestMethod] + + public void _Sample_Weight() + { + //>>> # Calling with 'sample_weight'. + //>>> cosine_loss(y_true, y_pred, sample_weight =[0.8, 0.2]).numpy() + //- 0.0999 + var loss = keras.losses.CosineSimilarity(); + var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f }); + Assert.AreEqual((NDArray)(-0.099999994f), call.numpy()); + } + + [TestMethod] + + public void _SUM() + { + //>>> # Using 'sum' reduction type. + //>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis = 1, + //... reduction = tf.keras.losses.Reduction.SUM) + //>>> cosine_loss(y_true, y_pred).numpy() + //- 0.999 + var loss = keras.losses.CosineSimilarity(axis: 1, reduction: ReductionV2.SUM); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)(-0.99999994f), call.numpy()); + } + + [TestMethod] + + public void _None() + { + //>>> # Using 'none' reduction type. + //>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis = 1, + //... reduction = tf.keras.losses.Reduction.NONE) + //>>> cosine_loss(y_true, y_pred).numpy() + //array([-0., -0.999], dtype = float32) + var loss = keras.losses.CosineSimilarity(axis: 1, reduction: ReductionV2.NONE); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)new float[] { -0f, -0.99999994f }, call.numpy()); + } + + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Huber.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Huber.Test.cs new file mode 100644 index 000000000..7bf5f5191 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Huber.Test.cs @@ -0,0 +1,70 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Keras.Losses; +using Tensorflow.NumPy; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class Huber + { + //https://keras.io/api/losses/regression_losses/#meansquarederror-class + + NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } }; + NDArray y_pred_float = new float[,] { { 0.6f, 0.4f }, { 0.4f, 0.6f } }; + + [TestMethod] + + public void _Default() + { + //>>> # Using 'auto'/'sum_over_batch_size' reduction type. + //>>> h = tf.keras.losses.Huber() + //>>> h(y_true, y_pred).numpy() + //0.155 + var loss = keras.losses.Huber(); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)0.155f, call.numpy()); + } + + [TestMethod] + + public void _Sample_Weight() + { + //>>> # Calling with 'sample_weight'. + //>>> h(y_true, y_pred, sample_weight =[1, 0]).numpy() + //0.09 + var loss = keras.losses.Huber(); + var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.1f, 0.0f }); + Assert.AreEqual((NDArray)0.009000001f, call.numpy()); + } + + [TestMethod] + + public void _SUM() + { + //>>> # Using 'sum' reduction type. + //>>> h = tf.keras.losses.Huber( + //... reduction = tf.keras.losses.Reduction.SUM) + //>>> h(y_true, y_pred).numpy() + //0.31 + var loss = keras.losses.Huber(reduction: ReductionV2.SUM); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)0.31f, call.numpy()); + } + + [TestMethod] + + public void _None() + { + //>>> # Using 'none' reduction type. + //>>> h = tf.keras.losses.Huber( + //... reduction = tf.keras.losses.Reduction.NONE) + //>>> h(y_true, y_pred).numpy() + //array([0.18, 0.13], dtype = float32) + var loss = keras.losses.Huber(reduction: ReductionV2.NONE); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)new float[] { 0.18f, 0.13000001f }, call.numpy()); + } + + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs new file mode 100644 index 000000000..15c6e80fe --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs @@ -0,0 +1,322 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Linq; +using Tensorflow.NumPy; +using static Tensorflow.KerasApi; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class LayersConvolutionTest : EagerModeTestBase + { + [TestMethod] + public void BasicConv1D() + { + var filters = 8; + + var conv = keras.layers.Conv1D(filters, kernel_size: 3, activation: "linear"); + + var x = np.arange(256.0f).reshape((8, 8, 4)); + var y = conv.Apply(x); + + Assert.AreEqual(y.shape, (8, 6, 8)); + Assert.AreEqual(filters, y.shape[2]); + } + + [TestMethod] + public void BasicConv1D_ksize() + { + var filters = 8; + + var conv = keras.layers.Conv1D(filters, kernel_size: 3, activation: "linear"); + + var x = np.arange(256.0f).reshape((8, 8, 4)); + var y = conv.Apply(x); + + Assert.AreEqual(3, y.shape.ndim); + Assert.AreEqual(x.dims[0], y.shape[0]); + Assert.AreEqual(x.dims[1] - 2, y.shape[1]); + Assert.AreEqual(filters, y.shape[2]); + } + + [TestMethod] + public void BasicConv1D_ksize_same() + { + var filters = 8; + + var conv = keras.layers.Conv1D(filters, kernel_size: 3, padding: "same", activation: "linear"); + + var x = np.arange(256.0f).reshape((8, 8, 4)); + var y = conv.Apply(x); + + Assert.AreEqual(3, y.shape.ndim); + Assert.AreEqual(x.dims[0], y.shape[0]); + Assert.AreEqual(x.dims[1], y.shape[1]); + Assert.AreEqual(filters, y.shape[2]); + } + + [TestMethod] + public void BasicConv1D_ksize_strides() + { + var filters = 8; + var conv = keras.layers.Conv1D(filters, kernel_size: 3, strides: 2, activation: "linear"); + + var x = np.arange(256.0f).reshape((8, 8, 4)); + var y = conv.Apply(x); + + Assert.AreEqual(3, y.shape.ndim); + Assert.AreEqual(x.dims[0], y.shape[0]); + Assert.AreEqual(x.dims[1] - 5, y.shape[1]); + Assert.AreEqual(filters, y.shape[2]); + } + + [TestMethod] + public void BasicConv1D_ksize_dilations() + { + var filters = 8; + var conv = keras.layers.Conv1D(filters, kernel_size: 3, dilation_rate: 2, activation: "linear"); + + var x = np.arange(256.0f).reshape((8, 8, 4)); + var y = conv.Apply(x); + + Assert.AreEqual(3, y.shape.ndim); + Assert.AreEqual(x.dims[0], y.shape[0]); + Assert.AreEqual(x.dims[1] - 4, y.shape[1]); + Assert.AreEqual(filters, y.shape[2]); + } + + [TestMethod] + public void BasicConv1D_ksize_dilation_same() + { + var filters = 8; + var conv = keras.layers.Conv1D(filters, kernel_size: 3, dilation_rate: 2, padding: "same", activation: "linear"); + + var x = np.arange(256.0f).reshape((8, 8, 4)); + var y = conv.Apply(x); + + Assert.AreEqual(3, y.shape.ndim); + Assert.AreEqual(x.dims[0], y.shape[0]); + Assert.AreEqual(x.dims[1], y.shape[1]); + Assert.AreEqual(filters, y.shape[2]); + } + + [TestMethod] + public void BasicConv2D() + { + var filters = 8; + var conv = keras.layers.Conv2D(filters, activation: "linear"); + + var x = np.arange(256.0f).reshape((1, 8, 8, 4)); + var y = conv.Apply(x); + + Assert.AreEqual(4, y.shape.ndim); + Assert.AreEqual(x.dims[0], y.shape[0]); + Assert.AreEqual(x.dims[1] - 4, y.shape[1]); + Assert.AreEqual(x.dims[2] - 4, y.shape[2]); + Assert.AreEqual(filters, y.shape[3]); + } + + [TestMethod] + public void BasicConv2D_ksize() + { + var filters = 8; + var conv = keras.layers.Conv2D(filters, kernel_size: 3, activation: "linear"); + + var x = np.arange(256.0f).reshape((1, 8, 8, 4)); + var y = conv.Apply(x); + + Assert.AreEqual(4, y.shape.ndim); + Assert.AreEqual(x.dims[0], y.shape[0]); + Assert.AreEqual(x.dims[1] - 2, y.shape[1]); + Assert.AreEqual(x.dims[2] - 2, y.shape[2]); + Assert.AreEqual(filters, y.shape[3]); + } + + [TestMethod] + public void BasicConv2D_ksize_same() + { + var filters = 8; + var conv = keras.layers.Conv2D(filters, kernel_size: 3, padding: "same", activation: "linear"); + + var x = np.arange(256.0f).reshape((1, 8, 8, 4)); + var y = conv.Apply(x); + + Assert.AreEqual(4, y.shape.ndim); + Assert.AreEqual(x.dims[0], y.shape[0]); + Assert.AreEqual(x.dims[1], y.shape[1]); + Assert.AreEqual(x.dims[2], y.shape[2]); + Assert.AreEqual(filters, y.shape[3]); + } + + [TestMethod] + public void BasicConv2D_ksize_strides() + { + var filters = 8; + var conv = keras.layers.Conv2D(filters, kernel_size: 3, strides: 2, activation: "linear"); + + var x = np.arange(256.0f).reshape((1, 8, 8, 4)); + var y = conv.Apply(x); + + Assert.AreEqual(4, y.shape.ndim); + Assert.AreEqual(x.dims[0], y.shape[0]); + Assert.AreEqual(x.dims[1] - 5, y.shape[1]); + Assert.AreEqual(x.dims[2] - 5, y.shape[2]); + Assert.AreEqual(filters, y.shape[3]); + } + + [TestMethod] + public void BasicConv2D_ksize_dilation() + { + var filters = 8; + var conv = keras.layers.Conv2D(filters, kernel_size: 3, dilation_rate: 2, activation: "linear"); + + var x = np.arange(256.0f).reshape((1, 8, 8, 4)); + var y = conv.Apply(x); + + Assert.AreEqual(4, y.shape.ndim); + Assert.AreEqual(x.dims[0], y.shape[0]); + Assert.AreEqual(x.dims[1] - 4, y.shape[1]); + Assert.AreEqual(x.dims[2] - 4, y.shape[2]); + Assert.AreEqual(filters, y.shape[3]); + } + + [TestMethod] + public void BasicConv2D_ksize_dilation_same() + { + var filters = 8; + var conv = keras.layers.Conv2D(filters, kernel_size: 3, dilation_rate: 2, padding: "same", activation: "linear"); + + var x = np.arange(256.0f).reshape((1, 8, 8, 4)); + var y = conv.Apply(x); + + Assert.AreEqual(4, y.shape.ndim); + Assert.AreEqual(x.dims[0], y.shape[0]); + Assert.AreEqual(x.dims[1], y.shape[1]); + Assert.AreEqual(x.dims[2], y.shape[2]); + Assert.AreEqual(filters, y.shape[3]); + } + + + [TestMethod] + public void BasicDepthwiseConv2D() + { + var conv = keras.layers.DepthwiseConv2D(kernel_size:3, strides:1, activation: null, + padding:"same", depthwise_initializer: "ones"); + + var x = np.arange(2 * 9* 9* 3).reshape((2, 9, 9, 3)); + var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT); + + var y = conv.Apply(x2); + + print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}"); + + + Assert.AreEqual(4, y.shape.ndim); + var arr = y.numpy().reshape((2, 9, 9, 3)); + + AssertArray(x[new int[] { 1, 1, 1 }].ToArray(), new int[] { 273, 274, 275 }); + AssertArray(arr[new int[] { 1, 1, 1 }].ToArray(), new float[] { 2457f, 2466f, 2475f }); + + var bn = keras.layers.BatchNormalization(); + var y2 = bn.Apply(y); + arr = y2.numpy().ToArray(); + + double delta = 0.0001; // 误差范围 + + Assert.AreEqual(arr[0], 59.97002f, delta); + Assert.AreEqual(arr[1], 63.96802f, delta); + } + + + [TestMethod] + public void BasicDepthwiseConv2D_strides_2() + { + var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: (1, 2, 2, 1), activation: null, + padding: "same", depthwise_initializer: "ones"); + + var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3)); + var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT); + + var y = conv.Apply(x2); + + print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}"); + + Assert.AreEqual(4, y.shape.ndim); + var arr = y.numpy().reshape((2, 5, 5, 3)); + + AssertArray(x[new int[] { 1, 1, 1 }].ToArray(), new int[] { 273, 274, 275 }); + AssertArray(arr[new int[] { 1, 1, 1 }].ToArray(), new float[] { 2727f, 2736f, 2745f }); + + var bn = keras.layers.BatchNormalization(); + var y2 = bn.Apply(y); + arr = y2.numpy().ToArray(); + + double delta = 0.0001; // 误差范围 + + Assert.AreEqual(arr[0], 59.97002f, delta); + Assert.AreEqual(arr[1], 63.96802f, delta); + } + + + + [TestMethod] + public void BasicDepthwiseConv2D_strides_3() + { + var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 3, activation: null, + padding: "same", depthwise_initializer: "ones"); + + var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3)); + var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT); + + var y = conv.Apply(x2); + + print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}"); + + Assert.AreEqual(4, y.shape.ndim); + var arr = y.numpy().reshape((2, 3, 3, 3)); + + AssertArray(x[new int[] { 1, 1, 1 }].ToArray(), new int[] { 273, 274, 275 }); + AssertArray(arr[new int[] { 1, 1, 1 }].ToArray(), new float[] { 3267f, 3276f, 3285f }); + + var bn = keras.layers.BatchNormalization(); + var y2 = bn.Apply(y); + arr = y2.numpy().ToArray(); + + double delta = 0.0001; // 误差范围 + + Assert.AreEqual(arr[0], 269.86508f, delta); + Assert.AreEqual(arr[1], 278.8606f, delta); + + } + [TestMethod] + public void BasicDepthwiseConv2D_UseBias() + { + var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 1, activation: null, + use_bias: true, padding: "same", + depthwise_initializer: "ones", + bias_initializer:"ones" + ); + + var weight = conv.get_weights(); + + var x = np.arange(9 * 9 * 3).reshape((1, 9, 9, 3)); + var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT); + var y = conv.Apply(x2); + + Assert.AreEqual(4, y.shape.ndim); + var arr = y.numpy().ToArray(); + + Assert.AreEqual(arr[0], 61f); + Assert.AreEqual(arr[1], 65f); + + var bn = keras.layers.BatchNormalization(); + var y2 = bn.Apply(y); + arr = y2.numpy().ToArray(); + + double delta = 0.0001; // 误差范围 + + Assert.AreEqual(arr[0], 60.96952f, delta); + Assert.AreEqual(arr[1], 64.96752f, delta); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Cropping.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Cropping.Test.cs new file mode 100644 index 000000000..b7981facb --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Cropping.Test.cs @@ -0,0 +1,43 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class LayersCroppingTest : EagerModeTestBase + { + [TestMethod] + public void Cropping1D() + { + Shape input_shape = (1, 5, 2); + var x = tf.zeros(input_shape); + var cropping_1d = keras.layers.Cropping1D(new[] { 1, 2 }); + var y = cropping_1d.Apply(x); + Assert.AreEqual((1, 2, 2), y.shape); + } + + [TestMethod] + public void Cropping2D() + { + Shape input_shape = (1, 5, 6, 1); + NDArray cropping = new NDArray(new[,] { { 1, 2 }, { 1, 3 } }); + var x = tf.zeros(input_shape); + var cropping_2d = keras.layers.Cropping2D(cropping); + var y = cropping_2d.Apply(x); + Assert.AreEqual((1, 2, 2, 1), y.shape); + } + + [TestMethod] + public void Cropping3D() + { + Shape input_shape = new Shape(1, 5, 6, 7, 1); + NDArray cropping = new NDArray(new[,] { { 1, 2 }, { 1, 3 }, { 1, 4 } }); + var x = tf.zeros(input_shape); + var cropping_3d = keras.layers.Cropping3D(cropping); + var y = cropping_3d.Apply(x); + Assert.AreEqual(new Shape(1, 2, 2, 2, 1), y.shape); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs new file mode 100644 index 000000000..9bc2fa767 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs @@ -0,0 +1,24 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Collections.Generic; +using Tensorflow.NumPy; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class LayersMergingTest : EagerModeTestBase + { + [TestMethod] + [DataRow(1, 4, 1, 5)] + [DataRow(2, 2, 2, 5)] + [DataRow(3, 2, 1, 10)] + public void Concatenate(int axis, int shapeA, int shapeB, int shapeC) + { + var x = np.arange(10).reshape((1, 2, 1, 5)); + var y = np.arange(10, 20).reshape((1, 2, 1, 5)); + var z = keras.layers.Concatenate(axis: axis).Apply(new Tensors(x, y)); + Assert.AreEqual((1, shapeA, shapeB, shapeC), z.shape); + } + + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Reshaping.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Reshaping.Test.cs new file mode 100644 index 000000000..5b16cc908 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Reshaping.Test.cs @@ -0,0 +1,58 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class LayersReshapingTest : EagerModeTestBase + { + [TestMethod] + public void ZeroPadding2D() + { + Shape input_shape = (1, 1, 2, 2); + var x = np.arange(input_shape.size).reshape(input_shape); + var zero_padding_2d = keras.layers.ZeroPadding2D(new[,] { { 1, 0 }, { 1, 0 } }); + var y = zero_padding_2d.Apply(x); + Assert.AreEqual((1, 2, 3, 2), y.shape); + } + + [TestMethod] + public void UpSampling1D() + { + Shape input_shape = (2, 2, 3); + var x = np.arange(input_shape.size).reshape(input_shape); + var y = tf.keras.layers.UpSampling1D(size: 2).Apply(x); + Assert.AreEqual((2, 4, 3), y.shape); + } + + [TestMethod] + public void UpSampling2D() + { + Shape input_shape = (2, 2, 1, 3); + var x = np.arange(input_shape.size).reshape(input_shape); + var y = keras.layers.UpSampling2D(size: (1, 2)).Apply(x); + Assert.AreEqual((2, 2, 2, 3), y.shape); + } + + [TestMethod] + public void Reshape() + { + var inputs = tf.zeros((10, 5, 20)); + var outputs = keras.layers.LeakyReLU().Apply(inputs); + outputs = keras.layers.Reshape((20, 5)).Apply(outputs); + Assert.AreEqual((10, 20, 5), outputs.shape); + } + + [TestMethod] + public void Permute() + { + var inputs = tf.zeros((2, 3, 4, 5)); + var outputs = keras.layers.Permute(new int[] { 3, 2, 1 }).Apply(inputs); + Assert.AreEqual((2, 5, 4, 3), outputs.shape); + } + + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs new file mode 100644 index 000000000..7ebb53db3 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -0,0 +1,303 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + /// + /// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers + /// + [TestClass] + public class LayersTest : EagerModeTestBase + { + [TestMethod] + public void AveragePooling2D() + { + var x = tf.constant(new float[,] + { + { 1, 2, 3 }, + { 4, 5, 6 }, + { 7, 8, 9 } + }); + x = tf.reshape(x, (1, 3, 3, 1)); + var avg_pool_2d = keras.layers.AveragePooling2D(pool_size: (2, 2), + strides: (1, 1), padding: "valid"); + Tensor avg = avg_pool_2d.Apply(x); + Assert.AreEqual((1, 2, 2, 1), avg.shape); + Equal(new float[] { 3, 4, 6, 7 }, avg.ToArray()); + } + + [TestMethod] + public void InputLayer() + { + var model = keras.Sequential(new List + { + keras.layers.InputLayer(input_shape: 4), + keras.layers.Dense(8) + }); + model.compile(optimizer: keras.optimizers.RMSprop(0.001f), + loss: keras.losses.MeanSquaredError(), + metrics: new[] { "accuracy" }); + model.fit(np.zeros((10, 4), dtype: tf.float32), np.ones((10, 8), dtype: tf.float32)); + } + + [TestMethod] + public void Sequential() + { + var model = keras.Sequential(); + model.add(keras.Input(shape: 16)); + } + + [TestMethod] + public void Functional() + { + var layers = keras.layers; + + var inputs = keras.Input(shape: 784); + Assert.AreEqual((-1, 784), inputs.shape); + + var dense = layers.Dense(64, activation: keras.activations.Relu); + var x = dense.Apply(inputs); + + x = layers.Dense(64, activation: keras.activations.Relu).Apply(x); + var outputs = layers.Dense(10).Apply(x); + + var model = keras.Model(inputs, outputs, name: "mnist_model"); + model.summary(); + } + + /// + /// Custom layer test, used in Dueling DQN + /// + [TestMethod, Ignore] + public void TensorFlowOpLayer() + { + var layers = keras.layers; + var inputs = layers.Input(shape: 24); + var x = layers.Dense(128, activation: "relu").Apply(inputs); + var value = layers.Dense(24).Apply(x); + var adv = layers.Dense(1).Apply(x); + + var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true); + adv = layers.Subtract().Apply((adv, mean)); + var outputs = layers.Add().Apply((value, adv)); + var model = keras.Model(inputs, outputs); + model.compile(optimizer: keras.optimizers.RMSprop(0.001f), + loss: keras.losses.MeanSquaredError(), + metrics: new[] { "acc" }); + model.summary(); + Assert.AreEqual(model.Layers.Count, 8); + var result = model.predict(tf.constant(np.arange(24).astype(np.float32)[np.newaxis, Slice.All])); + Assert.AreEqual(result.shape, new Shape(1, 24)); + model.fit(np.arange(24).astype(np.float32)[np.newaxis, Slice.All], np.arange(24).astype(np.float32)[np.newaxis, Slice.All], verbose: 0); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding + /// + [TestMethod] + public void Embedding() + { + var model = keras.Sequential(); + var layer = keras.layers.Embedding(1000, 64, input_length: 10); + model.add(layer); + var input_array = np.random.randint(1000, size: (32, 10)); + model.compile("rmsprop", "mse", new[] { "accuracy" }); + var output_array = model.predict(input_array); + Assert.AreEqual((32, 10, 64), output_array.shape); + } + [TestMethod] + public void EmbeddingGrad() + { + var inputs = keras.layers.Input(shape: new[] { 32, 10 }); + var outputs = keras.layers.Embedding(1000, 64, input_length: 10).Apply(inputs); + var model = keras.Model(inputs: inputs, outputs: outputs); + var input_array = np.random.randint(1000, size: (1, 32, 10)); + var output_array = np.random.random(size: (1, 32, 10, 64)); + model.compile("rmsprop", "mse", new[] { "accuracy" }); + model.fit(input_array, output_array); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense + /// + [TestMethod] + public void Dense() + { + // Create a `Sequential` model and add a Dense layer as the first layer. + var model = keras.Sequential(); + model.add(keras.Input(shape: 16)); + model.add(keras.layers.Dense(32, activation: keras.activations.Relu)); + // Now the model will take as input arrays of shape (None, 16) + // and output arrays of shape (None, 32). + // Note that after the first layer, you don't need to specify + // the size of the input anymore: + model.add(keras.layers.Dense(32)); + Assert.AreEqual((-1, 32), model.output_shape); + } + + [TestMethod] + public void EinsumDense() + { + var ed = keras.layers.EinsumDense( + equation: "...b,bc->...c", + output_shape: 4, + bias_axes: "c", + bias_initializer: tf.constant_initializer(0.03), + kernel_initializer: tf.constant_initializer(0.5) + ); + var inp = np.array(new[,] { { 1f, 2f }, { 3f, 4f } }); + var expected_output = np.array(new[,] {{1.53f, 1.53f, 1.53f, 1.53f }, + { 3.53f, 3.53f, 3.53f, 3.53f }}); + var actual_output = ed.Apply(inp)[0].numpy(); + Assert.AreEqual(expected_output, actual_output); + } + + [TestMethod] + public void Resizing() + { + var inputs = tf.random.uniform((10, 32, 32, 3)); + var layer = keras.layers.preprocessing.Resizing(16, 16); + var output = layer.Apply(inputs); + Assert.AreEqual((10, 16, 16, 3), output.shape); + } + + [TestMethod] + public void LayerNormalization() + { + var inputs = tf.constant(np.arange(10).reshape((5, 2)) * 10, dtype: tf.float32); + var layer = keras.layers.LayerNormalization(axis: 1); + Tensor output = layer.Apply(inputs); + Assert.AreEqual((5, 2), output.shape); + Assert.IsTrue(output[0].numpy().Equals(new[] { -0.99998f, 0.99998f })); + + // test_layernorm_weights + Assert.AreEqual(len(layer.TrainableWeights), 2); + Assert.AreEqual(len(layer.Weights), 2); + + var beta = layer.Weights.Where(x => x.Name.StartsWith("beta")).Single(); + var gamma = layer.Weights.Where(x => x.Name.StartsWith("gamma")).Single(); + + // correctness_test + layer = keras.layers.LayerNormalization(axis: -1, epsilon: (float) 1e-12); + var x = np.random.normal(loc: 5.0f, scale: 10.0f, size: (1000, 2, 2, 2)).astype(tf.float32); + + output = layer.Apply(x); + + var y = (output - beta.numpy()) / gamma.numpy(); + + var y_mean = np.mean(y.numpy()); + var y_std = np.sqrt(np.sum(np.power(y.numpy() - np.mean(y.numpy()), 2)) / 8000); + Assert.IsTrue(tf.greater(np.array(0.1f), tf.abs(y_std - 1.0)).ToArray()[0]); + Assert.IsTrue(tf.greater(np.array(0.1f), tf.abs(y_mean)).ToArray()[0]); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization + /// + [TestMethod] + public void Normalization() + { + // Calculate a global mean and variance by analyzing the dataset in adapt(). + var adapt_data = np.array(new[] { 1f, 2f, 3f, 4f, 5f }); + var input_data = np.array(new[] { 1f, 2f, 3f }); + var layer = tf.keras.layers.Normalization(axis: null); + layer.adapt(adapt_data); + var x = layer.Apply(input_data); + Assert.AreEqual(x.numpy(), new[] { -1.4142135f, -0.70710677f, 0f }); + + // Calculate a mean and variance for each index on the last axis. + adapt_data = np.array(new[,] + { + { 0, 7, 4 }, + { 2, 9, 6 }, + { 0, 7, 4 }, + { 2, 9, 6 } + }, dtype: tf.float32); + input_data = np.array(new[,] { { 0, 7, 4 } }, dtype: tf.float32); + layer = tf.keras.layers.Normalization(axis: -1); + layer.adapt(adapt_data); + x = layer.Apply(input_data); + Equal(x.numpy().ToArray(), new[] { -1f, -1f, -1f }); + + // Pass the mean and variance directly. + input_data = np.array(new[,] { { 1f }, { 2f }, { 3f } }, dtype: tf.float32); + layer = tf.keras.layers.Normalization(mean: 3f, variance: 2f); + x = layer.Apply(input_data); + Equal(x.numpy().ToArray(), new[] { -1.4142135f, -0.70710677f, 0f }); + + // Use the layer to de-normalize inputs (after adapting the layer). + adapt_data = np.array(new[,] + { + { 0, 7, 4 }, + { 2, 9, 6 }, + { 0, 7, 4 }, + { 2, 9, 6 } + }, dtype: tf.float32); + input_data = np.array(new[,] { { 1, 2, 3 } }, dtype: tf.float32); + layer = tf.keras.layers.Normalization(axis: -1, invert: true); + layer.adapt(adapt_data); + x = layer.Apply(input_data); + Equal(x.numpy().ToArray(), new[] { -2f, -10f, -8f }); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/CategoryEncoding + /// + [TestMethod] + public void CategoryEncoding() + { + // one-hot + var inputs = np.array(new[] { 3, 2, 0, 1 }); + var layer = tf.keras.layers.CategoryEncoding(4); + + Tensor output = layer.Apply(inputs); + Assert.AreEqual((4, 4), output.shape); + Assert.IsTrue(output[0].numpy().Equals(new[] { 0, 0, 0, 1f })); + Assert.IsTrue(output[1].numpy().Equals(new[] { 0, 0, 1, 0f })); + Assert.IsTrue(output[2].numpy().Equals(new[] { 1, 0, 0, 0f })); + Assert.IsTrue(output[3].numpy().Equals(new[] { 0, 1, 0, 0f })); + + // multi-hot + inputs = np.array(new[,] + { + { 0, 1 }, + { 0, 0 }, + { 1, 2 }, + { 3, 1 } + }); + layer = tf.keras.layers.CategoryEncoding(4, output_mode: "multi_hot"); + output = layer.Apply(inputs); + Assert.IsTrue(output[0].numpy().Equals(new[] { 1, 1, 0, 0f })); + Assert.IsTrue(output[1].numpy().Equals(new[] { 1, 0, 0, 0f })); + Assert.IsTrue(output[2].numpy().Equals(new[] { 0, 1, 1, 0f })); + Assert.IsTrue(output[3].numpy().Equals(new[] { 0, 1, 0, 1f })); + + // using weighted inputs in "count" mode + inputs = np.array(new[,] + { + { 0, 1 }, + { 0, 0 }, + { 1, 2 }, + { 3, 1 } + }); + var weights = np.array(new[,] + { + { 0.1f, 0.2f }, + { 0.1f, 0.1f }, + { 0.2f, 0.3f }, + { 0.4f, 0.2f } + }); + layer = tf.keras.layers.CategoryEncoding(4, output_mode: "count", count_weights: weights); + output = layer.Apply(inputs); + Assert.IsTrue(output[0].numpy().Equals(new[] { 0.1f, 0.2f, 0f, 0f })); + Assert.IsTrue(output[1].numpy().Equals(new[] { 0.2f, 0f, 0f, 0f })); + Assert.IsTrue(output[2].numpy().Equals(new[] { 0f, 0.2f, 0.3f, 0f })); + Assert.IsTrue(output[3].numpy().Equals(new[] { 0f, 0.2f, 0f, 0.4f })); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LogCosh.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LogCosh.Test.cs new file mode 100644 index 000000000..9bfd28b43 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LogCosh.Test.cs @@ -0,0 +1,70 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Keras.Losses; +using Tensorflow.NumPy; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class LogCosh + { + //https://keras.io/api/losses/regression_losses/#meansquarederror-class + + NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } }; + NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 0.0f, 0.0f } }; + + [TestMethod] + + public void _Default() + { + //>>> # Using 'auto'/'sum_over_batch_size' reduction type. + //>>> l = tf.keras.losses.LogCosh() + //>>> l(y_true, y_pred).numpy() + //0.108 + var loss = keras.losses.LogCosh(); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)0.1084452f, call.numpy()); + } + + [TestMethod] + + public void _Sample_Weight() + { + //>>> # Calling with 'sample_weight'. + //>>> l(y_true, y_pred, sample_weight =[0.8, 0.2]).numpy() + //0.087 + var loss = keras.losses.LogCosh(); + var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f }); + Assert.AreEqual((NDArray)0.08675616f, call.numpy()); + } + + [TestMethod] + + public void _SUM() + { + //>>> # Using 'sum' reduction type. + //>>> l = tf.keras.losses.LogCosh( + //... reduction = tf.keras.losses.Reduction.SUM) + //>>> l(y_true, y_pred).numpy() + //0.217 + var loss = keras.losses.LogCosh(reduction: ReductionV2.SUM); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)0.2168904f, call.numpy()); + } + + [TestMethod] + + public void _None() + { + //>>> # Using 'none' reduction type. + //>>> l = tf.keras.losses.LogCosh( + //... reduction = tf.keras.losses.Reduction.NONE) + //>>> l(y_true, y_pred).numpy() + //array([0.217, 0.], dtype = float32) + var loss = keras.losses.LogCosh(reduction: ReductionV2.NONE); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)new float[] { 0.2168904f, 0.0f }, call.numpy()); + } + + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/MeanAbsoluteError.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/MeanAbsoluteError.Test.cs new file mode 100644 index 000000000..1ef83adeb --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/MeanAbsoluteError.Test.cs @@ -0,0 +1,71 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Keras.Losses; +using Tensorflow.NumPy; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class MeanAbsoluteError + { + //https://keras.io/api/losses/regression_losses/ + + NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } }; + NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } }; + + [TestMethod] + + public void _Default() + { + + //>>> # Using 'auto'/'sum_over_batch_size' reduction type. + //>>> mae = tf.keras.losses.MeanAbsoluteError() + //>>> mae(y_true, y_pred).numpy() + //0.5 + var loss = keras.losses.MeanAbsoluteError(); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)(0.5f), call.numpy()); + } + + [TestMethod] + + public void _Sample_Weight() + { + //>>> # Calling with 'sample_weight'. + //>>> mae(y_true, y_pred, sample_weight =[0.7, 0.3]).numpy() + //0.25 + var loss = keras.losses.MeanAbsoluteError(); + var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.7f, 0.3f }); + Assert.AreEqual((NDArray)(0.25f), call.numpy()); + } + + [TestMethod] + + public void _SUM() + { + //>>> # Using 'sum' reduction type. + //>>> mae = tf.keras.losses.MeanAbsoluteError( + //... reduction = tf.keras.losses.Reduction.SUM) + //>>> mae(y_true, y_pred).numpy() + //1.0 + var loss = keras.losses.MeanAbsoluteError(reduction: ReductionV2.SUM); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)(1.0f), call.numpy()); + } + + [TestMethod] + + public void _None() + { + //>>> # Using 'none' reduction type. + //>>> mae = tf.keras.losses.MeanAbsoluteError( + //... reduction = tf.keras.losses.Reduction.NONE) + //>>> mae(y_true, y_pred).numpy() + //array([0.5, 0.5], dtype = float32) + var loss = keras.losses.MeanAbsoluteError(reduction: ReductionV2.NONE); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)new float[] { 0.5f, 0.5f }, call.numpy()); + } + + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/MeanAbsolutePercentageError.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/MeanAbsolutePercentageError.Test.cs new file mode 100644 index 000000000..440168396 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/MeanAbsolutePercentageError.Test.cs @@ -0,0 +1,70 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Keras.Losses; +using Tensorflow.NumPy; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class MeanAbsolutePercentageError + { + //https://keras.io/api/losses/regression_losses/ + + NDArray y_true_float = new float[,] { { 2.0f, 1.0f }, { 2.0f, 3.0f } }; + NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } }; + + [TestMethod] + + public void _Default() + { + //>>> # Using 'auto'/'sum_over_batch_size' reduction type. + //>>> mape = tf.keras.losses.MeanAbsolutePercentageError() + //>>> mape(y_true, y_pred).numpy() + //50. + var loss = keras.losses.MeanAbsolutePercentageError(); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)(50f), call.numpy()); + } + + [TestMethod] + + public void _Sample_Weight() + { + //>>> # Calling with 'sample_weight'. + //>>> mape(y_true, y_pred, sample_weight =[0.7, 0.3]).numpy() + //20. + var loss = keras.losses.MeanAbsolutePercentageError(); + var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.7f, 0.3f }); + Assert.AreEqual((NDArray)(20f), call.numpy()); + } + + [TestMethod] + + public void _SUM() + { + //>>> # Using 'sum' reduction type. + //>>> mape = tf.keras.losses.MeanAbsolutePercentageError( + //... reduction = tf.keras.losses.Reduction.SUM) + //>>> mape(y_true, y_pred).numpy() + //100. + var loss = keras.losses.MeanAbsolutePercentageError(reduction: ReductionV2.SUM); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)(100f), call.numpy()); + } + + [TestMethod] + + public void _None() + { + //>>> # Using 'none' reduction type. + //>>> mape = tf.keras.losses.MeanAbsolutePercentageError( + //... reduction = tf.keras.losses.Reduction.NONE) + //>>> mape(y_true, y_pred).numpy() + //array([25., 75.], dtype = float32) + var loss = keras.losses.MeanAbsolutePercentageError(reduction: ReductionV2.NONE); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)new float[] { 25f, 75f }, call.numpy()); + } + + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/MeanSquaredError.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/MeanSquaredError.Test.cs new file mode 100644 index 000000000..828d65e55 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/MeanSquaredError.Test.cs @@ -0,0 +1,62 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class MeanSquaredErrorTest + { + //https://keras.io/api/losses/regression_losses/#meansquarederror-class + + private NDArray y_true = new double[,] { { 0.0, 1.0 }, { 0.0, 0.0 } }; + private NDArray y_pred = new double[,] { { 1.0, 1.0 }, { 1.0, 0.0 } }; + + [TestMethod] + + public void Mse_Double() + { + var mse = keras.losses.MeanSquaredError(); + var call = mse.Call(y_true, y_pred); + Assert.AreEqual(call.numpy(), 0.5); + } + + [TestMethod] + + public void Mse_Float() + { + NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } }; + NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } }; + + var mse = keras.losses.MeanSquaredError(); + var call = mse.Call(y_true_float, y_pred_float); + Assert.AreEqual(call.numpy(), 0.5f); + } + + [TestMethod] + + public void Mse_Sample_Weight() + { + var mse = keras.losses.MeanSquaredError(); + var call = mse.Call(y_true, y_pred, sample_weight: (NDArray)new double[] { 0.7, 0.3 }); + Assert.AreEqual(call.numpy(), 0.25); + } + + [TestMethod] + public void Mse_Reduction_SUM() + { + var mse = keras.losses.MeanSquaredError(reduction: Reduction.SUM); + var call = mse.Call(y_true, y_pred); + Assert.AreEqual(call.numpy(), 1.0); + } + + [TestMethod] + + public void Mse_Reduction_NONE() + { + var mse = keras.losses.MeanSquaredError(reduction: Reduction.NONE); + var call = mse.Call(y_true, y_pred); + Assert.AreEqual(call.numpy(), new double[] { 0.5, 0.5 }); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/MeanSquaredLogarithmicError.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/MeanSquaredLogarithmicError.Test.cs new file mode 100644 index 000000000..5cecab0cc --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/MeanSquaredLogarithmicError.Test.cs @@ -0,0 +1,70 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Keras.Losses; +using Tensorflow.NumPy; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class MeanSquaredLogarithmicError + { + //https://keras.io/api/losses/regression_losses/ + + NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } }; + NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } }; + + [TestMethod] + + public void _Default() + { + //>>> # Using 'auto'/'sum_over_batch_size' reduction type. + //>>> msle = tf.keras.losses.MeanSquaredLogarithmicError() + //>>> msle(y_true, y_pred).numpy() + //0.240 + var loss = keras.losses.MeanSquaredLogarithmicError(); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)(0.24022643f), call.numpy()); + } + + [TestMethod] + + public void _Sample_Weight() + { + //>>> # Calling with 'sample_weight'. + //>>> msle(y_true, y_pred, sample_weight =[0.7, 0.3]).numpy() + //0.120 + var loss = keras.losses.MeanSquaredLogarithmicError(); + var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.7f, 0.3f }); + Assert.AreEqual((NDArray)(0.12011322f), call.numpy()); + } + + [TestMethod] + + public void _SUM() + { + //>>> # Using 'sum' reduction type. + //>>> msle = tf.keras.losses.MeanSquaredLogarithmicError( + //... reduction = tf.keras.losses.Reduction.SUM) + //>>> msle(y_true, y_pred).numpy() + //0.480 + var loss = keras.losses.MeanSquaredLogarithmicError(reduction: ReductionV2.SUM); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)(0.48045287f), call.numpy()); + } + + [TestMethod] + + public void _None() + { + //>>> # Using 'none' reduction type. + //>>> msle = tf.keras.losses.MeanSquaredLogarithmicError( + //... reduction = tf.keras.losses.Reduction.NONE) + //>>> msle(y_true, y_pred).numpy() + //array([0.240, 0.240], dtype = float32) + var loss = keras.losses.MeanSquaredLogarithmicError(reduction: ReductionV2.NONE); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)new float[] { 0.24022643f, 0.24022643f }, call.numpy()); + } + + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs new file mode 100644 index 000000000..a3516bc83 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs @@ -0,0 +1,302 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + /// + /// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers + /// + [TestClass] + public class PoolingTest : EagerModeTestBase + { + private NDArray input_array_1D = np.array(new float[,,] + { + {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,3}}, + {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}}, + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}, + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}} + }); + + private NDArray input_array_2D = np.array(new float[,,,] + {{ + {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,3}}, + {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}}, + },{ + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}, + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}} + },{ + {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,3}}, + {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}}, + },{ + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}, + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}} + }}); + + [TestMethod] + public void GlobalAverage1DPoolingChannelsLast() + { + var pool = keras.layers.GlobalAveragePooling1D(); + var y = pool.Apply(input_array_1D); + + Assert.AreEqual(4, y.shape[0]); + Assert.AreEqual(5, y.shape[1]); + + var expected = np.array(new float[,] + { + {1,2,3,3,3}, + {4,5,6,3,3}, + {7,8,9,3,3}, + {7,8,9,3,3} + }); + + Assert.AreEqual(expected, y[0].numpy()); + } + + [TestMethod] + public void GlobalAverage1DPoolingChannelsFirst() + { + var pool = keras.layers.GlobalAveragePooling1D(data_format: "channels_first"); + var y = pool.Apply(input_array_1D); + + Assert.AreEqual(4, y.shape[0]); + Assert.AreEqual(3, y.shape[1]); + + var expected = np.array(new float[,] + { + {2.4f, 2.4f, 2.4f}, + {4.2f, 4.2f, 4.2f}, + {6.0f, 6.0f, 6.0f}, + {6.0f, 6.0f, 6.0f} + }); + + Assert.AreEqual(expected, y[0].numpy()); + } + + [TestMethod] + public void GlobalAverage2DPoolingChannelsLast() + { + var pool = keras.layers.GlobalAveragePooling2D(); + var y = pool.Apply(input_array_2D); + + Assert.AreEqual(4, y.shape[0]); + Assert.AreEqual(5, y.shape[1]); + + var expected = np.array(new float[,] + { + {2.5f, 3.5f, 4.5f, 3.0f, 3.0f}, + {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}, + {2.5f, 3.5f, 4.5f, 3.0f, 3.0f}, + {7.0f, 8.0f, 9.0f, 3.0f, 3.0f} + }); + + Assert.AreEqual(expected, y[0].numpy()); + } + + [TestMethod] + public void GlobalAverage2DPoolingChannelsFirst() + { + var pool = keras.layers.GlobalAveragePooling2D(data_format: "channels_first"); + var y = pool.Apply(input_array_2D); + + Assert.AreEqual(4, y.shape[0]); + Assert.AreEqual(2, y.shape[1]); + + var expected = np.array(new float[,] + { + {2.4f, 4.2f}, + {6.0f, 6.0f}, + {2.4f, 4.2f}, + {6.0f, 6.0f} + }); + + Assert.AreEqual(expected, y[0].numpy()); + } + + [TestMethod] + public void GlobalMax1DPoolingChannelsLast() + { + var pool = keras.layers.GlobalMaxPooling1D(); + var y = pool.Apply(input_array_1D); + + Assert.AreEqual(4, y.shape[0]); + Assert.AreEqual(5, y.shape[1]); + + var expected = np.array(new float[,] + { + {1,2,3,3,3}, + {4,5,6,3,3}, + {7,8,9,3,3}, + {7,8,9,3,3} + }); + + Assert.AreEqual(expected, y[0].numpy()); + } + + [TestMethod] + public void GlobalMax1DPoolingChannelsFirst() + { + var pool = keras.layers.GlobalMaxPooling1D(data_format: "channels_first"); + var y = pool.Apply(input_array_1D); + + Assert.AreEqual(4, y.shape[0]); + Assert.AreEqual(3, y.shape[1]); + + var expected = np.array(new float[,] + { + {3.0f, 3.0f, 3.0f}, + {6.0f, 6.0f, 6.0f}, + {9.0f, 9.0f, 9.0f}, + {9.0f, 9.0f, 9.0f} + }); + + Assert.AreEqual(expected, y[0].numpy()); + } + + [TestMethod] + public void GlobalMax2DPoolingChannelsLast() + { + var input_array_2D = np.array(new float[,,,] + {{ + {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,9,3}}, + {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}}, + },{ + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}, + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}} + },{ + {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,9}}, + {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}}, + },{ + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}, + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}} + }}); + + var pool = keras.layers.GlobalMaxPooling2D(); + var y = pool.Apply(input_array_2D); + + Assert.AreEqual(4, y.shape[0]); + Assert.AreEqual(5, y.shape[1]); + + var expected = np.array(new float[,] + { + {4.0f, 5.0f, 6.0f, 9.0f, 3.0f}, + {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}, + {4.0f, 5.0f, 6.0f, 3.0f, 9.0f}, + {7.0f, 8.0f, 9.0f, 3.0f, 3.0f} + }); + + Assert.AreEqual(expected, y[0].numpy()); + } + + [TestMethod] + public void GlobalMax2DPoolingChannelsFirst() + { + var input_array_2D = np.array(new float[,,,] + {{ + {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,9,3}}, + {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}}, + },{ + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}, + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}} + },{ + {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,9}}, + {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}}, + },{ + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}, + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}} + }}); + + var pool = keras.layers.GlobalMaxPooling2D(data_format: "channels_first"); + var y = pool.Apply(input_array_2D); + + Assert.AreEqual(4, y.shape[0]); + Assert.AreEqual(2, y.shape[1]); + + var expected = np.array(new float[,] + { + {9.0f, 6.0f}, + {9.0f, 9.0f}, + {9.0f, 6.0f}, + {9.0f, 9.0f} + }); + + Assert.AreEqual(expected, y[0].numpy()); + } + + [TestMethod] + public void Max1DPoolingChannelsLast() + { + var x = input_array_1D; + var pool = keras.layers.MaxPooling1D(pool_size: 2, strides: 1); + var y = pool.Apply(x); + + Assert.AreEqual(4, y.shape[0]); + Assert.AreEqual(2, y.shape[1]); + Assert.AreEqual(5, y.shape[2]); + + var expected = np.array(new float[,,] + { + {{1.0f, 2.0f, 3.0f, 3.0f, 3.0f}, + { 1.0f, 2.0f, 3.0f, 3.0f, 3.0f}}, + + {{4.0f, 5.0f, 6.0f, 3.0f, 3.0f}, + {4.0f, 5.0f, 6.0f, 3.0f, 3.0f}}, + + {{7.0f, 8.0f, 9.0f, 3.0f, 3.0f}, + {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}}, + + {{7.0f, 8.0f, 9.0f, 3.0f, 3.0f}, + {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}} + }); + + Assert.AreEqual(expected, y[0].numpy()); + } + + [TestMethod] + public void Max2DPoolingChannelsLast() + { + var x = np.array(new float[,,,] + {{ + {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,9,3}}, + {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}}, + },{ + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}, + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}} + },{ + {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,9}}, + {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}}, + },{ + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}, + {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}} + }}); + + var pool = keras.layers.MaxPooling2D(pool_size: 2, strides: 1); + var y = pool.Apply(x); + + Assert.AreEqual(4, y.shape[0]); + Assert.AreEqual(1, y.shape[1]); + Assert.AreEqual(2, y.shape[2]); + Assert.AreEqual(5, y.shape[3]); + + var expected = np.array(new float[,,,] + { + {{{4.0f, 5.0f, 6.0f, 3.0f, 3.0f}, + {4.0f, 5.0f, 6.0f, 9.0f, 3.0f}}}, + + + {{{7.0f, 8.0f, 9.0f, 3.0f, 3.0f}, + {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}}}, + + + {{{4.0f, 5.0f, 6.0f, 3.0f, 3.0f}, + {4.0f, 5.0f, 6.0f, 3.0f, 9.0f}}}, + + + {{{7.0f, 8.0f, 9.0f, 3.0f, 3.0f}, + {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}}} + }); + + Assert.AreEqual(expected, y[0].numpy()); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs new file mode 100644 index 000000000..67e2b0464 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs @@ -0,0 +1,167 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Saving; +using Tensorflow.NumPy; +using Tensorflow.Train; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Layers +{ + [TestClass] + public class Rnn + { + [TestMethod] + public void SimpleRNNCell() + { + var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f); + var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; + var x = tf.random.normal((4, 100)); + var (y, h1) = cell.Apply(inputs: x, states: h0); + var h2 = h1; + Assert.AreEqual((4, 64), y.shape); + Assert.AreEqual((4, 64), h2[0].shape); + } + + [TestMethod] + public void StackedRNNCell() + { + var inputs = tf.ones((32, 10)); + var states = new Tensors { tf.zeros((32, 4)), tf.zeros((32, 5)) }; + var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) }; + var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells); + var (output, state) = stackedRNNCell.Apply(inputs, states); + Assert.AreEqual((32, 5), output.shape); + Assert.AreEqual((32, 4), state[0].shape); + } + + [TestMethod] + public void LSTMCell() + { + var inputs = tf.ones((2, 100)); + var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) }; + var rnn = tf.keras.layers.LSTMCell(4); + var (output, new_states) = rnn.Apply(inputs, states); + Assert.AreEqual((2, 4), output.shape); + Assert.AreEqual((2, 4), new_states[0].shape); + } + + [TestMethod] + public void TrainLSTMWithMnist() + { + var input = keras.Input((784)); + var x = keras.layers.Reshape((28, 28)).Apply(input); + x = keras.layers.LSTM(50, return_sequences: true).Apply(x); + x = keras.layers.LSTM(100).Apply(x); + var output = keras.layers.Dense(10, activation: "softmax").Apply(x); + + var model = keras.Model(input, output); + model.summary(); + model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" }); + + var data_loader = new MnistModelLoader(); + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = true, + ValidationSize = 55000, + }).Result; + var sample_weight = np.ones(((int)dataset.Train.Data.shape[0])); + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1, sample_weight:sample_weight); + } + + [TestMethod] + public void SimpleRNN() + { + var input = keras.Input((784)); + var x = keras.layers.Reshape((28, 28)).Apply(input); + x = keras.layers.SimpleRNN(10).Apply(x); + var output = keras.layers.Dense(10, activation: "softmax").Apply(x); + + var model = keras.Model(input, output); + model.summary(); + model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" }); + + var data_loader = new MnistModelLoader(); + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 58000, + }).Result; + + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 2); + } + + [TestMethod] + public void RNNForSimpleRNNCell() + { + var inputs = tf.random.normal((32, 10, 8)); + var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f); + var rnn = tf.keras.layers.RNN(cell: cell); + var cgf = rnn.get_config(); + var output = rnn.Apply(inputs); + Assert.AreEqual((32, 10), output.shape); + + } + [TestMethod] + public void RNNForStackedRNNCell() + { + var inputs = tf.random.normal((32, 10, 8)); + var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) }; + var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells); + var rnn = tf.keras.layers.RNN(cell: stackedRNNCell); + var output = rnn.Apply(inputs); + Assert.AreEqual((32, 5), output.shape); + } + + [TestMethod] + public void RNNForLSTMCell() + { + var inputs = tf.ones((5, 10, 8)); + var rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4)); + var output = rnn.Apply(inputs); + Console.WriteLine($"output: {output}"); + Assert.AreEqual((5, 4), output.shape); + } + + [TestMethod] + public void GRUCell() + { + var inputs = tf.random.normal((32, 10, 8)); + var rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4)); + var output = rnn.Apply(inputs); + Assert.AreEqual((32, 4), output.shape); + rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4, reset_after:false, use_bias:false)); + output = rnn.Apply(inputs); + Assert.AreEqual((32, 4), output.shape); + + } + + [TestMethod] + public void GRU() + { + var inputs = tf.ones((32, 10, 8)); + var gru = tf.keras.layers.GRU(4); + var output = gru.Apply(inputs); + Assert.AreEqual((32, 4), output.shape); + } + + [TestMethod] + public void Bidirectional() + { + var bi = tf.keras.layers.Bidirectional(keras.layers.LSTM(10, return_sequences:true)); + var inputs = tf.random.normal((32, 10, 8)); + var outputs = bi.Apply(inputs); + Assert.AreEqual((32, 10, 20), outputs.shape); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs b/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs new file mode 100644 index 000000000..0bb1d0110 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs @@ -0,0 +1,57 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.UnitTest.Losses; + +[TestClass] +public class LossesTest : EagerModeTestBase +{ + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/losses/BinaryCrossentropy + /// + [TestMethod] + public void BinaryCrossentropy() + { + // Example 1: (batch_size = 1, number of samples = 4) + var y_true = tf.constant(new float[] { 0, 1, 0, 0 }); + var y_pred = tf.constant(new float[] { -18.6f, 0.51f, 2.94f, -12.8f }); + var bce = tf.keras.losses.BinaryCrossentropy(from_logits: true); + var loss = bce.Call(y_true, y_pred); + Assert.AreEqual((float)loss, 0.865458f); + + // Example 2: (batch_size = 2, number of samples = 4) + y_true = tf.constant(new float[,] { { 0, 1 }, { 0, 0 } }); + y_pred = tf.constant(new float[,] { { -18.6f, 0.51f }, { 2.94f, -12.8f } }); + bce = tf.keras.losses.BinaryCrossentropy(from_logits: true); + loss = bce.Call(y_true, y_pred); + Assert.AreEqual((float)loss, 0.865458f); + + // Using 'sample_weight' attribute + loss = bce.Call(y_true, y_pred, sample_weight: tf.constant(new[] { 0.8f, 0.2f })); + Assert.AreEqual((float)loss, 0.2436386f); + + // Using 'sum' reduction` type. + bce = tf.keras.losses.BinaryCrossentropy(from_logits: true, reduction: Reduction.SUM); + loss = bce.Call(y_true, y_pred); + Assert.AreEqual((float)loss, 1.730916f); + + // Using 'none' reduction type. + bce = tf.keras.losses.BinaryCrossentropy(from_logits: true, reduction: Reduction.NONE); + loss = bce.Call(y_true, y_pred); + Assert.IsTrue(new NDArray(new float[] { 0.23515666f, 1.4957594f }) == loss.numpy()); + } + + /// + /// https://www.tensorflow.org/addons/api_docs/python/tfa/losses/SigmoidFocalCrossEntropy + /// + [TestMethod] + public void SigmoidFocalCrossEntropy() + { + var y_true = np.expand_dims(np.array(new[] { 1.0f, 1.0f, 0 })); + var y_pred = np.expand_dims(np.array(new[] { 0.97f, 0.91f, 0.03f })); + var bce = tf.keras.losses.SigmoidFocalCrossEntropy(); + var loss = bce.Call(y_true, y_pred); + Assert.AreEqual(new[] { 6.8532745e-06f, 1.909787e-04f, 2.0559824e-05f }, loss.numpy()); + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs new file mode 100644 index 000000000..560d3580c --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs @@ -0,0 +1,322 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.UnitTest.Layers.Metrics; + +[TestClass] +public class MetricsTest : EagerModeTestBase +{ + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Accuracy + /// + [TestMethod] + public void Accuracy() + { + var y_true = np.array(new[,] { { 1 }, { 2 }, { 3 }, { 4 } }); + var y_pred = np.array(new[,] { { 0f }, { 2f }, { 3f }, { 4f } }); + var m = tf.keras.metrics.Accuracy(); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.75f); + + m.reset_states(); + var weights = np.array(new[] { 1f, 1f, 0f, 0f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 0.5f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/BinaryAccuracy + /// + [TestMethod] + public void BinaryAccuracy() + { + var y_true = np.array(new[,] { { 1 }, { 1 }, { 0 }, { 0 } }); + var y_pred = np.array(new[,] { { 0.98f }, { 1f }, { 0f }, { 0.6f } }); + var m = tf.keras.metrics.BinaryAccuracy(); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.75f); + + m.reset_states(); + var weights = np.array(new[] { 1f, 0f, 0f, 1f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 0.5f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalAccuracy + /// + [TestMethod] + public void CategoricalAccuracy() + { + var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } }); + var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } }); + var m = tf.keras.metrics.CategoricalAccuracy(); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.5f); + + m.reset_states(); + var weights = np.array(new[] { 0.7f, 0.3f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 0.3f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalAccuracy + /// + [TestMethod] + public void SparseCategoricalAccuracy() + { + var y_true = np.array(new[] { 2, 1 }); + var y_pred = np.array(new[,] { { 0.1f, 0.6f, 0.3f }, { 0.05f, 0.95f, 0f } }); + var m = tf.keras.metrics.SparseCategoricalAccuracy(); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.5f); + + m.reset_states(); + var weights = np.array(new[] { 0.7f, 0.3f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 0.3f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalCrossentropy + /// + [TestMethod] + public void CategoricalCrossentropy() + { + var y_true = np.array(new[,] { { 0, 1, 0 }, { 0, 0, 1 } }); + var y_pred = np.array(new[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } }); + var m = tf.keras.metrics.CategoricalCrossentropy(); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 1.1769392f); + + m.reset_states(); + var weights = np.array(new[] { 0.3f, 0.7f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 1.6271976f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalCrossentropy + /// + [TestMethod] + public void SparseCategoricalCrossentropy() + { + var y_true = np.array(new[] { 1, 2 }); + var y_pred = np.array(new[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } }); + var m = tf.keras.metrics.SparseCategoricalCrossentropy(); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 1.1769392f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CosineSimilarity + /// + [TestMethod] + public void CosineSimilarity() + { + var y_true = np.array(new[,] { { 0, 1 }, { 1, 1 } }); + var y_pred = np.array(new[,] { { 1f, 0f }, { 1f, 1f } }); + var m = tf.keras.metrics.CosineSimilarity(axis: 1); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.49999997f); + + m.reset_states(); + var weights = np.array(new[] { 0.3f, 0.7f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 0.6999999f); + } + + /// + /// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score + /// + [TestMethod] + public void F1Score() + { + var y_true = np.array(new[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } }); + var y_pred = np.array(new[,] { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } }); + var m = tf.keras.metrics.F1Score(num_classes: 3, threshold: 0.5f); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, new[] { 0.5f, 0.8f, 0.6666667f }); + } + + /// + /// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/FBetaScore + /// + [TestMethod] + public void FBetaScore() + { + var y_true = np.array(new[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } }); + var y_pred = np.array(new[,] { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } }); + var m = tf.keras.metrics.FBetaScore(num_classes: 3, beta: 2.0f, threshold: 0.5f); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, new[] { 0.3846154f, 0.90909094f, 0.8333334f }); + } + + /// + /// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/HammingLoss + /// + [TestMethod] + public void HammingLoss() + { + // multi-class hamming loss + var y_true = np.array(new[,] + { + { 1, 0, 0, 0 }, + { 0, 0, 1, 0 }, + { 0, 0, 0, 1 }, + { 0, 1, 0, 0 } + }); + var y_pred = np.array(new[,] + { + { 0.8f, 0.1f, 0.1f, 0.0f }, + { 0.2f, 0.0f, 0.8f, 0.0f }, + { 0.05f, 0.05f, 0.1f, 0.8f }, + { 1.0f, 0.0f, 0.0f, 0.0f } + }); + var m = tf.keras.metrics.HammingLoss(mode: "multiclass", threshold: 0.6f); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.25f); + + // multi-label hamming loss + y_true = np.array(new[,] + { + { 1, 0, 1, 0 }, + { 0, 1, 0, 1 }, + { 0, 0, 0, 1 } + }); + y_pred = np.array(new[,] + { + { 0.82f, 0.5f, 0.9f, 0.0f }, + { 0f, 1f, 0.4f, 0.98f }, + { 0.89f, 0.79f, 0f, 0.3f } + }); + m = tf.keras.metrics.HammingLoss(mode: "multilabel", threshold: 0.8f); + m.update_state(y_true, y_pred); + r = m.result().numpy(); + Assert.AreEqual(r, 0.16666667f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy + /// + [TestMethod] + public void TopKCategoricalAccuracy() + { + var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } }); + var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } }); + var m = tf.keras.metrics.TopKCategoricalAccuracy(k: 1); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.5f); + + m.reset_states(); + var weights = np.array(new[] { 0.7f, 0.3f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 0.3f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseTopKCategoricalAccuracy + /// + [TestMethod] + public void SparseTopKCategoricalAccuracy() + { + var y_true = np.array(new[] { 2, 1 }); + var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } }); + var m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k: 1); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.5f); + + m.reset_states(); + var weights = np.array(new[] { 0.7f, 0.3f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 0.3f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy + /// + [TestMethod] + public void top_k_categorical_accuracy() + { + var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } }); + var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } }); + var m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k: 3); + Assert.AreEqual(m.numpy(), new[] { 1f, 1f }); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Precision + /// + [TestMethod] + public void Precision() + { + var y_true = np.array(new[] { 0, 1, 1, 1 }); + var y_pred = np.array(new[] { 1, 0, 1, 1 }); + var m = tf.keras.metrics.Precision(); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.6666667f); + + m.reset_states(); + var weights = np.array(new[] { 0f, 0f, 1f, 0f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 1f); + + // With top_k=2, it will calculate precision over y_true[:2] + // and y_pred[:2] + m = tf.keras.metrics.Precision(top_k: 2); + m.update_state(np.array(new[] { 0, 0, 1, 1 }), np.array(new[] { 1, 1, 1, 1 })); + r = m.result().numpy(); + Assert.AreEqual(r, 0f); + + // With top_k=4, it will calculate precision over y_true[:4] + // and y_pred[:4] + m = tf.keras.metrics.Precision(top_k: 4); + m.update_state(np.array(new[] { 0, 0, 1, 1 }), np.array(new[] { 1, 1, 1, 1 })); + r = m.result().numpy(); + Assert.AreEqual(r, 0.5f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall + /// + [TestMethod] + public void Recall() + { + var y_true = np.array(new[] { 0, 1, 1, 1 }); + var y_pred = np.array(new[] { 1, 0, 1, 1 }); + var m = tf.keras.metrics.Recall(); + m.update_state(y_true, y_pred); + var r = m.result().numpy(); + Assert.AreEqual(r, 0.6666667f); + + m.reset_states(); + var weights = np.array(new[] { 0f, 0f, 1f, 0f }); + m.update_state(y_true, y_pred, sample_weight: weights); + r = m.result().numpy(); + Assert.AreEqual(r, 1f); + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Model/ModelBuildTest.cs b/test/TensorFlowNET.Keras.UnitTest/Model/ModelBuildTest.cs new file mode 100644 index 000000000..d4b11a9b2 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Model/ModelBuildTest.cs @@ -0,0 +1,62 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Model +{ + [TestClass] + public class ModelBuildTest + { + [TestMethod] + public void DenseBuild() + { + // two dimensions input with unknown batchsize + var input = tf.keras.layers.Input((17, 60)); + var dense = tf.keras.layers.Dense(64); + var output = dense.Apply(input); + var model = tf.keras.Model(input, output); + model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy()); + + // one dimensions input with unknown batchsize + var input_2 = tf.keras.layers.Input((60)); + var dense_2 = tf.keras.layers.Dense(64); + var output_2 = dense_2.Apply(input_2); + var model_2 = tf.keras.Model(input_2, output_2); + model_2.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy()); + + // two dimensions input with specified batchsize + var input_3 = tf.keras.layers.Input((17, 60), 8); + var dense_3 = tf.keras.layers.Dense(64); + var output_3 = dense_3.Apply(input_3); + var model_3 = tf.keras.Model(input_3, output_3); + model_3.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy()); + + // one dimensions input with specified batchsize + var input_4 = tf.keras.layers.Input((60), 8); + var dense_4 = tf.keras.layers.Dense(64); + var output_4 = dense_4.Apply(input_4); + var model_4 = tf.keras.Model(input_4, output_4); + model_4.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy()); + } + + [TestMethod] + public void NestedSequential() + { + var block1 = keras.Sequential(new[] { + keras.layers.InputLayer((3, 3)), + keras.Sequential(new [] + { + keras.layers.Flatten(), + keras.layers.Dense(5) + } + ) + }); + block1.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy()); + + var x = tf.ones((1, 3, 3)); + var y = block1.predict(x); + Console.WriteLine(y); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs b/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs new file mode 100644 index 000000000..c733537e7 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs @@ -0,0 +1,218 @@ +using Microsoft.VisualStudio.TestPlatform.Utilities; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Newtonsoft.Json.Linq; +using System.Collections.Generic; +using System.Linq; +using System.Xml.Linq; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Optimizers; +using Tensorflow.Keras.UnitTest.Helpers; +using Tensorflow.NumPy; +using static HDF.PInvoke.H5Z; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Model; + +[TestClass] +public class ModelLoadTest +{ + [TestMethod] + public void SimpleModelFromAutoCompile() + { + var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile"); + model.summary(); + + model.compile(new Adam(0.0001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + // check the weights + var kernel1 = np.load(@"Assets/simple_model_from_auto_compile/kernel1.npy"); + var bias0 = np.load(@"Assets/simple_model_from_auto_compile/bias0.npy"); + + Assert.IsTrue(kernel1.Zip(model.TrainableWeights[2].numpy()).All(x => x.First == x.Second)); + Assert.IsTrue(bias0.Zip(model.TrainableWeights[1].numpy()).All(x => x.First == x.Second)); + + var data_loader = new MnistModelLoader(); + var num_epochs = 1; + var batch_size = 8; + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 58000, + }).Result; + + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + } + + [TestMethod] + public void AlexnetFromSequential() + { + new ModelSaveTest().AlexnetFromSequential(); + var model = tf.keras.models.load_model(@"./alexnet_from_sequential"); + model.summary(); + + model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); + + var num_epochs = 1; + var batch_size = 8; + + var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); + + model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); + } + + [TestMethod] + public void ModelWithSelfDefinedModule() + { + var model = tf.keras.models.load_model(@"Assets/python_func_model"); + model.summary(); + + model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + var data_loader = new MnistModelLoader(); + var num_epochs = 1; + var batch_size = 8; + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 55000, + }).Result; + + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + } + + [Ignore] + [TestMethod] + public void LSTMLoad() + { + var model = tf.keras.models.load_model(@"Assets/lstm_from_sequential"); + model.summary(); + model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.MeanSquaredError(), new string[] { "accuracy" }); + var inputs = tf.random.normal(shape: (10, 5, 3)); + var outputs = tf.random.normal(shape: (10, 1)); + model.fit(inputs.numpy(), outputs.numpy(), batch_size: 10, epochs: 5, workers: 16, use_multiprocessing: true); + } + + [Ignore] + [TestMethod] + public void VGG19() + { + var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19"); + model.summary(); + + var classify_model = keras.Sequential(new System.Collections.Generic.List() + { + model, + keras.layers.Flatten(), + keras.layers.Dense(10), + }); + classify_model.summary(); + + classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + var x = np.random.uniform(0, 1, (8, 512, 512, 3)); + var y = np.ones(8); + + classify_model.fit(x, y, batch_size: 4); + } + + [Ignore] + [TestMethod] + public void TestModelBeforeTF2_5() + { + var a = keras.layers; + var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model; + model.summary(); + } + + + [TestMethod] + public void BiasRegularizerSaveAndLoad() + { + var savemodel = keras.Sequential(new List() + { + tf.keras.layers.InputLayer((227, 227, 3)), + tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"), + tf.keras.layers.BatchNormalization(), + tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)), + + tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1L2), + tf.keras.layers.BatchNormalization(), + + tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L2), + tf.keras.layers.BatchNormalization(), + + tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1), + tf.keras.layers.BatchNormalization(), + tf.keras.layers.MaxPooling2D((3, 3), (2, 2)), + + tf.keras.layers.Flatten(), + + tf.keras.layers.Dense(1000, activation: "linear"), + tf.keras.layers.Softmax(1) + }); + + savemodel.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); + + var num_epochs = 1; + var batch_size = 8; + + var trainDataset = new RandomDataSet(new Shape(227, 227, 3), 16); + + savemodel.fit(trainDataset.Data, trainDataset.Labels, batch_size, num_epochs); + + savemodel.save(@"./bias_regularizer_save_and_load", save_format: "tf"); + + var loadModel = tf.keras.models.load_model(@"./bias_regularizer_save_and_load"); + loadModel.summary(); + + loadModel.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); + + var fitDataset = new RandomDataSet(new Shape(227, 227, 3), 16); + + loadModel.fit(fitDataset.Data, fitDataset.Labels, batch_size, num_epochs); + } + + + [TestMethod] + public void CreateConcatenateModelSaveAndLoad() + { + // a small demo model that is just here to see if the axis value for the concatenate method is saved and loaded. + var input_layer = tf.keras.layers.Input((8, 8, 5)); + + var conv1 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_1"*/).Apply(input_layer); + conv1.Name = "conv1"; + + var conv2 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_2"*/).Apply(input_layer); + conv2.Name = "conv2"; + + var concat1 = tf.keras.layers.Concatenate(axis: 3).Apply((conv1, conv2)); + concat1.Name = "concat1"; + + var model = tf.keras.Model(input_layer, concat1); + model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy()); + + model.save(@"Assets/concat_axis3_model"); + + + var tensorInput = np.arange(320).reshape((1, 8, 8, 5)).astype(TF_DataType.TF_FLOAT); + + var tensors1 = model.predict(tensorInput); + + Assert.AreEqual((1, 8, 8, 4), tensors1.shape); + + model = null; + keras.backend.clear_session(); + + var model2 = tf.keras.models.load_model(@"Assets/concat_axis3_model"); + + var tensors2 = model2.predict(tensorInput); + + Assert.AreEqual(tensors1.shape, tensors2.shape); + } + +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Model/ModelSaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/Model/ModelSaveTest.cs new file mode 100644 index 000000000..0854a09da --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Model/ModelSaveTest.cs @@ -0,0 +1,212 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Collections.Generic; +using System.Diagnostics; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Models; +using Tensorflow.Keras.Optimizers; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.UnitTest.Helpers; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest.Model +{ + /// + /// https://www.tensorflow.org/guide/keras/save_and_serialize + /// + [TestClass] + public class ModelSaveTest : EagerModeTestBase + { + [TestMethod] + public void GetAndFromConfig() + { + var model = GetFunctionalModel(); + var config = model.get_config(); + Debug.Assert(config is FunctionalConfig); + var new_model = new ModelsApi().from_config(config as FunctionalConfig); + Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); + } + + IModel GetFunctionalModel() + { + // Create a simple model. + var inputs = keras.Input(shape: 32); + var dense_layer = keras.layers.Dense(1); + var outputs = dense_layer.Apply(inputs); + return keras.Model(inputs, outputs); + } + + [TestMethod] + public void SimpleModelFromAutoCompile() + { + var inputs = tf.keras.layers.Input((28, 28, 1)); + var x = tf.keras.layers.Flatten().Apply(inputs); + x = tf.keras.layers.Dense(100, activation: "relu").Apply(x); + x = tf.keras.layers.Dense(units: 10).Apply(x); + var outputs = tf.keras.layers.Softmax(axis: 1).Apply(x); + var model = tf.keras.Model(inputs, outputs); + + model.compile(new Adam(0.001f), + tf.keras.losses.SparseCategoricalCrossentropy(), + new string[] { "accuracy" }); + + var data_loader = new MnistModelLoader(); + var num_epochs = 1; + var batch_size = 50; + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 58000, + }).Result; + + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + + model.save("./pb_simple_compile", save_format: "tf"); + } + + [TestMethod] + public void SimpleModelFromSequential() + { + var model = keras.Sequential(new List() + { + tf.keras.layers.InputLayer((28, 28, 1)), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(100, "relu"), + tf.keras.layers.Dense(10), + tf.keras.layers.Softmax() + }); + + model.summary(); + + model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + var data_loader = new MnistModelLoader(); + var num_epochs = 1; + var batch_size = 50; + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 58000, + }).Result; + + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + + model.save("./pb_simple_sequential", save_format: "tf"); + } + + [TestMethod] + public void AlexnetFromSequential() + { + var model = keras.Sequential(new List() + { + tf.keras.layers.InputLayer((227, 227, 3)), + tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"), + tf.keras.layers.BatchNormalization(), + tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)), + + tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"), + tf.keras.layers.BatchNormalization(), + tf.keras.layers.MaxPooling2D((3, 3), (2, 2)), + + tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"), + tf.keras.layers.BatchNormalization(), + + tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"), + tf.keras.layers.BatchNormalization(), + + tf.keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"), + tf.keras.layers.BatchNormalization(), + tf.keras.layers.MaxPooling2D((3, 3), (2, 2)), + + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(4096, activation: "relu"), + tf.keras.layers.Dropout(0.5f), + + tf.keras.layers.Dense(4096, activation: "relu"), + tf.keras.layers.Dropout(0.5f), + + tf.keras.layers.Dense(1000, activation: "linear"), + tf.keras.layers.Softmax(1) + }); + + model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); + + var num_epochs = 1; + var batch_size = 8; + + var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); + + model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); + + model.save("./alexnet_from_sequential", save_format: "tf"); + + // The saved model can be test with the following python code: + #region alexnet_python_code + //import pathlib + //import tensorflow as tf + + //def func(a): + // return -a + + //if __name__ == '__main__': + // model = tf.keras.models.load_model("./pb_alex_sequential") + // model.summary() + + // num_classes = 5 + // batch_size = 128 + // img_height = 227 + // img_width = 227 + // epochs = 100 + + // dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz" + // data_dir = tf.keras.utils.get_file('flower_photos', origin = dataset_url, untar = True) + // data_dir = pathlib.Path(data_dir) + + // train_ds = tf.keras.preprocessing.image_dataset_from_directory( + // data_dir, + // validation_split = 0.2, + // subset = "training", + // seed = 123, + // image_size = (img_height, img_width), + // batch_size = batch_size) + + // val_ds = tf.keras.preprocessing.image_dataset_from_directory( + // data_dir, + // validation_split = 0.2, + // subset = "validation", + // seed = 123, + // image_size = (img_height, img_width), + // batch_size = batch_size) + + + // model.compile(optimizer = 'adam', + // loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True), + // metrics =['accuracy']) + + // model.build((None, img_height, img_width, 3)) + + // history = model.fit( + // train_ds, + // validation_data = val_ds, + // epochs = epochs + // ) + #endregion + } + + [TestMethod] + public void SaveAfterLoad() + { + var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile"); + model.summary(); + + model.save("Assets/saved_auto_compile_after_loading"); + + //model = tf.keras.models.load_model(@"Assets/saved_auto_compile_after_loading"); + //model.summary(); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs new file mode 100644 index 000000000..54b76d41a --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs @@ -0,0 +1,145 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow.Keras.Optimizers; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest +{ + [TestClass] + public class MultiInputModelTest + { + [TestMethod] + public void LeNetModel() + { + var inputs = keras.Input((28, 28, 1)); + var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs); + var pool1 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv1); + var conv2 = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(pool1); + var pool2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2); + var flat1 = keras.layers.Flatten().Apply(pool2); + + var inputs_2 = keras.Input((28, 28, 1)); + var conv1_2 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs_2); + var pool1_2 = keras.layers.MaxPooling2D((4, 4), 4).Apply(conv1_2); + var conv2_2 = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(pool1_2); + var pool2_2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2_2); + var flat1_2 = keras.layers.Flatten().Apply(pool2_2); + + var concat = keras.layers.Concatenate().Apply((flat1, flat1_2)); + var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat); + var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1); + var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2); + var output = keras.layers.Softmax(-1).Apply(dense3); + + var model = keras.Model((inputs, inputs_2), output); + model.summary(); + + var data_loader = new MnistModelLoader(); + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 59900, + }).Result; + + var loss = keras.losses.SparseCategoricalCrossentropy(); + var optimizer = new Adam(0.001f); + model.compile(optimizer, loss, new string[] { "accuracy" }); + + NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1)); + NDArray x2 = x1; + + var x = new NDArray[] { x1, x2 }; + model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3); + + x1 = x1["0:8"]; + x2 = x1; + + x = new NDArray[] { x1, x2 }; + var y = dataset.Train.Labels["0:8"]; + (model as Engine.Model).evaluate(x, y); + + x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT); + x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT); + var pred = model.predict((x1, x2)); + Console.WriteLine(pred); + } + + [TestMethod] + public void LeNetModelDataset() + { + var inputs = keras.Input((28, 28, 1)); + var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs); + var pool1 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv1); + var conv2 = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(pool1); + var pool2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2); + var flat1 = keras.layers.Flatten().Apply(pool2); + + var inputs_2 = keras.Input((28, 28, 1)); + var conv1_2 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs_2); + var pool1_2 = keras.layers.MaxPooling2D((4, 4), 4).Apply(conv1_2); + var conv2_2 = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(pool1_2); + var pool2_2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2_2); + var flat1_2 = keras.layers.Flatten().Apply(pool2_2); + + var concat = keras.layers.Concatenate().Apply((flat1, flat1_2)); + var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat); + var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1); + var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2); + var output = keras.layers.Softmax(-1).Apply(dense3); + + var model = keras.Model((inputs, inputs_2), output); + model.summary(); + + var data_loader = new MnistModelLoader(); + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 59900, + }).Result; + + var loss = keras.losses.SparseCategoricalCrossentropy(); + var optimizer = new Adam(0.001f); + model.compile(optimizer, loss, new string[] { "accuracy" }); + + NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1)); + + var multiInputDataset = tf.data.Dataset.zip( + tf.data.Dataset.from_tensor_slices(x1), + tf.data.Dataset.from_tensor_slices(x1), + tf.data.Dataset.from_tensor_slices(dataset.Train.Labels) + ).batch(8); + multiInputDataset.FirstInputTensorCount = 2; + + model.fit(multiInputDataset, epochs: 3); + + x1 = x1["0:8"]; + + multiInputDataset = tf.data.Dataset.zip( + tf.data.Dataset.from_tensor_slices(x1), + tf.data.Dataset.from_tensor_slices(x1), + tf.data.Dataset.from_tensor_slices(dataset.Train.Labels["0:8"]) + ).batch(8); + multiInputDataset.FirstInputTensorCount = 2; + + (model as Engine.Model).evaluate(multiInputDataset); + + x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT); + var x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT); + + multiInputDataset = tf.data.Dataset.zip( + tf.data.Dataset.from_tensor_slices(x1), + tf.data.Dataset.from_tensor_slices(x2) + ).batch(8); + multiInputDataset.FirstInputTensorCount = 2; + + var pred = model.predict(multiInputDataset); + Console.WriteLine(pred); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs b/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs new file mode 100644 index 000000000..3706e65c8 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs @@ -0,0 +1,95 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Threading.Tasks; +using Tensorflow.Keras.Engine; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest +{ + [TestClass] + public class MultiThreads + { + [TestMethod, Ignore("Failed on MacOS")] + public void Test1() + { + //Arrange + string savefile = "mymodel.h5"; + var model1 = BuildModel(); + model1.save_weights(savefile); + var model2 = BuildModel(); + + //act + model1.load_weights(savefile); + model2.load_weights(savefile); + + } + + [TestMethod, Ignore("Failed on MacOS")] + public void Test2() + { + //Arrange + string savefile = "mymodel2.h5"; + var model1 = BuildModel(); + model1.save_weights(savefile); + model1 = BuildModel(); //recreate model + + //act + model1.load_weights(savefile); + + } + + [TestMethod, Ignore("Failed on MacOS")] + public void Test3Multithreading() + { + //Arrange + string savefile = "mymodel3.h5"; + var model = BuildModel(); + model.save_weights(savefile); + + //Sanity check without multithreading + for (int i = 0; i < 2; i++) + { + var clone = BuildModel(); + clone.load_weights(savefile); + + //Predict something + clone.predict(np.array(new float[,] { { 0, 0 } })); + } //works + + //act + ParallelOptions parallelOptions = new ParallelOptions(); + parallelOptions.MaxDegreeOfParallelism = 8; + var input = np.array(new float[,] { { 0, 0 } }); + Parallel.For(0, 8, parallelOptions, i => + { + var clone = BuildModel(); + clone.load_weights(savefile); + //Predict something + clone.predict(input); + }); + } + + IModel BuildModel() + { + tf.Context.reset_context(); + var inputs = keras.Input(shape: 2); + + // 1st dense layer + var DenseLayer = keras.layers.Dense(1, activation: keras.activations.Sigmoid); + var outputs = DenseLayer.Apply(inputs); + + // build keras model + var model = tf.keras.Model(inputs, outputs, name: Guid.NewGuid().ToString()); + // show model summary + model.summary(); + + // compile keras model into tensorflow's static graph + model.compile(loss: keras.losses.MeanSquaredError(name: Guid.NewGuid().ToString()), + optimizer: keras.optimizers.Adam(name: Guid.NewGuid().ToString()), + metrics: new[] { "accuracy" }); + return model; + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/OutputTest.cs b/test/TensorFlowNET.Keras.UnitTest/OutputTest.cs new file mode 100644 index 000000000..15fbe11a4 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/OutputTest.cs @@ -0,0 +1,49 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest +{ + [TestClass] + public class OutputTest + { + [TestMethod] + public void OutputRedirectTest() + { + using var newOutput = new System.IO.StringWriter(); + tf_output_redirect = newOutput; + var model = keras.Sequential(); + model.add(keras.Input(shape: 16)); + model.summary(); + string output = newOutput.ToString(); + Assert.IsTrue(output.StartsWith("Model: sequential")); + tf_output_redirect = null; // don't forget to change it to null !!!! + } + + [TestMethod] + public void SwitchOutputsTest() + { + using var newOutput = new System.IO.StringWriter(); + var model = keras.Sequential(); + model.add(keras.Input(shape: 16)); + model.summary(); // Console.Out + + tf_output_redirect = newOutput; // change to the custom one + model.summary(); + string firstOutput = newOutput.ToString(); + Assert.IsTrue(firstOutput.StartsWith("Model: sequential")); + + // if tf_output_reditect is StringWriter, calling "set" will make the writer clear. + tf_output_redirect = null; // null means Console.Out + model.summary(); + + tf_output_redirect = newOutput; // again, to test whether the newOutput is clear. + model.summary(); + string secondOutput = newOutput.ToString(); + Assert.IsTrue(secondOutput.StartsWith("Model: sequential")); + + Assert.IsTrue(firstOutput == secondOutput); + tf_output_redirect = null; // don't forget to change it to null !!!! + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs new file mode 100644 index 000000000..82c84e794 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs @@ -0,0 +1,396 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Linq; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.UnitTest +{ + [TestClass] + public class PreprocessingTests : EagerModeTestBase + { + private readonly string[] texts = new string[] { + "It was the best of times, it was the worst of times.", + "Mr and Mrs Dursley of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.", + "It was the best of times, it was the worst of times.", + "Mr and Mrs Dursley of number four, Privet Drive.", + }; + + private readonly string[][] tokenized_texts = new string[][] { + new string[] {"It","was","the","best","of","times","it","was","the","worst","of","times"}, + new string[] {"mr","and","mrs","dursley","of","number","four","privet","drive","were","proud","to","say","that","they","were","perfectly","normal","thank","you","very","much"}, + new string[] {"It","was","the","best","of","times","it","was","the","worst","of","times"}, + new string[] {"mr","and","mrs","dursley","of","number","four","privet","drive"}, + }; + + private readonly string[] processed_texts = new string[] { + "it was the best of times it was the worst of times", + "mr and mrs dursley of number four privet drive were proud to say that they were perfectly normal thank you very much", + "it was the best of times it was the worst of times", + "mr and mrs dursley of number four privet drive", + }; + + private const string OOV = ""; + + [TestMethod] + public void TokenizeWithNoOOV() + { + var tokenizer = keras.preprocessing.text.Tokenizer(); + tokenizer.fit_on_texts(texts); + + Assert.AreEqual(27, tokenizer.word_index.Count); + + Assert.AreEqual(7, tokenizer.word_index["worst"]); + Assert.AreEqual(12, tokenizer.word_index["number"]); + Assert.AreEqual(16, tokenizer.word_index["were"]); + } + + [TestMethod] + public void TokenizeWithNoOOV_Tkn() + { + var tokenizer = keras.preprocessing.text.Tokenizer(); + // Use the list version, where the tokenization has already been done. + tokenizer.fit_on_texts(tokenized_texts); + + Assert.AreEqual(27, tokenizer.word_index.Count); + + Assert.AreEqual(7, tokenizer.word_index["worst"]); + Assert.AreEqual(12, tokenizer.word_index["number"]); + Assert.AreEqual(16, tokenizer.word_index["were"]); + } + + [TestMethod] + public void TokenizeWithOOV() + { + var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV); + tokenizer.fit_on_texts(texts); + + Assert.AreEqual(28, tokenizer.word_index.Count); + + Assert.AreEqual(1, tokenizer.word_index[OOV]); + Assert.AreEqual(8, tokenizer.word_index["worst"]); + Assert.AreEqual(13, tokenizer.word_index["number"]); + Assert.AreEqual(17, tokenizer.word_index["were"]); + } + + [TestMethod] + public void TokenizeWithOOV_Tkn() + { + var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV); + // Use the list version, where the tokenization has already been done. + tokenizer.fit_on_texts(tokenized_texts); + + Assert.AreEqual(28, tokenizer.word_index.Count); + + Assert.AreEqual(1, tokenizer.word_index[OOV]); + Assert.AreEqual(8, tokenizer.word_index["worst"]); + Assert.AreEqual(13, tokenizer.word_index["number"]); + Assert.AreEqual(17, tokenizer.word_index["were"]); + } + + [TestMethod] + public void TokenizeTextsToSequences() + { + var tokenizer = keras.preprocessing.text.Tokenizer(); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + Assert.AreEqual(4, sequences.Count); + + Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]); + Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]); + } + + [TestMethod] + public void TokenizeTextsToSequences_Tkn() + { + var tokenizer = keras.preprocessing.text.Tokenizer(); + // Use the list version, where the tokenization has already been done. + tokenizer.fit_on_texts(tokenized_texts); + + var sequences = tokenizer.texts_to_sequences(tokenized_texts); + Assert.AreEqual(4, sequences.Count); + + Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]); + Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]); + } + + [TestMethod] + public void TokenizeTextsToSequencesAndBack() + { + var tokenizer = keras.preprocessing.text.Tokenizer(); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + Assert.AreEqual(4, sequences.Count); + + var processed = tokenizer.sequences_to_texts(sequences); + + Assert.AreEqual(4, processed.Count); + + for (var i = 0; i < processed.Count; i++) + Assert.AreEqual(processed_texts[i], processed[i]); + } + + [TestMethod] + public void TokenizeTextsToSequencesAndBack_Tkn1() + { + var tokenizer = keras.preprocessing.text.Tokenizer(); + // Use the list version, where the tokenization has already been done. + tokenizer.fit_on_texts(tokenized_texts); + + // Use the list version, where the tokenization has already been done. + var sequences = tokenizer.texts_to_sequences(tokenized_texts); + Assert.AreEqual(4, sequences.Count); + + var processed = tokenizer.sequences_to_texts(sequences); + + Assert.AreEqual(4, processed.Count); + + for (var i = 0; i < processed.Count; i++) + Assert.AreEqual(processed_texts[i], processed[i]); + } + + [TestMethod] + public void TokenizeTextsToSequencesAndBack_Tkn2() + { + var tokenizer = keras.preprocessing.text.Tokenizer(); + // Use the list version, where the tokenization has already been done. + tokenizer.fit_on_texts(tokenized_texts); + + var sequences = tokenizer.texts_to_sequences(texts); + Assert.AreEqual(4, sequences.Count); + + var processed = tokenizer.sequences_to_texts(sequences); + + Assert.AreEqual(4, processed.Count); + + for (var i = 0; i < processed.Count; i++) + Assert.AreEqual(processed_texts[i], processed[i]); + } + + [TestMethod] + public void TokenizeTextsToSequencesAndBack_Tkn3() + { + var tokenizer = keras.preprocessing.text.Tokenizer(); + tokenizer.fit_on_texts(texts); + + // Use the list version, where the tokenization has already been done. + var sequences = tokenizer.texts_to_sequences(tokenized_texts); + Assert.AreEqual(4, sequences.Count); + + var processed = tokenizer.sequences_to_texts(sequences); + + Assert.AreEqual(4, processed.Count); + + for (var i = 0; i < processed.Count; i++) + Assert.AreEqual(processed_texts[i], processed[i]); + } + [TestMethod] + public void TokenizeTextsToSequencesWithOOV() + { + var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + Assert.AreEqual(4, sequences.Count); + + Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]); + Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]); + + for (var i = 0; i < sequences.Count; i++) + for (var j = 0; j < sequences[i].Length; j++) + Assert.AreNotEqual(tokenizer.word_index[OOV], sequences[i][j]); + } + + [TestMethod] + public void TokenizeTextsToSequencesWithOOVPresent() + { + var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV, num_words: 20); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + Assert.AreEqual(4, sequences.Count); + + Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]); + Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]); + + var oov_count = 0; + for (var i = 0; i < sequences.Count; i++) + for (var j = 0; j < sequences[i].Length; j++) + if (tokenizer.word_index[OOV] == sequences[i][j]) + oov_count += 1; + + Assert.AreEqual(9, oov_count); + } + + [TestMethod] + public void PadSequencesWithDefaults() + { + var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + var padded = keras.preprocessing.sequence.pad_sequences(sequences); + + Assert.AreEqual(4, padded.dims[0]); + Assert.AreEqual(22, padded.dims[1]); + + Assert.AreEqual(padded[0, 19], tokenizer.word_index["worst"]); + for (var i = 0; i < 8; i++) + Assert.AreEqual(padded[0, i], 0); + Assert.AreEqual(padded[1, 10], tokenizer.word_index["proud"]); + for (var i = 0; i < 20; i++) + Assert.AreNotEqual(padded[1, i], 0); + } + + [TestMethod] + public void PadSequencesPrePaddingTrunc() + { + var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15); + + Assert.AreEqual(4, padded.dims[0]); + Assert.AreEqual(15, padded.dims[1]); + + Assert.AreEqual(padded[0, 12], tokenizer.word_index["worst"]); + for (var i = 0; i < 3; i++) + Assert.AreEqual(padded[0, i], 0); + Assert.AreEqual(padded[1, 3], tokenizer.word_index["proud"]); + for (var i = 0; i < 15; i++) + Assert.AreNotEqual(padded[1, i], 0); + } + + [TestMethod] + public void PadSequencesPrePaddingTrunc_Larger() + { + var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 45); + + Assert.AreEqual(4, padded.dims[0]); + Assert.AreEqual(45, padded.dims[1]); + + Assert.AreEqual(padded[0, 42], tokenizer.word_index["worst"]); + for (var i = 0; i < 33; i++) + Assert.AreEqual(padded[0, i], 0); + Assert.AreEqual(padded[1, 33], tokenizer.word_index["proud"]); + } + + [TestMethod] + public void PadSequencesPostPaddingTrunc() + { + var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15, padding: "post", truncating: "post"); + + Assert.AreEqual(4, padded.dims[0]); + Assert.AreEqual(15, padded.dims[1]); + + Assert.AreEqual(padded[0, 9], tokenizer.word_index["worst"]); + for (var i = 12; i < 15; i++) + Assert.AreEqual(padded[0, i], 0); + Assert.AreEqual(padded[1, 10], tokenizer.word_index["proud"]); + for (var i = 0; i < 15; i++) + Assert.AreNotEqual(padded[1, i], 0); + } + + [TestMethod] + public void PadSequencesPostPaddingTrunc_Larger() + { + var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 45, padding: "post", truncating: "post"); + + Assert.AreEqual(4, padded.dims[0]); + Assert.AreEqual(45, padded.dims[1]); + + Assert.AreEqual(padded[0, 9], tokenizer.word_index["worst"]); + for (var i = 32; i < 45; i++) + Assert.AreEqual(padded[0, i], 0); + Assert.AreEqual(padded[1, 10], tokenizer.word_index["proud"]); + } + + [TestMethod] + public void TextToMatrixBinary() + { + var tokenizer = keras.preprocessing.text.Tokenizer(); + tokenizer.fit_on_texts(texts); + + Assert.AreEqual(27, tokenizer.word_index.Count); + + var matrix = tokenizer.texts_to_matrix(texts); + + Assert.AreEqual(texts.Length, matrix.dims[0]); + + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, matrix[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }, matrix[1].ToArray())); + } + + [TestMethod] + public void TextToMatrixCount() + { + var tokenizer = keras.preprocessing.text.Tokenizer(); + tokenizer.fit_on_texts(texts); + + Assert.AreEqual(27, tokenizer.word_index.Count); + + var matrix = tokenizer.texts_to_matrix(texts, mode: "count"); + + Assert.AreEqual(texts.Length, matrix.dims[0]); + + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 0, 2, 2, 2, 1, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, matrix[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }, matrix[1].ToArray())); + } + + [TestMethod] + public void TextToMatrixFrequency() + { + var tokenizer = keras.preprocessing.text.Tokenizer(); + tokenizer.fit_on_texts(texts); + + Assert.AreEqual(27, tokenizer.word_index.Count); + + var matrix = tokenizer.texts_to_matrix(texts, mode: "freq"); + + Assert.AreEqual(texts.Length, matrix.dims[0]); + + double t12 = 2.0 / 12.0; + double o12 = 1.0 / 12.0; + double t22 = 2.0 / 22.0; + double o22 = 1.0 / 22.0; + + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 0, t12, t12, t12, o12, t12, t12, o12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, matrix[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 0, 0, 0, 0, 0, o22, 0, 0, o22, o22, o22, o22, o22, o22, o22, o22, t22, o22, o22, o22, o22, o22, o22, o22, o22, o22, o22, o22 }, matrix[1].ToArray())); + } + + [TestMethod] + public void TextToMatrixTDIDF() + { + var tokenizer = keras.preprocessing.text.Tokenizer(); + tokenizer.fit_on_texts(texts); + + Assert.AreEqual(27, tokenizer.word_index.Count); + + var matrix = tokenizer.texts_to_matrix(texts, mode: "tfidf"); + + Assert.AreEqual(texts.Length, matrix.dims[0]); + + double t1 = 1.1736001944781467; + double t2 = 0.69314718055994529; + double t3 = 1.860112299086919; + double t4 = 1.0986122886681098; + double t5 = 0.69314718055994529; + + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 0, t1, t1, t1, t2, 0, t1, t2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, matrix[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 0, 0, 0, 0, 0, 0, 0, 0, t5, t5, t5, t5, t5, t5, t5, t5, t3, t4, t4, t4, t4, t4, t4, t4, t4, t4, t4, t4 }, matrix[1].ToArray())); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj new file mode 100644 index 000000000..edac1c2ff --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj @@ -0,0 +1,87 @@ + + + + net6.0 + + false + AnyCPU;x64 + + + + DEBUG;TRACE + x64 + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + + diff --git a/test/TensorFlowNET.Native.UnitTest/Attributes/AttributesTestcs.cs b/test/TensorFlowNET.Native.UnitTest/Attributes/AttributesTestcs.cs new file mode 100644 index 000000000..4db19ed55 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Attributes/AttributesTestcs.cs @@ -0,0 +1,91 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; + +namespace Tensorflow.Native.UnitTest +{ + /// + /// tensorflow\c\c_api_test.cc + /// `class CApiAttributesTest` + /// + [TestClass] + public class AttributesTestcs : CApiTest, IDisposable + { + private Graph graph_; + private int counter_; + private Status s_; + + public AttributesTestcs() + { + s_ = new Status(); + graph_ = new Graph(); + } + + private OperationDescription init(string type) + { + // Construct op_name to match the name used by REGISTER_OP in the + // ATTR_TEST_REGISTER calls above. + string op_name = "CApiAttributesTestOp"; + if (type.Contains("list(")) + { + op_name += "List"; + type = type.Substring(5, type.Length - 6); + } + op_name += type; + return c_api.TF_NewOperation(graph_, op_name, $"name{counter_++}"); + } + + /// + /// REGISTER_OP for CApiAttributesTest test cases. + /// Registers two ops, each with a single attribute called 'v'. + /// The attribute in one op will have a type 'type', the other + /// will have list(type). + /// + /// + private void ATTR_TEST_REGISTER_OP(string type) + { + + } + + private void EXPECT_TF_META(Operation oper, string attr_name, int expected_list_size, TF_AttrType expected_type, uint expected_total_size) + { + var m = c_api.TF_OperationGetAttrMetadata(oper, attr_name, s_); + EXPECT_EQ(TF_Code.TF_OK, s_.Code); + char e = expected_list_size >= 0 ? (char)1 : (char)0; + /*EXPECT_EQ(e, m.is_list); + EXPECT_EQ(expected_list_size, m.list_size); + EXPECT_EQ(expected_type, m.type); + EXPECT_EQ(expected_total_size, m.total_size);*/ + } + + [TestMethod] + public void String() + { + var desc = init("string"); + c_api.TF_SetAttrString(desc, "v", "bunny", 5); + + var oper = c_api.TF_FinishOperation(desc, s_); + //ASSERT_EQ(TF_Code.TF_OK, s_.Code); + //EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5); + //var value = new char[5]; + + //c_api.TF_OperationGetAttrString(oper, "v", value, 5, s_); + //EXPECT_EQ(TF_Code.TF_OK, s_.Code); + //EXPECT_EQ("bunny", value, 5)); + } + + [TestMethod] + public void GetAttributesTest() + { + var desc = graph_.NewOperation("Placeholder", "node"); + desc.SetAttrType("dtype", TF_DataType.TF_FLOAT); + long[] ref_shape = new long[3] { 1, 2, 3 }; + desc.SetAttrShape("shape", ref_shape); + var oper = desc.FinishOperation(s_); + var metadata = oper.GetAttributeMetadata("shape", s_); + } + + public void Dispose() + { + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs b/test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs new file mode 100644 index 000000000..c162cb725 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs @@ -0,0 +1,103 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Runtime.InteropServices; + +namespace Tensorflow.Native.UnitTest +{ + /// + /// tensorflow\c\c_api_test.cc + /// `class CApiColocationTest` + /// + [TestClass] + public class CApiColocationTest : CApiTest, IDisposable + { + private Graph graph_ = new Graph(); + private Status s_ = new Status(); + private Operation feed1_; + private Operation feed2_; + private Operation constant_; + private OperationDescription desc_; + + [TestInitialize] + public void SetUp() + { + feed1_ = c_test_util.Placeholder(graph_, s_, "feed1"); + s_.Check(); + feed2_ = c_test_util.Placeholder(graph_, s_, "feed2"); + s_.Check(); + constant_ = c_test_util.ScalarConst(10, graph_, s_); + s_.Check(); + + desc_ = graph_.NewOperation("AddN", "add"); + TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) }; + desc_.AddInputList(inputs); + } + + private void SetViaStringList(OperationDescription desc, string[] list) + { + var list_ptrs = new IntPtr[list.Length]; + var list_lens = new uint[list.Length]; + StringVectorToArrays(list, list_ptrs, list_lens); + c_api.TF_SetAttrStringList(desc, "_class", list_ptrs, list_lens, list.Length); + } + + private void StringVectorToArrays(string[] v, IntPtr[] ptrs, uint[] lens) + { + for (int i = 0; i < v.Length; ++i) + { + ptrs[i] = Marshal.StringToHGlobalAnsi(v[i]); + lens[i] = (uint)v[i].Length; + } + } + + private void FinishAndVerify(OperationDescription desc, string[] expected) + { + var op = desc_.FinishOperation(s_); + ASSERT_EQ(TF_Code.TF_OK, s_.Code); + VerifyCollocation(op, expected); + } + + private void VerifyCollocation(Operation op, string[] expected) + { + var handle = c_api.TF_OperationGetAttrMetadata(op, "_class", s_); + TF_AttrMetadata m = new TF_AttrMetadata(); + if (expected.Length == 0) + { + ASSERT_EQ(TF_Code.TF_INVALID_ARGUMENT, s_.Code); + EXPECT_EQ("Operation 'add' has no attr named '_class'.", s_.Message); + return; + } + EXPECT_EQ(TF_Code.TF_OK, s_.Code); + // EXPECT_EQ(1, m.is_list); + // EXPECT_EQ(expected.Length, m.list_size); + // EXPECT_EQ(TF_AttrType.TF_ATTR_STRING, m.type); + string[] values = new string[expected.Length]; + uint[] lens = new uint[expected.Length]; + string[] storage = new string[m.total_size]; + //c_api.TF_OperationGetAttrStringList(op, "_class", values, lens, expected.Length, storage, m.total_size, s_); + // EXPECT_EQ(TF_Code.TF_OK, s_.Code); + for (int i = 0; i < expected.Length; ++i) + { + // EXPECT_EQ(expected[i], values[i] + lens[i]); + } + } + + [TestMethod] + public void ColocateWith() + { + c_api.TF_ColocateWith(desc_, feed1_); + FinishAndVerify(desc_, new string[] { "loc:@feed1" }); + } + + [TestMethod] + public void StringList() + { + SetViaStringList(desc_, new string[] { "loc:@feed1" }); + FinishAndVerify(desc_, new string[] { "loc:@feed1" }); + } + + public void Dispose() + { + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/CApiTest.cs b/test/TensorFlowNET.Native.UnitTest/CApiTest.cs new file mode 100644 index 000000000..fb4ed482e --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/CApiTest.cs @@ -0,0 +1,155 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow.Device; +using Tensorflow.Eager; + +namespace Tensorflow.Native.UnitTest +{ + public class CApiTest + { + protected static readonly TF_Code TF_OK = TF_Code.TF_OK; + protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; + protected static readonly TF_DataType TF_BOOL = TF_DataType.TF_BOOL; + + protected void EXPECT_TRUE(bool expected, string msg = "") + => Assert.IsTrue(expected, msg); + + protected static void EXPECT_EQ(object expected, object actual, string msg = "") + => Assert.AreEqual(expected, actual, msg); + + protected void CHECK_EQ(object expected, object actual, string msg = "") + => Assert.AreEqual(expected, actual, msg); + + protected void EXPECT_NE(object expected, object actual, string msg = "") + => Assert.AreNotEqual(expected, actual, msg); + + protected void CHECK_NE(object expected, object actual, string msg = "") + => Assert.AreNotEqual(expected, actual, msg); + + protected void EXPECT_GE(int expected, int actual, string msg = "") + => Assert.IsTrue(expected >= actual, msg); + + protected void ASSERT_EQ(object expected, object actual, string msg = "") + => Assert.AreEqual(expected, actual, msg); + + protected void ASSERT_NE(object expected, object actual, string msg = "") + => Assert.AreNotEqual(expected, actual, msg); + + protected void ASSERT_TRUE(bool condition, string msg = "") + => Assert.IsTrue(condition, msg); + + protected OperationDescription TF_NewOperation(Graph graph, string opType, string opName) + => c_api.TF_NewOperation(graph, opType, opName); + + protected void TF_AddInput(OperationDescription desc, TF_Output input) + => c_api.TF_AddInput(desc, input); + + protected Operation TF_FinishOperation(OperationDescription desc, Status s) + => c_api.TF_FinishOperation(desc, s); + + protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s) + => c_api.TF_SetAttrTensor(desc, attrName, value, s); + + protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype) + => c_api.TF_SetAttrType(desc, attrName, dtype); + + protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value) + => c_api.TF_SetAttrBool(desc, attrName, value); + + protected TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h) + => c_api.TFE_TensorHandleDataType(h); + + protected int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status) + => c_api.TFE_TensorHandleNumDims(h, status); + + protected TF_Code TF_GetCode(Status s) + => s.Code; + + protected static TF_Code TF_GetCode(SafeStatusHandle s) + => c_api.TF_GetCode(s); + + protected static string TF_Message(SafeStatusHandle s) + => c_api.StringPiece(c_api.TF_Message(s)); + + protected SafeStatusHandle TF_NewStatus() + => c_api.TF_NewStatus(); + + protected IntPtr TF_TensorData(SafeTensorHandle t) + => c_api.TF_TensorData(t); + + protected ulong TF_TensorByteSize(SafeTensorHandle t) + => c_api.TF_TensorByteSize(t); + + protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status) + => c_api.TFE_OpAddInput(op, h, status); + + protected void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value) + => c_api.TFE_OpSetAttrType(op, attr_name, value); + + protected void TFE_OpSetAttrShape(SafeEagerOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) + => c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status); + + protected void TFE_OpSetAttrString(SafeEagerOpHandle op, string attr_name, string value, uint length) + => c_api.TFE_OpSetAttrString(op, attr_name, value, length); + + protected SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) + => c_api.TFE_NewOp(ctx, op_or_function_name, status); + + protected SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status) + => c_api.TFE_NewTensorHandle(t, status); + + protected void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status) + => c_api.TFE_Execute(op, retvals, out num_retvals, status); + + protected SafeContextOptionsHandle TFE_NewContextOptions() + => c_api.TFE_NewContextOptions(); + + protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status) + => c_api.TFE_NewContext(opts, status); + + protected int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) + => c_api.TFE_OpGetInputLength(op, input_name, status); + + protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status) + => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); + + protected int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) + => c_api.TFE_OpGetOutputLength(op, input_name, status); + + protected void TFE_DeleteTensorHandle(IntPtr h) + => c_api.TFE_DeleteTensorHandle(h); + + protected SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx) + => c_api.TFE_ContextGetExecutorForThread(ctx); + + protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status) + => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); + + protected SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status) + => c_api.TFE_TensorHandleResolve(h, status); + + protected string TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status) + => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status)); + + protected string TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status) + => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); + + protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) + => c_api.TFE_ContextListDevices(ctx, status); + + protected int TF_DeviceListCount(SafeDeviceListHandle list) + => c_api.TF_DeviceListCount(list); + + protected string TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status) + => c_api.StringPiece(c_api.TF_DeviceListType(list, index, status)); + + protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) + => c_api.TF_DeviceListName(list, index, status); + + protected SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) + => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); + + protected void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status) + => c_api.TFE_OpSetDevice(op, device_name, status); + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Context.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Context.cs new file mode 100644 index 000000000..7628bbc2b --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Context.cs @@ -0,0 +1,43 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Device; +using Tensorflow.Eager; + +namespace Tensorflow.Native.UnitTest.Eager +{ + public partial class CApiEagerTest + { + /// + /// TEST(CAPI, Context) + /// + [TestMethod] + public void Context() + { + using var status = c_api.TF_NewStatus(); + + static SafeContextHandle NewContext(SafeStatusHandle status) + { + using var opts = c_api.TFE_NewContextOptions(); + return c_api.TFE_NewContext(opts, status); + } + + static SafeDeviceListHandle ListDevices(SafeStatusHandle status) + { + using var ctx = NewContext(status); + var devices = c_api.TFE_ContextListDevices(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + return devices; + } + + using var devices = ListDevices(status); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + int num_devices = c_api.TF_DeviceListCount(devices); + EXPECT_GE(num_devices, 1, TF_Message(status)); + for (int i = 0; i < num_devices; ++i) + { + EXPECT_NE("", c_api.TF_DeviceListName(devices, i, status), TF_Message(status)); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + } + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs new file mode 100644 index 000000000..c8502735d --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Execute_MatMul_CPU.cs @@ -0,0 +1,68 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow.Eager; +using static Tensorflow.Binding; + +namespace Tensorflow.Native.UnitTest.Eager +{ + public partial class CApiEagerTest + { + /// + /// TEST(CAPI, Execute_MatMul_CPU) + /// + [TestMethod] + public unsafe void Execute_MatMul_CPU() + { + Execute_MatMul_CPU(false); + } + + unsafe void Execute_MatMul_CPU(bool async) + { + using var status = TF_NewStatus(); + + static SafeContextHandle NewContext(bool async, SafeStatusHandle status) + { + using var opts = c_api.TFE_NewContextOptions(); + c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); + return c_api.TFE_NewContext(opts, status); + } + + SafeTensorHandle t; + using (var ctx = NewContext(async, status)) + { + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + var retvals = new SafeEagerTensorHandle[2]; + using (var m = TestMatrixTensorHandle()) + using (var matmul = MatMulOp(ctx, m, m)) + { + int num_retvals; + c_api.TFE_Execute(matmul, retvals, out num_retvals, status); + EXPECT_EQ(1, num_retvals); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + } + + try + { + t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + } + finally + { + retvals[0].Dispose(); + } + } + + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + var product = new float[4]; + EXPECT_EQ(product.Length * sizeof(float), (int)TF_TensorByteSize(t)); + tf.memcpy(product, TF_TensorData(t), TF_TensorByteSize(t)); + + t.Dispose(); + EXPECT_EQ(7f, product[0]); + EXPECT_EQ(10f, product[1]); + EXPECT_EQ(15f, product[2]); + EXPECT_EQ(22f, product[3]); + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs new file mode 100644 index 000000000..ff31b195d --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpGetInputAndOutputLengths.cs @@ -0,0 +1,70 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Eager; + +namespace Tensorflow.Native.UnitTest.Eager +{ + public partial class CApiEagerTest + { + /// + /// TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) + /// + [TestMethod] + public unsafe void OpGetInputAndOutputLengths() + { + using var status = TF_NewStatus(); + + static SafeContextHandle NewContext(SafeStatusHandle status) + { + using var opts = c_api.TFE_NewContextOptions(); + return c_api.TFE_NewContext(opts, status); + } + + using var ctx = NewContext(status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + using var input1 = TestMatrixTensorHandle(); + using var input2 = TestMatrixTensorHandle(); + + var retvals = new SafeEagerTensorHandle[2]; + using (var identityOp = TFE_NewOp(ctx, "IdentityN", status)) + { + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + // Try to retrieve lengths before building the attributes (should fail) + EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); + CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); + CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); + + var inputs = new SafeEagerTensorHandle[] { input1, input2 }; + TFE_OpAddInputList(identityOp, inputs, 2, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + // Try to retrieve lengths before executing the op (should work) + EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + int num_retvals; + TFE_Execute(identityOp, retvals, out num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(2, num_retvals); + + try + { + // Try to retrieve lengths after executing the op (should work) + EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + } + finally + { + retvals[0].Dispose(); + retvals[1].Dispose(); + } + } + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs new file mode 100644 index 000000000..ab0d51817 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.OpInferMixedTypeInputListAttrs.cs @@ -0,0 +1,52 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Eager; + +namespace Tensorflow.Native.UnitTest.Eager +{ + public partial class CApiEagerTest + { + /// + /// TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) + /// + [TestMethod] + public unsafe void OpInferMixedTypeInputListAttrs() + { + using var status = TF_NewStatus(); + + static SafeContextHandle NewContext(SafeStatusHandle status) + { + using var opts = c_api.TFE_NewContextOptions(); + return c_api.TFE_NewContext(opts, status); + } + + using var ctx = NewContext(status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + using var condition = TestScalarTensorHandle(true); + using var t1 = TestMatrixTensorHandle(); + using var t2 = TestAxisTensorHandle(); + using (var assertOp = TFE_NewOp(ctx, "Assert", status)) + { + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_OpAddInput(assertOp, condition, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + var data = new[] { condition, t1, t2 }; + TFE_OpAddInputList(assertOp, data, 3, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + /*var attr_values = Graph.TFE_GetOpDef("Assert").Attr; + var attr_found = attr_values.First(x => x.Name == "T"); + EXPECT_NE(attr_found, attr_values.Last());*/ + // EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); + //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); + //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); + + var retvals = new SafeEagerTensorHandle[0]; + int num_retvals; + TFE_Execute(assertOp, retvals, out num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(0, num_retvals); + } + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandle.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandle.cs new file mode 100644 index 000000000..6f5e30b7f --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandle.cs @@ -0,0 +1,31 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using static Tensorflow.Binding; + +namespace Tensorflow.Native.UnitTest.Eager +{ + public partial class CApiEagerTest + { + /// + /// TEST(CAPI, TensorHandle) + /// + [TestMethod] + public unsafe void TensorHandle() + { + using var h = TestMatrixTensorHandle(); + EXPECT_EQ(TF_FLOAT, c_api.TFE_TensorHandleDataType(h)); + + var status = c_api.TF_NewStatus(); + var t = c_api.TFE_TensorHandleResolve(h, status); + ASSERT_EQ(16ul, c_api.TF_TensorByteSize(t)); + + var data = new float[] { 0f, 0f, 0f, 0f }; + tf.memcpy(data, c_api.TF_TensorData(t), data.Length * sizeof(float)); + + EXPECT_EQ(1.0f, data[0]); + EXPECT_EQ(2.0f, data[1]); + EXPECT_EQ(3.0f, data[2]); + EXPECT_EQ(4.0f, data[3]); + t.Dispose(); + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs new file mode 100644 index 000000000..bc430f87c --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.TensorHandleDevices.cs @@ -0,0 +1,78 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Eager; + +namespace Tensorflow.Native.UnitTest.Eager +{ + public partial class CApiEagerTest + { + /// + /// TEST(CAPI, TensorHandleDevices) + /// + [TestMethod] + public unsafe void TensorHandleDevices() + { + var status = c_api.TF_NewStatus(); + + static SafeContextHandle NewContext(SafeStatusHandle status) + { + using var opts = c_api.TFE_NewContextOptions(); + return c_api.TFE_NewContext(opts, status); + } + + using var ctx = NewContext(status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + using (var hcpu = TestMatrixTensorHandle()) + { + var device_name = TFE_TensorHandleDeviceName(hcpu, status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + ASSERT_TRUE(device_name.Contains("CPU:0")); + + var backing_device_name = TFE_TensorHandleBackingDeviceName(hcpu, status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + ASSERT_TRUE(backing_device_name.Contains("CPU:0")); + + // Disable the test if no GPU is present. + string gpu_device_name = ""; + if (GetDeviceName(ctx, ref gpu_device_name, "GPU")) + { + using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); + + var retvals = new SafeEagerTensorHandle[1]; + using (var shape_op = ShapeOp(ctx, hgpu)) + { + TFE_OpSetDevice(shape_op, gpu_device_name, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); + int num_retvals; + c_api.TFE_Execute(shape_op, retvals, out num_retvals, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); + ASSERT_EQ(1, num_retvals); + + try + { + // .device of shape is GPU since the op is executed on GPU + device_name = TFE_TensorHandleDeviceName(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + ASSERT_TRUE(device_name.Contains("GPU:0")); + + // .backing_device of shape is CPU since the tensor is backed by CPU + backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + ASSERT_TRUE(backing_device_name.Contains("CPU:0")); + } + finally + { + retvals[0].Dispose(); + } + } + } + } + + // not export api + using var executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs new file mode 100644 index 000000000..7c43e111a --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.Variables.cs @@ -0,0 +1,66 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.Eager; +using static Tensorflow.Binding; + +namespace Tensorflow.Native.UnitTest.Eager +{ + public partial class CApiEagerTest + { + /// + /// TEST(CAPI, Variables) + /// + [TestMethod] + public unsafe void Variables() + { + using var status = c_api.TF_NewStatus(); + + static SafeContextHandle NewContext(SafeStatusHandle status) + { + using var opts = c_api.TFE_NewContextOptions(); + return c_api.TFE_NewContext(opts, status); + } + + using var ctx = NewContext(status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + using (var var_handle = CreateVariable(ctx, 12.0f, status)) + { + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + int num_retvals = 1; + var value_handle = new SafeEagerTensorHandle[1]; + using (var op = TFE_NewOp(ctx, "ReadVariableOp", status)) + { + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var_handle, status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_Execute(op, value_handle, out num_retvals, status); + ASSERT_EQ(1, num_retvals); + } + + try + { + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + ASSERT_EQ(1, num_retvals); + EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle[0])); + EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle[0], status)); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + var value = 0f; // new float[1]; + var t = TFE_TensorHandleResolve(value_handle[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + ASSERT_EQ(sizeof(float), (int)TF_TensorByteSize(t)); + tf.memcpy(&value, TF_TensorData(t).ToPointer(), sizeof(float)); + t.Dispose(); + EXPECT_EQ(12.0f, value); + } + finally + { + value_handle[0].Dispose(); + } + } + + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs new file mode 100644 index 000000000..c38ba5a5c --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Eager/Eager.cs @@ -0,0 +1,158 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow.Eager; +using static Tensorflow.Binding; + +namespace Tensorflow.Native.UnitTest.Eager +{ + /// + /// tensorflow\c\eager\c_api_test.cc + /// + [TestClass] + public partial class CApiEagerTest : CApiTest + { + SafeEagerTensorHandle TestMatrixTensorHandle() + { + var dims = new long[] { 2, 2 }; + var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + var t = c_api.TF_AllocateTensor(TF_FLOAT, dims, dims.Length, (ulong)data.Length * sizeof(float)); + tf.memcpy(c_api.TF_TensorData(t), data, data.Length * sizeof(float)); + + using var status = c_api.TF_NewStatus(); + var th = c_api.TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + t.Dispose(); + return th; + } + + SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeEagerTensorHandle a, SafeEagerTensorHandle b) + { + using var status = TF_NewStatus(); + + var op = TFE_NewOp(ctx, "MatMul", status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_OpAddInput(op, a, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_OpAddInput(op, b, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); + + return op; + } + + bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type) + { + using var status = TF_NewStatus(); + using var devices = TFE_ContextListDevices(ctx, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + int num_devices = TF_DeviceListCount(devices); + for (int i = 0; i < num_devices; ++i) + { + var dev_type = TF_DeviceListType(devices, i, status); + CHECK_EQ(TF_GetCode(status), TF_OK, TF_Message(status)); + var dev_name = TF_DeviceListName(devices, i, status); + CHECK_EQ(TF_GetCode(status), TF_OK, TF_Message(status)); + if (dev_type == device_type) + { + device_name = dev_name; + return true; + } + } + + return false; + } + + SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeEagerTensorHandle a) + { + using var status = TF_NewStatus(); + + var op = TFE_NewOp(ctx, "Shape", status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_OpAddInput(op, a, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); + + return op; + } + + unsafe SafeEagerTensorHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) + { + var var_handle = new SafeEagerTensorHandle[1]; + int num_retvals; + using (var op = TFE_NewOp(ctx, "VarHandleOp", status)) + { + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); + TFE_OpSetAttrString(op, "container", "", 0); + TFE_OpSetAttrString(op, "shared_name", "", 0); + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); + TFE_Execute(op, var_handle, out num_retvals, status); + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); + CHECK_EQ(1, num_retvals); + } + + // Assign 'value' to it. + using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) + { + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var_handle[0], status); + + // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. + var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float)); + tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); + + var value_handle = c_api.TFE_NewTensorHandle(t, status); + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); + + TFE_OpAddInput(op, value_handle, status); + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); + + c_api.TFE_Execute(op, null, out num_retvals, status); + if (TF_GetCode(status) != TF_OK) return new SafeEagerTensorHandle(IntPtr.Zero); + CHECK_EQ(0, num_retvals); + } + + return var_handle[0]; + } + + SafeEagerTensorHandle TestAxisTensorHandle() + { + var dims = new long[] { 1 }; + var data = new int[] { 1 }; + var t = c_api.TF_AllocateTensor(TF_DataType.TF_INT32, dims, 1, sizeof(int)); + tf.memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); + using var status = TF_NewStatus(); + var th = c_api.TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + t.Dispose(); + return th; + } + + SafeEagerTensorHandle TestScalarTensorHandle(bool value) + { + var data = new[] { value }; + var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool)); + tf.memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); + using var status = TF_NewStatus(); + var th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + t.Dispose(); + return th; + } + + SafeEagerTensorHandle TestScalarTensorHandle(float value) + { + var data = new[] { value }; + var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float)); + tf.memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); + using var status = TF_NewStatus(); + var th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + t.Dispose(); + return th; + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs b/test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs new file mode 100644 index 000000000..9230bc731 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs @@ -0,0 +1,585 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using static Tensorflow.Native.UnitTest.c_test_util; + +namespace Tensorflow.Native.UnitTest +{ + /// + /// tensorflow\c\c_api_function_test.cc + /// `class CApiColocationTest` + /// + [TestClass] + public class FunctionTest : CApiTest, IDisposable + { + Graph func_graph_; + Graph host_graph_; + string func_name_ = "MyFunc"; + string func_node_name_ = "MyFunc_0"; + Status s_; + SafeFuncGraphHandle func_; + + [TestInitialize] + public void Initialize() + { + func_graph_ = new Graph(); + host_graph_ = new Graph(); + s_ = new Status(); + } + + [TestMethod] + public void OneOp_ZeroInputs_OneOutput() + { + var c = ScalarConst(10, func_graph_, s_, "scalar10"); + // Define + Define(-1, new Operation[0], new Operation[0], new[] { c }, new string[0]); + + // Use, run, and verify + var func_op = Use(new Operation[0]); + Run(new KeyValuePair[0], func_op, 10); + VerifyFDef(new[] { "scalar10_0" }, + new List(), + new List { new IOSpec("scalar10", DataType.DtInt32) }, + new List { new EdgeSpec("scalar10_0:output:0", "scalar10") }, + new List()); + } + + [TestMethod] + public void OneOp_OneInput_OneOutput() + { + // Define + var feed = Placeholder(func_graph_, s_); + var neg = Neg(feed, func_graph_, s_); + Define(-1, new Operation[0], new[] { feed }, new[] { neg }, new string[0]); + + // Use, run, and verify + var func_feed = Placeholder(host_graph_, s_); + var func_op = Use(new[] { func_feed }); + Run(new[] { new KeyValuePair(func_feed, Int32Tensor(3)) }, func_op, -3); + VerifyFDef(new string[] { "neg_0" }, + new List { new IOSpec("feed", DataType.DtInt32) }, + new List { new IOSpec("neg", DataType.DtInt32) }, + new List { new EdgeSpec("feed", "neg_0:0"), new EdgeSpec("neg_0:y:0", "neg") }, + new List()); + } + + [TestMethod] + public void OneOutput_OutputNames() + { + // Define + var feed = Placeholder(func_graph_, s_); + var neg = Neg(feed, func_graph_, s_); + Define(-1, + new Operation[0], + new[] { feed }, + new[] { neg }, + new[] { "negated_num" }); + + // Use, run, and verify + var func_feed = Placeholder(host_graph_, s_); + var func_op = Use(new[] { func_feed }); + Run(new[] { new KeyValuePair(func_feed, Int32Tensor(3)) }, func_op, -3); + VerifyFDef(new string[] { "neg" }, + new List { new IOSpec("feed", DataType.DtInt32) }, + new List { new IOSpec("negated_num", DataType.DtInt32) }, + new List { new EdgeSpec("feed", "neg:0"), new EdgeSpec("neg:y:0", "negated_num") }, + new List()); + } + + [TestMethod] + public void OutputNames_SameNameAsInput() + { + // Define + var feed = Placeholder(func_graph_, s_, "negation"); + var neg = Neg(feed, func_graph_, s_, "neg"); + Define(-1, + new Operation[0], + new[] { feed }, + new[] { neg }, + new[] { "negation" }); + + // Use, run, and verify + var func_feed = Placeholder(host_graph_, s_); + var func_op = Use(new[] { func_feed }); + Run(new[] { new KeyValuePair(func_feed, Int32Tensor(3)) }, func_op, -3); + VerifyFDef(new string[] { "neg" }, + new List { new IOSpec("negation_0", DataType.DtInt32) }, + new List { new IOSpec("negation", DataType.DtInt32) }, + new List { new EdgeSpec("negation_0", "neg:0"), new EdgeSpec("neg:y:0", "negation") }, + new List()); + } + + [TestMethod] + public void ZeroOps_Identity() + { + // Define + var feed = Placeholder(func_graph_, s_); + Define(-1, + new Operation[0], + new[] { feed }, + new[] { feed }, + new string[0]); + + // Use, run, and verify + var func_feed = Placeholder(host_graph_, s_); + var func_op = Use(new[] { func_feed }); + Run(new[] { new KeyValuePair(func_feed, Int32Tensor(3)) }, func_op, 3); + VerifyFDef(new string[0], + new List { new IOSpec("feed_0", DataType.DtInt32) }, + new List { new IOSpec("feed", DataType.DtInt32) }, + new List { new EdgeSpec("feed_0", "feed") }, + new List()); + } + + [TestMethod] + public void ZeroOps_Permutation() + { + // Define + var feed1 = Placeholder(func_graph_, s_, "feed1"); + var feed2 = Placeholder(func_graph_, s_, "feed2"); + Define(-1, + null, + new[] { feed1, feed2 }, + new[] { feed2, feed1 }, + null); + + // Use, run, and verify + var two = ScalarConst(2, host_graph_, s_); + var func_feed = Placeholder(host_graph_, s_); + var func_op = Use(new[] { two, func_feed }); + Run(new[] { new KeyValuePair(func_feed, Int32Tensor(3)) }, + new[] { new TF_Output(func_op, 0), new TF_Output(func_op, 1) }, + new[] { 3, 2 }); + VerifyFDef(new string[0], + new List { new IOSpec("feed1_0"), new IOSpec("feed2_0") }, + new List { new IOSpec("feed2"), new IOSpec("feed1") }, + new List { new EdgeSpec("feed1_0", "feed1"), new EdgeSpec("feed2_0", "feed2") }, + new List()); + } + + [TestMethod] + public void ZeroOps_Permutation_OutputNames() + { + // Define + var feed1 = Placeholder(func_graph_, s_, "feed1"); + var feed2 = Placeholder(func_graph_, s_, "feed2"); + Define(-1, + null, + new[] { feed1, feed2 }, + new[] { feed2, feed1 }, + new[] { "first", "second" }); + + // Use, run, and verify + var two = ScalarConst(2, host_graph_, s_); + var func_feed = Placeholder(host_graph_, s_); + var func_op = Use(new[] { two, func_feed }); + Run(new[] { new KeyValuePair(func_feed, Int32Tensor(3)) }, + new[] { new TF_Output(func_op, 0), new TF_Output(func_op, 1) }, + new[] { 3, 2 }); + VerifyFDef(new string[0], + new List { new IOSpec("feed1"), new IOSpec("feed2") }, + new List { new IOSpec("first"), new IOSpec("second") }, + new List { new EdgeSpec("feed1", "second"), new EdgeSpec("feed2", "first") }, + new List()); + } + + [TestMethod] + public void OneOp_TwoInputs_OneOutput() + { + // Define + var feed1 = Placeholder(func_graph_, s_, "feed1"); + var feed2 = Placeholder(func_graph_, s_, "feed2"); + var add = Add(feed1, feed2, func_graph_, s_); + Define(-1, + null, + new[] { feed1, feed2 }, + new[] { add }, + null); + + // Use, run, and verify + var two = ScalarConst(2, host_graph_, s_); + var func_feed = Placeholder(host_graph_, s_); + var func_op = Use(new[] { two, func_feed }); + Run(new[] { new KeyValuePair(func_feed, Int32Tensor(3)) }, + func_op, + 2 + 3); + VerifyFDef(new string[] { "add_0" }, + new List { new IOSpec("feed1"), new IOSpec("feed2") }, + new List { new IOSpec("add") }, + new List + { + new EdgeSpec("feed1", "add_0:0"), + new EdgeSpec("feed2", "add_0:1"), + new EdgeSpec("add_0:sum:0", "add") + }, + new List()); + } + + [TestMethod] + public void OneOp_TwoInputs_ZeroOutputs() + { + // Define + var feed1 = Placeholder(func_graph_, s_, "feed1"); + var feed2 = Placeholder(func_graph_, s_, "feed2"); + var add = Add(feed1, feed2, func_graph_, s_); + Define(-1, + null, + new[] { feed1, feed2 }, + new Operation[0], + null); + + // Use, run, and verify + var two = ScalarConst(2, host_graph_, s_); + var func_feed = Placeholder(host_graph_, s_); + var func_op = Use(new[] { two, func_feed }); + VerifyFDef(new string[] { "add" }, + new List { new IOSpec("feed1"), new IOSpec("feed2") }, + new List(), + new List + { + new EdgeSpec("feed1", "add:0"), + new EdgeSpec("feed2", "add:1") + }, + new List()); + } + + [TestMethod] + public void TwoOps_ThreeInputs_OneOutput() + { + // Define + var feed1 = Placeholder(func_graph_, s_, "feed1"); + var feed2 = Placeholder(func_graph_, s_, "feed2"); + var feed3 = Placeholder(func_graph_, s_, "feed3"); + var add1 = Add(feed1, feed2, func_graph_, s_, "add1"); + var add2 = Add(add1, feed3, func_graph_, s_, "add2"); + Define(-1, + null, + new[] { feed1, feed2, feed3 }, + new[] { add2 }, + null); + + // Use, run, and verify + var two = ScalarConst(2, host_graph_, s_, "two"); + var ten = ScalarConst(10, host_graph_, s_, "ten"); + var func_feed = Placeholder(host_graph_, s_); + var func_op = Use(new[] { two, ten, func_feed }); + Run(new[] { new KeyValuePair(func_feed, Int32Tensor(3)) }, + func_op, + 2 + 10 + 3); + VerifyFDef(new string[] { "add1", "add2_0" }, + new List { new IOSpec("feed1"), new IOSpec("feed2"), new IOSpec("feed3") }, + new List { new IOSpec("add2") }, + new List + { + new EdgeSpec("feed1", "add1:0"), + new EdgeSpec("feed2", "add1:1"), + new EdgeSpec("add1:sum:0", "add2_0:0"), + new EdgeSpec("feed3", "add2_0:1"), + new EdgeSpec("add2_0:sum:0", "add2"), + }, + new List()); + } + + [TestMethod] + public void OneOp_TwoInputs_TwoDuplicateOutputs() + { + // Define + var feed1 = Placeholder(func_graph_, s_, "feed1"); + var feed2 = Placeholder(func_graph_, s_, "feed2"); + var add = Add(feed1, feed2, func_graph_, s_); + Define(-1, + null, + new[] { feed1, feed2 }, + new[] { add, add }, + null); + + // Use, run, and verify + var two = ScalarConst(2, host_graph_, s_); + var func_feed = Placeholder(host_graph_, s_); + var func_op = Use(new[] { two, func_feed }); + Run(new[] { new KeyValuePair(func_feed, Int32Tensor(3)) }, + new[] { new TF_Output(func_op, 0), new TF_Output(func_op, 1) }, + new[] { 5, 5 }); + VerifyFDef(new string[] { "add_1" }, + new List { new IOSpec("feed1"), new IOSpec("feed2") }, + new List { new IOSpec("add"), new IOSpec("add_0") }, + new List + { + new EdgeSpec("feed1", "add_1:0"), + new EdgeSpec("feed2", "add_1:1"), + new EdgeSpec("add_1:sum:0", "add"), + new EdgeSpec("add_1:sum:0", "add_0") + }, + new List()); + } + + [TestMethod] + public void TwoDuplicateOutputs_OutputNames() + { + // Define + var feed1 = Placeholder(func_graph_, s_, "feed1"); + var feed2 = Placeholder(func_graph_, s_, "feed2"); + var add = Add(feed1, feed2, func_graph_, s_); + Define(-1, + null, + new[] { feed1, feed2 }, + new[] { add, add }, + new[] { "out1", "out2" }); + + // Use, run, and verify + var two = ScalarConst(2, host_graph_, s_); + var func_feed = Placeholder(host_graph_, s_); + var func_op = Use(new[] { two, func_feed }); + Run(new[] { new KeyValuePair(func_feed, Int32Tensor(3)) }, + new[] { new TF_Output(func_op, 0), new TF_Output(func_op, 1) }, + new[] { 5, 5 }); + VerifyFDef(new string[] { "add" }, + new List { new IOSpec("feed1"), new IOSpec("feed2") }, + new List { new IOSpec("out1"), new IOSpec("out2") }, + new List + { + new EdgeSpec("feed1", "add:0"), + new EdgeSpec("feed2", "add:1"), + new EdgeSpec("add:sum:0", "out1"), + new EdgeSpec("add:sum:0", "out2") + }, + new List()); + } + + [TestMethod] + public void TwoOps_ThreeInputs_TwoOutputs() + { + // Define + var feed1 = Placeholder(func_graph_, s_, "feed1"); + var feed2 = Placeholder(func_graph_, s_, "feed2"); + var feed3 = Placeholder(func_graph_, s_, "feed3"); + var add1 = Add(feed1, feed2, func_graph_, s_, "add1"); + var add2 = Add(add1, feed3, func_graph_, s_, "add2"); + Define(-1, + null, + new[] { feed1, feed2, feed3 }, + new[] { add1, add2 }, + null); + + // Use, run, and verify + var two = ScalarConst(2, host_graph_, s_, "two"); + var ten = ScalarConst(10, host_graph_, s_, "ten"); + var func_feed = Placeholder(host_graph_, s_); + var func_op = Use(new[] { two, ten, func_feed }); + Run(new[] { new KeyValuePair(func_feed, Int32Tensor(3)) }, + new[] { new TF_Output(func_op, 0), new TF_Output(func_op, 1) }, + new[] { 12, 15 }); + VerifyFDef(new string[] { "add1_0", "add2_0" }, + new List { new IOSpec("feed1"), new IOSpec("feed2"), new IOSpec("feed3") }, + new List { new IOSpec("add1"), new IOSpec("add2") }, + new List + { + new EdgeSpec("feed1", "add1_0:0"), + new EdgeSpec("feed2", "add1_0:1"), + new EdgeSpec("add1_0:sum:0", "add2_0:0"), + new EdgeSpec("feed3", "add2_0:1"), + new EdgeSpec("add1_0:sum:0", "add1"), + new EdgeSpec("add2_0:sum:0", "add2") + }, + new List()); + } + + void Define(int num_opers, Operation[] opers, + Operation[] inputs, Operation[] outputs, + string[] output_names, bool expect_failure = false) + => DefineT(num_opers, opers, + inputs.Select(x => new TF_Output(x, 0)).ToArray(), + outputs.Select(x => new TF_Output(x, 0)).ToArray(), + output_names, expect_failure); + + void DefineT(int num_opers, Operation[] opers, + TF_Output[] inputs, TF_Output[] outputs, + string[] output_names, bool expect_failure = false) + { + func_ = c_api.TF_GraphToFunction(func_graph_, func_name_, false, + num_opers, num_opers == -1 ? null : opers.Select(x => (IntPtr)x).ToArray(), + inputs.Length, inputs.ToArray(), + outputs.Length, outputs.ToArray(), + output_names == null || output_names.Length == 0 ? null : output_names, + IntPtr.Zero, null, s_); + + if (expect_failure) + { + ASSERT_EQ(IntPtr.Zero, func_); + return; + } + + ASSERT_EQ(TF_OK, s_.Code, s_.Message); + ASSERT_NE(func_, IntPtr.Zero); + ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_))); + c_api.TF_GraphCopyFunction(host_graph_, func_, new SafeFuncGraphHandle(IntPtr.Zero), s_); + ASSERT_EQ(TF_OK, s_.Code, s_.Message); + } + + Operation Use(Operation[] inputs) + => UseT(inputs.Select(x => new TF_Output(x, 0)).ToArray()); + + Operation UseT(TF_Output[] inputs) + => UseHelper(inputs); + + Operation UseHelper(TF_Output[] inputs) + { + var desc = TF_NewOperation(host_graph_, func_name_, func_node_name_); + foreach (var input in inputs) + TF_AddInput(desc, input); + c_api.TF_SetDevice(desc, "/cpu:0"); + var op = TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_OK, s_.Code, s_.Message); + ASSERT_NE(op, IntPtr.Zero); + + return op; + } + + void Run(KeyValuePair[] inputs, Operation output, int expected_result) + => Run(inputs, new[] { new TF_Output(output, 0) }, new[] { expected_result }); + + unsafe void Run(KeyValuePair[] inputs, TF_Output[] outputs, int[] expected_results) + { + var csession = new CSession(host_graph_, s_); + ASSERT_EQ(TF_OK, s_.Code, s_.Message); + + csession.SetInputs(inputs); + csession.SetOutputs(outputs); + csession.Run(s_); + ASSERT_EQ(TF_OK, s_.Code, s_.Message); + + for (int i = 0; i < expected_results.Length; ++i) + { + var output = csession.output_tensor(i); + ASSERT_TRUE(!output.IsInvalid); + EXPECT_EQ(TF_DataType.TF_INT32, c_api.TF_TensorType(output)); + EXPECT_EQ(0, c_api.TF_NumDims(output)); + ASSERT_EQ(sizeof(int), (int)c_api.TF_TensorByteSize(output)); + var output_contents = c_api.TF_TensorData(output); + EXPECT_EQ(expected_results[i], *(int*)output_contents.ToPointer()); + } + } + + void VerifyFDef(string[] nodes, List inputs, List outputs, + List e_edges, List c_edges, + bool is_exact_edges = true) + { + var fdef = GetFunctionDef(func_); + EXPECT_NE(fdef, IntPtr.Zero); + VerifyFDefNodes(fdef, nodes); + VerifyFDefInputs(fdef, inputs); + VerifyFDefOutputs(fdef, outputs); + VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges); + } + + void VerifyFDefNodes(FunctionDef fdef, string[] nodes) + { + ASSERT_EQ(nodes.Length, fdef.NodeDef.Count); + foreach (var node in fdef.NodeDef) + { + ASSERT_TRUE(nodes.Contains(node.Name), $"Got unexpected node: {node.Name} in fdef: {fdef}"); + } + } + + void VerifyFDefInputs(FunctionDef fdef, List inputs) + { + var signature = fdef.Signature; + ASSERT_EQ(inputs.Count, signature.InputArg.Count); + for (int i = 0; i < inputs.Count; ++i) + { + var arg = signature.InputArg[i]; + var input = inputs[i]; + if (input.Value != DataType.DtInvalid) + ASSERT_EQ(arg.Type, input.Value, $""); + ASSERT_EQ(arg.Name, input.Key, $"Got unexpected name for input {i}. fdef: {fdef}"); + } + } + + void VerifyFDefOutputs(FunctionDef fdef, List outputs) + { + var signature = fdef.Signature; + ASSERT_EQ(outputs.Count, signature.OutputArg.Count); + for (int i = 0; i < outputs.Count; ++i) + { + var arg = signature.OutputArg[i]; + var output = outputs[i]; + if (output.Value != DataType.DtInvalid) + ASSERT_EQ(arg.Type, output.Value, $""); + ASSERT_EQ(arg.Name, output.Key, $"Got unexpected name for input {i}. fdef: {fdef}"); + } + } + + void VerifyFDefEdges(FunctionDef fdef, List e_edges, List c_edges, bool is_exact_edges = true) + { + // Build a set of edges from fdef + var a_edges = new List(); // actual edges + // Get edges from inputs to body nodes and between body nodes + foreach (var node in fdef.NodeDef) + { + for (int i = 0; i < node.Input.Count; ++i) + { + var input = node.Input[i]; + a_edges.Add(new EdgeSpec(input, $"{node.Name}:{i}")); + } + } + // Get edges from body nodes to outputs and from inputs to outputs + foreach (var arg in fdef.Signature.OutputArg) + { + var iter = fdef.Ret.FirstOrDefault(x => x.Key == arg.Name); + if (iter.Key != null) + { + a_edges.Add(new EdgeSpec(iter.Value, arg.Name)); + } + else + { + a_edges.Add(new EdgeSpec(arg.Name, arg.Name)); + } + } + // Verify edges + foreach (var edge in e_edges) + { + ASSERT_TRUE(a_edges.Contains(edge)); + } + foreach (var edge in c_edges) + { + ASSERT_TRUE(a_edges.Contains(edge)); + } + // If caller specified all edges, check that we have seen all + if (is_exact_edges) + { + ASSERT_EQ(e_edges.Count + c_edges.Count, a_edges.Count, + $"Expected edges: {e_edges}, Expected Control edges: {c_edges}, Actual edges: {a_edges}"); + } + } + + public void Dispose() + { + + } + + public struct IOSpec + { + KeyValuePair pair; + public string Key => pair.Key; + public DataType Value => pair.Value; + + public IOSpec(string key, DataType value = DataType.DtInvalid) + { + pair = new KeyValuePair(key, value); + } + } + + public struct EdgeSpec + { + KeyValuePair pair; + public string Key => pair.Key; + public string Value => pair.Value; + + public EdgeSpec(string key, string value) + { + pair = new KeyValuePair(key, value); + } + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Gradients/GradientsTest.cs b/test/TensorFlowNET.Native.UnitTest/Gradients/GradientsTest.cs new file mode 100644 index 000000000..79fa44890 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Gradients/GradientsTest.cs @@ -0,0 +1,276 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using Tensorflow.Util; + +namespace Tensorflow.Native.UnitTest +{ + /// + /// tensorflow\c\c_api_test.cc + /// `class CApiGradientsTest` + /// + [TestClass] + public class GradientsTest : CApiTest, IDisposable + { + private Graph graph_ = new Graph(); + private Graph expected_graph_ = new Graph(); + private Status s_ = new Status(); + + private void TestGradientsSuccess(bool grad_inputs_provided) + { + var inputs = new TF_Output[2]; + var outputs = new TF_Output[1]; + var grad_outputs = new TF_Output[2]; + var expected_grad_outputs = new TF_Output[2]; + + BuildSuccessGraph(inputs, outputs); + BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs); + + AddGradients(grad_inputs_provided, "gradients", inputs, 2, outputs, 1, + grad_outputs); + EXPECT_EQ(TF_OK, TF_GetCode(s_)); + + // Compare that the graphs match. + GraphDef expected_gdef; + GraphDef gdef; + EXPECT_TRUE(GetGraphDef(expected_graph_, out expected_gdef)); + EXPECT_TRUE(GetGraphDef(graph_, out gdef)); + // Assert.IsTrue(expected_gdef.ToString().Equals(gdef.ToString())); + + // Compare that the output of the gradients of both graphs match. + RunGraphsAndCompareOutputs(grad_outputs, expected_grad_outputs); + } + + private bool GetGraphDef(Graph graph, out GraphDef graph_def) + { + graph_def = null; + var s = new Status(); + var buffer = new Buffer(); + c_api.TF_GraphToGraphDef(graph, buffer, s); + bool ret = TF_GetCode(s) == TF_OK; + EXPECT_EQ(TF_OK, TF_GetCode(s)); + if (ret) + graph_def = GraphDef.Parser.ParseFrom(buffer.ToArray()); + return ret; + } + + private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) + { + var csession = new CSession(graph_, s_); + var expected_csession = new CSession(expected_graph_, s_); + + var grad_outputs_vec = grad_outputs; + csession.SetOutputs(grad_outputs_vec); + csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)); + var out0 = csession.output_tensor(0); + var out1 = csession.output_tensor(1); + + var expected_grad_outputs_vec = expected_grad_outputs; + expected_csession.SetOutputs(expected_grad_outputs_vec); + expected_csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)); + var expected_out0 = expected_csession.output_tensor(0); + var expected_out1 = expected_csession.output_tensor(1); + + //CompareTensors(out0, expected_out0); + //CompareTensors(out1, expected_out1); + } + /*void TestGradientsError(bool grad_inputs_provided) + { + var inputs = new TF_Output[1]; + var outputs = new TF_Output[1]; + var grad_outputs = new TF_Output[1]; + + BuildErrorGraph(inputs, outputs); + + AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1, + grad_outputs); + + string expected_msg = + "No gradient defined for op: TestOpWithNoGradient. Please see " + "https://www.tensorflow.org/code/" + "tensorflow/cc/gradients/README.md" + " for instructions on how to add C++ gradients."; + EXPECT_EQ(expected_msg, TF_Message(s_)); + }*/ + + private void AddGradients(bool grad_inputs_provided, string prefix, TF_Output[] inputs, int ninputs, + TF_Output[] outputs, int noutputs, TF_Output[] grad_outputs) + { + if (grad_inputs_provided) + { + var grad_inputs = new TF_Output[1]; + float[] grad_inputs_val = { 1.0f, 1.0f, 1.0f, 1.0f }; + var grad_inputs_op = FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs"); + grad_inputs[0] = new TF_Output(grad_inputs_op, 0); + + IntPtr[] handles = new IntPtr[2] { IntPtr.Zero, IntPtr.Zero }; + c_api.TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, + ninputs, grad_inputs, s_, handles); + + // var op = new Operation(handles[0]); + } + else + { + //c_api.TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, + //ninputs, null, s_, grad_outputs); + } + } + + private void BuildSuccessGraph(TF_Output[] inputs, TF_Output[] outputs) + { + // Construct the following graph: + // | + // z| + // | + // MatMul + // / \ + // ^ ^ + // | | + // x| y| + // | | + // | | + // Const_0 Const_1 + // + var const0_val = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + var const1_val = new float[] { 1.0f, 0.0f, 0.0f, 1.0f }; + var const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0"); + var const1 = FloatConst2x2(graph_, s_, const1_val, "Const_1"); + var matmul = MatMul(graph_, s_, const0, const1, "MatMul"); + inputs[0] = new TF_Output(const0, 0); + inputs[1] = new TF_Output(const1, 0); + outputs[0] = new TF_Output(matmul, 0); + EXPECT_EQ(TF_OK, TF_GetCode(s_)); + } + + private void BuildExpectedGraph(bool grad_inputs_provided, TF_Output[] expected_grad_outputs) + { + // The expected graph looks like this if grad_inputs_provided. + // If grad_inputs_provided is false, Const_0 will be a OnesLike op. + // ^ ^ + // dy| dx| // MatMul Gradient Graph + // | | + // MatMul_2 MatMul_1 + // ^ ^ ^ ^ + // | |----------| | + // | ^ | + // | dz| | + // | | | + // | Const_3 | + // | | + // | ^ | + // | z| | // MatMul Forward Graph + // | | | + // | MatMul | + // | / \ | + // | ^ ^ | + // | | | | + // |---x| y|----| + // | | + // | | + // Const_0 Const_1 + // + float[] const0_val = { 1.0f, 2.0f, 3.0f, 4.0f }; + float[] const1_val = { 1.0f, 0.0f, 0.0f, 1.0f }; + var const0 = FloatConst2x2(expected_graph_, s_, const0_val, "Const_0"); + var const1 = FloatConst2x2(expected_graph_, s_, const1_val, "Const_1"); + var matmul = MatMul(expected_graph_, s_, const0, const1, "MatMul"); + + Operation const3; + if (grad_inputs_provided) + { + float[] const3_val = { 1.0f, 1.0f, 1.0f, 1.0f }; + const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs"); + } + else + { + const3 = OnesLike(expected_graph_, s_, matmul, "gradients/OnesLike"); + } + + var matmul1 = MatMul(expected_graph_, s_, const3, const1, + "gradients/MatMul", false, true); + var matmul2 = MatMul(expected_graph_, s_, const0, const3, + "gradients/MatMul_1", true, false); + expected_grad_outputs[0] = new TF_Output(matmul1, 0); + expected_grad_outputs[1] = new TF_Output(matmul2, 0); + } + + private Operation OnesLike(Graph graph, Status s, Operation input, string name) + { + var desc = TF_NewOperation(graph, "OnesLike", name); + TF_AddInput(desc, new TF_Output(input, 0)); + var op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + return op; + } + + private Operation FloatConst2x2(Graph graph, Status s, float[] values, string name) + { + var tensor = FloatTensor2x2(values); + var desc = TF_NewOperation(graph, "Const", name); + TF_SetAttrTensor(desc, "value", tensor, s); + if (TF_GetCode(s) != TF_OK) return IntPtr.Zero; + TF_SetAttrType(desc, "dtype", TF_FLOAT); + var op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + return op; + } + + private Tensor FloatTensor2x2(float[] values) + { + //long[] dims = { 2, 2 }; + //Tensor t = c_api.TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4); + //Marshal.Copy(values, 0, t, 4); + Tensor t = np.array(values).reshape((2, 2)); + return t; + } + + private Operation MatMul(Graph graph, Status s, Operation l, Operation r, string name, + bool transpose_a = false, bool transpose_b = false) + { + var desc = TF_NewOperation(graph, "MatMul", name); + if (transpose_a) + { + TF_SetAttrBool(desc, "transpose_a", true); + } + if (transpose_b) + { + TF_SetAttrBool(desc, "transpose_b", true); + } + TF_AddInput(desc, new TF_Output(l, 0)); + TF_AddInput(desc, new TF_Output(r, 0)); + var op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + return op; + } + + [TestMethod] + public void Gradients_GradInputs() + { + //TestGradientsSuccess(true); + } + + [TestMethod] + public void Gradients_NoGradInputs() + { + //TestGradientsSuccess(false); + } + + [TestMethod] + public void OpWithNoGradientRegistered_GradInputs() + { + //TestGradientsError(true); + } + + [TestMethod] + public void OpWithNoGradientRegistered_NoGradInputs() + { + //TestGradientsError(false); + } + + public void Dispose() + { + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Graphs/GraphBuildTest.cs b/test/TensorFlowNET.Native.UnitTest/Graphs/GraphBuildTest.cs new file mode 100644 index 000000000..ed39882e5 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Graphs/GraphBuildTest.cs @@ -0,0 +1,30 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using static Tensorflow.Binding; + +namespace Tensorflow.Native.UnitTest +{ + [TestClass] + public class GraphBuildTest : CApiTest + { + [TestMethod, Ignore("Waiting to merge https://github.com/tensorflow/tensorflow/pull/43383")] + public void UpdateEdge() + { + var graph = new Graph().as_default(); + + var one = tf.constant(1, name: "one"); + var two = tf.constant(2, name: "two"); + var add = tf.add(one, two, name: "add"); + var neg = tf.negative(add, name: "neg"); + + Assert.AreEqual(1, one.consumers().Length); + Assert.AreEqual("add", neg.op.node_def.Input[0]); + + // update edge + neg.op._update_input(0, one); + // c_api.TF_UpdateEdge(graph, new TF_Output(c1.op, 0), new TF_Input(neg.op, 0), tf.Status.Handle); + + Assert.AreEqual(2, one.consumers().Length); + Assert.AreEqual("one:0", neg.op.node_def.Input[0]); + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Graphs/GraphTest.cs b/test/TensorFlowNET.Native.UnitTest/Graphs/GraphTest.cs new file mode 100644 index 000000000..33b5cd9f3 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Graphs/GraphTest.cs @@ -0,0 +1,425 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using static Tensorflow.Binding; + +namespace Tensorflow.Native.UnitTest +{ + [TestClass] + public class GraphTest : CApiTest + { + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, Graph)` + /// + [TestMethod] + public void Graph() + { + var s = new Status(); + var graph = new Graph(); + + // Make a placeholder operation. + var feed = c_test_util.Placeholder(graph, s); + EXPECT_EQ("feed", feed.name); + EXPECT_EQ("Placeholder", feed.OpType); + EXPECT_EQ("", feed.Device); + EXPECT_EQ(1, feed.NumOutputs); + EXPECT_EQ(TF_DataType.TF_INT32, feed.OutputType(0)); + EXPECT_EQ(1, feed.OutputListLength("output")); + EXPECT_EQ(0, feed.NumInputs); + EXPECT_EQ(0, feed.OutputNumConsumers(0)); + EXPECT_EQ(0, feed.NumControlInputs); + EXPECT_EQ(0, feed.NumControlOutputs); + + AttrValue attr_value = null; + ASSERT_TRUE(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s)); + EXPECT_EQ(attr_value.Type, DataType.DtInt32); + + // Test not found errors in TF_Operation*() query functions. + EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); + EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code); + Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); + EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message); + + // Make a constant oper with the scalar "3". + var three = c_test_util.ScalarConst(3, graph, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + + // Add oper. + var add = c_test_util.Add(feed, three, graph, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + + // Test TF_Operation*() query functions. + EXPECT_EQ("add", add.name); + EXPECT_EQ("AddN", add.OpType); + EXPECT_EQ("", add.Device); + EXPECT_EQ(1, add.NumOutputs); + EXPECT_EQ(TF_DataType.TF_INT32, add.OutputType(0)); + EXPECT_EQ(1, add.OutputListLength("sum")); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + EXPECT_EQ(2, add.InputListLength("inputs")); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(0)); + EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(1)); + var add_in_0 = add.Input(0); + EXPECT_EQ(feed, add_in_0.oper); + EXPECT_EQ(0, add_in_0.index); + var add_in_1 = add.Input(1); + EXPECT_EQ(three, add_in_1.oper); + EXPECT_EQ(0, add_in_1.index); + EXPECT_EQ(0, add.OutputNumConsumers(0)); + EXPECT_EQ(0, add.NumControlInputs); + EXPECT_EQ(0, add.NumControlOutputs); + + ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); + EXPECT_EQ(DataType.DtInt32, attr_value.Type); + ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); + EXPECT_EQ(2, (int)attr_value.I); + + // Placeholder oper now has a consumer. + EXPECT_EQ(1, feed.OutputNumConsumers(0)); + TF_Input[] feed_port = feed.OutputConsumers(0, 1); + EXPECT_EQ(1, feed_port.Length); + EXPECT_EQ(add, feed_port[0].oper); + EXPECT_EQ(0, feed_port[0].index); + + // The scalar const oper also has a consumer. + EXPECT_EQ(1, three.OutputNumConsumers(0)); + TF_Input[] three_port = three.OutputConsumers(0, 1); + EXPECT_EQ(add, three_port[0].oper); + EXPECT_EQ(1, three_port[0].index); + + // Serialize to GraphDef. + var graph_def = c_test_util.GetGraphDef(graph); + + // Validate GraphDef is what we expect. + bool found_placeholder = false; + bool found_scalar_const = false; + bool found_add = false; + foreach (var n in graph_def.Node) + { + if (c_test_util.IsPlaceholder(n)) + { + Assert.IsFalse(found_placeholder); + found_placeholder = true; + } + else if (c_test_util.IsScalarConst(n, 3)) + { + Assert.IsFalse(found_scalar_const); + found_scalar_const = true; + } + else if (c_test_util.IsAddN(n, 2)) + { + Assert.IsFalse(found_add); + found_add = true; + } + else + { + Assert.Fail($"Unexpected NodeDef: {n}"); + } + } + ASSERT_TRUE(found_placeholder); + ASSERT_TRUE(found_scalar_const); + ASSERT_TRUE(found_add); + + // Add another oper to the graph. + var neg = c_test_util.Neg(add, graph, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + + // Serialize to NodeDef. + var node_def = neg.node_def; + + // Validate NodeDef is what we expect. + ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); + + // Serialize to GraphDef. + var graph_def2 = c_test_util.GetGraphDef(graph); + + // Compare with first GraphDef + added NodeDef. + graph_def.Node.Add(node_def); + EXPECT_EQ(graph_def, graph_def2); + + // Look up some nodes by name. + Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); + EXPECT_EQ(neg, neg2); + var node_def2 = neg2.node_def; + EXPECT_EQ(node_def, node_def2); + + Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); + EXPECT_EQ(feed, feed2); + node_def = feed.node_def; + node_def2 = feed2.node_def; + EXPECT_EQ(node_def, node_def2); + + // Test iterating through the nodes of a graph. + found_placeholder = false; + found_scalar_const = false; + found_add = false; + bool found_neg = false; + uint pos = 0; + Operation oper; + + while ((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) + { + if (oper.Equals(feed)) + { + Assert.IsFalse(found_placeholder); + found_placeholder = true; + } + else if (oper.Equals(three)) + { + Assert.IsFalse(found_scalar_const); + found_scalar_const = true; + } + else if (oper.Equals(add)) + { + Assert.IsFalse(found_add); + found_add = true; + } + else if (oper.Equals(neg)) + { + Assert.IsFalse(found_neg); + found_neg = true; + } + else + { + node_def = oper.node_def; + Assert.Fail($"Unexpected Node: {node_def.ToString()}"); + } + } + + ASSERT_TRUE(found_placeholder); + ASSERT_TRUE(found_scalar_const); + ASSERT_TRUE(found_add); + ASSERT_TRUE(found_neg); + } + + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, ImportGraphDef)` + /// + [TestMethod] + public void ImportGraphDef() + { + var s = new Status(); + var graph = new Graph().as_default(); + + // Create a simple graph. + c_test_util.Placeholder(graph, s); + var oper = c_test_util.ScalarConst(3, graph, s); + c_test_util.Neg(oper, graph, s); + + // Export to a GraphDef. + var graph_def = new Buffer(); + c_api.TF_GraphToGraphDef(graph, graph_def, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + + // Import it, with a prefix, in a fresh graph. + graph = new Graph().as_default(); + using (var opts = c_api.TF_NewImportGraphDefOptions()) + { + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); + c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + } + + Operation scalar = graph.OperationByName("imported/scalar"); + Operation feed = graph.OperationByName("imported/feed"); + Operation neg = graph.OperationByName("imported/neg"); + + // Test basic structure of the imported graph. + EXPECT_EQ(0, scalar.NumInputs); + EXPECT_EQ(0, feed.NumInputs); + EXPECT_EQ(1, neg.NumInputs); + + var neg_input = neg.Input(0); + EXPECT_EQ(scalar, neg_input.oper); + EXPECT_EQ(0, neg_input.index); + + // Test that we can't see control edges involving the source and sink nodes. + EXPECT_EQ(0, scalar.NumControlInputs); + EXPECT_EQ(0, scalar.GetControlInputs().Length); + EXPECT_EQ(0, scalar.NumControlOutputs); + EXPECT_EQ(0, scalar.GetControlOutputs().Length); + + EXPECT_EQ(0, feed.NumControlInputs); + EXPECT_EQ(0, feed.GetControlInputs().Length); + EXPECT_EQ(0, feed.NumControlOutputs); + EXPECT_EQ(0, feed.GetControlOutputs().Length); + + EXPECT_EQ(0, neg.NumControlInputs); + EXPECT_EQ(0, neg.GetControlInputs().Length); + EXPECT_EQ(0, neg.NumControlOutputs); + EXPECT_EQ(0, neg.GetControlOutputs().Length); + + static SafeImportGraphDefResultsHandle ImportGraph(Status s, Graph graph, Buffer graph_def, Operation scalar) + { + using var opts = c_api.TF_NewImportGraphDefOptions(); + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); + c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); + c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); + c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); + EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); + c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); + EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); + var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + + return results; + } + + // Import it again, with an input mapping, return outputs, and a return + // operation, into the same graph. + Operation feed2; + using (SafeImportGraphDefResultsHandle results = ImportGraph(s, graph, graph_def, scalar)) + { + Operation scalar2 = graph.OperationByName("imported2/scalar"); + feed2 = graph.OperationByName("imported2/feed"); + Operation neg2 = graph.OperationByName("imported2/neg"); + + // Check input mapping + neg_input = neg.Input(0); + EXPECT_EQ(scalar, neg_input.oper); + EXPECT_EQ(0, neg_input.index); + + // Check return outputs + var return_outputs = graph.ReturnOutputs(results); + ASSERT_EQ(2, return_outputs.Length); + EXPECT_EQ(feed2, return_outputs[0].oper); + EXPECT_EQ(0, return_outputs[0].index); + EXPECT_EQ(scalar, return_outputs[1].oper); // remapped + EXPECT_EQ(0, return_outputs[1].index); + + // Check return operation + var return_opers = graph.ReturnOperations(results); + ASSERT_EQ(1, return_opers.Length); + EXPECT_EQ(scalar2, return_opers[0]); // not remapped + } + + // Import again, with control dependencies, into the same graph. + using (var opts = c_api.TF_NewImportGraphDefOptions()) + { + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); + c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); + c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); + c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + } + + var scalar3 = graph.OperationByName("imported3/scalar"); + var feed3 = graph.OperationByName("imported3/feed"); + var neg3 = graph.OperationByName("imported3/neg"); + ASSERT_TRUE(scalar3 != IntPtr.Zero); + ASSERT_TRUE(feed3 != IntPtr.Zero); + ASSERT_TRUE(neg3 != IntPtr.Zero); + + // Check that newly-imported scalar and feed have control deps (neg3 will + // inherit them from input) + var control_inputs = scalar3.GetControlInputs(); + ASSERT_EQ(2, scalar3.NumControlInputs); + EXPECT_EQ(feed, control_inputs[0]); + EXPECT_EQ(feed2, control_inputs[1]); + + control_inputs = feed3.GetControlInputs(); + ASSERT_EQ(2, feed3.NumControlInputs); + EXPECT_EQ(feed, control_inputs[0]); + EXPECT_EQ(feed2, control_inputs[1]); + + // Export to a graph def so we can import a graph with control dependencies + graph_def = new Buffer(); + c_api.TF_GraphToGraphDef(graph, graph_def, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + + // Import again, with remapped control dependency, into the same graph + using (var opts = c_api.TF_NewImportGraphDefOptions()) + { + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); + c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); + c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + } + + var scalar4 = graph.OperationByName("imported4/imported3/scalar"); + var feed4 = graph.OperationByName("imported4/imported2/feed"); + + // Check that imported `imported3/scalar` has remapped control dep from + // original graph and imported control dep + control_inputs = scalar4.GetControlInputs(); + ASSERT_EQ(2, scalar4.NumControlInputs); + EXPECT_EQ(feed, control_inputs[0]); + EXPECT_EQ(feed4, control_inputs[1]); + + // Can add nodes to the imported graph without trouble. + c_test_util.Add(feed, scalar, graph, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + } + + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, ImportGraphDef_WithReturnOutputs)` + /// + [TestMethod] + public void ImportGraphDef_WithReturnOutputs() + { + var s = new Status(); + var graph = new Graph().as_default(); + + // Create a graph with two nodes: x and 3 + c_test_util.Placeholder(graph, s); + ASSERT_TRUE(graph.OperationByName("feed") != null); + var oper = c_test_util.ScalarConst(3, graph, s); + ASSERT_TRUE(graph.OperationByName("scalar") != null); + c_test_util.Neg(oper, graph, s); + ASSERT_TRUE(graph.OperationByName("neg") != null); + + // Export to a GraphDef. + var graph_def = graph.ToGraphDef(s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + // Import it in a fresh graph with return outputs. + graph = new Graph().as_default(); + var opts = new ImportGraphDefOptions(); + opts.AddReturnOutput("feed", 0); + opts.AddReturnOutput("scalar", 0); + EXPECT_EQ(2, opts.NumReturnOutputs); + var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + var scalar = graph.OperationByName("scalar"); + var feed = graph.OperationByName("feed"); + var neg = graph.OperationByName("neg"); + ASSERT_TRUE(scalar != IntPtr.Zero); + ASSERT_TRUE(feed != IntPtr.Zero); + ASSERT_TRUE(neg != IntPtr.Zero); + + // Check return outputs + EXPECT_EQ(feed, return_outputs[0].oper); + EXPECT_EQ(0, return_outputs[0].index); + EXPECT_EQ(scalar, return_outputs[1].oper); + EXPECT_EQ(0, return_outputs[1].index); + } + + /// + /// `TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings)` + /// + [TestMethod] + public void ImportGraphDef_MissingUnusedInputMappings() + { + + } + + [Ignore] + [TestMethod] + public void ImportGraphMeta() + { + var dir = "my-save-dir/"; + var sess = tf.Session(); + var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta"); + new_saver.restore(sess, dir + "my-model-10000"); + var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); + var batch_size = tf.size(labels); + var logits = tf.get_collection("logits")[0] as Tensor; + var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels, + logits: logits); + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs b/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs new file mode 100644 index 000000000..4d0d6d8c9 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs @@ -0,0 +1,141 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; +using Tensorflow.Lite; + +namespace Tensorflow.Native.UnitTest +{ + [TestClass] + public class TfLiteTest + { + [TestMethod] + [Ignore] + public void TfLiteVersion() + { + var ver = c_api_lite.StringPiece(c_api_lite.TfLiteVersion()); + Assert.IsNotNull(ver); + } + + [TestMethod] + [Ignore] + public unsafe void SmokeTest() + { + var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add.bin"); + var options = c_api_lite.TfLiteInterpreterOptionsCreate(); + c_api_lite.TfLiteInterpreterOptionsSetNumThreads(options, 2); + + var interpreter = c_api_lite.TfLiteInterpreterCreate(model, options); + + c_api_lite.TfLiteInterpreterOptionsDelete(options.DangerousGetHandle()); + c_api_lite.TfLiteModelDelete(model.DangerousGetHandle()); + + Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter)); + Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetInputTensorCount(interpreter)); + Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetOutputTensorCount(interpreter)); + + var input_dims = new int[] { 2 }; + Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, input_dims.Length)); + Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter)); + + var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0); + Assert.AreEqual(TfLiteDataType.kTfLiteFloat32, c_api_lite.TfLiteTensorType(input_tensor)); + Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor)); + Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0)); + Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(input_tensor)); + Assert.IsNotNull(c_api_lite.TfLiteTensorData(input_tensor)); + Assert.AreEqual("input", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(input_tensor))); + + var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor); + Assert.AreEqual(0f, input_params.scale); + Assert.AreEqual(0, input_params.zero_point); + + var input = new[] { 1f, 3f }; + fixed (float* addr = &input[0]) + { + Assert.AreEqual(TfLiteStatus.kTfLiteOk, + c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(float))); + } + + Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter)); + + var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0); + Assert.AreEqual(TfLiteDataType.kTfLiteFloat32, c_api_lite.TfLiteTensorType(output_tensor)); + Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(output_tensor)); + Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(output_tensor, 0)); + Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(output_tensor)); + Assert.IsNotNull(c_api_lite.TfLiteTensorData(output_tensor)); + Assert.AreEqual("output", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(output_tensor))); + + var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor); + Assert.AreEqual(0f, output_params.scale); + Assert.AreEqual(0, output_params.zero_point); + + var output = new float[2]; + fixed (float* addr = &output[0]) + { + Assert.AreEqual(TfLiteStatus.kTfLiteOk, + c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(float))); + } + Assert.AreEqual(3f, output[0]); + Assert.AreEqual(9f, output[1]); + + c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle()); + } + + [TestMethod] + [Ignore] + public unsafe void QuantizationParamsTest() + { + var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add_quantized.bin"); + var interpreter = c_api_lite.TfLiteInterpreterCreate(model, new SafeTfLiteInterpreterOptionsHandle(IntPtr.Zero)); + c_api_lite.TfLiteModelDelete(model.DangerousGetHandle()); + var input_dims = new[] { 2 }; + Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, 1)); + Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter)); + + var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0); + Assert.IsNotNull(input_tensor); + + Assert.AreEqual(TfLiteDataType.kTfLiteUInt8, c_api_lite.TfLiteTensorType(input_tensor)); + Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor)); + Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0)); + + var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor); + Assert.AreEqual((0.003922f, 0), (input_params.scale, input_params.zero_point)); + + var input = new byte[] { 1, 3 }; + fixed (byte* addr = &input[0]) + { + Assert.AreEqual(TfLiteStatus.kTfLiteOk, + c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(byte))); + } + Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter)); + + var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0); + Assert.IsNotNull(output_tensor); + + var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor); + Assert.AreEqual((0.003922f, 0), (output_params.scale, output_params.zero_point)); + + var output = new byte[2]; + fixed (byte* addr = &output[0]) + { + Assert.AreEqual(TfLiteStatus.kTfLiteOk, + c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(byte))); + } + Assert.AreEqual(3f, output[0]); + Assert.AreEqual(9f, output[1]); + + var dequantizedOutput0 = output_params.scale * (output[0] - output_params.zero_point); + var dequantizedOutput1 = output_params.scale * (output[1] - output_params.zero_point); + Assert.AreEqual(dequantizedOutput0, 0.011766f); + Assert.AreEqual(dequantizedOutput1, 0.035298f); + + c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle()); + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Lite/testdata/add.bin b/test/TensorFlowNET.Native.UnitTest/Lite/testdata/add.bin new file mode 100644 index 000000000..b4c02350c Binary files /dev/null and b/test/TensorFlowNET.Native.UnitTest/Lite/testdata/add.bin differ diff --git a/test/TensorFlowNET.Native.UnitTest/Lite/testdata/add_quantized.bin b/test/TensorFlowNET.Native.UnitTest/Lite/testdata/add_quantized.bin new file mode 100644 index 000000000..07d48b93e Binary files /dev/null and b/test/TensorFlowNET.Native.UnitTest/Lite/testdata/add_quantized.bin differ diff --git a/test/TensorFlowNET.Native.UnitTest/Sessions/CSession.cs b/test/TensorFlowNET.Native.UnitTest/Sessions/CSession.cs new file mode 100644 index 000000000..e79571000 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Sessions/CSession.cs @@ -0,0 +1,104 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Util; + +namespace Tensorflow.Native.UnitTest +{ + /// + /// tensorflow\c\c_test_util.cc + /// TEST(CAPI, Session) + /// + public class CSession + { + private SafeSessionHandle session_; + + private List inputs_ = new List(); + private List input_values_ = new List(); + private List outputs_ = new List(); + private List output_values_ = new List(); + + private List targets_ = new List(); + + public CSession(Graph graph, Status s, bool user_XLA = false) + { + var config = new ConfigProto { InterOpParallelismThreads = 4 }; + session_ = new Session(graph, config, s); + } + + public void SetInputs(Dictionary inputs) + { + DeleteInputValues(); + inputs_.Clear(); + foreach (var input in inputs) + { + inputs_.Add(new TF_Output(input.Key, 0)); + input_values_.Add(input.Value); + } + } + + public void SetInputs(KeyValuePair[] inputs) + { + DeleteInputValues(); + inputs_.Clear(); + foreach (var input in inputs) + { + inputs_.Add(new TF_Output(input.Key, 0)); + input_values_.Add(input.Value); + } + } + + private void DeleteInputValues() + { + //clearing is enough as they will be disposed by the GC unless they are referenced else-where. + input_values_.Clear(); + } + + public void SetOutputs(TF_Output[] outputs) + { + ResetOutputValues(); + outputs_.Clear(); + foreach (var output in outputs) + { + outputs_.Add(output); + output_values_.Add(null); + } + } + + private void ResetOutputValues() + { + //clearing is enough as they will be disposed by the GC unless they are referenced else-where. + output_values_.Clear(); + } + + public unsafe void Run(Status s) + { + var inputs_ptr = inputs_.ToArray(); + var input_values_ptr = input_values_.Select(x => x.Handle.DangerousGetHandle()).ToArray(); + var outputs_ptr = outputs_.ToArray(); + var output_values_ptr = output_values_.Select(x => IntPtr.Zero).ToArray(); + IntPtr[] targets_ptr = new IntPtr[0]; + + c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, + outputs_ptr, output_values_ptr, outputs_.Count, + targets_ptr, targets_.Count, + IntPtr.Zero, s); + + s.Check(); + + for (var i = 0; i < outputs_.Count; i++) + output_values_[i] = new SafeTensorHandle(output_values_ptr[i]); + } + + public SafeTensorHandle output_tensor(int i) + { + return output_values_[i].Handle; + } + + public void CloseAndDelete(Status s) + { + DeleteInputValues(); + ResetOutputValues(); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs b/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs new file mode 100644 index 000000000..74f9366c7 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs @@ -0,0 +1,74 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; + +namespace Tensorflow.Native.UnitTest.Sessions +{ + [TestClass] + public class SessionTest : CApiTest + { + /// + /// tensorflow\c\c_api_test.cc + /// `TEST(CAPI, Session)` + /// + [TestMethod] + public void Session() + { + var s = new Status(); + var graph = new Graph(); + + // Make a placeholder operation. + var feed = c_test_util.Placeholder(graph, s); + + // Make a constant operation with the scalar "2". + var two = c_test_util.ScalarConst(2, graph, s); + + // Add operation. + var add = c_test_util.Add(feed, two, graph, s); + + var csession = new CSession(graph, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + // Run the graph. + var inputs = new Dictionary(); + inputs.Add(feed, new Tensor(3)); + csession.SetInputs(inputs); + + var outputs = new TF_Output[] { new TF_Output(add, 0) }; + csession.SetOutputs(outputs); + + csession.Run(s); + Tensor outTensor = csession.output_tensor(0); + EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); + EXPECT_EQ(0, outTensor.ndim); + ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); + var output_contents = outTensor.ToArray(); + EXPECT_EQ(3 + 2, output_contents[0]); + + // Add another operation to the graph. + var neg = c_test_util.Neg(add, graph, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + // Run up to the new operation. + inputs = new Dictionary(); + inputs.Add(feed, new Tensor(7)); + csession.SetInputs(inputs); + outputs = new TF_Output[] { new TF_Output(neg, 0) }; + csession.SetOutputs(outputs); + csession.Run(s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + outTensor = csession.output_tensor(0); + ASSERT_TRUE(outTensor.Handle.DangerousGetHandle() != IntPtr.Zero); + EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); + EXPECT_EQ(0, outTensor.ndim); // scalar + ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); + output_contents = outTensor.ToArray(); + EXPECT_EQ(-(7 + 2), output_contents[0]); + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj b/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj new file mode 100644 index 000000000..c054a8707 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj @@ -0,0 +1,61 @@ + + + + net6.0 + + false + + AnyCPU;x64 + + + + true + DEBUG;TRACE + x64 + + + + true + DEBUG;TRACE + x64 + + + + true + + + + true + + + + + + + + + + PreserveNewest + + + PreserveNewest + + + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + diff --git a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs new file mode 100644 index 000000000..6ccc6cdd1 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs @@ -0,0 +1,221 @@ +using FluentAssertions; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.Linq; +using System.Runtime.InteropServices; +using static Tensorflow.Binding; + +namespace Tensorflow.Native.UnitTest.Tensors +{ + [TestClass] + public class TensorTest : CApiTest + { + [TestMethod] + public unsafe void TensorFromFixed() + { + var array = new float[1000]; + var span = new Span(array, 100, 500); + fixed (float* ptr = &MemoryMarshal.GetReference(span)) + { + using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(2000, (int)t.bytesize); + } + } + + fixed (float* ptr = &array[0]) + { + using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(4000, (int)t.bytesize); + } + } + } + + [TestMethod] + public void TensorFromArray() + { + var array = new float[1000]; + using (var t = new Tensor(array)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(1000 * sizeof(float), (int)t.bytesize); + } + + using (var t = new Tensor(1)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(1 * sizeof(float), (int)t.bytesize); + Assert.AreEqual(t.shape, Shape.Scalar); + } + } + + [TestMethod] + public void AllocateTensor() + { + ulong num_bytes = 6 * sizeof(float); + long[] dims = { 2, 3 }; + Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); + EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); + EXPECT_EQ(2, t.ndim); + EXPECT_EQ(dims[0], t.shape[0]); + EXPECT_EQ(num_bytes, t.bytesize); + t.Dispose(); + } + + + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, MaybeMove)` + /// + [TestMethod, Ignore] + public void MaybeMove() + { + Tensor t = new Tensor(new[] { 2, 3 }); + Tensor o = t.MaybeMove(); + ASSERT_TRUE(o.Handle.IsInvalid); // It is unsafe to move memory TF might not own. + t.Dispose(); + } + + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, Tensor)` + /// + [TestMethod] + public void Tensor() + { + var array = new[] { 1f, 2f, 3f, 4f, 5f, 6f }; + var tensor = new Tensor(array, (2, 3)); + + EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); + EXPECT_EQ(tensor.rank, 2); + EXPECT_EQ(tensor.shape[0], 2L); + EXPECT_EQ(tensor.shape[1], 3L); + EXPECT_EQ(tensor.bytesize, 6ul * sizeof(float)); + Assert.IsTrue(Enumerable.SequenceEqual(tensor.ToArray(), new float[] { 1, 2, 3, 4, 5, 6 })); + } + + /// + /// Port from c_api_test.cc + /// `TEST_F(CApiAttributesTest, StringTensor)` + /// + [TestMethod] + public void StringTensor() + { + string text = "Hello world!."; + + var tensor = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, + null, + 0, + 1 * 24); + var tstr = c_api.TF_TensorData(tensor); + c_api.TF_StringInit(tstr); + c_api.TF_StringCopy(tstr, text, text.Length); + var data = c_api.TF_StringGetDataPointer(tstr); + + Assert.AreEqual((ulong)text.Length, c_api.TF_StringGetSize(tstr)); + Assert.AreEqual(text, c_api.StringPiece(data)); + Assert.AreEqual(TF_TString_Type.TF_TSTR_SMALL, c_api.TF_StringGetType(tensor)); + Assert.AreEqual(0, c_api.TF_NumDims(tensor)); + + tensor.Dispose(); + c_api.TF_StringDealloc(tstr); + } + + /// + /// Port from tensorflow\c\c_api_test.cc + /// `TEST(CAPI, SetShape)` + /// + [TestMethod] + public void SetShape() + { + var s = new Status(); + var graph = new Graph().as_default(); + + var feed = c_test_util.Placeholder(graph, s); + var feed_out_0 = new TF_Output(feed, 0); + + // Fetch the shape, it should be completely unknown. + int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); + + Assert.IsTrue(s.Code == TF_Code.TF_OK); + EXPECT_EQ(-1, num_dims); + + // Set the shape to be unknown, expect no change. + c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); + EXPECT_EQ(-1, num_dims); + + // Set the shape to be 2 x Unknown + long[] dims = { 2, -1 }; + c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); + EXPECT_EQ(2, num_dims); + + // Get the dimension vector appropriately. + var returned_dims = new long[dims.Length]; + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); + + // Set to a new valid shape: [2, 3] + dims[1] = 3; + c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + + // Fetch and see that the new value is returned. + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); + + // Try to set 'unknown' with unknown rank on the shape and see that + // it doesn't change. + c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + EXPECT_EQ(2, num_dims); + EXPECT_EQ(2, (int)returned_dims[0]); + EXPECT_EQ(3, (int)returned_dims[1]); + + // Try to set 'unknown' with same rank on the shape and see that + // it doesn't change. + dims[0] = -1; + dims[1] = -1; + c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + EXPECT_EQ(2, num_dims); + EXPECT_EQ(2, (int)returned_dims[0]); + EXPECT_EQ(3, (int)returned_dims[1]); + + // Try to fetch a shape with the wrong num_dims + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); + Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); + + // Try to set an invalid shape (cannot change 2x3 to a 2x5). + dims[1] = 5; + c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); + Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); + + // Test for a scalar. + var three = c_test_util.ScalarConst(3, graph, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + var three_out_0 = new TF_Output(three, 0); + + num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + EXPECT_EQ(0, num_dims); + c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s); + Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); + + graph.Exit(); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.Native.UnitTest/c_test_util.cs b/test/TensorFlowNET.Native.UnitTest/c_test_util.cs new file mode 100644 index 000000000..4044046bd --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/c_test_util.cs @@ -0,0 +1,237 @@ +using System; +using System.Diagnostics.CodeAnalysis; +using Tensorflow.Util; + +namespace Tensorflow.Native.UnitTest +{ + /// + /// Port from `tensorflow\c\c_test_util.cc` + /// + public static class c_test_util + { + public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") + { + lock (Locks.ProcessWide) + { + var desc = c_api.TF_NewOperation(graph, "AddN", name); + + var inputs = new TF_Output[] + { + new TF_Output(l, 0), + new TF_Output(r, 0), + }; + + c_api.TF_AddInputList(desc, inputs, inputs.Length); + + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); + + return op; + } + } + + [SuppressMessage("ReSharper", "RedundantAssignment")] + public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) + { + var buffer = new Buffer(); + + c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); + attr_value = AttrValue.Parser.ParseFrom(buffer.ToArray()); + + return s.Code == TF_Code.TF_OK; + } + + public static GraphDef GetGraphDef(Graph graph) + { + var s = new Status(); + var buffer = new Buffer(); + + c_api.TF_GraphToGraphDef(graph, buffer, s); + s.Check(); + return GraphDef.Parser.ParseFrom(buffer.ToArray()); + } + + public static FunctionDef GetFunctionDef(SafeFuncGraphHandle func) + { + var s = new Status(); + var buffer = new Buffer(); + c_api.TF_FunctionToFunctionDef(func, buffer, s); + s.Check(true); + var func_def = FunctionDef.Parser.ParseFrom(buffer.ToArray()); + return func_def; + } + + public static bool IsAddN(NodeDef node_def, int n) + { + if (node_def.Op != "AddN" || node_def.Name != "add" || + node_def.Input.Count != n) + { + return false; + } + + bool found_t = false; + bool found_n = false; + foreach (var attr in node_def.Attr) + { + if (attr.Key == "T") + { + if (attr.Value.Type == DataType.DtInt32) + { + found_t = true; + } + else + { + return false; + } + } + else if (attr.Key == "N") + { + if (attr.Value.I == n) + { + found_n = true; + } + else + { + return false; + } + } + } + + return found_t && found_n; + } + + public static bool IsNeg(NodeDef node_def, string input) + { + return node_def.Op == "Neg" && node_def.Name == "neg" && + node_def.Input.Count == 1 && node_def.Input[0] == input; + } + + public static bool IsPlaceholder(NodeDef node_def) + { + if (node_def.Op != "Placeholder" || node_def.Name != "feed") + { + return false; + } + + bool found_dtype = false; + bool found_shape = false; + foreach (var attr in node_def.Attr) + { + if (attr.Key == "dtype") + { + if (attr.Value.Type == DataType.DtInt32) + { + found_dtype = true; + } + else + { + return false; + } + } + else if (attr.Key == "shape") + { + found_shape = true; + } + } + + return found_dtype && found_shape; + } + + public static bool IsScalarConst(NodeDef node_def, int v) + { + if (node_def.Op != "Const" || node_def.Name != "scalar") + { + return false; + } + + bool found_dtype = false; + bool found_value = false; + foreach (var attr in node_def.Attr) + { + if (attr.Key == "dtype") + { + if (attr.Value.Type == DataType.DtInt32) + { + found_dtype = true; + } + else + { + return false; + } + } + else if (attr.Key == "value") + { + if (attr.Value.Tensor != null && + attr.Value.Tensor.IntVal.Count == 1 && + attr.Value.Tensor.IntVal[0] == v) + { + found_value = true; + } + else + { + return false; + } + } + } + + return found_dtype && found_value; + } + + public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") + { + lock (Locks.ProcessWide) + { + OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); + var neg_input = new TF_Output(n, 0); + c_api.TF_AddInput(desc, neg_input); + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); + + return op; + } + } + + public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) + { + lock (Locks.ProcessWide) + { + var desc = c_api.TF_NewOperation(graph, "Placeholder", name); + c_api.TF_SetAttrType(desc, "dtype", dtype); + if (dims != null) + { + c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); + } + + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); + + return op; + } + } + + public static Operation Const(Tensor t, Graph graph, Status s, string name) + { + lock (Locks.ProcessWide) + { + var desc = c_api.TF_NewOperation(graph, "Const", name); + c_api.TF_SetAttrTensor(desc, "value", t, s); + s.Check(); + c_api.TF_SetAttrType(desc, "dtype", t.dtype); + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); + + return op; + } + } + + public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") + { + return Const(new Tensor(v), graph, s, name); + } + + public static Tensor Int32Tensor(int v) + { + return new Tensor(v); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs new file mode 100644 index 000000000..9f4719575 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs @@ -0,0 +1,106 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class RandomTest + { + /// + /// Test the function of setting random seed + /// This will help regenerate the same result + /// + [TestMethod] + public void TFRandomSeedTest() + { + var initValue = np.arange(6).reshape((3, 2)); + tf.set_random_seed(1234); + var a1 = tf.random_uniform(1); + var b1 = tf.random_shuffle(tf.constant(initValue)); + + // This part we consider to be a refresh + tf.set_random_seed(10); + tf.random_uniform(1); + tf.random_shuffle(tf.constant(initValue)); + + tf.set_random_seed(1234); + var a2 = tf.random_uniform(1); + var b2 = tf.random_shuffle(tf.constant(initValue)); + Assert.AreEqual(a1.numpy(), a2.numpy()); + Assert.AreEqual(b1.numpy(), b2.numpy()); + } + + /// + /// compare to Test above, seed is also added in params + /// + [TestMethod, Ignore] + public void TFRandomSeedTest2() + { + var initValue = np.arange(6).reshape((3, 2)); + tf.set_random_seed(1234); + var a1 = tf.random_uniform(1, seed:1234); + var b1 = tf.random_shuffle(tf.constant(initValue), seed: 1234); + + // This part we consider to be a refresh + tf.set_random_seed(10); + tf.random_uniform(1); + tf.random_shuffle(tf.constant(initValue)); + + tf.set_random_seed(1234); + var a2 = tf.random_uniform(1); + var b2 = tf.random_shuffle(tf.constant(initValue)); + Assert.AreEqual(a1, a2); + Assert.AreEqual(b1, b2); + } + + /// + /// This part we use funcs in tf.random rather than only tf + /// + [TestMethod] + public void TFRandomRaodomSeedTest() + { + tf.set_random_seed(1234); + var a1 = tf.random.normal(1); + var b1 = tf.random.truncated_normal(1); + + // This part we consider to be a refresh + tf.set_random_seed(10); + tf.random.normal(1); + tf.random.truncated_normal(1); + + tf.set_random_seed(1234); + var a2 = tf.random.normal(1); + var b2 = tf.random.truncated_normal(1); + + Assert.AreEqual(a1.numpy(), a2.numpy()); + Assert.AreEqual(b1.numpy(), b2.numpy()); + } + + /// + /// compare to Test above, seed is also added in params + /// + [TestMethod, Ignore] + public void TFRandomRaodomSeedTest2() + { + tf.set_random_seed(1234); + var a1 = tf.random.normal(1, seed:1234); + var b1 = tf.random.truncated_normal(1); + + // This part we consider to be a refresh + tf.set_random_seed(10); + tf.random.normal(1); + tf.random.truncated_normal(1); + + tf.set_random_seed(1234); + var a2 = tf.random.normal(1, seed:1234); + var b2 = tf.random.truncated_normal(1, seed:1234); + + Assert.AreEqual(a1, a2); + Assert.AreEqual(b1, b2); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Basics/ThreadSafeTest.cs b/test/TensorFlowNET.UnitTest/Basics/ThreadSafeTest.cs new file mode 100644 index 000000000..6a633448c --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Basics/ThreadSafeTest.cs @@ -0,0 +1,41 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class ThreadSafeTest + { + [TestMethod] + public void GraphWithMultiThreads() + { + List threads = new List(); + + const int THREADS_COUNT = 5; + + for (int t = 0; t < THREADS_COUNT; t++) + { + Thread thread = new Thread(() => + { + Graph g = new Graph(); + Session session = new Session(g); + session.as_default(); + var input = tf.placeholder(tf.int32, shape: new Shape(6)); + var op = tf.reshape(input, new int[] { 2, 3 }); + }); + thread.Start(); + threads.Add(thread); + } + + threads.ForEach(t => t.Join()); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Basics/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/Basics/TrainSaverTest.cs new file mode 100644 index 000000000..ca073e1ef --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Basics/TrainSaverTest.cs @@ -0,0 +1,94 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class TrainSaverTest + { + public void ExportGraph() + { + var v = tf.Variable(0, name: "my_variable"); + var sess = tf.Session(); + tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt"); + } + + public void ImportGraph() + { + var sess = tf.Session(); + var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta"); + + //tf.train.export_meta_graph(filename: "linear_regression.meta.bin"); + // import meta + /*tf.train.import_meta_graph("linear_regression.meta.bin"); + + var cost = graph.OperationByName("truediv").output; + var pred = graph.OperationByName("Add").output; + var optimizer = graph.OperationByName("GradientDescent"); + var X = graph.OperationByName("Placeholder").output; + var Y = graph.OperationByName("Placeholder_1").output; + var W = graph.OperationByName("weight").output; + var b = graph.OperationByName("bias").output;*/ + + /*var text = JsonConvert.SerializeObject(graph, new JsonSerializerSettings + { + Formatting = Formatting.Indented + });*/ + } + + public void ImportSavedModel() + { + Session.LoadFromSavedModel("mobilenet"); + } + + public void ImportGraphDefFromPbFile() + { + var g = new Graph(); + var status = g.Import("mobilenet/saved_model.pb"); + } + + public void Save1() + { + var w1 = tf.Variable(0, name: "save1"); + + var init_op = tf.global_variables_initializer(); + + // Add ops to save and restore all the variables. + var saver = tf.train.Saver(); + + var sess = tf.Session(); + sess.run(init_op); + + // Save the variables to disk. + var save_path = saver.save(sess, "/tmp/model1.ckpt"); + Console.WriteLine($"Model saved in path: {save_path}"); + } + + public void Save2() + { + var v1 = tf.compat.v1.get_variable("v1", shape: new Shape(3), initializer: tf.zeros_initializer); + var v2 = tf.compat.v1.get_variable("v2", shape: new Shape(5), initializer: tf.zeros_initializer); + + var inc_v1 = v1.assign(v1.AsTensor() + 1.0f); + var dec_v2 = v2.assign(v2.AsTensor() - 1.0f); + + // Add an op to initialize the variables. + var init_op = tf.global_variables_initializer(); + + // Add ops to save and restore all the variables. + var saver = tf.train.Saver(); + + var sess = tf.Session(); + sess.run(init_op); + // o some work with the model. + inc_v1.op.run(); + dec_v2.op.run(); + + // Save the variables to disk. + var save_path = saver.save(sess, "/tmp/model2.ckpt"); + Console.WriteLine($"Model saved in path: {save_path}"); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs new file mode 100644 index 000000000..1b55508b0 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs @@ -0,0 +1,143 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System.Linq; +using static Tensorflow.Binding; +using System; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class VariableTest : EagerModeTestBase + { + [TestMethod] + public void NewVariable() + { + var x = tf.Variable(10, name: "x"); + Assert.AreEqual(0, x.shape.ndim); + Assert.AreEqual(x.numpy(), 10); + } + + [TestMethod] + public void StringVar() + { + var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.@string); + var mammal2 = tf.Variable("Tiger"); + } + + [TestMethod] + public void VarSum() + { + var x = tf.constant(3, name: "x"); + var y = tf.Variable(x + 1, name: "y"); + Assert.AreEqual(y.numpy(), 4); + } + + [TestMethod] + public void Assign1() + { + var variable = tf.Variable(31, name: "tree"); + var unread = variable.assign(12); + Assert.AreEqual(unread.numpy(), 12); + } + + [TestMethod] + public void Assign2() + { + var v1 = tf.Variable(10.0f, name: "v1"); + var v2 = v1.assign(v1 + 1.0f); + Assert.AreEqual(v1.numpy(), v2.numpy()); + Assert.AreEqual(v1.numpy(), 11f); + } + + [TestMethod] + public void Assign3() + { + var v1 = tf.Variable(10.0f, name: "v1"); + var v2 = tf.Variable(v1, name: "v2"); + Assert.AreEqual(v1.numpy(), v2.numpy()); + v1.assign(30.0f); + Assert.AreNotEqual(v1.numpy(), v2.numpy()); + } + + /// + /// Assign tensor to slice of other tensor. + /// https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__ + /// + [TestMethod] + public void SliceAssign() + { + NDArray nd = new float[,] + { + { 1, 2, 3 }, + { 4, 5, 6 }, + { 7, 8, 9 } + }; + + var x = tf.Variable(nd); + + // get slice form variable + var sliced = x[":2", ":2"]; + Assert.AreEqual(nd[0][":2"], sliced[0].numpy()); + Assert.AreEqual(nd[1][":2"], sliced[1].numpy()); + + // assign to the sliced tensor + sliced.assign(22 * tf.ones((2, 2))); + + // test assigned value + nd = new float[,] + { + { 22, 22, 3 }, + { 22, 22, 6 }, + { 7, 8, 9 } + }; + Assert.AreEqual(nd[0], x[0].numpy()); + Assert.AreEqual(nd[1], x[1].numpy()); + Assert.AreEqual(nd[2], x[2].numpy()); + } + + [TestMethod] + [ExpectedException(typeof(ArrayTypeMismatchException))] + public void TypeMismatchedSliceAssign() + { + NDArray intNd = new int[] + { + 1, -2, 3 + }; + NDArray doubleNd = new double[] + { + -5, 6, -7 + }; + var x = tf.Variable(doubleNd); + x[":"].assign(intNd); + } + + [TestMethod] + public void Accumulation() + { + var x = tf.Variable(10, name: "x"); + for (int i = 0; i < 5; i++) + x.assign(x + 1); + + Assert.AreEqual(x.numpy(), 15); + } + + [TestMethod] + public void ShouldReturnNegative() + { + var x = tf.constant(new[,] { { 1, 2 } }); + var neg_x = tf.negative(x); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 1, 2 }, neg_x.shape.dims)); + Assert.IsTrue(Enumerable.SequenceEqual(new[] { -1, -2 }, neg_x.numpy().ToArray())); + } + + [TestMethod] + public void IdentityOriginalTensor() + { + var a = tf.Variable(5); + var a_identity = tf.identity(a); + a.assign_add(1); + Assert.AreEqual(a_identity.numpy(), 5); + Assert.AreEqual(a.numpy(), 6); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/VersionTest.cs b/test/TensorFlowNET.UnitTest/Basics/VersionTest.cs similarity index 57% rename from test/TensorFlowNET.UnitTest/VersionTest.cs rename to test/TensorFlowNET.UnitTest/Basics/VersionTest.cs index 2e47f32a9..a53255641 100644 --- a/test/TensorFlowNET.UnitTest/VersionTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/VersionTest.cs @@ -1,10 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using System; -using System.Collections.Generic; -using System.Text; -using Tensorflow; +using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest +namespace TensorFlowNET.UnitTest.Basics { [TestClass] public class VersionTest @@ -13,7 +10,7 @@ public class VersionTest public void GetVersion() { var ver = tf.VERSION; - Assert.IsTrue(ver.StartsWith("1.")); + Assert.IsTrue(ver.StartsWith("2.")); } } } diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs new file mode 100644 index 000000000..183544ab6 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -0,0 +1,237 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace TensorFlowNET.UnitTest.Dataset +{ + [TestClass] + public class DatasetTest : EagerModeTestBase + { + [TestMethod] + public void Range() + { + int iStep = 0; + long value = 0; + + var dataset = tf.data.Dataset.range(3); + foreach (var (step, item) in enumerate(dataset)) + { + Assert.AreEqual(iStep, step); + iStep++; + + Assert.AreEqual(value, (long)item.Item1); + value++; + } + } + + [TestMethod] + public void Prefetch() + { + int iStep = 0; + long value = 1; + + var dataset = tf.data.Dataset.range(1, 5, 2); + dataset = dataset.prefetch(2); + + foreach (var (step, item) in enumerate(dataset)) + { + Assert.AreEqual(iStep, step); + iStep++; + + Assert.AreEqual(value, (long)item.Item1); + value += 2; + } + } + + [TestMethod] + public void FromTensorSlices() + { + var X = tf.constant(new[] { 2013, 2014, 2015, 2016, 2017 }); + var Y = tf.constant(new[] { 12000, 14000, 15000, 16500, 17500 }); + + var dataset = tf.data.Dataset.from_tensor_slices(X, Y); + int n = 0; + foreach (var (item_x, item_y) in dataset) + { + print($"x:{item_x.numpy()},y:{item_y.numpy()}"); + n += 1; + } + Assert.AreEqual(5, n); + } + + [TestMethod] + public void FromTensor() + { + var X = new[] { 2013, 2014, 2015, 2016, 2017 }; + + var dataset = tf.data.Dataset.from_tensors(X); + int n = 0; + foreach (var x in dataset) + { + Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray())); + n += 1; + } + Assert.AreEqual(1, n); + } + + [TestMethod] + public void Shard() + { + long value = 0; + + var dataset1 = tf.data.Dataset.range(10); + var dataset2 = dataset1.shard(num_shards: 3, index: 0); + + foreach (var item in dataset2) + { + Assert.AreEqual(value, (long)item.Item1); + value += 3; + } + + value = 1; + var dataset3 = dataset1.shard(num_shards: 3, index: 1); + foreach (var item in dataset3) + { + Assert.AreEqual(value, (long)item.Item1); + value += 3; + } + } + + [TestMethod] + public void Skip() + { + long value = 7; + + var dataset = tf.data.Dataset.range(10); + dataset = dataset.skip(7); + + foreach (var item in dataset) + { + Assert.AreEqual(value, (long)item.Item1); + value++; + } + } + + [TestMethod] + public void Map() + { + long value = 0; + + var dataset = tf.data.Dataset.range(0, 2); + dataset = dataset.map(x => x[0] + 10); + + foreach (var item in dataset) + { + Assert.AreEqual(value + 10, (long)item.Item1); + value++; + } + } + + [TestMethod] + public void Cache() + { + long value = 0; + + var dataset = tf.data.Dataset.range(5); + dataset = dataset.cache(); + + foreach (var item in dataset) + { + Assert.AreEqual(value, (long)item.Item1); + value++; + } + } + + [TestMethod] + public void Cardinality() + { + var dataset = tf.data.Dataset.range(10); + var cardinality = dataset.cardinality(); + Assert.AreEqual(cardinality.numpy(), 10L); + dataset = dataset.map(x => x[0] + 1); + cardinality = dataset.cardinality(); + Assert.AreEqual(cardinality.numpy(), 10L); + } + + [TestMethod] + public void CardinalityWithAutoTune() + { + var dataset = tf.data.Dataset.range(10); + dataset = dataset.map(x => x, num_parallel_calls: -1); + var cardinality = dataset.cardinality(); + Assert.AreEqual(cardinality.numpy(), 10L); + } + + [TestMethod] + public void CardinalityWithRepeat() + { + var dataset = tf.data.Dataset.range(10); + dataset = dataset.repeat(); + var cardinality = dataset.cardinality(); + Assert.IsTrue((cardinality == tf.data.INFINITE_CARDINALITY).numpy()); + + dataset = dataset.filter(x => true); + cardinality = dataset.cardinality(); + Assert.IsTrue((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy()); + } + + [TestMethod] + public void Shuffle() + { + tf.set_random_seed(1234); + + var dataset = tf.data.Dataset.range(3); + var shuffled = dataset.shuffle(3); + + var zipped = tf.data.Dataset.zip(dataset, shuffled); + + bool allEqual = true; + foreach (var item in zipped) + { + if (item.Item1 != item.Item2) + allEqual = false; + } + + Assert.IsFalse(allEqual); + } + [Ignore] + [TestMethod] + public void GetData() + { + var vocab_size = 20000; // Only consider the top 20k words + var maxlen = 200; // Only consider the first 200 words of each movie review + var dataset = keras.datasets.imdb.load_data(num_words: vocab_size, maxlen: maxlen); + var x_train = dataset.Train.Item1; + var y_train = dataset.Train.Item2; + var x_val = dataset.Test.Item1; + var y_val = dataset.Test.Item2; + + x_train = keras.preprocessing.sequence.pad_sequences(RemoveZeros(x_train), maxlen: maxlen); + x_val = keras.preprocessing.sequence.pad_sequences(RemoveZeros(x_val), maxlen: maxlen); + print(len(x_train) + " Training sequences"); + print(len(x_val) + " Validation sequences"); + } + IEnumerable RemoveZeros(NDArray data) + { + var data_array = (int[,])data.ToMultiDimArray(); + List new_data = new List(); + for (var i = 0; i < data_array.GetLength(0); i++) + { + List new_array = new List(); + for (var j = 0; j < data_array.GetLength(1); j++) + { + if (data_array[i, j] == 0) + break; + else + new_array.Add(data_array[i, j]); + } + new_data.Add(new_array.ToArray()); + } + return new_data; + } + } +} diff --git a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs new file mode 100644 index 000000000..b7b9ae128 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs @@ -0,0 +1,65 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + public class EagerModeTestBase : PythonTest + { + [TestInitialize] + public void TestInit() + { + if (!tf.executing_eagerly()) + tf.enable_eager_execution(); + tf.Context.ensure_initialized(); + } + + public bool Equal(float f1, float f2) + { + var tolerance = .000001f; + return Math.Abs(f1 - f2) <= tolerance; + } + + public bool Equal(long[] l1, long[] l2) + { + if (l1.Length != l2.Length) + return false; + + for (var i = 0; i < l1.Length; i++) + { + if (l1[i] != l2[i]) + return false; + } + + return true; + } + + public bool Equal(float[] f1, float[] f2) + { + bool ret = false; + var tolerance = .000001f; + for (var i = 0; i < f1.Length; i++) + { + ret = Math.Abs(f1[i] - f2[i]) <= tolerance; + if (!ret) + break; + } + + return ret; + } + + public bool Equal(double[] d1, double[] d2) + { + bool ret = false; + var tolerance = .000000000000001f; + for (var i = 0; i < d1.Length; i++) + { + ret = Math.Abs(d1[i] - d2[i]) <= tolerance; + if (!ret) + break; + } + + return ret; + } + } +} diff --git a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs new file mode 100644 index 000000000..1cfceb3e3 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs @@ -0,0 +1,206 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Gradient +{ + [TestClass] + public class GradientEagerTest : EagerModeTestBase + { + [TestMethod] + public void ConstantSquare() + { + // Calcute the gradient of w * w + // by Automatic Differentiation in Eager mode + var w = tf.constant(1.5f); + using var tape = tf.GradientTape(); + // w is defined before tape is recording + tape.watch(w); + var loss = w * w; + var grad = tape.gradient(loss, w); + Assert.AreEqual((float)grad, 3.0f); + } + + [TestMethod] + public void SquaredDifference_Constant() + { + // Calcute the gradient of (x1-x2)^2 + // by Automatic Differentiation in Eager mode + var x1 = tf.constant(7f); + var x2 = tf.constant(11f); + + // Sanity check + using (var tape = tf.GradientTape()) + { + tape.watch(x2); + var loss = tf.multiply((x1 - x2), (x1 - x2)); + + var result = tape.gradient(loss, x2); + // Expected is 2*(11-7) = 8 + Assert.AreEqual((float)result, 8f); + } + + // Actual test + using (var tape = tf.GradientTape()) + { + tape.watch(x2); + var loss = tf.squared_difference(x1, x2); + + // Expected is 2*(11-7) = 8 + var result = tape.gradient(loss, x2); + Assert.AreEqual((float)result, 8f); + } + } + + [TestMethod] + public void SquaredDifference_1D() + { + // Calcute the gradient of (x1-x2)^2 + // by Automatic Differentiation in Eager mode + // Expected is 2*(abs(x1-x2)) + Tensor x1 = new NDArray(new float[] { 1, 3, 5, 21, 19, 17 }); + Tensor x2 = new NDArray(new float[] { 29, 27, 23, 7, 11, 13 }); + float[] expected = new float[] + { + (29-1) * 2, + (27-3) * 2, + (23-5) * 2, + (7-21) * 2, + (11-19) * 2, + (13-17) * 2 + }; + + // Sanity check + using (var tape = tf.GradientTape()) + { + tape.watch(x1); + tape.watch(x2); + var loss = tf.multiply((x1 - x2), (x1 - x2)); + + var result = tape.gradient(loss, x2); + CollectionAssert.AreEqual(result.ToArray(), expected); + } + + // Actual test + using (var tape = tf.GradientTape()) + { + tape.watch(x1); + tape.watch(x2); + var loss = tf.squared_difference(x1, x2); + + var result = tape.gradient(loss, x2); + CollectionAssert.AreEqual(result.ToArray(), expected); + } + } + + + /// + /// Calcute the higher derivative gradient of w * w * w + /// 高阶梯度 + /// + [TestMethod] + public void HighGradient() + { + var x = tf.Variable(1.0f); + using var tape1 = tf.GradientTape(); + using var tape2 = tf.GradientTape(); + var y = x * x * x; + var dy_dx = tape2.gradient(y, x); + Assert.AreEqual((float)dy_dx, 3.0f); + var d2y_d2x = tape1.gradient(dy_dx, x); + Assert.AreEqual((float)d2y_d2x, 6.0f); + } + + [TestMethod] + public void ConstantMultiply() + { + var x = tf.ones((2, 2)); + using var tape = tf.GradientTape(); + tape.watch(x); + var y = tf.reduce_sum(x); + var z = tf.multiply(y, y); + var dz_dx = tape.gradient(z, x); + + var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; + Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray(), expected)); + } + + [TestMethod] + public void PersistentTape() + { + var x = tf.ones((2, 2)); + using var tape = tf.GradientTape(persistent: true); + tape.watch(x); + var y = tf.reduce_sum(x); + var z = tf.multiply(y, y); + var dz_dx = tape.gradient(z, x); + + var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; + Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray(), expected)); + + var dz_dy = tape.gradient(z, y); + Assert.AreEqual((float)dz_dy, 8.0f); + } + + [TestMethod] + public void ConditionalMultiply() + { + Func func = (x, y) => + { + Tensor output = tf.constant(1.0f); + foreach (var i in range(y)) + { + if (i > 1) + output = tf.multiply(output, x); + } + return output; + }; + + Func grad = (x, y) => + { + using var tape = tf.GradientTape(); + tape.watch(x); + var output = func(x, y); + var grad = tape.gradient(output, x); + return grad; + }; + + var x = tf.constant(2.0f); + var result = grad(x, 4); + Assert.AreEqual((float)result, 4.0f); + } + + [TestMethod] + public void Tile() + { + var a = tf.constant(new int[] { 1 }, TF_DataType.TF_FLOAT); + var b = tf.constant(new int[] { 2 }); + using (var tape = tf.GradientTape()) + { + tape.watch(a); + var y = tf.tile(a, b); + var grad = tape.gradient(y, a); + Assert.AreEqual((float)grad.numpy(), 2.0f); + } + } + + [TestMethod] + public void GatherNdTest() + { + var x = tf.constant(new float[,] { { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f } }, dtype: TF_DataType.TF_FLOAT); + var indices = tf.constant(new int[,] { { 0, 1 }, { 1, 1 }, { 2, 1 } }, dtype: TF_DataType.TF_INT32); + using (var tape = tf.GradientTape()) + { + tape.watch(x); + var res = tf.gather_nd(x, indices); + var grad = tape.gradient(res, x); + var expected = np.array(new float[,] { { 0f, 1f, 0f }, { 0f, 1f, 0f }, { 0f, 1f, 0f } }); + Assert.IsTrue(Enumerable.SequenceEqual(grad.ToArray(), expected.ToArray())); + } + } + } +} diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs deleted file mode 100644 index 1864fede7..000000000 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ /dev/null @@ -1,18 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System; -using System.Collections.Generic; -using System.Text; -using Tensorflow; - -namespace TensorFlowNET.UnitTest -{ - [TestClass] - public class GraphTest - { - [TestMethod] - public void ConstructGraph() - { - var g = tf.Graph(); - } - } -} diff --git a/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs b/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs new file mode 100644 index 000000000..e2fc0c89c --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs @@ -0,0 +1,24 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Threading.Tasks; +using Tensorflow; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class MnistModelLoaderTest + { + [TestMethod] + public async Task TestLoad() + { + var loader = new MnistModelLoader(); + var result = await loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = true, + ValidationSize = 5000, + }); + + Assert.IsNotNull(result); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/KerasTests.cs b/test/TensorFlowNET.UnitTest/KerasTests.cs new file mode 100644 index 000000000..dfbe38601 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/KerasTests.cs @@ -0,0 +1,27 @@ +using Tensorflow; +using Keras.Layers; +using NumSharp; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class BaseTests + { + [TestMethod] + public void Dense_Tensor_ShapeTest() + { + var dense_1 = new Dense(1, name: "dense_1", activation: tf.nn.relu()); + var input = new Tensor(np.array(new int[] { 3 })); + dense_1.__build__(input.TensorShape); + var outputShape = dense_1.output_shape(input.TensorShape); + var a = (int[])(outputShape.dims); + var b = (int[])(new int[] { 1 }); + var _a = np.array(a); + var _b = np.array(b); + + Assert.IsTrue(np.array_equal(_a, _b)); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs new file mode 100644 index 000000000..bf8e1cbf7 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs @@ -0,0 +1,40 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.NenuralNetwork +{ + [TestClass] + public class ActivationFunctionTest : EagerModeTestBase + { + // A constant vector of size 6 + Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); + + [TestMethod] + public void Sigmoid() + { + var b = tf.nn.sigmoid(a, name: "sigmoid"); + var expected = new float[] { 0.7310586f, 0.37754068f, 0.9677046f, 0.10909683f, 0.5f, 0.00150118f }; + var actual = b.ToArray(); + Assert.IsTrue(Equal(expected, actual)); + } + + [TestMethod] + public void ReLU() + { + var b = tf.nn.relu(a, name: "ReLU"); + var expected = new float[] { 1f, 0f, 3.4f, 0f, 0f, 0f }; + var actual = b.ToArray(); + Assert.IsTrue(Equal(expected, actual)); + } + + [TestMethod] + public void TanH() + { + var b = tf.nn.tanh(a, name: "TanH"); + var expected = new float[] { 0.7615942f, -0.46211717f, 0.9977749f, -0.970452f, 0f, -0.99999547f }; + var actual = b.ToArray(); + Assert.IsTrue(Equal(expected, actual)); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs new file mode 100644 index 000000000..e25c9779d --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs @@ -0,0 +1,426 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using Tensorflow; +using static Tensorflow.Binding; +using System.Linq; +using Tensorflow.Operations; + +namespace TensorFlowNET.UnitTest.ManagedAPI +{ + [TestClass] + public class ArrayOpsTest : EagerModeTestBase + { + /// + /// https://www.tensorflow.org/api_docs/python/tf/slice + /// + [TestMethod] + public void Slice() + { + // Tests based on example code in TF documentation + var input_array = tf.constant(np.array(new int[] { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }).reshape((3,2,3))); + var indices = tf.constant(np.array(new int[] { 0, 2 })); + + var r1 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 1, 1, 3 })); + Assert.AreEqual(new Shape(1,1,3), r1.shape); + var r1np = r1.numpy(); + Assert.AreEqual(r1np[0, 0, 0], 3); + Assert.AreEqual(r1np[0, 0, 1], 3); + Assert.AreEqual(r1np[0, 0, 2], 3); + + + var r2 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 1, 2, 3 })); + Assert.AreEqual(new Shape(1, 2, 3), r2.shape); + var r2np = r2.numpy(); + Assert.AreEqual(r2np[0, 0, 0], 3); + Assert.AreEqual(r2np[0, 0, 1], 3); + Assert.AreEqual(r2np[0, 0, 2], 3); + Assert.AreEqual(r2np[0, 1, 0], 4); + Assert.AreEqual(r2np[0, 1, 1], 4); + Assert.AreEqual(r2np[0, 1, 2], 4); + + var r3 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 2, 1, 3 })); + Assert.AreEqual(new Shape(2, 1, 3), r3.shape); + var r3np = r3.numpy(); + Assert.AreEqual(r3np[0, 0, 0], 3); + Assert.AreEqual(r3np[0, 0, 1], 3); + Assert.AreEqual(r3np[0, 0, 2], 3); + Assert.AreEqual(r3np[1, 0, 0], 5); + Assert.AreEqual(r3np[1, 0, 1], 5); + Assert.AreEqual(r3np[1, 0, 2], 5); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/gather + /// + [TestMethod] + public void Gather() + { + var input_array = tf.constant(np.arange(12).reshape((3, 4)).astype(np.float32)); + var indices = tf.constant(np.array(new int[] { 0, 2 })); + + var result = array_ops.gather(input_array, indices); + Assert.AreEqual(new Shape(2, 4), result.shape); + Assert.AreEqual(result.numpy()[0, 0], 0.0f); + Assert.AreEqual(result.numpy()[0, 1], 1.0f); + Assert.AreEqual(result.numpy()[1, 3], 11.0f); + + // Tests based on example code in Python doc string for tf.gather() + + var p1 = tf.random.normal(new Shape(5, 6, 7, 8)); + var i1 = tf.random_uniform(new Shape(10, 11), maxval: 7, dtype: tf.int32); + var r1 = tf.gather(p1, i1, axis:2); + Assert.AreEqual(new Shape(5, 6, 10, 11, 8), r1.shape); + + var p2 = tf.random.normal(new Shape(4,3)); + var i2 = tf.constant(new int[,] { { 0, 2} }); + var r2 = tf.gather(p2, i2, axis: 0); + Assert.AreEqual(new Shape(1, 2, 3), r2.shape); + + var r3 = tf.gather(p2, i2, axis: 1); + Assert.AreEqual(new Shape(4,1,2), r3.shape); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/TensorArray + /// + [TestMethod] + public void TensorArray() + { + var ta = tf.TensorArray(tf.float32, size: 0, dynamic_size: true, clear_after_read: false); + ta.write(0, 10); + ta.write(1, 20); + ta.write(2, 30); + Assert.AreEqual(ta.read(0).numpy(), 10f); + Assert.AreEqual(ta.read(1).numpy(), 20f); + Assert.AreEqual(ta.read(2).numpy(), 30f); + } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/reverse + /// + [TestMethod] + public void ReverseArray() + { + var a = tf.random.normal((2, 3)); + var b = tf.reverse(a, -1); + Assert.IsTrue(Equal(a[0].ToArray().Reverse().ToArray(), b[0].ToArray())); + Assert.IsTrue(Equal(a[1].ToArray().Reverse().ToArray(), b[1].ToArray())); + } + + [TestMethod] + public void ReverseImgArray3D() + { + // 创建 sourceImg 数组 + var sourceImgArray = new float[,,] { + { + { 237, 28, 36 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + }; + var sourceImg = ops.convert_to_tensor(sourceImgArray); + + // 创建 lrImg 数组 + var lrImgArray = new float[,,] { + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 237, 28, 36 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + }; + var lrImg = ops.convert_to_tensor(lrImgArray); + + var lr = tf.image.flip_left_right(sourceImg); + Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr.numpy().ToArray()), "tf.image.flip_left_right fail."); + + var lr2 = tf.reverse(sourceImg, 1); + Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr2.numpy().ToArray()), "tf.reverse (axis=1) fail."); + + var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 })); + Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr3.numpy().ToArray()), "gen_array_ops.reverse_v2 axis=1 fail."); + + // 创建 udImg 数组 + var udImgArray = new float[,,] { + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 237, 28, 36 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + }; + var udImg = ops.convert_to_tensor(udImgArray); + + var ud = tf.image.flip_up_down(sourceImg); + Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud.numpy().ToArray()), "tf.image.flip_up_down fail."); + + var ud2 = tf.reverse(sourceImg, new Axis(0)); + Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud2.numpy().ToArray()), "tf.reverse (axis=0) fail."); + + var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 })); + Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud3.numpy().ToArray()), "gen_array_ops.reverse_v2 axis=0 fail."); + } + + [TestMethod] + public void ReverseImgArray4D() + { + // 原图左上角,加一张左右翻转后的图片 + var m = new float[,,,] { + { + { + { 237, 28, 36 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + }, + { + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 237, 28, 36 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + } + }; + var sourceImg = ops.convert_to_tensor(m); + + var lrArray = new float[,,,] { + { + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 237, 28, 36 }, + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + }, + { + { + { 237, 28, 36 }, + { 255, 255, 255 }, + { 255, 255, 255 }, + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + } + }; + var lrImg = ops.convert_to_tensor(lrArray); + + // 创建 ud 数组 + var udArray = new float[,,,] { + { + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 237, 28, 36 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + }, + { + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 237, 28, 36 } + } + } + }; + var udImg = ops.convert_to_tensor(udArray); + + var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 })); + Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud3.numpy().ToArray()), "gen_array_ops.reverse_v2 axis=1 fail."); + + var ud2 = tf.reverse(sourceImg, new Axis(1)); + Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud2.numpy().ToArray()), "tf.reverse (axis=1) fail."); + + var ud = tf.image.flip_up_down(sourceImg); + Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud.numpy().ToArray()), "tf.image.flip_up_down fail."); + + // 左右翻转 + var lr = tf.image.flip_left_right(sourceImg); + Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr.numpy().ToArray()), "tf.image.flip_left_right fail."); + + var lr2 = tf.reverse(sourceImg, 0); + Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr2.numpy().ToArray()), "tf.reverse (axis=1) fail."); + + var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 })); + Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr3.numpy().ToArray()), "gen_array_ops.reverse_v2 axis=1 fail."); + + } + + [TestMethod] + public void ReverseImgArray4D_3x3() + { + // 原图左上角,加一张左右翻转后的图片 + var m = new float[,,,] { + { + { + { 237, 28, 36 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + }, + { + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 237, 28, 36 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + } + }; + var sourceImg = ops.convert_to_tensor(m); + + var lrArray = new float[,,,] { + { + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 237, 28, 36 }, + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + }, + { + { + { 237, 28, 36 }, + { 255, 255, 255 }, + { 255, 255, 255 }, + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + } + }; + var lrImg = ops.convert_to_tensor(lrArray); + + // 创建 ud 数组 + var udArray = new float[,,,] { + { + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 237, 28, 36 }, + { 255, 255, 255 }, + { 255, 255, 255 } + } + }, + { { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 255, 255, 255 } + }, + { + { 255, 255, 255 }, + { 255, 255, 255 }, + { 237, 28, 36 } + } + } + }; + var udImg = ops.convert_to_tensor(udArray); + + var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 })); + Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud3.numpy().ToArray()), "gen_array_ops.reverse_v2 axis=1 fail."); + + var ud2 = tf.reverse(sourceImg, new Axis(1)); + Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud2.numpy().ToArray()), "tf.reverse (axis=1) fail."); + + var ud = tf.image.flip_up_down(sourceImg); + Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud.numpy().ToArray()), "tf.image.flip_up_down fail."); + + // 左右翻转 + var lr = tf.image.flip_left_right(sourceImg); + Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr.numpy().ToArray()), "tf.image.flip_left_right fail."); + + var lr2 = tf.reverse(sourceImg, 0); + Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr2.numpy().ToArray()), "tf.reverse (axis=1) fail."); + + var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 })); + Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr3.numpy().ToArray()), "gen_array_ops.reverse_v2 axis=1 fail."); + + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs new file mode 100644 index 000000000..e57e50722 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs @@ -0,0 +1,89 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.ManagedAPI +{ + [TestClass] + public class BitwiseApiTest : EagerModeTestBase + { + [TestInitialize] + public void Init() + { + tf.enable_eager_execution(); + } + + [TestMethod] + public void BitwiseAnd() + { + Tensor lhs = tf.constant(new int[] { 0, 5, 3, 14 }); + Tensor rhs = tf.constant(new int[] { 5, 0, 7, 11 }); + + var bitwise_and_result = tf.bitwise.bitwise_and(lhs, rhs); + var expected = new int[] { 0, 0, 3, 10 }; + var actual = bitwise_and_result.ToArray(); + Assert.IsTrue(Enumerable.SequenceEqual(expected, actual)); + } + + [TestMethod] + public void BitwiseOr() + { + Tensor lhs = tf.constant(new int[] { 0, 5, 3, 14 }); + Tensor rhs = tf.constant(new int[] { 5, 0, 7, 11 }); + + var bitwise_or_result = tf.bitwise.bitwise_or(lhs, rhs); + var expected = new int[] { 5, 5, 7, 15 }; + var actual = bitwise_or_result.ToArray(); + Assert.IsTrue(Enumerable.SequenceEqual(expected, actual)); + } + + [TestMethod] + public void BitwiseXOR() + { + Tensor lhs = tf.constant(new int[] { 0, 5, 3, 14 }); + Tensor rhs = tf.constant(new int[] { 5, 0, 7, 11 }); + + var bitwise_xor_result = tf.bitwise.bitwise_xor(lhs, rhs); + var expected = new int[] { 5, 5, 4, 5 }; + var actual = bitwise_xor_result.ToArray(); + Assert.IsTrue(Enumerable.SequenceEqual(expected, actual)); + } + + [TestMethod] + public void Invert() + { + Tensor lhs = tf.constant(new int[] { 0, 1, -3, int.MaxValue }); + + var invert_result = tf.bitwise.invert(lhs); + var expected = new int[] { -1, -2, 2, int.MinValue }; + var actual = invert_result.ToArray(); + Assert.IsTrue(Enumerable.SequenceEqual(expected, actual)); + } + + [TestMethod] + public void LeftShift() + { + Tensor lhs = tf.constant(new int[] { -1, -5, -3, -14 }); + Tensor rhs = tf.constant(new int[] { 5, 0, 7, 11 }); + + var left_shift_result = tf.bitwise.left_shift(lhs, rhs); + var expected = new int[] { -32, -5, -384, -28672 }; + var actual = left_shift_result.ToArray(); + Assert.IsTrue(Enumerable.SequenceEqual(expected, actual)); + } + + [TestMethod] + public void RightShift() + { + Tensor lhs = tf.constant(new int[] { -2, 64, 101, 32 }); + Tensor rhs = tf.constant(new int[] { -1, -5, -3, -14 }); + + var right_shift_result = tf.bitwise.right_shift(lhs, rhs); + var expected = new int[] { -2, 64, 101, 32 }; + var actual = right_shift_result.ToArray(); + Assert.IsTrue(Enumerable.SequenceEqual(expected, actual)); + } + + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ClipTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ClipTest.cs new file mode 100644 index 000000000..6cbc69adb --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ClipTest.cs @@ -0,0 +1,21 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using static Tensorflow.Binding; +using Tensorflow; + +namespace TensorFlowNET.UnitTest.ClipOps +{ + [TestClass] + public class ClipTest : EagerModeTestBase + { + [TestMethod] + public void clip_by_global_norm() + { + var t_list = new Tensors(tf.constant(new float[] { 1, 2, 3, 4 }), tf.constant(new float[] { 5, 6, 7, 8 })); + var clip_norm = .8f; + var (res, norm) = tf.clip_by_global_norm(t_list, clip_norm); + Equal(res[0].ToArray(), new[] { 0.0560112074f, 0.112022415f, 0.16803363f, 0.22404483f }); + Equal(res[1].ToArray(), new[] { 0.28005603f, 0.336067259f, 0.392078459f, 0.448089659f }); + Assert.AreEqual(norm.numpy(), 14.282857f); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs new file mode 100644 index 000000000..2062dbc30 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs @@ -0,0 +1,173 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class ConstantTest : EagerModeTestBase + { + Status status = new Status(); + + [TestMethod] + public void ScalarConst() + { + var tensor1 = tf.constant(8); // int + Assert.AreEqual(tensor1.dtype, TF_DataType.TF_INT32); + var tensor2 = tf.constant(6.0f); // float + Assert.AreEqual(tensor2.dtype, TF_DataType.TF_FLOAT); + var tensor3 = tf.constant(6.0); // double + Assert.AreEqual(tensor3.dtype, TF_DataType.TF_DOUBLE); + } + + /*[DataTestMethod] + [DataRow(int.MinValue)] + [DataRow(-1)] + [DataRow(0)] + [DataRow(1)] + [DataRow(int.MaxValue)] + public void ScalarConstTypecast_int(int value) + { + var tensor = (Tensor)value; + with(tf.Session(), sess => + { + var result = sess.run(tensor); + Assert.AreEqual(result.Data()[0], value); + }); + } + + [DataTestMethod] + [DataRow(double.NegativeInfinity)] + [DataRow(double.MinValue)] + [DataRow(-1d)] + [DataRow(0d)] + [DataRow(double.Epsilon)] + [DataRow(1d)] + [DataRow(double.MaxValue)] + [DataRow(double.PositiveInfinity)] + [DataRow(double.NaN)] + public void ScalarConstTypecast_double(double value) + { + var tensor = (Tensor)value; + with(tf.Session(), sess => + { + var result = sess.run(tensor); + Assert.AreEqual(result.Data()[0], value); + }); + } + + [DataTestMethod] + [DataRow(float.NegativeInfinity)] + [DataRow(float.MinValue)] + [DataRow(-1f)] + [DataRow(0f)] + [DataRow(float.Epsilon)] + [DataRow(1f)] + [DataRow(float.MaxValue)] + [DataRow(float.PositiveInfinity)] + [DataRow(float.NaN)] + public void ScalarConstTypecast_float(float value) + { + var tensor = (Tensor)value; + with(tf.Session(), sess => + { + var result = sess.run(tensor); + Assert.AreEqual(result.Data()[0], value); + }); + } + + [TestMethod] + public void StringConst() + { + string str = "Hello, TensorFlow.NET!"; + var tensor = tf.constant(str); + with(tf.Session(), sess => + { + var result = sess.run(tensor); + Assert.IsTrue(result.Data()[0] == str); + }); + }*/ + + [TestMethod] + public void ZerosConst() + { + // small size + var tensor = tf.zeros((3, 2), tf.int32, "small"); + + Assert.AreEqual(tensor.shape[0], 3); + Assert.AreEqual(tensor.shape[1], 2); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, tensor.numpy().ToArray())); + + // big size + tensor = tf.zeros((200, 100), tf.int32, "big"); + + Assert.AreEqual(tensor.shape[0], 200); + Assert.AreEqual(tensor.shape[1], 100); + + var data = tensor.numpy().ToArray(); + Assert.AreEqual(0, data[0]); + Assert.AreEqual(0, data[500]); + Assert.AreEqual(0, data[data.Length - 1]); + } + + [TestMethod] + public void OnesConst() + { + var ones = tf.ones(new Shape(3, 2), tf.float32, "ones"); + Assert.AreEqual(ones.dtype, tf.float32); + Assert.AreEqual(ones.shape[0], 3); + Assert.AreEqual(ones.shape[1], 2); + Assert.IsTrue(new float[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(ones.numpy().ToArray())); + } + + [TestMethod] + public void OnesToHalves() + { + var ones = tf.ones(new Shape(3, 2), tf.float64, "ones"); + var halfes = ones * 0.5; + Assert.AreEqual(halfes.shape[0], 3); + Assert.AreEqual(halfes.shape[1], 2); + Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(halfes.numpy().ToArray())); + } + + [TestMethod] + public void NDimConst() + { + var nd = np.array(new int[,] + { + { 3, 1, 1 }, + { 2, 1, 3 } + }); + + var tensor = tf.constant(nd); + var data = tensor.numpy().ToArray(); + + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3 }, tensor.shape.dims)); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data)); + } + + [TestMethod] + public void Multiply() + { + var a = tf.constant(3.0); + var b = tf.constant(2.0); + var c = a * b; + + Assert.AreEqual(6.0, (double)c); + } + + [TestMethod] + public void Reshape() + { + var ones = tf.ones((3, 2), tf.float32, "ones"); + var reshaped = tf.reshape(ones, (2, 3)); + Assert.AreEqual(reshaped.dtype, tf.float32); + Assert.AreEqual(reshaped.shape[0], 2); + Assert.AreEqual(reshaped.shape[1], 3); + Assert.IsTrue(new float[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(ones.numpy().ToArray())); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs new file mode 100644 index 000000000..23dc1d44d --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs @@ -0,0 +1,66 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.ManagedAPI +{ + [TestClass] + public class ControlFlowApiTest + { + [TestMethod] + public void WhileLoopOneInputEagerMode() + { + tf.enable_eager_execution(); + + var i = tf.constant(2); + Func c = (x) => tf.less(x, 10); + Func b = (x) => tf.add(x, 1); + var r = tf.while_loop(c, b, i); + Assert.AreEqual(10, (int)r); + } + + [TestMethod] + public void WhileLoopTwoInputsEagerMode() + { + tf.enable_eager_execution(); + + var i = tf.constant(2); + var j = tf.constant(3); + Func c = (x) => tf.less(x[0] + x[1], 10); + Func b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) }; + var r = tf.while_loop(c, b, new[] { i, j }); + Assert.AreEqual(5, (int)r[0]); + Assert.AreEqual(6, (int)r[1]); + } + + [TestMethod, Ignore] + public void WhileLoopGraphMode() + { + tf.compat.v1.disable_eager_execution(); + + var i = tf.constant(2); + Func c = (x) => tf.less(x, 10); + Func b = (x) => tf.add(x, 1); + var r = tf.while_loop(c, b, i); + Assert.AreEqual(10, (int)r); + } + + + [TestMethod, Ignore] + public void ScanFunctionGraphMode() + { + tf.compat.v1.disable_eager_execution(); + + Func fn = (prev, current) => tf.add(prev, current); + var input = tf.placeholder(TF_DataType.TF_FLOAT, new Shape(6)); + var scan = tf.scan(fn, input); + + var sess = tf.Session(); + sess.run(tf.global_variables_initializer()); + var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6))); + Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray()); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs new file mode 100644 index 000000000..df00d5880 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs @@ -0,0 +1,101 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow; +using Tensorflow.Graphs; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.ManagedAPI +{ + [TestClass] + public class FunctionApiTest : EagerModeTestBase + { + Tensor Min(Tensor a, Tensor b) + { + return tf.cond(a < b, () => a, () => b); + } + + [TestMethod] + public void MulInAutoGraph() + { + var a = tf.constant(1); + var b = tf.constant(2); + // For first time running, tf.net will record the operations in graph mode. + // And register to tensorflow op library. + var output = Mul(a, b); + Assert.AreEqual(2, (int)output); + + var c = tf.constant(3); + // for the following invoke, Mul will be intercepted and run it in eager mode. + output = Mul(b, c); + Assert.AreEqual(6, (int)output); + } + + /// + /// Method with AutoGraph attribute will be converted to FuncGraph + /// when it's invoked for the first time. + /// + /// + /// + /// + [AutoGraph] + Tensor Mul(Tensor a, Tensor b) + { + return a * b; + } + + [TestMethod] + public void TwoInputs_OneOutput() + { + var func = tf.autograph.to_graph(Add); + var a = tf.constant(1); + var b = tf.constant(2); + var output = func(a, b); + Assert.AreEqual(3, (int)output); + } + + Tensor Add(Tensor a, Tensor b) + { + return a + b; + } + + [TestMethod] + public void TwoInputs_OneOutput_Condition() + { + var func = tf.autograph.to_graph(Condition); + var a = tf.constant(3); + var b = tf.constant(2); + var output = func(a, b); + Assert.AreEqual(2, (int)output); + } + + Tensor Condition(Tensor a, Tensor b) + { + return tf.cond(a < b, a, b); + } + + [TestMethod] + public void TwoInputs_OneOutput_Lambda() + { + var func = tf.autograph.to_graph((x, y) => x * y); + var output = func(tf.constant(3), tf.constant(2)); + Assert.AreEqual(6, (int)output); + } + + [TestMethod] + public void TwoInputs_OneOutput_WhileLoop() + { + var func = tf.autograph.to_graph((x, y) => x * y); + var output = func(tf.constant(3), tf.constant(2)); + Assert.AreEqual(6, (int)output); + } + + Tensor WhileLoop() + { + var i = tf.constant(0); + Func c = i => tf.less(i, 10); + Func b = i => tf.add(i, 1); + //var r = tf.(c, b, [i]) + throw new NotImplementedException(""); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs new file mode 100644 index 000000000..902bcdbfb --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs @@ -0,0 +1,81 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.ManagedAPI +{ + [TestClass] + public class GradientTest + { + [TestMethod] + public void GradientFloatTest() + { + var x = tf.Variable(3.0, dtype: tf.float32); + using var tape = tf.GradientTape(); + var y = tf.square(x); + var y_grad = tape.gradient(y, x); + Assert.AreEqual(9.0f, (float)y); + } + + [TestMethod] + public void GradientDefaultTest() + { + var x = tf.Variable(3.0); + using var tape = tf.GradientTape(); + var y = tf.square(x); + var y_grad = tape.gradient(y, x); + Assert.AreEqual(9.0, (double)y); + } + + [TestMethod] + public void GradientDoubleTest() + { + var x = tf.Variable(3.0, dtype: tf.float64); + using var tape = tf.GradientTape(); + var y = tf.square(x); + var y_grad = tape.gradient(y, x); + Assert.AreEqual(9.0, (double)y); + } + + [TestMethod] + public void GradientOperatorMulTest() + { + var x = tf.constant(0f); + var w = tf.Variable(new float[] { 1, 1 }); + using var gt = tf.GradientTape(); + var y = x * w; + var gr = gt.gradient(y, w); + Assert.AreEqual(new float[] { 0, 0 }, gr.numpy()); + } + + [TestMethod] + public void GradientSliceTest() + { + var X = tf.zeros(10); + var W = tf.Variable(-0.06f, name: "weight"); + var b = tf.Variable(-0.73f, name: "bias"); + using var g = tf.GradientTape(); + var pred = W * X + b; + var test = tf.slice(pred, new[] { 0 }, (int[])pred.shape); + var gradients = g.gradient(test, (W, b)); + Assert.AreEqual((float)gradients.Item1, 0f); + Assert.AreEqual((float)gradients.Item2, 10f); + } + + [TestMethod] + public void GradientConcatTest() + { + var w1 = tf.Variable(new[] { new[] { 1f } }); + var w2 = tf.Variable(new[] { new[] { 3f } }); + using var g = tf.GradientTape(); + var w = tf.concat(new Tensor[] { w1, w2 }, 0); + var x = tf.ones((1, 2)); + var y = tf.reduce_sum(x, 1); + var r = tf.matmul(w, x); + var gradients = g.gradient(r, w); + Assert.AreEqual((float)gradients[0][0], 2f); + Assert.AreEqual((float)gradients[1][0], 2f); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs new file mode 100644 index 000000000..fb515af1a --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs @@ -0,0 +1,92 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.ManagedAPI +{ + [TestClass] + public class LinalgTest : EagerModeTestBase + { + [TestMethod] + public void EyeTest() + { + var tensor = tf.linalg.eye(3); + + Assert.AreEqual(tensor.shape, (3, 3)); + + Assert.AreEqual(0.0f, (double)tensor[2, 0]); + Assert.AreEqual(0.0f, (double)tensor[2, 1]); + Assert.AreEqual(1.0f, (double)tensor[2, 2]); + } + + /// + /// https://colab.research.google.com/github/biswajitsahoo1111/blog_notebooks/blob/master/Doing_Linear_Algebra_using_Tensorflow_2.ipynb#scrollTo=6xfOcTFBL3Up + /// + [TestMethod] + public void LSTSQ() + { + var A_over = tf.constant(new float[,] { { 1, 2 }, { 2, 0.5f }, { 3, 1 }, { 4, 5.0f} }); + var A_under = tf.constant(new float[,] { { 3, 1, 2, 5 }, { 7, 9, 1, 4.0f } }); + var b_over = tf.constant(new float[] { 3, 4, 5, 6.0f }, shape: (4, 1)); + var b_under = tf.constant(new float[] { 7.2f, -5.8f }, shape: (2, 1)); + var x_over = tf.linalg.lstsq(A_over, b_over); + + var x = tf.matmul(tf.linalg.inv(tf.matmul(A_over, A_over, transpose_a: true)), tf.matmul(A_over, b_over, transpose_a: true)); + Assert.AreEqual(x_over.shape, (2, 1)); + AssetSequenceEqual(x_over.ToArray(), x.ToArray()); + + var x_under = tf.linalg.lstsq(A_under, b_under); + var y = tf.matmul(A_under, tf.matmul(tf.linalg.inv(tf.matmul(A_under, A_under, transpose_b: true)), b_under), transpose_a: true); + + Assert.AreEqual(x_under.shape, (4, 1)); + AssetSequenceEqual(x_under.ToArray(), y.ToArray()); + + /*var x_over_reg = tf.linalg.lstsq(A_over, b_over, l2_regularizer: 2.0f); + var x_under_reg = tf.linalg.lstsq(A_under, b_under, l2_regularizer: 2.0f); + Assert.AreEqual(x_under_reg.shape, (4, 1)); + AssetSequenceEqual(x_under_reg.ToArray(), new float[] { -0.04763567f, -1.214508f, 0.62748903f, 1.299031f });*/ + } + + [TestMethod] + public void Einsum() + { + var m0 = tf.random.normal((2, 3)); + var m1 = tf.random.normal((3, 5)); + var e = tf.linalg.einsum("ij,jk->ik", (m0, m1)); + Assert.AreEqual(e.shape, (2, 5)); + } + + [TestMethod] + public void GlobalNorm() + { + var t_list = new Tensors(tf.constant(new float[] { 1, 2, 3, 4 }), tf.constant(new float[] { 5, 6, 7, 8 })); + var norm = tf.linalg.global_norm(t_list); + Assert.AreEqual(norm.numpy(), 14.282857f); + } + + [TestMethod] + public void Tensordot() + { + var a = tf.constant(new[] { 1, 2 }); + var b = tf.constant(new[] { 2, 3 }); + var c = tf.linalg.tensordot(a, b, 0); + Assert.AreEqual(c.shape, (2, 2)); + AssetSequenceEqual(c.ToArray(), new[] { 2, 3, 4, 6 }); + + c = tf.linalg.tensordot(a, b, new[] { 0, 0 }); + Assert.AreEqual(c.shape.ndim, 0); + Assert.AreEqual(c.numpy(), 8); + } + + [TestMethod] + public void Matmul() + { + var a = tf.constant(new[] { 1, 2, 3, 4, 5, 6 }, shape: (2, 3)); + var b = tf.constant(new[] { 7, 8, 9, 10, 11, 12 }, shape: (3, 2)); + var c = tf.linalg.matmul(a, b); + + Assert.AreEqual(c.shape, (2, 2)); + AssetSequenceEqual(c.ToArray(), new[] { 58, 64, 139, 154 }); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/LoggingTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/LoggingTest.cs new file mode 100644 index 000000000..3fa0d0187 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/LoggingTest.cs @@ -0,0 +1,16 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.ManagedAPI +{ + [TestClass] + public class LoggingTest + { + [TestMethod] + public void PrintTest() + { + var tensor = tf.range(10); + tf.print(tensor); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs new file mode 100644 index 000000000..411deb18f --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs @@ -0,0 +1,84 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Linq; +using Tensorflow; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.ManagedAPI +{ + [TestClass] + public class MathApiTest : EagerModeTestBase + { + // A constant vector of size 6 + Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); + Tensor b = tf.constant(new float[,] { { 1.0f, -0.5f, 3.4f }, { -2.1f, 0.0f, -6.5f } }); + + [TestMethod] + public void Sin() + { + var b = tf.sin(a, name: "Sin"); + var expected = new float[] { 0.84147096f, -0.47942555f, -0.2555412f, -0.86320937f, 0f, -0.21511999f }; + var actual = b.ToArray(); + Assert.IsTrue(Equal(expected, actual)); + } + + [TestMethod] + public void Tan() + { + var b = tf.tan(a, name: "Tan"); + var expected = new float[] { 1.5574077f, -0.5463025f, 0.264317f, 1.709847f, 0f, -0.2202772f }; + var actual = b.ToArray(); + Assert.IsTrue(Equal(expected, actual)); + } + + [TestMethod] + public void ReduceSum() + { + var x1 = tf.reduce_sum(b); + Assert.AreEqual(-4.7f, (float)x1); + + var x2 = tf.reduce_sum(b, 0); + Assert.IsTrue(Enumerable.SequenceEqual(new[] { -1.0999999f, -0.5f, -3.1f }, x2.ToArray())); + + var x3 = tf.reduce_sum(b, 1); + Assert.IsTrue(Enumerable.SequenceEqual(new[] { 3.9f, -8.6f }, x3.ToArray())); + + var x4 = tf.reduce_sum(b, 1, keepdims: true); + Assert.AreEqual((2, 1), x4.shape); + + var x5 = tf.reduce_sum(b, (0, 1)); + Assert.AreEqual(-4.7f, (float)x5); + } + + [TestMethod] + public void Erf() + { + var erf = tf.math.erf(a, name: "erf"); + var expected = new float[] { 0.8427007f, -0.5204999f, 0.99999845f, -0.9970206f, 0f, -1f }; + var actual = erf.ToArray(); + Assert.IsTrue(Equal(expected, actual)); + } + + [TestMethod] + public void ReduceEuclideanNorm() + { + var x = tf.constant(new[,] { { 1, 2, 3 }, { 1, 1, 1 } }); + Assert.AreEqual(tf.math.reduce_euclidean_norm(x).numpy(), 4); + + var y = tf.constant(new[,] { { 1, 2, 3 }, { 1, 1, 1 } }, dtype: tf.float32); + Assert.IsTrue(Equal(tf.math.reduce_euclidean_norm(y).numpy(), 4.1231055f)); + + Assert.IsTrue(Equal(tf.math.reduce_euclidean_norm(y, 0).ToArray(), + new float[] { np.sqrt(2f), np.sqrt(5f), np.sqrt(10f) })); + + Assert.IsTrue(Equal(tf.math.reduce_euclidean_norm(y, 1).ToArray(), + new float[] { np.sqrt(14f), np.sqrt(3f) })); + + Assert.IsTrue(Equal(tf.math.reduce_euclidean_norm(y, 1, keepdims: true).ToArray(), + new float[] { np.sqrt(14f), np.sqrt(3f) })); + + Assert.AreEqual(tf.math.reduce_euclidean_norm(y, (0, 1)).numpy(), np.sqrt(17f)); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/NeuralNetworkTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/NeuralNetworkTest.cs new file mode 100644 index 000000000..f1b9f08a8 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/NeuralNetworkTest.cs @@ -0,0 +1,18 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using static Tensorflow.Binding; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NenuralNetwork +{ + [TestClass] + public class NeuralNetworkTest : EagerModeTestBase + { + [TestMethod] + public void l2_loss() + { + var x = tf.Variable(np.array(new[,] { { 1, 2, 3, 4 }, { 5, 6, 7, 8 } }), dtype: tf.float32); + var l2 = tf.nn.l2_loss(x); + Assert.AreEqual(l2.numpy(), 102f); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/RaggedTensorTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/RaggedTensorTest.cs new file mode 100644 index 000000000..7a3de882e --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/RaggedTensorTest.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.ManagedAPI +{ + public class RaggedTensorTest :EagerModeTestBase + { + [TestMethod] + public void Test_from_row_lengths() + { + var row_lengths = tf.convert_to_tensor(np.array(new int[] { 2, 0, 3, 1, 1 }, TF_DataType.TF_INT64)); + var rp = RowPartition.from_row_lengths(row_lengths, validate: false); + var rp_row_lengths = rp.row_lengths(); + var rp_nrows = rp.nrows(); + Assert.IsTrue(rp_nrows.ToArray()[0] == rp.nrows().ToArray()[0]); + + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs new file mode 100644 index 000000000..353d192f6 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs @@ -0,0 +1,69 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.ManagedAPI +{ + [TestClass] + public class StringsApiTest + { + [TestMethod] + public void StringFromBytes() + { + var jpg = tf.constant(new byte[] { 0x41, 0xff, 0xd8, 0xff }, tf.@string); + var strings = jpg.ToString(); + Assert.AreEqual(strings, @"tf.Tensor: shape=(), dtype=string, numpy='A\xff\xd8\xff'"); + } + + [TestMethod] + public void StringEqual() + { + var str1 = tf.constant("Hello1"); + var str2 = tf.constant("Hello2"); + var result = tf.equal(str1, str2); + Assert.IsFalse(result.numpy()); + + var str3 = tf.constant("Hello1"); + result = tf.equal(str1, str3); + Assert.IsTrue(result.numpy()); + + var str4 = tf.strings.substr(str1, 0, 5); + var str5 = tf.strings.substr(str2, 0, 5); + result = tf.equal(str4, str5); + Assert.IsTrue(result.numpy()); + } + + [TestMethod] + public void ImageType() + { + var imgPath = TestHelper.GetFullPathFromDataDir("shasta-daisy.jpg"); + var contents = tf.io.read_file(imgPath); + + var substr = tf.strings.substr(contents, 0, 3); + var jpg = tf.constant(new byte[] { 0xff, 0xd8, 0xff }, tf.@string); + + var result = math_ops.equal(substr, jpg); + Assert.IsTrue((bool)result); + } + + [TestMethod] + public void StringArray() + { + var strings = new[] { "map_and_batch_fusion", "noop_elimination", "shuffle_and_repeat_fusion" }; + var tensor = tf.constant(strings, dtype: tf.@string, name: "optimizations"); + + Assert.AreEqual(3, tensor.shape[0]); + Assert.AreEqual(tensor[0].numpy(), strings[0]); + Assert.AreEqual(tensor[1].numpy(), strings[1]); + Assert.AreEqual(tensor[2].numpy(), strings[2]); + } + + [TestMethod] + public void StringSplit() + { + var tensor = tf.constant(new[] { "hello world", "tensorflow .net csharp", "fsharp" }); + var ragged_tensor = tf.strings.split(tensor); + Assert.AreEqual((3, -1), ragged_tensor.shape); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs new file mode 100644 index 000000000..43c6c4293 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs @@ -0,0 +1,184 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System.Linq; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.ManagedAPI +{ + [TestClass] + public class TensorOperate + { + [TestMethod] + public void TransposeTest() + { + // https://www.tensorflow.org/api_docs/python/tf/transpose#for_example_2 + var x = tf.constant(new int[,] + { + { 1, 2, 3 }, + { 4, 5, 6 } + }); + var transpose_x = tf.transpose(x); + Assert.AreEqual(new[] { 1, 4 }, transpose_x[0].numpy()); + Assert.AreEqual(new[] { 2, 5 }, transpose_x[1].numpy()); + Assert.AreEqual(new[] { 3, 6 }, transpose_x[2].numpy()); + + #region constant a + var a = tf.constant(np.array(new[, , ,] + { + { + { + { 1, 11, 2, 22 } + }, + { + { 3, 33, 4, 44 } + } + }, + { + { + { 5, 55, 6, 66 } + }, + { + { 7, 77, 8, 88 } + } + } + })); + + #endregion + var actual_transposed_a = tf.transpose(a, new[] { 3, 1, 2, 0 }); + + #region constant transpose_a + var expected_transposed_a = tf.constant(np.array(new[, , ,] + { + { + { { 1, 5 } }, { { 3, 7 } } + }, + { + { { 11, 55 } }, { { 33, 77 } } + }, + { + { + { 2, 6 } + }, + { + { 4, 8 } + } + }, + { + { + { 22, 66 } + }, + { + { 44, 88 } + } + } + })); + #endregion + Assert.AreEqual((4, 2, 1, 2), actual_transposed_a.shape); + Assert.AreEqual(expected_transposed_a.numpy(), actual_transposed_a.numpy()); + } + + [TestMethod] + public void InitTensorTest() + { + var a = tf.constant(np.array(new[, ,] + { + { { 1 }, { 2 }, { 3 } }, + { { 4 }, { 5 }, { 6 } } + })); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3, 1 }, a.shape.dims)); + + var b = tf.constant(new[, ,] + { + { { 1 }, { 2 }, { 3 } }, + { { 4 }, { 5 }, { 6 } } + }); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3, 1 }, b.shape.dims)); + } + + [TestMethod] + public void ConcatTest() + { + var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } }); + var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } }); + var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); + + var concatValue = tf.concat(new[] { a, b, c }, axis: 0); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 6, 2 }, concatValue.shape.dims)); + } + + [TestMethod] + public void ConcatDoubleTest() + { + var a = tf.constant(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }); + var b = tf.constant(new[,] { { 5.0, 6.0 }, { 7.0, 8.0 } }); + var c = tf.constant(new[,] { { 9.0, 10.0 }, { 11.0, 12.0 } }); + + var concatValue = tf.concat(new[] { a, b, c }, axis: 0); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 6, 2 }, concatValue.shape.dims)); + } + + [TestMethod] + public void ConcatAndSplitTest() + { + var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } }); + var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } }); + var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); + + var value = tf.concat(new[] { a, b, c }, axis: 0); + + var splitValue = tf.split(value, 3, axis: 0); + Assert.AreEqual(3, splitValue.Length); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 2 }, splitValue[0].shape.dims)); + } + + #region ones/zeros like + [TestMethod] + public void TestOnesLike() + { + #region 2-dimension + var ones2D = tf.ones_like(new int[,] + { + { 1, 2, 3 }, + { 4, 5, 6 } + }); + + Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[0].numpy()); + Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[1].numpy()); + #endregion + + #region 1-dimension + var ones1D = tf.ones_like(new int[,] + { + { 1, 2, 3 } + }); + + Assert.AreEqual(new[] { 1, 1, 1 }, ones1D[0].numpy()); + #endregion + } + + [TestMethod] + public void TestZerosLike() + { + #region 2-dimension + var zeros2D = tf.zeros_like(new int[,] + { + { 1, 2, 3 }, + { 4, 5, 6 } + }); + + Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[0].numpy()); + Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[1].numpy()); + #endregion + + #region 1-dimension + var zeros1D = tf.zeros_like(new int[,] + { + { 1, 2, 3 } + }); + + Assert.AreEqual(new[] { 0, 0, 0 }, zeros1D[0].numpy()); + #endregion + } + #endregion + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/nn_test.py b/test/TensorFlowNET.UnitTest/ManagedAPI/nn_test.py new file mode 100644 index 000000000..82fab7418 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/nn_test.py @@ -0,0 +1,1243 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for miscellaneous functionality in tensorflow.ops.nn.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from absl.testing import parameterized +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +import tensorflow.python.ops.nn_grad # pylint: disable=unused-import +from tensorflow.python.ops.nn_impl import _compute_sampled_logits +from tensorflow.python.platform import test as test_lib + + +class ZeroFractionTest(test_lib.TestCase): + + def _ZeroFraction(self, x): + assert x.shape + total_elements = np.prod(x.shape) + nonzeros = np.count_nonzero(x.flatten()) + return 1.0 - nonzeros / total_elements + + @test_util.run_deprecated_v1 + def testZeroFraction(self): + x_shape = [5, 17] + x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32) + y_np = self._ZeroFraction(x_np) + + x_tf = constant_op.constant(x_np) + x_tf.set_shape(x_shape) + y_tf = nn_impl.zero_fraction(x_tf) + y_tf_np = self.evaluate(y_tf) + + eps = 1e-8 + self.assertAllClose(y_tf_np, y_np, eps) + + @test_util.run_deprecated_v1 + def testZeroFractionEmpty(self): + x = np.zeros(0) + y = self.evaluate(nn_impl.zero_fraction(x)) + self.assertTrue(np.isnan(y)) + + @test_util.run_deprecated_v1 + def testZeroFraction2_27Zeros(self): + sparsity = nn_impl.zero_fraction( + array_ops.zeros([int(2**27 * 1.01)], dtype=dtypes.int8)) + self.assertAllClose(1.0, self.evaluate(sparsity)) + + @test_util.run_deprecated_v1 + def testZeroFraction2_27Ones(self): + sparsity = nn_impl.zero_fraction( + array_ops.ones([int(2**27 * 1.01)], dtype=dtypes.int8)) + self.assertAllClose(0.0, self.evaluate(sparsity)) + + @test_util.run_deprecated_v1 + def testUnknownSize(self): + value = array_ops.placeholder(dtype=dtypes.float32) + sparsity = nn_impl.zero_fraction(value) + with self.cached_session() as sess: + self.assertAllClose( + 0.25, + sess.run(sparsity, {value: [[0., 1.], [0.3, 2.]]})) + + +class SoftmaxTest(test_lib.TestCase, parameterized.TestCase): + + def _softmax(self, x): + assert len(x.shape) == 2 + m = x.max(1)[:, np.newaxis] + u = np.exp(x - m) + z = u.sum(1)[:, np.newaxis] + return u / z + + @test_util.run_in_graph_and_eager_modes + def testSoftmax(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float32) + y_np = self._softmax(x_np) + x_tf = constant_op.constant(x_np) + y_tf = nn_ops.softmax_v2(x_tf) + y_tf_last_dim = nn_ops.softmax_v2(x_tf, 1) + y_tf_np = self.evaluate(y_tf) + y_tf_last_dim_np = self.evaluate(y_tf_last_dim) + eps = 1e-3 + self.assertAllClose(y_tf_np, y_np, eps) + self.assertAllClose(y_tf_last_dim_np, y_np, eps) + + def testSoftmaxAxes(self): + arr = np.linspace(0., 1, 12).reshape(3, 4) + x_neg_axis = nn_ops.softmax_v2(arr, axis=-2) + y_pos_axis = nn_ops.softmax_v2(arr, axis=0) + z_gt_axis = nn_ops.softmax_v2(arr, axis=0) + x_neg_axis_tf = self.evaluate(x_neg_axis) + y_pos_axis_tf = self.evaluate(y_pos_axis) + z_gt_axis_tf = self.evaluate(z_gt_axis) + eps = 1e-3 + self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) + self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) + + @parameterized.parameters(((5, 10),), ((2, 3, 4),)) + @test_util.run_deprecated_v1 + def testGradient(self, x_shape): + x_np = np.random.randn(*x_shape).astype(np.float64) + with self.cached_session(): + x_tf = constant_op.constant(x_np) + y_tf = nn_ops.softmax_v2(x_tf) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + eps = 2e-8 + self.assertLess(err, eps) + + +class LogPoissonLossTest(test_lib.TestCase): + + def _log_poisson_loss(self, x, z, compute_full_loss=False): + lpl = np.exp(x) - z * x + if compute_full_loss: + stirling_approx = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z) + lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.) + return lpl + + @test_util.run_in_graph_and_eager_modes + def testLogPoissonLoss(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float32) + z_np = np.random.randint(0, 5, size=x_shape).astype(np.float32) + y_np = self._log_poisson_loss(x_np, z_np, compute_full_loss=False) + y_np_stirling = self._log_poisson_loss(x_np, z_np, compute_full_loss=True) + y_tf = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=False) + y_tf_stirling = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=True) + y_tf_np = self.evaluate(y_tf) + y_tf_np_stirling = self.evaluate(y_tf_stirling) + eps = 1e-3 + self.assertAllClose(y_tf_np, y_np, eps) + self.assertAllClose(y_tf_np_stirling, y_np_stirling, eps) + + @test_util.run_deprecated_v1 + def testGradient(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float64) + z_np = np.random.randint(0, 5, size=x_shape).astype(np.float64) + with self.cached_session(): + x_tf = constant_op.constant(x_np) + y_tf = nn_impl.log_poisson_loss(z_np, x_tf, compute_full_loss=False) + y_tf_stirling = nn_impl.log_poisson_loss( + z_np, x_tf, compute_full_loss=True) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + err_stirling = gradient_checker.compute_gradient_error( + x_tf, x_shape, y_tf_stirling, x_shape) + eps = 1e-6 + self.assertLess(err, eps) + self.assertLess(err_stirling, eps) + + +class LogSoftmaxTest(test_lib.TestCase, parameterized.TestCase): + + def _log_softmax(self, x): + assert len(x.shape) == 2 + m = x.max(1)[:, np.newaxis] + u = x - m + return u - np.log(np.sum(np.exp(u), 1, keepdims=True)) + + @test_util.run_in_graph_and_eager_modes + def testLogSoftmax(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float32) + y_np = self._log_softmax(x_np) + x_tf = constant_op.constant(x_np) + y_tf = nn_ops.log_softmax_v2(x_tf) + y_tf_np = self.evaluate(y_tf) + eps = 1e-3 + self.assertAllClose(y_tf_np, y_np, eps) + + def testLogSoftmaxAxes(self): + arr = np.linspace(0., 1, 12).reshape(3, 4) + x_neg_axis = nn_ops.log_softmax_v2(arr, axis=-2) + y_pos_axis = nn_ops.log_softmax_v2(arr, axis=0) + z_gt_axis = nn_ops.log_softmax_v2(arr, axis=0) + x_neg_axis_tf = self.evaluate(x_neg_axis) + y_pos_axis_tf = self.evaluate(y_pos_axis) + z_gt_axis_tf = self.evaluate(z_gt_axis) + eps = 1e-3 + self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) + self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) + + @parameterized.parameters(((5, 10),), ((2, 3, 4),)) + @test_util.run_deprecated_v1 + def testGradient(self, x_shape): + x_np = np.random.randn(*x_shape).astype(np.float64) + with self.cached_session(): + x_tf = constant_op.constant(x_np) + y_tf = nn_ops.log_softmax_v2(x_tf) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + eps = 1e-7 + self.assertLess(err, eps) + + +class L2LossTest(test_lib.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testL2Loss(self): + for dtype in [dtypes.float32, dtypes.float64]: + x = constant_op.constant( + [1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x", dtype=dtype) + l2loss = nn_ops.l2_loss(x) + value = self.evaluate(l2loss) + self.assertAllClose(7.0, value) + + @test_util.run_deprecated_v1 + def testGradient(self): + x_shape = [20, 7, 3] + np.random.seed(1) # Make it reproducible. + x_val = np.random.random_sample(x_shape).astype(np.float64) + with self.cached_session(): + x = constant_op.constant(x_val, name="x") + output = nn_ops.l2_loss(x) + err = gradient_checker.compute_gradient_error(x, x_shape, output, [1]) + print("L2Loss gradient err = %g " % err) + err_tolerance = 1e-10 + self.assertLess(err, err_tolerance) + + +class L2NormalizeTest(test_lib.TestCase): + + def _l2Normalize(self, x, dim): + if isinstance(dim, list): + norm = np.linalg.norm(x, axis=tuple(dim)) + for d in dim: + norm = np.expand_dims(norm, d) + return x / norm + else: + norm = np.apply_along_axis(np.linalg.norm, dim, x) + return x / np.expand_dims(norm, dim) + + @test_util.run_in_graph_and_eager_modes + def testL2Normalize(self): + x_shape = [20, 7, 3] + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float32) + for dim in range(len(x_shape)): + y_np = self._l2Normalize(x_np, dim) + x_tf = constant_op.constant(x_np, name="x") + y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + self.assertAllClose(y_np, self.evaluate(y_tf)) + + @test_util.run_in_graph_and_eager_modes + def testL2NormalizeDimArray(self): + x_shape = [20, 7, 3] + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float32) + dim = [1, 2] + y_np = self._l2Normalize(x_np, dim) + x_tf = constant_op.constant(x_np, name="x") + y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + self.assertAllClose(y_np, self.evaluate(y_tf)) + + @test_util.run_deprecated_v1 + def testL2NormalizeGradient(self): + x_shape = [20, 7, 3] + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float64) + for dim in range(len(x_shape)): + with self.cached_session(): + x_tf = constant_op.constant(x_np, name="x") + y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + print("L2Normalize gradient err = %g " % err) + self.assertLess(err, 1e-4) + + +class DropoutTest(test_lib.TestCase): + + def testDropout(self): + # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate + # that it is producing approximately the right number of ones over a large + # number of samples, based on the keep probability. + x_dim = 40 + y_dim = 30 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + dropout = nn_ops.dropout(t, keep_prob) + final_count = 0 + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + for _ in xrange(0, num_iter): + value = self.evaluate(dropout) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + def testShapedDropout(self): + # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate + # that it is producing approximately the right number of ones over a large + # number of samples, based on the keep probability. This time with shaped + # noise. + x_dim = 40 * 30 + y_dim = 3 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + dropout = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1]) + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + final_count = 0 + for _ in xrange(0, num_iter): + value = self.evaluate(dropout) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + def testShapedDropoutCorrelation(self): + # Runs a shaped dropout and tests that the correlations are correct. + x_dim = 40 + y_dim = 30 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + dropout = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1]) + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + for _ in xrange(0, num_iter): + value = self.evaluate(dropout) + # Verifies that each y column as only one type of activation. + for i in xrange(x_dim): + sorted_value = np.unique(np.sort(value[i, :])) + self.assertEqual(sorted_value.size, 1) + + @test_util.run_deprecated_v1 + def testDropoutPlaceholderKeepProb(self): + # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate + # that it is producing approximately the right number of ones over a large + # number of samples, based on the keep probability. + x_dim = 40 + y_dim = 30 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + with self.cached_session(): + t = constant_op.constant( + 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + keep_prob_placeholder = array_ops.placeholder(dtypes.float32) + dropout = nn_ops.dropout(t, keep_prob_placeholder) + final_count = 0 + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + for _ in xrange(0, num_iter): + value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob}) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + @test_util.run_deprecated_v1 + def testShapedDropoutUnknownShape(self): + x_dim = 40 + y_dim = 30 + keep_prob = 0.5 + x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + dropout_x = nn_ops.dropout( + x, keep_prob, noise_shape=array_ops.placeholder(dtypes.int32)) + self.assertEqual(x.get_shape(), dropout_x.get_shape()) + + def testPartialShapedDropout(self): + x_dim = 40 * 30 + y_dim = 3 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + # Set noise_shape=[None, 1] which means [x_dim, 1]. + dropout = nn_ops.dropout(t, keep_prob, noise_shape=[None, 1]) + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + final_count = 0 + for _ in xrange(0, num_iter): + value = self.evaluate(dropout) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + @test_util.run_deprecated_v1 + def testInvalidKeepProb(self): + x_dim = 40 + y_dim = 30 + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + with self.assertRaises(ValueError): + nn_ops.dropout(t, -1.0) + with self.assertRaises(ValueError): + nn_ops.dropout(t, 1.1) + with self.assertRaises(ValueError): + nn_ops.dropout(t, [0.0, 1.0]) + with self.assertRaises(ValueError): + nn_ops.dropout(t, array_ops.placeholder(dtypes.float64)) + with self.assertRaises(ValueError): + nn_ops.dropout(t, array_ops.placeholder(dtypes.float32, shape=[2])) + + @test_util.run_deprecated_v1 + def testInvalidRate(self): + x_dim = 40 + y_dim = 30 + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + with self.assertRaises(ValueError): + nn_ops.dropout_v2(t, -1.0) + with self.assertRaises(ValueError): + nn_ops.dropout_v2(t, 1.1) + with self.assertRaises(ValueError): + nn_ops.dropout_v2(t, [0.0, 1.0]) + + @test_util.run_deprecated_v1 + def testShapedDropoutShapeError(self): + # Runs shaped dropout and verifies an error is thrown on misshapen noise. + x_dim = 40 + y_dim = 30 + keep_prob = 0.5 + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10]) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5]) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim + 3]) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim]) + # test that broadcasting proceeds + _ = nn_ops.dropout(t, keep_prob, noise_shape=[y_dim]) + _ = nn_ops.dropout(t, keep_prob, noise_shape=[1, y_dim]) + _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1]) + _ = nn_ops.dropout(t, keep_prob, noise_shape=[1, 1]) + + def testNoDropoutFast(self): + x = array_ops.zeros((5,)) + y = nn_ops.dropout(x, keep_prob=1) + self.assertTrue(x is y) + + y = nn_ops.dropout_v2(x, rate=0) + self.assertTrue(x is y) + + def testDropoutWithIntegerInputs(self): + x = constant_op.constant([1, 1, 1, 1, 1]) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(x, 0.5) + + +class ComputeSampledLogitsTest(test_lib.TestCase): + + def setUp(self): + self._eps = 1e-3 + + def _GenerateTestData(self, num_classes, dim, batch_size, num_true, labels, + sampled, subtract_log_q): + """Randomly generates input/output data for a single test case. + + This function returns numpy constants for use in a test case. + + Args: + num_classes: An int. The number of embedding classes in the test case. + dim: An int. The dimension of the embedding. + batch_size: An int. The batch size. + num_true: An int. The number of target classes per training example. + labels: A list of batch_size * num_true ints. The target classes. + sampled: A list of indices in [0, num_classes). + subtract_log_q: A bool corresponding to the parameter in + _compute_sampled_logits(). + + Returns: + weights: Embedding weights to use as test input. It is a numpy array + of shape [num_classes, dim] + biases: Embedding biases to use as test input. It is a numpy array + of shape [num_classes]. + hidden_acts: Forward activations of the network to use as test input. + It is a numpy array of shape [batch_size, dim]. + sampled_vals: A tuple based on `sampled` to use as test input in the + format returned by a *_candidate_sampler function. + exp_logits: The output logits expected from _compute_sampled_logits(). + It is a numpy array of shape [batch_size, num_true + len(sampled)]. + exp_labels: The output labels expected from _compute_sampled_logits(). + It is a numpy array of shape [batch_size, num_true + len(sampled)]. + """ + weights = np.random.randn(num_classes, dim).astype(np.float32) + biases = np.random.randn(num_classes).astype(np.float32) + hidden_acts = np.random.randn(batch_size, dim).astype(np.float32) + + true_exp = np.full([batch_size, 1], fill_value=0.5, dtype=np.float32) + sampled_exp = np.full([len(sampled)], fill_value=0.5, dtype=np.float32) + sampled_vals = (sampled, true_exp, sampled_exp) + + sampled_w, sampled_b = weights[sampled], biases[sampled] + true_w, true_b = weights[labels], biases[labels] + + true_logits = np.sum( + hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape( + (batch_size, num_true, dim)), + axis=2) + true_b = true_b.reshape((batch_size, num_true)) + true_logits += true_b + sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b + + if subtract_log_q: + true_logits -= np.log(true_exp) + sampled_logits -= np.log(sampled_exp[np.newaxis, :]) + + exp_logits = np.concatenate([true_logits, sampled_logits], axis=1) + exp_labels = np.hstack((np.ones_like(true_logits) / num_true, + np.zeros_like(sampled_logits))) + + return weights, biases, hidden_acts, sampled_vals, exp_logits, exp_labels + + def _ShardTestEmbeddings(self, weights, biases, num_shards): + """Shards the weights and biases returned by _GenerateTestData. + + Args: + weights: The weights returned by _GenerateTestData. + biases: The biases returned by _GenerateTestData. + num_shards: The number of shards to create. + + Returns: + sharded_weights: A list of size `num_shards` containing all the weights. + sharded_biases: A list of size `num_shards` containing all the biases. + """ + with ops.Graph().as_default() as g: + sharded_weights = variable_scope.get_variable( + "w", + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), + initializer=constant_op.constant(weights)) + sharded_biases = variable_scope.get_variable( + "b", + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), + initializer=constant_op.constant(biases)) + with self.session(graph=g) as sess: + variables.global_variables_initializer().run() + return self.evaluate([list(sharded_weights), list(sharded_biases)]) + + def testShapes(self): + np.random.seed(0) + num_classes = 5 + batch_size = 3 + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=False) + logits_tensor, labels_tensor = _compute_sampled_logits( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=False, + remove_accidental_hits=False, + partition_strategy="div", + name="sampled_logits_basic_num_true_%d" % num_true) + got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) + self.assertEqual(exp_logits.shape, got_logits.shape, self._eps) + self.assertEqual(exp_labels.shape, got_labels.shape, self._eps) + + def testBasic(self): + """Without accidental hit removal or subtract_log_q.""" + np.random.seed(0) + num_classes = 5 + batch_size = 3 + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=False) + logits_tensor, labels_tensor = _compute_sampled_logits( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=False, + remove_accidental_hits=False, + partition_strategy="div", + name="sampled_logits_basic_num_true_%d" % num_true) + got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) + self.assertAllClose(exp_logits, got_logits, self._eps) + self.assertAllClose(exp_labels, got_labels, self._eps) + + def testAccidentalHitRemoval(self): + """With accidental hit removal, no subtract_log_q.""" + np.random.seed(0) + num_classes = 5 + batch_size = 3 + sampled = [1, 0, 2, 3] + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, _, + _) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=sampled, + subtract_log_q=False) + logits_tensor, _ = _compute_sampled_logits( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=len(sampled), + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=False, + remove_accidental_hits=True, + partition_strategy="div", + name="sampled_logits_accidental_hit_removal_num_true_%d" % num_true) + # Test that the exponentiated logits of accidental hits are near 0. + # First we need to find the hits in this random test run: + labels_reshape = labels.reshape((batch_size, num_true)) + got_logits = self.evaluate(logits_tensor) + for row in xrange(batch_size): + row_labels = labels_reshape[row, :] + for col in xrange(len(sampled)): + if sampled[col] in row_labels: + # We need to add the num_true_test offset into logits_* + self.assertNear( + np.exp(got_logits[row, col + num_true]), 0., self._eps) + + def testSubtractLogQ(self): + """With subtract_log_q, no accidental hit removal.""" + np.random.seed(0) + num_classes = 5 + batch_size = 3 + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=True) + logits_tensor, labels_tensor = _compute_sampled_logits( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=True, + remove_accidental_hits=False, + partition_strategy="div", + name="sampled_logits_subtract_log_q_num_true_%d" % num_true) + got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) + self.assertAllClose(exp_logits, got_logits, self._eps) + self.assertAllClose(exp_labels, got_labels, self._eps) + + def testSharded(self): + """With sharded weights and sharded biases.""" + np.random.seed(0) + num_classes = 5 + batch_size = 3 + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=False) + weight_shards, bias_shards = self._ShardTestEmbeddings( + weights, biases, num_shards=3) + logits_tensor, labels_tensor = _compute_sampled_logits( + weights=[constant_op.constant(shard) for shard in weight_shards], + biases=[constant_op.constant(shard) for shard in bias_shards], + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=False, + remove_accidental_hits=False, + partition_strategy="div", + name="sampled_logits_sharded_num_true_%d" % num_true) + got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) + self.assertAllClose(exp_logits, got_logits, self._eps) + self.assertAllClose(exp_labels, got_labels, self._eps) + + def testNCELoss(self): + # A simple test to verify the numerics. + + def _SigmoidCrossEntropyWithLogits(logits, targets): + # logits, targets: float arrays of the same shape. + assert logits.shape == targets.shape + pred = 1. / (1. + np.exp(-logits)) + eps = 0.0001 + pred = np.minimum(np.maximum(pred, eps), 1 - eps) + return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred) + + np.random.seed(0) + num_classes = 5 + batch_size = 3 + labels = [0, 1, 2] + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=1, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=True) + exp_nce_loss = np.sum( + _SigmoidCrossEntropyWithLogits(exp_logits, exp_labels), 1) + + got_nce_loss = nn_impl.nce_loss_v2( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant(labels, shape=(batch_size, 1)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals) + + self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4) + + # Test with sharded weights and sharded biases. + weight_shards, bias_shards = self._ShardTestEmbeddings( + weights, biases, num_shards=3) + got_nce_loss = nn_impl.nce_loss_v2( + weights=[constant_op.constant(shard) for shard in weight_shards], + biases=[constant_op.constant(shard) for shard in bias_shards], + labels=constant_op.constant(labels, shape=(batch_size, 1)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals) + + self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4) + + def testSampledSoftmaxLoss(self): + # A simple test to verify the numerics. + + def _SoftmaxCrossEntropyWithLogits(logits, targets): + # logits, targets: float arrays of the same shape. + assert logits.shape == targets.shape + stable_exp_logits = np.exp( + logits - np.amax(logits, axis=1, keepdims=True)) + pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) + return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) + + np.random.seed(0) + num_classes = 5 + batch_size = 3 + labels = [0, 1, 2] + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=1, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=True) + exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( + exp_logits, exp_labels) + + got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant(labels, shape=(batch_size, 1)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals, + remove_accidental_hits=False) + + self.assertAllClose(exp_sampled_softmax_loss, + self.evaluate(got_sampled_softmax_loss), 1e-4) + + # Test with sharded weights and sharded biases. + weight_shards, bias_shards = self._ShardTestEmbeddings( + weights, biases, num_shards=3) + got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2( + weights=[constant_op.constant(shard) for shard in weight_shards], + biases=[constant_op.constant(shard) for shard in bias_shards], + labels=constant_op.constant(labels, shape=(batch_size, 1)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals, + remove_accidental_hits=False) + + self.assertAllClose(exp_sampled_softmax_loss, + self.evaluate(got_sampled_softmax_loss), 1e-4) + + def testSampledSoftmaxLossBf16(self): + # A simple test to verify the numerics for bfloat16. + def _SoftmaxCrossEntropyWithLogits(logits, targets): + # logits, targets: float arrays of the same shape. + assert logits.shape == targets.shape + stable_exp_logits = np.exp( + logits - np.amax(logits, axis=1, keepdims=True)) + pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) + return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) + + np.random.seed(0) + num_classes = 5 + batch_size = 3 + labels = [0, 1, 2] + sampled = [1, 0, 2, 3] + (weights, biases, hidden_acts, _, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=1, + labels=labels, + sampled=sampled, + subtract_log_q=True) + exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( + exp_logits, exp_labels) + + true_exp_bf16 = np.full([batch_size, 1], + fill_value=0.5, + dtype=dtypes.bfloat16.as_numpy_dtype) + sampled_exp_bf16 = np.full([len(sampled)], + fill_value=0.5, + dtype=dtypes.bfloat16.as_numpy_dtype) + sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16) + + got_sampled_softmax_loss = math_ops.cast( + nn_impl.sampled_softmax_loss_v2( + weights=constant_op.constant(weights, dtype=dtypes.bfloat16), + biases=constant_op.constant(biases, dtype=dtypes.bfloat16), + labels=constant_op.constant( + labels, shape=(batch_size, 1), dtype=dtypes.bfloat16), + inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals_bf16, + remove_accidental_hits=False), dtypes.float32) + + self.assertAllClose(exp_sampled_softmax_loss, + self.evaluate(got_sampled_softmax_loss), 1e-1) + + +class CReluTest(test_lib.TestCase): + + def test(self): + np.random.seed(1) # Make it reproducible. + x = np.random.randn(3, 4).astype(np.float32) + y = np.concatenate([x * (x > 0), -x * (x < 0)], axis=1) + + z = self.evaluate(nn_ops.crelu(constant_op.constant(x))) + self.assertAllClose(y, z, 1e-4) + + +class ReluTest(test_lib.TestCase): + + def test(self): + np.random.seed(1) # Make it reproducible. + x = np.random.randn(3, 4).astype(np.float32) + y = np.maximum(x, 0.0) + + z = self.evaluate(nn_ops.relu(constant_op.constant(x))) + self.assertAllEqual(y, z) + + @test_util.run_deprecated_v1 + def testNaNs(self): + # Test that relu(nan) = nan for various sizes. + for i in range(18): + x = np.zeros(i) + np.nan + with self.cached_session(): + z = nn_ops.relu(constant_op.constant(x)).eval() + self.assertTrue(np.isnan(z).all()) + + +class LeakyReluTest(test_lib.TestCase): + + def testRange(self): + batch_size = 3 + height, width = 4, 4 + np.random.seed(1) # Make it reproducible. + inputs = np.random.uniform(size=(batch_size, height, width, 3)).astype( + np.float32) + inputs = constant_op.constant(inputs) + + outputs = nn_ops.leaky_relu(inputs) + self.assertEquals(inputs.shape, outputs.shape) + + inputs, outputs = self.evaluate([inputs, outputs]) + + self.assertGreaterEqual(outputs.min(), 0.0) + self.assertLessEqual(outputs.max(), 1.0) + self.assertAllClose(inputs, outputs) + + @test_util.run_deprecated_v1 + def testValues(self): + for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]: + np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype) + outputs = nn_ops.leaky_relu(constant_op.constant(np_values)) + + outputs = self.evaluate(outputs) + + tol = 2e-3 if dtype == np.float16 else 1e-6 + self.assertAllClose( + outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol) + + @test_util.run_deprecated_v1 + def testName(self): + np_values = np.array([-2, -1, 0, 1, 2], dtype=np.float64) + outputs_with_name_set = nn_ops.leaky_relu( + constant_op.constant(np_values), + name='test_relu_op') + self.assertEqual(outputs_with_name_set.name, 'test_relu_op:0') + outputs_without_name_set = nn_ops.leaky_relu( + constant_op.constant(np_values)) + self.assertEqual(outputs_without_name_set.name, 'LeakyRelu:0') + + +class SwishTest(test_lib.TestCase): + + @test_util.run_deprecated_v1 + def testValues(self): + np_values = np.array( + [np.linspace(-10.0, 0.0, 100), + np.linspace(0.0, 10.0, 100)], + dtype=np.float32) + tf_values = constant_op.constant(np_values) + actual_tf_outputs = nn_impl.swish(tf_values) + expected_tf_outputs = tf_values * math_ops.sigmoid(tf_values) + + actual_outputs, expected_outputs = self.evaluate( + [actual_tf_outputs, expected_tf_outputs]) + + self.assertAllClose(actual_outputs, expected_outputs) + + @test_util.run_deprecated_v1 + def testGradients(self): + shape = [5, 3, 4] + sigma = 5 + input_values = np.random.randn(*shape) * sigma + x_tf = constant_op.constant(input_values) + y_tf = nn_impl.swish(x_tf) + with self.cached_session(): + err = gradient_checker.compute_gradient_error(x_tf, shape, y_tf, shape) + self.assertLess(err, 1e-4) + + +class MomentsTest(test_lib.TestCase): + + def doOutputTest(self, + input_shape, + moments_axes, + tol=1e-4, + check_gradients=False): + for mu in [0.0, 1.0, 1e3]: + for sigma in [1.0, 0.1]: + for keep_dims in [True, False]: + input_values = np.random.rand(*input_shape) * sigma + mu + expected_mean = np.mean( + input_values, axis=moments_axes, keepdims=keep_dims) + expected_var = np.var( + input_values, axis=moments_axes, keepdims=keep_dims) + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + inputs = constant_op.constant( + input_values, shape=input_shape, dtype=dtypes.float32) + mean, variance = nn_impl.moments_v2( + inputs, moments_axes, keepdims=keep_dims) + + if check_gradients: + err = gradient_checker.compute_gradient_error( + inputs, input_shape, mean, mean.shape.as_list()) + self.assertLess(err, 1e-3) + err = gradient_checker.compute_gradient_error( + inputs, input_shape, variance, variance.shape.as_list()) + self.assertLess(err, 1e-3) + + # Evaluate. + [mean, variance] = self.evaluate([mean, variance]) + # Make sure that there are no NaNs + self.assertFalse(np.isnan(mean).any()) + self.assertFalse(np.isnan(variance).any()) + self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol) + self.assertAllClose(variance, expected_var, rtol=tol, atol=tol) + + def testOutputAndGradient2DInput0(self): + self.doOutputTest((10, 10), (0,), check_gradients=True) + + def testOutputAndGradient2DInput01(self): + self.doOutputTest((10, 10), (0, 1), check_gradients=True) + + def testOutput2DInput0(self): + self.doOutputTest((10, 300), (0,)) + + def testOutput2DInput1(self): + self.doOutputTest((10, 300), (1,)) + + def testOutput2DInput01(self): + self.doOutputTest((10, 300), (0, 1)) + + def testOutput4DInput0(self): + self.doOutputTest((10, 10, 10, 30), (0,)) + + def testOutput4DInput1(self): + self.doOutputTest((10, 10, 10, 30), (1,)) + + def testOutput4DInput3(self): + self.doOutputTest((10, 10, 10, 30), (3,)) + + def testOutput4DInput012(self): + self.doOutputTest((10, 10, 10, 30), (0, 1, 2)) + + def testOutput4DInput123(self): + self.doOutputTest((10, 10, 10, 30), (1, 2, 3)) + + +class DataFormatDimMapTest(test_lib.TestCase): + + def _test(self, x_val, y_val_expected): + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x) + + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def test(self): + self._test(0, 0) + self._test(1, 2) + self._test(2, 3) + self._test(3, 1) + self._test(-1, 1) + self._test(-2, 3) + self._test(-3, 2) + self._test(-4, 0) + self._test([1, 3], [2, 1]) + self._test([1, 3, -2], [2, 1, 3]) + self._test([1, -3, -2], [2, 2, 3]) + self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]]) + + def testNHWCtoNCHW(self): + x_val = [1, -3, -2] + y_val_expected = [2, 2, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="NCHW") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def testNHWCtoHWNC(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [2, 0, 1, 3, 2, 0, 1, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="HWNC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def testNHWCtoWHCN(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [3, 1, 0, 2, 3, 1, 0, 2] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="WHCN") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def testArbitraryASCII(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="qwer", dst_format="rewq") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + +class DataFormatVectorPermuteTest(test_lib.TestCase): + + def testNHWCToNCHW(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x) + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [7, 3, 4, 9]) + + def testNCHWToNHWC(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [7, 9, 3, 4]) + + def testNHWCToHWNC(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [4, 9, 7, 3]) + + def testHWNCToNHWC(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [9, 7, 4, 3]) + + def testNHWCToNCHW2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x) + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]]) + + def testNHWCToHWNC2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]]) + + def testHWNCToNHWC2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]]) + + def testNCHWToNHWC2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]]) + + +if __name__ == "__main__": + test_lib.main() diff --git a/test/TensorFlowNET.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.UnitTest/NameScopeTest.cs new file mode 100644 index 000000000..5bf89f2c7 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NameScopeTest.cs @@ -0,0 +1,22 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class NameScopeTest : EagerModeTestBase + { + string name = ""; + + [TestMethod] + public void NameScopeInEagerMode() + { + tf_with(new ops.NameScope("scope"), scope => + { + string name = scope; + var const1 = tf.constant(1.0); + }); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs new file mode 100644 index 000000000..1d3ff9be5 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs @@ -0,0 +1,179 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/stable/user/basics.indexing.html + /// + [TestClass] + public class ArrayIndexingTest : EagerModeTestBase + { + [TestMethod] + public void int_params() + { + var x = np.arange(24).reshape((2, 3, 4)); + x[1, 2, 3] = 1; + var y = x[1, 2, 3]; + Assert.AreEqual(y.shape, Shape.Scalar); + Assert.AreEqual(y, 1); + + x[0, 0] = new[] { 3, 1, 1, 2 }; + y = x[0, 0]; + Assert.AreEqual(y.shape, 4); + Assert.AreEqual(y, new[] { 3, 1, 1, 2 }); + + y = x[0]; + Assert.AreEqual(y.shape, (3, 4)); + + var z = np.arange(12).reshape((3, 4)); + x[1] = z; + Assert.AreEqual(x[1], z); + } + + [TestMethod] + public void slice_newaxis() + { + var x = np.arange(20).reshape((4, 5)); + var y = x[np.newaxis, ":2"]; + Assert.AreEqual(y.shape, (1, 2, 5)); + } + + [TestMethod] + public void slice_params() + { + var x = np.arange(12).reshape((3, 4)); + var y = x[new Slice(0, 1), new Slice(2)]; + Assert.AreEqual(y.shape, (1, 2)); + Assert.AreEqual(y, np.array(new[] { 2, 3 }).reshape((1, 2))); + } + + [TestMethod] + public void slice_string_params() + { + var x = np.arange(12).reshape((3, 4)); + var y = x[Slice.ParseSlices("0:1,2:")]; + Assert.AreEqual(y.shape, (1, 2)); + Assert.AreEqual(y, np.array(new[] { 2, 3 }).reshape((1, 2))); + } + + [TestMethod] + public void slice_out_bound() + { + var input_shape = tf.constant(new int[] { 1, 1 }); + var input_shape_val = input_shape.numpy(); + input_shape_val[(int)input_shape.size - 1] = 1; + input_shape.Dispose(); + } + + [TestMethod] + public void shape_helper_get_shape_3dim() + { + var x = np.arange(24).reshape((4, 3, 2)); + var shape1 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true)); + Assert.AreEqual(shape1, (3, 2)); + + var shape2 = ShapeHelper.GetShape(x.shape, new Slice(1)); + Assert.AreEqual(shape2, (3, 3, 2)); + + var shape3 = ShapeHelper.GetShape(x.shape, new Slice(2), Slice.All); + Assert.AreEqual(shape3, (2, 3, 2)); + + var shape4 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true), new Slice(2)); + Assert.AreEqual(shape4, (1, 2)); + + var shape5 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true), new Slice(1)); + Assert.AreEqual(shape5, (2, 2)); + + var shape6 = ShapeHelper.GetShape(x.shape, new Slice(1), new Slice(1, isIndex: true), new Slice(1)); + Assert.AreEqual(shape6, (3, 1)); + } + + [TestMethod] + public void shape_helper_get_shape_4dim() + { + var x = np.arange(120).reshape((4, 3, 2, 5)); + var slices = new[] { new Slice(1, isIndex: true), new Slice(1), new Slice(0, isIndex: true), new Slice(1) }; + var shape1 = ShapeHelper.GetShape(x.shape, slices); + Assert.AreEqual(shape1, (2, 4)); + + var shape2 = ShapeHelper.GetShape(x.shape, Slice.All); + Assert.AreEqual(shape2, (4, 3, 2, 5)); + + var shape3 = ShapeHelper.GetShape(x.shape, Slice.All, new Slice(0, isIndex: true)); + Assert.AreEqual(shape3, (4, 3, 2)); + } + + [TestMethod] + public void iterating() + { + var array = np.array(new[,] { { 0, 3 }, { 2, 2 }, { 3, 1 } }); + int i = 0; + foreach(var x in array) + { + if (i == 0) + Assert.AreEqual(x, new[] { 0, 3 }); + else + Assert.AreEqual(x, array[i]); + i++; + } + } + + [TestMethod] + public void slice_step_setter() + { + var array = np.arange(32).reshape((4, 8)); + var s1 = array[Slice.All, new Slice(2, 5, 2)] + 1; + Assert.AreEqual(s1.shape, (4, 2)); + var expected = new[] { 3, 5, 11, 13, 19, 21, 27, 29 }; + Assert.IsTrue(Enumerable.SequenceEqual(expected, s1.ToArray())); + array[Slice.All, new Slice(2, 5, 2)] = s1; + Assert.AreEqual(array[0], new[] { 0, 1, 3, 3, 5, 5, 6, 7 }); + Assert.AreEqual(array[1], new[] { 8, 9, 11, 11, 13, 13, 14, 15 }); + Assert.AreEqual(array[2], new[] { 16, 17, 19, 19, 21, 21, 22, 23 }); + Assert.AreEqual(array[3], new[] { 24, 25, 27, 27, 29, 29, 30, 31 }); + } + + [TestMethod] + public void slice_step_setter_diff_shape() + { + var array = np.arange(32).reshape((4, 8)); + var s1 = np.array(new[] { 100, 200 }); + array[Slice.All, new Slice(2, 5, 2)] = s1; + Assert.AreEqual(array[0], new[] { 0, 1, 100, 3, 200, 5, 6, 7 }); + Assert.AreEqual(array[1], new[] { 8, 9, 100, 11, 200, 13, 14, 15 }); + Assert.AreEqual(array[2], new[] { 16, 17, 100, 19, 200, 21, 22, 23 }); + Assert.AreEqual(array[3], new[] { 24, 25, 100, 27, 200, 29, 30, 31 }); + } + + [TestMethod] + public void mask_2d_get_value() + { + var x = np.arange(25).reshape((5, 5)); + var y = np.array(new[] { true, false, true, false, true }); + var z = x[y]; + Assert.AreEqual(z.shape, (3, 5)); + Assert.AreEqual(z[0], new[] { 0, 1, 2, 3, 4 }); + Assert.AreEqual(z[1], new[] { 10, 11, 12, 13, 14 }); + Assert.AreEqual(z[2], new[] { 20, 21, 22, 23, 24 }); + } + + [TestMethod] + public void mask_2d_set_value() + { + var x = np.arange(25).reshape((5, 5)); + var y = np.array(new[] {true, false, true, false, false}); + x[y] = 0; + Assert.AreEqual(x[0], new[] { 0, 0, 0, 0, 0 }); + Assert.AreEqual(x[1], new[] { 5, 6, 7, 8, 9 }); + Assert.AreEqual(x[2], new[] { 0, 0, 0, 0, 0 }); + Assert.AreEqual(x[3], new[] { 15, 16, 17, 18, 19 }); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs new file mode 100644 index 000000000..289172a45 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs @@ -0,0 +1,44 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/stable/user/basics.indexing.html + /// + [TestClass] + public class ArraySortingTest : EagerModeTestBase + { + /// + /// https://numpy.org/doc/stable/reference/generated/numpy.argsort.html + /// + [TestMethod] + public void argsort() + { + var x = np.array(new[] { 3, 1, 2 }); + var ind = np.argsort(x); + Assert.AreEqual(ind, new[] { 1, 2, 0 }); + + var y = np.array(new[,] { { 0, 3 }, { 2, 2 } }); + ind = np.argsort(y, axis: 0); + Assert.AreEqual(ind[0], new[] { 0, 1 }); + Assert.AreEqual(ind[1], new[] { 1, 0 }); + } + + /// + /// https://numpy.org/doc/stable/reference/generated/numpy.sort.html + /// + [TestMethod] + public void sort() + { + var x = np.array(new int[] { 3, 1, 2 }); + var sorted = np.sort(x); + // Assert.IsTrue(sorted.ToArray() is [1, 2, 3]); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/NumPy/LinearAlgebra.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/LinearAlgebra.Test.cs new file mode 100644 index 000000000..d6beb2599 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/LinearAlgebra.Test.cs @@ -0,0 +1,31 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/stable/reference/generated/numpy.prod.html + /// + [TestClass] + public class LinearAlgebraTest : EagerModeTestBase + { + [TestMethod] + public void lstsq() + { + + } + + [TestMethod] + public void norm() + { + var x = np.arange(9) - 4; + var y = x.reshape((3, 3)); + var norm = np.linalg.norm(y); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs new file mode 100644 index 000000000..d9c04be6e --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs @@ -0,0 +1,42 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/stable/reference/routines.array-manipulation.html + /// + [TestClass] + public class ManipulationTest : EagerModeTestBase + { + [TestMethod] + public void expand_dims() + { + var x = np.array(new[] { 1, 2 }); + var y = np.expand_dims(x, axis: 0); + Assert.AreEqual(y.shape, (1, 2)); + + y = np.expand_dims(x, axis: 1); + Assert.AreEqual(y.shape, (2, 1)); + } + + [TestMethod] + public void moveaxis() + { + var x = np.zeros((3, 4, 5)); + var y = np.moveaxis(x, 0, -1); + Assert.AreEqual(y.shape, (4, 5, 3)); + + y = np.moveaxis(x, (0, 1), (-1, -2)); + Assert.AreEqual(y.shape, (5, 4, 3)); + + y = np.moveaxis(x, (0, 1, 2), (-1, -2, -3)); + Assert.AreEqual(y.shape, (5, 4, 3)); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/NumPy/OperatorsTest.cs b/test/TensorFlowNET.UnitTest/NumPy/OperatorsTest.cs new file mode 100644 index 000000000..e4989a1dc --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/OperatorsTest.cs @@ -0,0 +1,33 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + [TestClass] + public class OperatorsTest + { + [TestMethod] + public void EqualToOperator() + { + NDArray n1 = null; + NDArray n2 = new NDArray(1); + + Assert.IsTrue(n1 == null); + Assert.IsFalse(n2 == null); + Assert.IsFalse(n1 == 1); + Assert.IsTrue(n2 == 1); + } + + [TestMethod] + public void NotEqualToOperator() + { + NDArray n1 = null; + NDArray n2 = new NDArray(1); + + Assert.IsFalse(n1 != null); + Assert.IsTrue(n2 != null); + Assert.IsTrue(n1 != 1); + Assert.IsFalse(n2 != 1); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/NumPy/Persistence.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Persistence.Test.cs new file mode 100644 index 000000000..21db6acc0 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/Persistence.Test.cs @@ -0,0 +1,42 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy; + +/// +/// https://numpy.org/doc/stable/reference/generated/numpy.save.html +/// +[TestClass] +public class PersistenceTest : EagerModeTestBase +{ + [TestMethod] + public void SaveNpy() + { + var x = np.arange(10f).reshape((2, 5)); + np.save("arange.npy", x); + + var x2 = np.load("arange.npy"); + Assert.AreEqual(x.shape, x2.shape); + } + + [TestMethod] + public void SaveNpz() + { + var x = np.arange(10f).reshape((2, 5)); + var y = np.arange(10f).reshape((5, 2)); + + np.savez("arange.npz", x, y); + var z = np.loadz("arange.npz"); + + np.savez("arange_named.npz", new { x, y }); + z = np.loadz("arange_named.npz"); + Assert.AreEqual(z["x"].shape, x.shape); + Assert.AreEqual(z["y"].shape, y.shape); + + np.savez_compressed("arange_compressed.npz", x, y); + np.savez_compressed("arange_compressed_named.npz", new { x, y }); + z = np.loadz("arange_compressed_named.npz"); + Assert.AreEqual(z["x"].shape, x.shape); + Assert.AreEqual(z["y"].shape, y.shape); + } +} diff --git a/test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs new file mode 100644 index 000000000..55801f55d --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs @@ -0,0 +1,47 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/1.20/reference/random/index.html + /// + [TestClass] + public class RandomizeTest : EagerModeTestBase + { + [TestMethod] + public void permutation() + { + var x = np.random.permutation(10); + Assert.AreEqual(x.shape, 10); + var y = np.random.permutation(x); + Assert.AreEqual(x.shape, 10); + Assert.AreNotEqual(x.ToArray(), y.ToArray()); + } + + /// + /// https://numpy.org/doc/stable/reference/random/generated/numpy.random.normal.html + /// + [TestMethod] + public void normal() + { + var x = np.random.normal(0, 0.1f, 1000); + Equal(np.mean(x), 0f); + } + + [TestMethod] + public void randn() + { + var x = np.random.randn(); + Assert.AreEqual(np.float32, x.dtype); + + x = np.random.randn(2, 4); + Equal(np.mean(x), 0f); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs b/test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs new file mode 100644 index 000000000..f5a8685be --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs @@ -0,0 +1,44 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.Linq; +using static Tensorflow.Binding; +using Tensorflow; + +namespace TensorFlowNET.UnitTest.NumPy +{ + [TestClass] + public class ShapeTest : EagerModeTestBase + { + [Ignore] + [TestMethod] + public unsafe void ShapeGetLastElements() + { + // test code from function _CheckAtLeast3DImage + // 之前的 _CheckAtLeast3DImage 有bug,现在通过测试,下面的代码是正确的 + // todo: shape["-3:"] 的写法,目前有bug,需要修复,单元测试等修复后再放开,暂时先忽略测试 + + var image_shape = new Shape(new[] { 32, 64, 3 }); + var image_shape_4d = new Shape(new[] { 4, 64, 32, 3 }); + + var image_shape_last_three_elements = new Shape(new[] { + image_shape.dims[image_shape.dims.Length - 3], + image_shape.dims[image_shape.dims.Length - 2], + image_shape.dims[image_shape.dims.Length - 1]}); + + var image_shape_last_three_elements2 = image_shape["-3:"]; + + Assert.IsTrue(Equal(image_shape_last_three_elements.dims, image_shape_last_three_elements2.dims), "3dims get fail."); + + var image_shape_last_three_elements_4d = new Shape(new[] { + image_shape_4d.dims[image_shape_4d.dims.Length - 3], + image_shape_4d.dims[image_shape_4d.dims.Length - 2], + image_shape_4d.dims[image_shape_4d.dims.Length - 1]}); + + var image_shape_last_three_elements2_4d = image_shape_4d["-3:"]; + + Assert.IsTrue(Equals(image_shape_last_three_elements_4d.dims, image_shape_last_three_elements2_4d.dims), "4dims get fail."); + } + + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/NumPy/Statistics.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Statistics.Test.cs new file mode 100644 index 000000000..42005b151 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/Statistics.Test.cs @@ -0,0 +1,32 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/stable/reference/routines.statistics.html + /// + [TestClass] + public class StatisticsTest : EagerModeTestBase + { + [TestMethod] + public void average() + { + var data = np.arange(1, 5); + var avg = np.average(data); + Assert.AreEqual(avg, 2.5); + + data = np.arange(6).reshape((3, 2)); + avg = np.average(data, axis: 1); + assertAllEqual(avg.ToArray(), new[] { 0.5, 2.5, 4.5 }); + + // avg = np.average(data, axis: 1, weights: new[] { 1.0 / 4, 3.0 / 4 }); + // assertAllEqual(avg.ToArray(), new[] { 0.75, 2.75, 4.75 }); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs b/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs new file mode 100644 index 000000000..fc309c3c6 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs @@ -0,0 +1,119 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/stable/reference/routines.array-creation.html + /// + [TestClass] + public class ArrayCreationTest : EagerModeTestBase + { + [TestMethod] + public void empty_zeros_ones_full() + { + var empty = np.empty((2, 2)); + var zeros = np.zeros((2, 2)); + var ones = np.ones((2, 2)); + var full = np.full((2, 2), 0.1f); + Assert.AreEqual(np.float32, full.dtype); + } + + [TestMethod] + public void arange() + { + var x = np.arange(3); + AssetSequenceEqual(new[] { 0, 1, 2 }, x.ToArray()); + + x = np.arange(3f); + Assert.IsTrue(Equal(new float[] { 0, 1, 2 }, x.ToArray())); + + var y = np.arange(3, 7); + AssetSequenceEqual(new[] { 3, 4, 5, 6 }, y.ToArray()); + + y = np.arange(3, 7, 2); + AssetSequenceEqual(new[] { 3, 5 }, y.ToArray()); + } + + [TestMethod] + public void array() + { + var x = np.array(1, 2, 3); + AssetSequenceEqual(new[] { 1, 2, 3 }, x.ToArray()); + + x = np.array(new[,] { { 1, 2 }, { 3, 4 }, { 5, 6 } }); + AssetSequenceEqual(new[] { 1, 2, 3, 4, 5, 6 }, x.ToArray()); + } + + [TestMethod] + public void to_multi_dim_array() + { + var x1 = np.arange(12); + var y1 = x1.ToMultiDimArray(); + AssetSequenceEqual((int[])y1, x1.ToArray()); + + var x2 = np.arange(12).reshape((2, 6)); + var y2 = (int[,])x2.ToMultiDimArray(); + Assert.AreEqual(x2[0, 5], y2[0, 5]); + + var x3 = np.arange(12).reshape((2, 2, 3)); + var y3 = (int[,,])x3.ToMultiDimArray(); + Assert.AreEqual(x3[0, 1, 2], y3[0, 1, 2]); + } + + [TestMethod] + public void eye() + { + var x = np.eye(3, k: 1); + Assert.IsTrue(Equal(new double[] { 0, 1, 0, 0, 0, 1, 0, 0, 0 }, x.ToArray())); + } + + [TestMethod] + public void linspace() + { + var x = np.linspace(2.0, 3.0, num: 5); + Assert.IsTrue(Equal(new double[] { 2, 2.25, 2.5, 2.75, 3 }, x.ToArray())); + + x = np.linspace(2.0, 3.0, num: 5, endpoint: false); + Assert.IsTrue(Equal(new double[] { 2, 2.2, 2.4, 2.6, 2.8 }, x.ToArray())); + } + + [TestMethod] + public void meshgrid() + { + var x = np.linspace(0, 1, num: 3); + var y = np.linspace(0, 1, num: 2); + var (xv, yv) = np.meshgrid(x, y); + Assert.IsTrue(Equal(new double[] { 0, 0.5, 1, 0, 0.5, 1 }, xv.ToArray())); + Assert.IsTrue(Equal(new double[] { 0, 0, 0, 1, 1, 1 }, yv.ToArray())); + + (xv, yv) = np.meshgrid(x, y, sparse: true); + Assert.IsTrue(Equal(new double[] { 0, 0.5, 1 }, xv.ToArray())); + AssetSequenceEqual(new long[] { 1, 3 }, xv.shape.dims); + Assert.IsTrue(Equal(new double[] { 0, 1 }, yv.ToArray())); + AssetSequenceEqual(new long[] { 2, 1 }, yv.shape.dims); + } + + [TestMethod] + public void meshgrid_same_ndim() + { + var (a, b) = np.meshgrid(np.arange(3), np.arange(3)); + AssetSequenceEqual(a.ToArray(), new int[] { 0, 1, 2, 0, 1, 2, 0, 1, 2 }); + AssetSequenceEqual(b.ToArray(), new int[] { 0, 0, 0, 1, 1, 1, 2, 2, 2 }); + } + + [TestMethod] + public void to_numpy_string() + { + var nd = np.arange(10 * 10 * 10 * 10).reshape((10, 10, 10, 10)); + var str = NDArrayRender.ToString(nd); + Assert.AreEqual("array([[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],", str.Substring(0, 40)); + Assert.AreEqual("[9990, 9991, 9992, 9993, 9994, 9995, 9996, 9997, 9998, 9999]]]])", str.Substring(str.Length - 64)); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs new file mode 100644 index 000000000..65cdaedd9 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs @@ -0,0 +1,111 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/stable/reference/generated/numpy.prod.html + /// + [TestClass] + public class MathTest : EagerModeTestBase + { + [TestMethod] + public void prod() + { + var p = np.prod(1.0, 2.0); + Assert.AreEqual(p, 2.0); + + p = np.prod(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }); + Assert.AreEqual(p, 24.0); + + p = np.prod(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }, axis: 1); + Assert.AreEqual(p.shape, 2); + Assert.IsTrue(Equal(p.ToArray(), new[] { 2.0, 12.0 })); + } + + [TestMethod] + public void astype() + { + var x = np.array(new byte[] { 1, 100, 200 }); + var x1 = x.astype(np.float32); + Assert.AreEqual(x1[2], 200f); + } + + [TestMethod] + public void divide() + { + var x = np.array(new float[] { 1, 100, 200 }); + var y = x / 2; + Assert.AreEqual(y.dtype, np.float32); + } + + [TestMethod] + public void sin() + { + var x = np.sin(np.pi / 2); + Assert.AreEqual(x, 1d); + } + + [TestMethod] + public void cos() + { + var x = np.cos(np.pi / 2); + Assert.AreEqual(x, 6.123233995736766e-17); + } + + [TestMethod] + public void power() + { + var x = np.arange(6); + var y = np.power(x, 3); + Assert.AreEqual(y, new[] { 0, 1, 8, 27, 64, 125 }); + } + [TestMethod] + public void square() + { + var x = np.arange(6); + var y = np.square(x); + Assert.AreEqual(y, new[] { 0, 1, 4, 9, 16, 25 }); + } + [TestMethod] + public void dotproduct() + { + var x1 = new NDArray(new[] { 1, 2, 3 }); + var x2 = new NDArray(new[] { 4, 5, 6 }); + double result1 = np.dot(x1, x2); + NDArray y1 = new float[,] { + { 1.0f, 2.0f, 3.0f }, + { 4.0f, 5.1f,6.0f }, + { 4.0f, 5.1f,6.0f } + }; + NDArray y2 = new float[,] { + { 3.0f, 2.0f, 1.0f }, + { 6.0f, 5.1f, 4.0f }, + { 6.0f, 5.1f, 4.0f } + }; + double result2 = np.dot(y1, y2); + Assert.AreEqual(result1, 32); + Assert.AreEqual(Math.Round(result2, 2), 158.02); + } + [TestMethod] + public void maximum() + { + var x1 = new NDArray(new[,] { { 1, 2, 3 }, { 4, 5.1, 6 } }); + var x2 = new NDArray(new[,] { { 3, 2, 1 }, { 6, 5.1, 4 } }); + var y0 = np.maximum(x1,x2); + var y1 = np.maximum(x1, x2, axis: 0); + var y2 = np.maximum(x1, x2, axis: 1); + var y3 = new NDArray(new[,] { { 3, 2, 3 }, { 6, 5.1, 6 } }); + var y4 = new NDArray(new[] { 6, 5.1, 6 }); + var y5 = new NDArray(new[] { 3.0, 6 }); + Assert.AreEqual(y0, y3); + Assert.AreEqual(y1, y4); + Assert.AreEqual(y2, y5); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Open.snk b/test/TensorFlowNET.UnitTest/Open.snk new file mode 100644 index 000000000..22a3cbd25 Binary files /dev/null and b/test/TensorFlowNET.UnitTest/Open.snk differ diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs deleted file mode 100644 index 23b170702..000000000 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ /dev/null @@ -1,55 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System; -using System.Collections.Generic; -using System.Text; -using Tensorflow; - -namespace TensorFlowNET.UnitTest -{ - [TestClass] - public class OperationsTest - { - [TestMethod] - public void constant() - { - var x = tf.constant(4.0f); - } - - [TestMethod] - public void placeholder() - { - var x = tf.placeholder(tf.float32); - } - - [TestMethod] - public void addInPlaceholder() - { - var a = tf.placeholder(tf.float32); - var b = tf.placeholder(tf.float32); - var c = tf.add(a, b); - - using(var sess = tf.Session()) - { - var feed_dict = new Dictionary(); - feed_dict.Add(a, 3.0f); - feed_dict.Add(b, 2.0f); - - var o = sess.run(c, feed_dict); - } - } - - [TestMethod] - public void addInConstant() - { - var a = tf.constant(4.0f); - var b = tf.constant(5.0f); - var c = tf.add(a, b); - - using (var sess = tf.Session()) - { - var o = sess.run(c); - Assert.AreEqual(o, 9.0f); - } - } - } -} diff --git a/test/TensorFlowNET.UnitTest/StatusTest.cs b/test/TensorFlowNET.UnitTest/StatusTest.cs index 8e1baede5..6dcdc158e 100644 --- a/test/TensorFlowNET.UnitTest/StatusTest.cs +++ b/test/TensorFlowNET.UnitTest/StatusTest.cs @@ -1,10 +1,8 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; -using System.Collections.Generic; -using System.Text; using Tensorflow; -namespace TensorFlowNET.UnitTest +namespace TensorFlowNET.UnitTest.Basics { [TestClass] public class StatusTest @@ -23,14 +21,13 @@ public void SetStatus() var s = new Status(); s.SetStatus(TF_Code.TF_CANCELLED, "cancel"); Assert.AreEqual(s.Code, TF_Code.TF_CANCELLED); - // Assert.AreEqual(s.Message, "cancel"); + Assert.AreEqual(s.Message, "cancel"); } [TestMethod] public void DeleteStatus() { var s = new Status(); - s.Dispose(); } } } diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj deleted file mode 100644 index de2eb99e8..000000000 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ /dev/null @@ -1,25 +0,0 @@ - - - - netcoreapp2.1 - - false - - - - DEBUG;TRACE - true - - - - - - - - - - - - - - diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings new file mode 100644 index 000000000..6cbf8796d --- /dev/null +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings @@ -0,0 +1,2 @@ + + True \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs deleted file mode 100644 index f111e07c4..000000000 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ /dev/null @@ -1,31 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using NumSharp.Core; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.InteropServices; -using System.Text; -using Tensorflow; - -namespace TensorFlowNET.UnitTest -{ - [TestClass] - public class TensorTest - { - [TestMethod] - public unsafe void NewTensor() - { - var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); - - var tensor = new Tensor(nd); - var array = tensor.Data(); - - Assert.AreEqual(tensor.dtype, TF_DataType.TF_FLOAT); - Assert.AreEqual(tensor.rank, nd.ndim); - Assert.AreEqual(tensor.shape[0], nd.shape[0]); - Assert.AreEqual(tensor.shape[1], nd.shape[1]); - Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float)); - Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), array)); - } - } -} diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj new file mode 100644 index 000000000..5264cb104 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj @@ -0,0 +1,66 @@ + + + + net6.0 + false + false + false + Open.snk + 10.0 + AnyCPU;x64 + + + + DEBUG;TRACE + true + x64 + + + + DEBUG;TRACE + true + x64 + + + + true + x64 + + + + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + PreserveNewest + + + Always + + + + diff --git a/test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs b/test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs new file mode 100644 index 000000000..65c69a3f9 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs @@ -0,0 +1,21 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using static Tensorflow.TextApi; + +namespace TensorFlowNET.UnitTest.Text +{ + [TestClass] + public class TokenizerTest + { + [TestMethod, Ignore] + public void Tokenize() + { + var docs = tf.constant(new[] { "Everything not saved will be lost." }); + var tokenizer = text.WhitespaceTokenizer(); + var tokens = tokenizer.tokenize(docs); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Training/BasicLinearModel.cs b/test/TensorFlowNET.UnitTest/Training/BasicLinearModel.cs new file mode 100644 index 000000000..1283ecaf2 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Training/BasicLinearModel.cs @@ -0,0 +1,64 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Training +{ + [TestClass] + public class BasicLinearModel + { + /// + /// Linear Regression without tf.train.Optimizer + /// https://www.tensorflow.org/tutorials/customization/custom_training + /// + [TestMethod] + public void LinearRegression() + { + // Initialize the weights to `5.0` and the bias to `0.0` + // In practice, these should be initialized to random values (for example, with `tf.random.normal`) + var W = tf.Variable(5.0f); + var b = tf.Variable(0.0f); + + // Define linear model + Func model = (x) => W * x + b; + + // Define the loss function + Func loss = (target_y, predicted_y) + => tf.reduce_mean(tf.square(target_y - predicted_y)); + + int NUM_EXAMPLES = 1000; + float TRUE_W = 3.0f; + float TRUE_b = 2.0f; + + var inputs = tf.random.normal(shape: NUM_EXAMPLES); + var noise = tf.random.normal(shape: NUM_EXAMPLES); + var outputs = inputs * TRUE_W + TRUE_b + noise; + + Tensor init_loss = loss(model(inputs), outputs); + // print($"Current loss: {init_loss.numpy()}"); + + // Define a training loop + Func train = (inputs, outputs, learning_rate) + => + { + using var t = tf.GradientTape(); + var current_loss = loss(outputs, model(inputs)); + var (dW, db) = t.gradient(current_loss, (W, b)); + W.assign_sub(learning_rate * dW); + b.assign_sub(learning_rate * db); + return current_loss; + }; + + var epochs = range(10); + foreach (var epoch in epochs) + { + var current_loss = train(inputs, outputs, 0.1f); + print($"Epoch {epoch}: W={(float)W.numpy()} b={(float)b.numpy()}, loss={(float)current_loss.numpy()}"); + + if (epoch > 0) // skip first epoch + Assert.IsTrue((bool)(current_loss < init_loss)); + } + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs b/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs new file mode 100644 index 000000000..3b53ff9cd --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs @@ -0,0 +1,232 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Linq; +using Tensorflow; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Training +{ + [TestClass] + public class GradientDescentOptimizerTest : PythonTest + { + private static TF_DataType GetTypeForNumericType() where T : struct + { + return Type.GetTypeCode(typeof(T)) switch + { + TypeCode.Single => np.float32, + TypeCode.Double => np.float64, + _ => throw new NotImplementedException(), + }; + } + + private void TestBasic() where T : struct + { + var dtype = GetTypeForNumericType(); + + // train.GradientDescentOptimizer is V1 only API. + tf.Graph().as_default(); + using (var sess = self.cached_session()) + { + var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype); + var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype); + var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype); + var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype); + var optimizer = tf.train.GradientDescentOptimizer(3.0f); + var grads_and_vars = new[] { + Tuple.Create(grads0, var0 as IVariableV1), + Tuple.Create(grads1, var1 as IVariableV1) + }; + var sgd_op = optimizer.apply_gradients(grads_and_vars); + + var global_variables = tf.global_variables_initializer(); + sess.run(global_variables); + + var initialVar0 = sess.run(var0); + var initialVar1 = sess.run(var1); + // Fetch params to validate initial values + self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate(var0)); + self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate(var1)); + // Run 1 step of sgd + sgd_op.run(); + // Validate updated params + self.assertAllCloseAccordingToType( + new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 }, + self.evaluate(var0)); + self.assertAllCloseAccordingToType( + new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 }, + self.evaluate(var1)); + // TODO: self.assertEqual(0, len(optimizer.variables())); + } + } + + [TestMethod] + public void TestBasic() + { + //TODO: add np.half + TestBasic(); + TestBasic(); + } + + private void TestMinimizeResourceVariable() where T : struct + { + var dtype = GetTypeForNumericType(); + + // train.GradientDescentOptimizer is V1 only API. + tf.Graph().as_default(); + using (var sess = self.cached_session()) + { + var var0 = tf.Variable(new[,] { { 1.0f, 2.0f } }, dtype: dtype); + var var1 = tf.Variable(new[] { 3.0 }, dtype: dtype); + var x = tf.constant(new[,] { { 4.0f }, { 5.0f } }, dtype: dtype); + + var pred = math_ops.matmul(var0, x) + var1; + var loss = pred * pred; + var sgd_op = tf.train.GradientDescentOptimizer(1.0f).minimize(loss); + + var global_variables = tf.global_variables_initializer(); + sess.run(global_variables); + + sess.run(new[] { var0, var1 }); + // Fetch params to validate initial values + self.assertAllCloseAccordingToType(new[,] { { 1.0, 2.0 } }, self.evaluate(var0)); + self.assertAllCloseAccordingToType(new[] { 3.0 }, self.evaluate(var1)); + // Run 1 step of sgd + sgd_op.run(); + // Validate updated params + var np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0; + var np_grad = 2 * np_pred; + self.assertAllCloseAccordingToType( + new[,] { { 1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0 } }, + self.evaluate(var0)); + self.assertAllCloseAccordingToType( + new[] { 3.0 - np_grad }, + self.evaluate(var1)); + } + } + + [TestMethod] + public void TestMinimizeResourceVariable() + { + //TODO: add np.half + TestMinimizeResourceVariable(); + TestMinimizeResourceVariable(); + } + + private void TestTensorLearningRate() where T : struct + { + var dtype = GetTypeForNumericType(); + + // train.GradientDescentOptimizer is V1 only API. + tf.Graph().as_default(); + using (var sess = self.cached_session()) + { + var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype); + var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype); + var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype); + var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype); + var lrate = constant_op.constant(3.0); + var grads_and_vars = new[] { + Tuple.Create(grads0, var0 as IVariableV1), + Tuple.Create(grads1, var1 as IVariableV1) + }; + var sgd_op = tf.train.GradientDescentOptimizer(lrate) + .apply_gradients(grads_and_vars); + + var global_variables = tf.global_variables_initializer(); + sess.run(global_variables); + + var initialVar0 = sess.run(var0); + var initialVar1 = sess.run(var1); + // Fetch params to validate initial values + self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate(var0)); + self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate(var1)); + // Run 1 step of sgd + sgd_op.run(); + // Validate updated params + self.assertAllCloseAccordingToType( + new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 }, + self.evaluate(var0)); + self.assertAllCloseAccordingToType( + new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 }, + self.evaluate(var1)); + // TODO: self.assertEqual(0, len(optimizer.variables())); + } + } + + [TestMethod] + public void TestTensorLearningRate() + { + //TODO: add np.half + TestTensorLearningRate(); + TestTensorLearningRate(); + } + + public void TestGradWrtRef() where T : struct + { + var dtype = GetTypeForNumericType(); + + var graph = tf.Graph().as_default(); + using (var sess = self.cached_session()) + { + var opt = tf.train.GradientDescentOptimizer(3.0f); + var values = new[] { 1.0, 3.0 }; + var vars_ = values.Select( + v => tf.Variable(new[] { v }, dtype: dtype) as IVariableV1 + ).ToList(); + var grads_and_vars = opt.compute_gradients(tf.add(vars_[0], vars_[1]), vars_); + sess.run(tf.global_variables_initializer()); + foreach (var (grad, _) in grads_and_vars) + self.assertAllCloseAccordingToType(new[] { 1.0 }, self.evaluate(grad)); + + } + } + + [TestMethod] + public void TestGradWrtRef() + { + TestGradWrtRef(); + TestGradWrtRef(); + } + + public void TestWithGlobalStep() where T : struct + { + var dtype = GetTypeForNumericType(); + + tf.Graph().as_default(); + using (var sess = self.cached_session()) + { + var global_step = tf.Variable(0, trainable: false); + var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype); + var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype); + var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype); + var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype); + var grads_and_vars = new[] { + Tuple.Create(grads0, var0 as IVariableV1), + Tuple.Create(grads1, var1 as IVariableV1) + }; + var sgd_op = tf.train.GradientDescentOptimizer(3.0f) + .apply_gradients(grads_and_vars, global_step: global_step); + + sess.run(tf.global_variables_initializer()); + // Fetch params to validate initial values + self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate(var0)); + self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate(var1)); + // Run 1 step of sgd + sgd_op.run(); + // Validate updated params and global_step + self.assertAllCloseAccordingToType(new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 }, self.evaluate(var0)); + self.assertAllCloseAccordingToType(new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 }, self.evaluate(var1)); + Assert.AreEqual(1, self.evaluate(global_step)); + } + + } + + [TestMethod] + public void TestWithGlobalStep() + { + TestWithGlobalStep(); + TestWithGlobalStep(); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs b/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs new file mode 100644 index 000000000..41d8ab031 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs @@ -0,0 +1,744 @@ +using FluentAssertions; +using FluentAssertions.Execution; +using FluentAssertions.Primitives; +using Tensorflow.NumPy; +using System; +using System.Diagnostics; +using System.Linq; +using System.Runtime.CompilerServices; +using Tensorflow; + +namespace TensorFlowNET.UnitTest +{ + [DebuggerStepThrough] + public static class FluentExtension + { + public static ShapeAssertions Should(this Shape shape) + { + return new ShapeAssertions(shape); + } + + public static NDArrayAssertions Should(this NDArray arr) + { + return new NDArrayAssertions(arr); + } + + public static string ToString(this Array arr, bool flat) + { + // return new NDArray(arr).ToString(flat); + throw new NotImplementedException(""); + } + } + + [DebuggerStepThrough] + public class ShapeAssertions : ReferenceTypeAssertions + { + public ShapeAssertions(Shape instance) + { + Subject = instance; + } + + protected override string Identifier => "shape"; + + public AndConstraint BeOfSize(int size, string because = null, params object[] becauseArgs) + { + Subject.size.Should().Be(size, because, becauseArgs); + return new AndConstraint(this); + } + + public AndConstraint NotBeOfSize(int size, string because = null, params object[] becauseArgs) + { + Subject.size.Should().NotBe(size, because, becauseArgs); + return new AndConstraint(this); + } + + public AndConstraint BeShaped(params int[] dimensions) + { + if (dimensions == null) + throw new ArgumentNullException(nameof(dimensions)); + + if (dimensions.Length == 0) + throw new ArgumentException("Value cannot be an empty collection.", nameof(dimensions)); + + Subject.dims.Should().BeEquivalentTo(dimensions); + return new AndConstraint(this); + } + + public AndConstraint Be(Shape shape, string because = null, params object[] becauseArgs) + { + Execute.Assertion + .BecauseOf(because, becauseArgs) + .ForCondition(Subject.Equals(shape)) + .FailWith($"Expected shape to be {shape.ToString()} but got {Subject.ToString()}"); + + return new AndConstraint(this); + } + + public AndConstraint BeEquivalentTo(int? size = null, int? ndim = null, ITuple shape = null) + { + if (size.HasValue) + { + BeOfSize(size.Value, null); + } + + if (ndim.HasValue) + HaveNDim(ndim.Value); + + if (shape != null) + for (int i = 0; i < shape.Length; i++) + { + Subject.dims[i].Should().Be((int)shape[i]); + } + + return new AndConstraint(this); + } + + public AndConstraint NotBe(Shape shape, string because = null, params object[] becauseArgs) + { + Execute.Assertion + .BecauseOf(because, becauseArgs) + .ForCondition(!Subject.Equals(shape)) + .FailWith($"Expected shape to be {shape.ToString()} but got {Subject.ToString()}"); + + return new AndConstraint(this); + } + + public AndConstraint HaveNDim(int ndim) + { + Subject.dims.Length.Should().Be(ndim); + return new AndConstraint(this); + } + + public AndConstraint BeScalar() + { + Subject.IsScalar.Should().BeTrue(); + return new AndConstraint(this); + } + + public AndConstraint NotBeScalar() + { + Subject.IsScalar.Should().BeFalse(); + return new AndConstraint(this); + } + + public AndConstraint BeNDim(int ndim) + { + Subject.dims.Length.Should().Be(ndim); + return new AndConstraint(this); + } + } + + //[DebuggerStepThrough] + public class NDArrayAssertions : ReferenceTypeAssertions + { + public NDArrayAssertions(NDArray instance) + { + Subject = instance; + } + + protected override string Identifier => "shape"; + + public AndConstraint BeOfSize(int size, string because = null, params object[] becauseArgs) + { + Subject.size.Should().Be((ulong)size, because, becauseArgs); + return new AndConstraint(this); + } + + public AndConstraint BeShaped(params int[] dimensions) + { + if (dimensions == null) + throw new ArgumentNullException(nameof(dimensions)); + + if (dimensions.Length == 0) + throw new ArgumentException("Value cannot be an empty collection.", nameof(dimensions)); + + Subject.dims.Should().BeEquivalentTo(dimensions); + return new AndConstraint(this); + } + + public AndConstraint BeShaped(int? size = null, int? ndim = null, ITuple shape = null) + { + if (size.HasValue) + { + BeOfSize(size.Value, null); + } + + if (ndim.HasValue) + HaveNDim(ndim.Value); + + if (shape != null) + for (int i = 0; i < shape.Length; i++) + { + Subject.dims[i].Should().Be((int)shape[i]); + } + + return new AndConstraint(this); + } + + public AndConstraint NotBeShaped(Shape shape, string because = null, params object[] becauseArgs) + { + Execute.Assertion + .BecauseOf(because, becauseArgs) + .ForCondition(!Subject.dims.Equals(shape.dims)) + .FailWith($"Expected shape to be {shape} but got {Subject}"); + + return new AndConstraint(this); + } + + public AndConstraint HaveNDim(int ndim) + { + Subject.ndim.Should().Be(ndim); + return new AndConstraint(this); + } + + public AndConstraint BeScalar() + { + Subject.shape.IsScalar.Should().BeTrue(); + return new AndConstraint(this); + } + + public AndConstraint BeOfType(Type typeCode) + { + Subject.dtype.Should().Be(typeCode); + return new AndConstraint(this); + } + + public AndConstraint NotBeScalar() + { + Subject.shape.IsScalar.Should().BeFalse(); + return new AndConstraint(this); + } + + + public AndConstraint BeNDim(int ndim) + { + Subject.ndim.Should().Be(ndim); + return new AndConstraint(this); + } + + public AndConstraint Be(NDArray expected) + { + Execute.Assertion + .ForCondition(np.array_equal(Subject, expected)) + .FailWith($"Expected the subject and other ndarray to be equals.\n------- Subject -------\n{Subject}\n------- Expected -------\n{expected}"); + + return new AndConstraint(this); + } + + public AndConstraint AllValuesBe(object val) + { + + #region Compute + + /*switch (Subject.typecode) + { + case NPTypeCode.Boolean: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToBoolean(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Byte: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToByte(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Byte).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int16: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToInt16(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Int16).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt16: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToUInt16(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: UInt16).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int32: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToInt32(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Int32).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt32: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToUInt32(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: UInt32).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int64: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToInt64(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Int64).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt64: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToUInt64(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: UInt64).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Char: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToChar(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Char).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Double: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToDouble(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Double).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Single: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToSingle(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Single).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Decimal: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToDecimal(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Decimal).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + default: + throw new NotSupportedException(); + }*/ + + #endregion + + return new AndConstraint(this); + } + + public AndConstraint BeOfValuesApproximately(double sensitivity, params object[] values) + { + if (values == null) + throw new ArgumentNullException(nameof(values)); + + Subject.size.Should().Be((ulong)values.Length, "the method BeOfValuesApproximately also confirms the sizes are matching with given values."); + + #region Compute + + /*switch (Subject.typecode) + { + case NPTypeCode.Boolean: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToBoolean(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Byte: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToByte(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int16: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToInt16(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt16: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToUInt16(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int32: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToInt32(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt32: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToUInt32(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int64: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToInt64(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt64: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToUInt64(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs((double)(expected - nextval)) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Char: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToChar(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Double: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToDouble(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Single: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToSingle(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Decimal: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToDecimal(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= (decimal)sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + default: + throw new NotSupportedException(); + }*/ + + #endregion + + return new AndConstraint(this); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Utilities/PrivateType.cs b/test/TensorFlowNET.UnitTest/Utilities/PrivateType.cs new file mode 100644 index 000000000..f58d765b7 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/PrivateType.cs @@ -0,0 +1,573 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.VisualStudio.TestTools.UnitTesting +{ + using System; + //using System.Diagnostics; + using System.Globalization; + using System.Reflection; + + /// + /// This class represents a private class for the Private Accessor functionality. + /// + internal class PrivateType + { + /// + /// Binds to everything + /// + private const BindingFlags BindToEveryThing = BindingFlags.Default + | BindingFlags.NonPublic | BindingFlags.Instance + | BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy; + + /// + /// The wrapped type. + /// + private Type type; + + ///// + ///// Initializes a new instance of the class that contains the private type. + ///// + ///// Assembly name + ///// fully qualified name of the + //public PrivateType(string assemblyName, string typeName) + //{ + // Helper.CheckParameterNotNullOrEmpty(assemblyName, "assemblyName", string.Empty); + // Helper.CheckParameterNotNullOrEmpty(typeName, "typeName", string.Empty); + // Assembly asm = Assembly.Load(assemblyName); + + // this.type = asm.GetType(typeName, true); + //} + + /// + /// Initializes a new instance of the class that contains + /// the private type from the type object + /// + /// The wrapped Type to create. + public PrivateType(Type type) + { + if (type == null) + { + throw new ArgumentNullException("type"); + } + + this.type = type; + } + + /// + /// Gets the referenced type + /// + public Type ReferencedType => this.type; + + ///// + ///// Invokes static member + ///// + ///// Name of the member to InvokeHelper + ///// Arguements to the invoction + ///// Result of invocation + //public object InvokeStatic(string name, params object[] args) + //{ + // return this.InvokeStatic(name, null, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes static member + ///// + ///// Name of the member to InvokeHelper + ///// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invoction + ///// Result of invocation + //public object InvokeStatic(string name, Type[] parameterTypes, object[] args) + //{ + // return this.InvokeStatic(name, parameterTypes, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes static member + ///// + ///// Name of the member to InvokeHelper + ///// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invoction + ///// An array of types corresponding to the types of the generic arguments. + ///// Result of invocation + //public object InvokeStatic(string name, Type[] parameterTypes, object[] args, Type[] typeArguments) + //{ + // return this.InvokeStatic(name, BindToEveryThing, parameterTypes, args, CultureInfo.InvariantCulture, typeArguments); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Arguements to the invocation + ///// Culture + ///// Result of invocation + //public object InvokeStatic(string name, object[] args, CultureInfo culture) + //{ + // return this.InvokeStatic(name, null, args, culture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invocation + ///// Culture info + ///// Result of invocation + //public object InvokeStatic(string name, Type[] parameterTypes, object[] args, CultureInfo culture) + //{ + // return this.InvokeStatic(name, BindingFlags.InvokeMethod, parameterTypes, args, culture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// Arguements to the invocation + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, params object[] args) + //{ + // return this.InvokeStatic(name, bindingFlags, null, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invocation + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) + //{ + // return this.InvokeStatic(name, bindingFlags, parameterTypes, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// Arguements to the invocation + ///// Culture + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) + //{ + // return this.InvokeStatic(name, bindingFlags, null, args, culture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// /// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invocation + ///// Culture + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture) + //{ + // return this.InvokeStatic(name, bindingFlags, parameterTypes, args, culture, null); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// /// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invocation + ///// Culture + ///// An array of types corresponding to the types of the generic arguments. + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture, Type[] typeArguments) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // if (parameterTypes != null) + // { + // MethodInfo member = this.type.GetMethod(name, bindingFlags | BindToEveryThing | BindingFlags.Static, null, parameterTypes, null); + // if (member == null) + // { + // throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // try + // { + // if (member.IsGenericMethodDefinition) + // { + // MethodInfo constructed = member.MakeGenericMethod(typeArguments); + // return constructed.Invoke(null, bindingFlags, null, args, culture); + // } + // else + // { + // return member.Invoke(null, bindingFlags, null, args, culture); + // } + // } + // catch (TargetInvocationException e) + // { + // Debug.Assert(e.InnerException != null, "Inner Exception should not be null."); + // if (e.InnerException != null) + // { + // throw e.InnerException; + // } + + // throw; + // } + // } + // else + // { + // return this.InvokeHelperStatic(name, bindingFlags | BindingFlags.InvokeMethod, args, culture); + // } + //} + + ///// + ///// Gets the element in static array + ///// + ///// Name of the array + ///// + ///// A one-dimensional array of 32-bit integers that represent the indexes specifying + ///// the position of the element to get. For instance, to access a[10][11] the indices would be {10,11} + ///// + ///// element at the specified location + //public object GetStaticArrayElement(string name, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.GetStaticArrayElement(name, BindToEveryThing, indices); + //} + + ///// + ///// Sets the memeber of the static array + ///// + ///// Name of the array + ///// value to set + ///// + ///// A one-dimensional array of 32-bit integers that represent the indexes specifying + ///// the position of the element to set. For instance, to access a[10][11] the array would be {10,11} + ///// + //public void SetStaticArrayElement(string name, object value, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.SetStaticArrayElement(name, BindToEveryThing, value, indices); + //} + + ///// + ///// Gets the element in satatic array + ///// + ///// Name of the array + ///// Additional InvokeHelper attributes + ///// + ///// A one-dimensional array of 32-bit integers that represent the indexes specifying + ///// the position of the element to get. For instance, to access a[10][11] the array would be {10,11} + ///// + ///// element at the spcified location + //public object GetStaticArrayElement(string name, BindingFlags bindingFlags, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // Array arr = (Array)this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.GetProperty | bindingFlags, null, CultureInfo.InvariantCulture); + // return arr.GetValue(indices); + //} + + ///// + ///// Sets the memeber of the static array + ///// + ///// Name of the array + ///// Additional InvokeHelper attributes + ///// value to set + ///// + ///// A one-dimensional array of 32-bit integers that represent the indexes specifying + ///// the position of the element to set. For instance, to access a[10][11] the array would be {10,11} + ///// + //public void SetStaticArrayElement(string name, BindingFlags bindingFlags, object value, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // Array arr = (Array)this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.GetProperty | BindingFlags.Static | bindingFlags, null, CultureInfo.InvariantCulture); + // arr.SetValue(value, indices); + //} + + ///// + ///// Gets the static field + ///// + ///// Name of the field + ///// The static field. + //public object GetStaticField(string name) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.GetStaticField(name, BindToEveryThing); + //} + + ///// + ///// Sets the static field + ///// + ///// Name of the field + ///// Arguement to the invocation + //public void SetStaticField(string name, object value) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.SetStaticField(name, BindToEveryThing, value); + //} + + ///// + ///// Gets the static field using specified InvokeHelper attributes + ///// + ///// Name of the field + ///// Additional invocation attributes + ///// The static field. + //public object GetStaticField(string name, BindingFlags bindingFlags) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.Static | bindingFlags, null, CultureInfo.InvariantCulture); + //} + + ///// + ///// Sets the static field using binding attributes + ///// + ///// Name of the field + ///// Additional InvokeHelper attributes + ///// Arguement to the invocation + //public void SetStaticField(string name, BindingFlags bindingFlags, object value) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.InvokeHelperStatic(name, BindingFlags.SetField | bindingFlags | BindingFlags.Static, new[] { value }, CultureInfo.InvariantCulture); + //} + + /// + /// Gets the static field or property + /// + /// Name of the field or property + /// The static field or property. + public object GetStaticFieldOrProperty(string name) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + return this.GetStaticFieldOrProperty(name, BindToEveryThing); + } + + /// + /// Sets the static field or property + /// + /// Name of the field or property + /// Value to be set to field or property + public void SetStaticFieldOrProperty(string name, object value) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + this.SetStaticFieldOrProperty(name, BindToEveryThing, value); + } + + /// + /// Gets the static field or property using specified InvokeHelper attributes + /// + /// Name of the field or property + /// Additional invocation attributes + /// The static field or property. + public object GetStaticFieldOrProperty(string name, BindingFlags bindingFlags) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + return this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.GetProperty | BindingFlags.Static | bindingFlags, null, CultureInfo.InvariantCulture); + } + + /// + /// Sets the static field or property using binding attributes + /// + /// Name of the field or property + /// Additional invocation attributes + /// Value to be set to field or property + public void SetStaticFieldOrProperty(string name, BindingFlags bindingFlags, object value) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + this.InvokeHelperStatic(name, BindingFlags.SetField | BindingFlags.SetProperty | bindingFlags | BindingFlags.Static, new[] { value }, CultureInfo.InvariantCulture); + } + + ///// + ///// Gets the static property + ///// + ///// Name of the field or property + ///// Arguements to the invocation + ///// The static property. + //public object GetStaticProperty(string name, params object[] args) + //{ + // return this.GetStaticProperty(name, BindToEveryThing, args); + //} + + ///// + ///// Sets the static property + ///// + ///// Name of the property + ///// Value to be set to field or property + ///// Arguments to pass to the member to invoke. + //public void SetStaticProperty(string name, object value, params object[] args) + //{ + // this.SetStaticProperty(name, BindToEveryThing, value, null, args); + //} + + ///// + ///// Sets the static property + ///// + ///// Name of the property + ///// Value to be set to field or property + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + //public void SetStaticProperty(string name, object value, Type[] parameterTypes, object[] args) + //{ + // this.SetStaticProperty(name, BindingFlags.SetProperty, value, parameterTypes, args); + //} + + ///// + ///// Gets the static property + ///// + ///// Name of the property + ///// Additional invocation attributes. + ///// Arguments to pass to the member to invoke. + ///// The static property. + //public object GetStaticProperty(string name, BindingFlags bindingFlags, params object[] args) + //{ + // return this.GetStaticProperty(name, BindingFlags.GetProperty | BindingFlags.Static | bindingFlags, null, args); + //} + + ///// + ///// Gets the static property + ///// + ///// Name of the property + ///// Additional invocation attributes. + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + ///// The static property. + //public object GetStaticProperty(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // if (parameterTypes != null) + // { + // PropertyInfo pi = this.type.GetProperty(name, bindingFlags | BindingFlags.Static, null, null, parameterTypes, null); + // if (pi == null) + // { + // throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // return pi.GetValue(null, args); + // } + // else + // { + // return this.InvokeHelperStatic(name, bindingFlags | BindingFlags.GetProperty, args, null); + // } + //} + + ///// + ///// Sets the static property + ///// + ///// Name of the property + ///// Additional invocation attributes. + ///// Value to be set to field or property + ///// Optional index values for indexed properties. The indexes of indexed properties are zero-based. This value should be null for non-indexed properties. + //public void SetStaticProperty(string name, BindingFlags bindingFlags, object value, params object[] args) + //{ + // this.SetStaticProperty(name, bindingFlags, value, null, args); + //} + + ///// + ///// Sets the static property + ///// + ///// Name of the property + ///// Additional invocation attributes. + ///// Value to be set to field or property + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + //public void SetStaticProperty(string name, BindingFlags bindingFlags, object value, Type[] parameterTypes, object[] args) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + + // if (parameterTypes != null) + // { + // PropertyInfo pi = this.type.GetProperty(name, bindingFlags | BindingFlags.Static, null, null, parameterTypes, null); + // if (pi == null) + // { + // throw new ArgumentException( + // string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // pi.SetValue(null, value, args); + // } + // else + // { + // object[] pass = new object[(args?.Length ?? 0) + 1]; + // pass[0] = value; + // args?.CopyTo(pass, 1); + // this.InvokeHelperStatic(name, bindingFlags | BindingFlags.SetProperty, pass, null); + // } + //} + + /// + /// Invokes the static method + /// + /// Name of the member + /// Additional invocation attributes + /// Arguements to the invocation + /// Culture + /// Result of invocation + private object InvokeHelperStatic(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + try + { + return this.type.InvokeMember(name, bindingFlags | BindToEveryThing | BindingFlags.Static, null, null, args, culture); + } + catch (TargetInvocationException e) + { + //Debug.Assert(e.InnerException != null, "Inner Exception should not be null."); + if (e.InnerException != null) + { + throw e.InnerException; + } + + throw; + } + } + } + + /// + /// The helper. + /// + internal static class Helper + { + /// + /// The check parameter not null. + /// + /// + /// The parameter. + /// + /// + /// The parameter name. + /// + /// + /// The message. + /// + /// Throws argument null exception when parameter is null. + internal static void CheckParameterNotNull(object param, string parameterName, string message) + { + if (param == null) + { + throw new ArgumentNullException(parameterName, message); + } + } + + ///// + ///// The check parameter not null or empty. + ///// + ///// + ///// The parameter. + ///// + ///// + ///// The parameter name. + ///// + ///// + ///// The message. + ///// + ///// Throws ArgumentException when parameter is null. + //internal static void CheckParameterNotNullOrEmpty(string param, string parameterName, string message) + //{ + // if (string.IsNullOrEmpty(param)) + // { + // throw new ArgumentException(message, parameterName); + // } + //} + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs b/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs new file mode 100644 index 000000000..d1cda7286 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs @@ -0,0 +1,22 @@ +using System; +using System.IO; + +namespace TensorFlowNET.UnitTest +{ + public class TestHelper + { + public static string GetFullPathFromDataDir(string fileName) + { + var dataDir = GetRootContentDir(Directory.GetCurrentDirectory()); + return Path.Combine(dataDir, fileName); + } + + static string GetRootContentDir(string dir) + { + var path = Path.GetFullPath(Path.Combine(dir, "data")); + if (Directory.Exists(path)) + return path; + return GetRootContentDir(Path.GetFullPath(Path.Combine(dir, ".."))); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Utilities/models/example1/saved_model.pb b/test/TensorFlowNET.UnitTest/Utilities/models/example1/saved_model.pb new file mode 100644 index 000000000..f37debb5a Binary files /dev/null and b/test/TensorFlowNET.UnitTest/Utilities/models/example1/saved_model.pb differ diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs deleted file mode 100644 index a761d93e8..000000000 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ /dev/null @@ -1,17 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System; -using System.Collections.Generic; -using System.Text; -using Tensorflow; - -namespace TensorFlowNET.UnitTest -{ - [TestClass] - public class VariableTest - { - public void Creating() - { - var mammal = tf.Variable("Elephant", tf.chars); - } - } -} diff --git a/test/Tensorflow.UnitTest/PythonTest.cs b/test/Tensorflow.UnitTest/PythonTest.cs new file mode 100644 index 000000000..1ccd39f02 --- /dev/null +++ b/test/Tensorflow.UnitTest/PythonTest.cs @@ -0,0 +1,555 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Newtonsoft.Json.Linq; +using Tensorflow.NumPy; +using System.Collections; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + /// + /// Use as base class for test classes to get additional assertions + /// + public class PythonTest + { + #region python compatibility layer + protected PythonTest self { get => this; } + protected int None => -1; + #endregion + + #region pytest assertions + + public void assertItemsEqual(ICollection given, ICollection expected) + { + if (given is Hashtable && expected is Hashtable) + { + Assert.AreEqual(JObject.FromObject(expected).ToString(), JObject.FromObject(given).ToString()); + return; + } + Assert.IsNotNull(expected); + Assert.IsNotNull(given); + var e = expected.OfType().ToArray(); + var g = given.OfType().ToArray(); + Assert.AreEqual(e.Length, g.Length, $"The collections differ in length expected {e.Length} but got {g.Length}"); + for (int i = 0; i < e.Length; i++) + { + /*if (g[i] is NDArray && e[i] is NDArray) + assertItemsEqual((g[i] as NDArray).GetData(), (e[i] as NDArray).GetData()); + else*/ + if (e[i] is ICollection && g[i] is ICollection) + assertEqual(g[i], e[i]); + else + Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}"); + } + } + + public void assertAllEqual(ICollection given, ICollection expected) + { + assertItemsEqual(given, expected); + } + + public void assertFloat32Equal(float expected, float actual, string msg) + { + float eps = 1e-6f; + Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); + } + + public void assertFloat64Equal(double expected, double actual, string msg) + { + double eps = 1e-16f; + Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); + } + + public void AssetSequenceEqual(float[] expected, float[] actual) + { + float eps = 1e-5f; + for (int i = 0; i < expected.Length; i++) + Assert.IsTrue(Math.Abs(expected[i] - actual[i]) < eps * Math.Max(1.0f, Math.Abs(expected[i])), $"expected {expected} vs actual {actual}"); + } + + public void AssetSequenceEqual(double[] expected, double[] actual) + { + double eps = 1e-5f; + for (int i = 0; i < expected.Length; i++) + Assert.IsTrue(Math.Abs(expected[i] - actual[i]) < eps * Math.Max(1.0f, Math.Abs(expected[i])), $"expected {expected} vs actual {actual}"); + } + + public void assertEqual(object given, object expected) + { + /*if (given is NDArray && expected is NDArray) + { + assertItemsEqual((given as NDArray).GetData(), (expected as NDArray).GetData()); + return; + }*/ + if (given is Hashtable && expected is Hashtable) + { + Assert.AreEqual(JObject.FromObject(expected).ToString(), JObject.FromObject(given).ToString()); + return; + } + if (given is ICollection collectionGiven && expected is ICollection collectionExpected) + { + assertItemsEqual(collectionGiven, collectionExpected); + return; + } + if (given is float && expected is float) + { + assertFloat32Equal((float)expected, (float)given, ""); + return; + } + if (given is double && expected is double) + { + assertFloat64Equal((double)expected, (double)given, ""); + return; + } + Assert.AreEqual(expected, given); + } + + public void assertEquals(object given, object expected) + { + assertEqual(given, expected); + } + + public void assert(object given) + { + if (given is bool) + Assert.IsTrue((bool)given); + Assert.IsNotNull(given); + } + + public void assertIsNotNone(object given) + { + Assert.IsNotNull(given); + } + + public void assertFalse(bool cond) + { + Assert.IsFalse(cond); + } + + public void assertTrue(bool cond) + { + Assert.IsTrue(cond); + } + + public void assertAllClose(NDArray array1, NDArray array2, double eps = 1e-5) + { + CollectionAssert.AreEqual(array1.ToArray(), array2.ToArray(), new CollectionComparer(eps)); + + //TODO: Assert.IsTrue(np.allclose(array1, array2, rtol: eps)); + } + + public void assertAllClose(double value, NDArray array2, double eps = 1e-5) + { + if (array2.shape.IsScalar) + { + double value2 = array2; + Assert.AreEqual(value, value2, eps); + return; + } + var array1 = np.ones_like(array2) * value; + CollectionAssert.AreEqual(array1.ToArray(), array2.ToArray(), new CollectionComparer(eps)); + + //TODO: Assert.IsTrue(np.allclose(array1, array2, rtol: eps)); + } + + private class CollectionComparer : IComparer + { + private readonly double _epsilon; + + public CollectionComparer(double eps = 1e-06) + { + _epsilon = eps; + } + public int Compare(object? x, object? y) + { + if (x == null && y == null) + { + return 0; + } + else if (x == null) + { + return -1; + } + else if (y == null) + { + return 1; + } + + var a = Convert.ToDouble(x); + var b = Convert.ToDouble(y); + + double delta = Math.Abs(a - b); + if (delta < _epsilon) + { + return 0; + } + return a.CompareTo(b); + } + } + + public void assertAllCloseAccordingToType( + double[,] expected, + T[,] given, + double eps = 1e-6, + float float_eps = 1e-6f) + { + Assert.AreEqual(expected.GetLength(0), given.GetLength(0)); + Assert.AreEqual(expected.GetLength(1), given.GetLength(1)); + + var flattenGiven = given.Cast().ToArray(); + assertAllCloseAccordingToType(expected, flattenGiven, eps, float_eps); + } + + public void assertAllCloseAccordingToType( + ICollection expected, + ICollection given, + double eps = 1e-6, + float float_eps = 1e-6f) + { + // TODO: check if any of arguments is not double and change toletance + // remove givenAsDouble and cast expected instead + var givenAsDouble = given.Select(x => Convert.ToDouble(x)).ToArray(); + CollectionAssert.AreEqual(expected, givenAsDouble, new CollectionComparer(eps)); + } + + public void assertProtoEquals(object toProto, object o) + { + throw new NotImplementedException(); + } + + #endregion + + #region tensor evaluation and test session + + private Session? _cached_session = null; + private Graph? _cached_graph = null; + private object? _cached_config = null; + private bool _cached_force_gpu = false; + + private void _ClearCachedSession() + { + if (self._cached_session != null) + { + self._cached_session.Dispose(); + self._cached_session = null; + } + } + + //protected object _eval_helper(Tensor[] tensors) + //{ + // if (tensors == null) + // return null; + // return nest.map_structure(self._eval_tensor, tensors); + //} + + protected object? _eval_tensor(object tensor) + { + if (tensor == null) + return None; + //else if (callable(tensor)) + // return self._eval_helper(tensor()) + else + { + try + { + //TODO: + // if sparse_tensor.is_sparse(tensor): + // return sparse_tensor.SparseTensorValue(tensor.indices, tensor.values, + // tensor.dense_shape) + //return (tensor as Tensor).numpy(); + } + catch (Exception) + { + throw new ValueError("Unsupported type: " + tensor.GetType()); + } + return null; + } + } + + /// + /// This function is used in many original tensorflow unit tests to evaluate tensors + /// in a test session with special settings (for instance constant folding off) + /// + /// + public T evaluate(Tensor tensor) + { + object? result = null; + // if context.executing_eagerly(): + // return self._eval_helper(tensors) + // else: + { + var sess = tf.get_default_session(); + var ndarray = tensor.eval(sess); + + if (typeof(T) == typeof(int)) + { + int i = ndarray; + result = i; + } + else if (typeof(T) == typeof(float)) + { + float f = ndarray; + result = f; + } + else if (typeof(T) == typeof(double)) + { + double d = ndarray; + result = d; + } + else if ( + typeof(T) == typeof(double[]) + || typeof(T) == typeof(double[,])) + { + result = ndarray.ToMultiDimArray(); + } + else if (typeof(T) == typeof(float[]) + || typeof(T) == typeof(float[,])) + { + result = ndarray.ToMultiDimArray(); + } + else if (typeof(T) == typeof(int[]) + || typeof(T) == typeof(int[,])) + { + result = ndarray.ToMultiDimArray(); + } + else + { + result = ndarray; + } + + return (T)result; + } + } + + + ///Returns a TensorFlow Session for use in executing tests. + public Session? cached_session( + Graph? graph = null, object? config = null, bool use_gpu = false, bool force_gpu = false) + { + // This method behaves differently than self.session(): for performance reasons + // `cached_session` will by default reuse the same session within the same + // test.The session returned by this function will only be closed at the end + // of the test(in the TearDown function). + + // Use the `use_gpu` and `force_gpu` options to control where ops are run.If + // `force_gpu` is True, all ops are pinned to `/ device:GPU:0`. Otherwise, if + // `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as + // possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to + // the CPU. + + // Example: + // python + // class MyOperatorTest(test_util.TensorFlowTestCase) : + // def testMyOperator(self): + // with self.cached_session() as sess: + // valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] + // result = MyOperator(valid_input).eval() + // self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] + // invalid_input = [-1.0, 2.0, 7.0] + // with self.assertRaisesOpError("negative input not supported"): + // MyOperator(invalid_input).eval() + + + // Args: + // graph: Optional graph to use during the returned session. + // config: An optional config_pb2.ConfigProto to use to configure the + // session. + // use_gpu: If True, attempt to run as many ops as possible on GPU. + // force_gpu: If True, pin all ops to `/device:GPU:0`. + + // Yields: + // A Session object that should be used as a context manager to surround + // the graph building and execution code in a test case. + + + // TODO: + // if context.executing_eagerly(): + // return self._eval_helper(tensors) + // else: + { + var sess = self._get_cached_session( + graph, config, force_gpu, crash_if_inconsistent_args: true); + using var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu); + return cached; + } + } + + //Returns a TensorFlow Session for use in executing tests. + public Session session(Graph? graph = null, object? config = null, bool use_gpu = false, bool force_gpu = false) + { + //Note that this will set this session and the graph as global defaults. + + //Use the `use_gpu` and `force_gpu` options to control where ops are run.If + //`force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if + //`use_gpu` is True, TensorFlow tries to run as many ops on the GPU as + //possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to + //the CPU. + + //Example: + //```python + //class MyOperatorTest(test_util.TensorFlowTestCase): + // def testMyOperator(self): + // with self.session(use_gpu= True): + // valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] + // result = MyOperator(valid_input).eval() + // self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] + // invalid_input = [-1.0, 2.0, 7.0] + // with self.assertRaisesOpError("negative input not supported"): + // MyOperator(invalid_input).eval() + //``` + + //Args: + // graph: Optional graph to use during the returned session. + // config: An optional config_pb2.ConfigProto to use to configure the + // session. + // use_gpu: If True, attempt to run as many ops as possible on GPU. + // force_gpu: If True, pin all ops to `/device:GPU:0`. + + //Yields: + // A Session object that should be used as a context manager to surround + // the graph building and execution code in a test case. + + Session? s = null; + //if (context.executing_eagerly()) + // yield None + //else + //{ + s = self._create_session(graph, config, force_gpu); + //} + return s.as_default(); + } + + private Session? _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu) + { + // Set the session and its graph to global default and constrain devices.""" + if (tf.executing_eagerly()) + return null; + else + { + sess.graph.as_default(); + sess.as_default(); + { + if (force_gpu) + { + // TODO: + + // Use the name of an actual device if one is detected, or + // '/device:GPU:0' otherwise + /* var gpu_name = gpu_device_name(); + if (!gpu_name) + gpu_name = "/device:GPU:0" + using (sess.graph.device(gpu_name)) { + yield return sess; + }*/ + return sess; + } + else if (use_gpu) + return sess; + else + using (sess.graph.device("/device:CPU:0")) + return sess; + } + + } + } + + // See session() for details. + private Session _create_session(Graph? graph, object? cfg, bool forceGpu) + { + var prepare_config = new Func((config) => + { + // """Returns a config for sessions. + // Args: + // config: An optional config_pb2.ConfigProto to use to configure the + // session. + // Returns: + // A config_pb2.ConfigProto object. + + //TODO: config + + // # use_gpu=False. Currently many tests rely on the fact that any device + // # will be used even when a specific device is supposed to be used. + // allow_soft_placement = not force_gpu + // if config is None: + // config = config_pb2.ConfigProto() + // config.allow_soft_placement = allow_soft_placement + // config.gpu_options.per_process_gpu_memory_fraction = 0.3 + // elif not allow_soft_placement and config.allow_soft_placement: + // config_copy = config_pb2.ConfigProto() + // config_copy.CopyFrom(config) + // config = config_copy + // config.allow_soft_placement = False + // # Don't perform optimizations for tests so we don't inadvertently run + // # gpu ops on cpu + // config.graph_options.optimizer_options.opt_level = -1 + // # Disable Grappler constant folding since some tests & benchmarks + // # use constant input and become meaningless after constant folding. + // # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE + // # GRAPPLER TEAM. + // config.graph_options.rewrite_options.constant_folding = ( + // rewriter_config_pb2.RewriterConfig.OFF) + // config.graph_options.rewrite_options.pin_to_host_optimization = ( + // rewriter_config_pb2.RewriterConfig.OFF) + return config; + }); + //TODO: use this instead of normal session + //return new ErrorLoggingSession(graph = graph, config = prepare_config(config)) + return new Session(graph);//, config = prepare_config(config)) + } + + private Session _get_cached_session( + Graph? graph = null, + object? config = null, + bool force_gpu = false, + bool crash_if_inconsistent_args = true) + { + // See cached_session() for documentation. + if (self._cached_session == null) + { + var sess = self._create_session(graph, config, force_gpu); + self._cached_session = sess; + self._cached_graph = graph; + self._cached_config = config; + self._cached_force_gpu = force_gpu; + return sess; + } + else + { + + if (crash_if_inconsistent_args && self._cached_graph != null && !self._cached_graph.Equals(graph)) + throw new ValueError(@"The graph used to get the cached session is + different than the one that was used to create the + session. Maybe create a new session with + self.session()"); + if (crash_if_inconsistent_args && self._cached_config != null && !self._cached_config.Equals(config)) + { + throw new ValueError(@"The config used to get the cached session is + different than the one that was used to create the + session. Maybe create a new session with + self.session()"); + } + if (crash_if_inconsistent_args && !self._cached_force_gpu.Equals(force_gpu)) + { + throw new ValueError(@"The force_gpu value used to get the cached session is + different than the one that was used to create the + session. Maybe create a new session with + self.session()"); + } + return self._cached_session; + } + } + + [TestCleanup] + public void Cleanup() + { + _ClearCachedSession(); + } + + #endregion + + public void AssetSequenceEqual(T[] a, T[] b) + { + Assert.IsTrue(Enumerable.SequenceEqual(a, b)); + } + } +} diff --git a/test/Tensorflow.UnitTest/Tensorflow.UnitTest.csproj b/test/Tensorflow.UnitTest/Tensorflow.UnitTest.csproj new file mode 100644 index 000000000..9ad6bc7a5 --- /dev/null +++ b/test/Tensorflow.UnitTest/Tensorflow.UnitTest.csproj @@ -0,0 +1,24 @@ + + + + net6.0 + enable + enable + + false + true + + + + + + + + + + + + + + + diff --git a/test/TensorflowNET.Hub.Unittest/KerasLayerTest.cs b/test/TensorflowNET.Hub.Unittest/KerasLayerTest.cs new file mode 100644 index 000000000..b9a8ed804 --- /dev/null +++ b/test/TensorflowNET.Hub.Unittest/KerasLayerTest.cs @@ -0,0 +1,47 @@ +using static Tensorflow.Binding; +using static Tensorflow.HubAPI; + +namespace Tensorflow.Hub.Unittest +{ + [TestClass] + public class KerasLayerTest + { + [Ignore] + [TestMethod] + public void SmallBert() + { + var layer = hub.KerasLayer("https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1"); + + var input_type_ids = tf.convert_to_tensor(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, dtype: tf.int32); + input_type_ids = tf.reshape(input_type_ids, (1, 128)); + var input_word_ids = tf.convert_to_tensor(new int[] { 101, 2129, 2024, 2017, 102, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0 }, dtype: tf.int32); + input_word_ids = tf.reshape(input_word_ids, (1, 128)); + var input_mask = tf.convert_to_tensor(new int[] { 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, dtype: dtypes.int32); + input_mask = tf.reshape(input_mask, (1, 128)); + + var result = layer.Apply(new Tensors(input_type_ids, input_word_ids, input_mask)); + } + + } +} \ No newline at end of file diff --git a/test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj b/test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj new file mode 100644 index 000000000..c93b89256 --- /dev/null +++ b/test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj @@ -0,0 +1,23 @@ + + + + net6 + enable + enable + + false + + + + + + + + + + + + + + + diff --git a/test/TensorflowNET.Hub.Unittest/Usings.cs b/test/TensorflowNET.Hub.Unittest/Usings.cs new file mode 100644 index 000000000..ab67c7ea9 --- /dev/null +++ b/test/TensorflowNET.Hub.Unittest/Usings.cs @@ -0,0 +1 @@ +global using Microsoft.VisualStudio.TestTools.UnitTesting; \ No newline at end of file diff --git a/tools/TensorFlowNET.Benchmarks/Crash/RepeatDataSetCrash.cs b/tools/TensorFlowNET.Benchmarks/Crash/RepeatDataSetCrash.cs new file mode 100644 index 000000000..76ba7c281 --- /dev/null +++ b/tools/TensorFlowNET.Benchmarks/Crash/RepeatDataSetCrash.cs @@ -0,0 +1,28 @@ +using BenchmarkDotNet.Attributes; +using System; +using System.Collections.Generic; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Benchmark.Crash +{ + public class RepeatDataSetCrash + { + [Benchmark] + public void Run() + { + var data = tf.convert_to_tensor(np.arange(0, 50000 * 10).astype(np.float32).reshape((50000, 10))); + + var dataset = keras.preprocessing.timeseries_dataset_from_array(data, + sequence_length: 10, + sequence_stride: 1, + shuffle: false, + batch_size: 32); + + while (true) + foreach (var d in dataset) + ; + } + } +} diff --git a/tools/TensorFlowNET.Benchmarks/Leak/GpuLeakByCNN.cs b/tools/TensorFlowNET.Benchmarks/Leak/GpuLeakByCNN.cs new file mode 100644 index 000000000..ed4e69cc8 --- /dev/null +++ b/tools/TensorFlowNET.Benchmarks/Leak/GpuLeakByCNN.cs @@ -0,0 +1,58 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Layers; +using Tensorflow.NumPy; +using Tensorflow.Keras; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using BenchmarkDotNet.Attributes; + +namespace Tensorflow.Benchmark.Leak +{ + public class GpuLeakByCNN + { + protected static LayersApi layers = new LayersApi(); + [Benchmark] + public void Run() + { + // tf.debugging.set_log_device_placement(true); + tf.Context.Config.GpuOptions.AllowGrowth = true; + + int num = 50, width = 64, height = 64; + // if width = 128, height = 128, the exception occurs faster + + var bytes = new byte[num * width * height * 3]; + var inputImages = np.array(bytes) / 255.0f; + // inputImages = inputImages.reshape((num, height, width, 3)); + + bytes = new byte[num]; + var outLables = np.array(bytes); + Console.WriteLine("Image.Shape={0}", inputImages.dims); + Console.WriteLine("Label.Shape={0}", outLables.dims); + + tf.enable_eager_execution(); + + var inputs = keras.Input((height, width, 3)); + + var layer = layers.Conv2D(32, (3, 3), activation: keras.activations.Relu).Apply(inputs); + layer = layers.MaxPooling2D((2, 2)).Apply(layer); + + layer = layers.Flatten().Apply(layer); + + var outputs = layers.Dense(10).Apply(layer); + + var model = keras.Model(inputs, outputs, "gpuleak"); + + model.summary(); + + model.compile(loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true), + optimizer: keras.optimizers.RMSprop(), + metrics: new[] { "accuracy" }); + + model.fit(inputImages, outLables, batch_size: 32, epochs: 200); + + keras.backend.clear_session(); + } + } +} diff --git a/tools/TensorFlowNET.Benchmarks/Leak/SavedModelCleanup.cs b/tools/TensorFlowNET.Benchmarks/Leak/SavedModelCleanup.cs new file mode 100644 index 000000000..9231f3a80 --- /dev/null +++ b/tools/TensorFlowNET.Benchmarks/Leak/SavedModelCleanup.cs @@ -0,0 +1,37 @@ +using BenchmarkDotNet.Attributes; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using Tensorflow.NumPy; +using static Tensorflow.Binding; + +namespace Tensorflow.Benchmark.Leak +{ + /// + /// https://github.com/SciSharp/TensorFlow.NET/issues/418 + /// + public class SavedModelCleanup + { + [Benchmark] + public void Run() + { + var modelDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); + var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); + + for (var i = 0; i < 1024; i++) + { + var sess = Session.LoadFromSavedModel(ClassifierModelPath); + var g = sess.graph.as_default(); + var inputOp = g.OperationByName("inference_input"); + var outputOp = g.OperationByName("StatefulPartitionedCall"); + + var inp = np.zeros(new Shape(new int[] { 1, 2, 96 }), TF_DataType.TF_FLOAT); + sess.run(outputOp.outputs[0], new FeedItem(inputOp.outputs[0], inp)); + } + } + } +} diff --git a/tools/TensorFlowNET.Benchmarks/Leak/TestModel/saved_model/saved_model.pb b/tools/TensorFlowNET.Benchmarks/Leak/TestModel/saved_model/saved_model.pb new file mode 100644 index 000000000..f75f28564 Binary files /dev/null and b/tools/TensorFlowNET.Benchmarks/Leak/TestModel/saved_model/saved_model.pb differ diff --git a/tools/TensorFlowNET.Benchmarks/Leak/TestModel/saved_model/variables/variables.data-00000-of-00001 b/tools/TensorFlowNET.Benchmarks/Leak/TestModel/saved_model/variables/variables.data-00000-of-00001 new file mode 100644 index 000000000..4c7f99dba Binary files /dev/null and b/tools/TensorFlowNET.Benchmarks/Leak/TestModel/saved_model/variables/variables.data-00000-of-00001 differ diff --git a/tools/TensorFlowNET.Benchmarks/Leak/TestModel/saved_model/variables/variables.index b/tools/TensorFlowNET.Benchmarks/Leak/TestModel/saved_model/variables/variables.index new file mode 100644 index 000000000..ee0efb7c0 Binary files /dev/null and b/tools/TensorFlowNET.Benchmarks/Leak/TestModel/saved_model/variables/variables.index differ diff --git a/tools/TensorFlowNET.Benchmarks/Program.cs b/tools/TensorFlowNET.Benchmarks/Program.cs new file mode 100644 index 000000000..22abf7302 --- /dev/null +++ b/tools/TensorFlowNET.Benchmarks/Program.cs @@ -0,0 +1,40 @@ +using BenchmarkDotNet.Configs; +using BenchmarkDotNet.Running; +using System; +using System.Reflection; +using Tensorflow.Benchmark.Crash; +using Tensorflow.Benchmark.Leak; +using static Tensorflow.Binding; + +namespace TensorFlowBenchmark +{ + class Program + { + static void Main(string[] args) + { + print(tf.VERSION); + + /*new SavedModelCleanup().Run(); + new RepeatDataSetCrash().Run(); + new GpuLeakByCNN().Run();*/ + + if (args?.Length > 0) + { + for (int i = 0; i < args.Length; i++) + { + string name = $"TensorFlowBenchmark.{args[i]}"; + var type = Type.GetType(name); + BenchmarkRunner.Run(type); + } + } + else + { +#pragma warning disable CS0618 // Type or member is obsolete + BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, ManualConfig.Create(DefaultConfig.Instance).With(ConfigOptions.DisableOptimizationsValidator)); +#pragma warning restore CS0618 // Type or member is obsolete + } + + Console.ReadLine(); + } + } +} diff --git a/tools/TensorFlowNET.Benchmarks/README.md b/tools/TensorFlowNET.Benchmarks/README.md new file mode 100644 index 000000000..29a915691 --- /dev/null +++ b/tools/TensorFlowNET.Benchmarks/README.md @@ -0,0 +1,4 @@ +```powershell +dotnet run -c release +``` + diff --git a/tools/TensorFlowNET.Benchmarks/TensorBenchmark.cs b/tools/TensorFlowNET.Benchmarks/TensorBenchmark.cs new file mode 100644 index 000000000..fa99755e2 --- /dev/null +++ b/tools/TensorFlowNET.Benchmarks/TensorBenchmark.cs @@ -0,0 +1,73 @@ +using BenchmarkDotNet.Attributes; + +namespace TensorFlowBenchmark +{ + [SimpleJob(launchCount: 1, warmupCount: 1)] + [MinColumn, MaxColumn, MeanColumn, MedianColumn] + public class TensorBenchmark + { + private double[] data; + + [GlobalSetup] + public void Setup() + { + data = new double[100]; + } + + /*[Benchmark] + public void ScalarTensor() + { + var g = new Graph(); + for (int i = 0; i < 100; i++) + { + using (var tensor = new Tensor(17.0)) + { + + } + } + } + + [Benchmark] + public unsafe void TensorFromFixedPtr() + { + var g = new Graph(); + for (int i = 0; i < 100; i++) + { + fixed (double* ptr = &data[0]) + { + using (var t = new Tensor((IntPtr)ptr, new long[] { data.Length }, tf.float64, 8 * data.Length)) + { + } + } + } + } + + [Benchmark] + public void TensorFromArray() + { + var g=new Graph(); + for (int i = 0; i < 100; i++) + { + using (var tensor = new Tensor(data)) + { + + } + } + } + + + [Benchmark] + public void TensorFromNDArray() + { + var g = new Graph(); + for (int i = 0; i < 100; i++) + { + using (var tensor = new Tensor(new NDArray(data))) + { + + } + } + }*/ + } +} + diff --git a/tools/TensorFlowNET.Benchmarks/Tensorflow.Benchmark.csproj b/tools/TensorFlowNET.Benchmarks/Tensorflow.Benchmark.csproj new file mode 100644 index 000000000..dd6f9538b --- /dev/null +++ b/tools/TensorFlowNET.Benchmarks/Tensorflow.Benchmark.csproj @@ -0,0 +1,63 @@ + + + + Exe + net6.0 + AnyCPU;x64 + + + + true + DEBUG;TRACE + x64 + + + + true + DEBUG;TRACE + + + + true + + + + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + + diff --git a/tools/TensorFlowNET.Benchmarks/Unmanaged/StructCastBenchmark.cs b/tools/TensorFlowNET.Benchmarks/Unmanaged/StructCastBenchmark.cs new file mode 100644 index 000000000..6e2b71605 --- /dev/null +++ b/tools/TensorFlowNET.Benchmarks/Unmanaged/StructCastBenchmark.cs @@ -0,0 +1,72 @@ +using BenchmarkDotNet.Attributes; +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace TensorFlowBenchmark.Unmanaged +{ + public struct UnmanagedStruct + { + public int a; + public long b; + public UnmanagedStruct(int _) + { + a = 2; + b = 3; + } + } + + [SimpleJob(launchCount: 1, warmupCount: 2)] + [MinColumn, MaxColumn, MeanColumn, MedianColumn] + public unsafe class StructCastBenchmark + { + private static void EnsureIsUnmanaged(T _) where T : unmanaged + { } + + static StructCastBenchmark() //if UnmanagedStruct is not unmanaged struct then this will fail to compile. + => EnsureIsUnmanaged(new UnmanagedStruct()); + + private IntPtr data; + private void* dataptr; + + [GlobalSetup] + public void Setup() + { + data = Marshal.AllocHGlobal(Marshal.SizeOf()); + dataptr = data.ToPointer(); + } + + [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] + public void Marshal_PtrToStructure() + { + UnmanagedStruct _; + for (int i = 0; i < 10000; i++) + { + _ = Marshal.PtrToStructure(data); + } + } + + [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] + public void PointerCast() + { + var dptr = dataptr; + UnmanagedStruct _; + for (int i = 0; i < 10000; i++) + { + _ = *(UnmanagedStruct*)dptr; + } + } + + [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] + public void Unsafe_Read() + { + var dptr = dataptr; + UnmanagedStruct _; + for (int i = 0; i < 10000; i++) + { + _ = Unsafe.Read(dptr); + } + } + + } +} \ No newline at end of file diff --git a/tools/TensorFlowNET.Console/Diagnostician.cs b/tools/TensorFlowNET.Console/Diagnostician.cs new file mode 100644 index 000000000..c52be7737 --- /dev/null +++ b/tools/TensorFlowNET.Console/Diagnostician.cs @@ -0,0 +1,64 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Linq; +using static Tensorflow.Binding; +using System.Text.RegularExpressions; + +namespace Tensorflow +{ + public class Diagnostician + { + public void Diagnose(string log) + { + var lines = File.ReadAllLines(log); + + foreach(var (i, line) in enumerate(lines)) + { + if(line.StartsWith("New Tensor ")) + { + var pointers = Regex.Matches(line, "0x[0-9a-f]{16}"); + var tensorHandle = pointers[0].Value; + var tensorDataHandle = pointers[1].Value; + + if (lines.Skip(i).Count(x => x.StartsWith("Delete Tensor ") + && x.Contains(tensorHandle) + && x.Contains(tensorDataHandle)) == 0) + Console.WriteLine(line); + } + else if (line.StartsWith("New EagerTensorHandle ")) + { + var pointers = Regex.Matches(line, "0x[0-9a-f]{16}"); + var tensorHandle = pointers[0].Value; + + var del = $"Delete EagerTensorHandle {tensorHandle}"; + + if (lines.Skip(i).Count(x => x == del) == 0) + Console.WriteLine(line); + } + else if (line.StartsWith("Take EagerTensorHandle ")) + { + var pointers = Regex.Matches(line, "0x[0-9a-f]{16}"); + var eagerTensorHandle = pointers[0].Value; + var tensorHandle = pointers[1].Value; + + var delTensor = $"Delete Tensor {tensorHandle}"; + var delEagerTensor = $"Delete EagerTensorHandle {eagerTensorHandle}"; + if (lines.Skip(i).Count(x => x.StartsWith(delTensor)) == 0 + || lines.Skip(i).Count(x => x.StartsWith(delEagerTensor)) == 0) + Console.WriteLine(line); + } + else if (line.StartsWith("Created Resource ")) + { + var pointers = Regex.Matches(line, "0x[0-9a-f]{16}"); + var eagerTensorHandle = pointers[0].Value; + + var delTensor = $"Deleted Resource {eagerTensorHandle}"; + if (lines.Skip(i).Count(x => x.StartsWith(delTensor)) == 0) + Console.WriteLine(line); + } + } + } + } +} diff --git a/tools/TensorFlowNET.Console/Exploring.cs b/tools/TensorFlowNET.Console/Exploring.cs new file mode 100644 index 000000000..4241c9bf3 --- /dev/null +++ b/tools/TensorFlowNET.Console/Exploring.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using static Tensorflow.TextApi; + +namespace Tensorflow +{ + public class Exploring + { + public void Run() + { + var docs = tf.constant(new[] { "Everything not saved will be lost." }); + var tokenizer = text.WhitespaceTokenizer(); + text.wordshape(docs, Text.WordShape.HAS_TITLE_CASE); + + throw new NotImplementedException(""); + } + } +} diff --git a/tools/TensorFlowNET.Console/MemoryBasicTest.cs b/tools/TensorFlowNET.Console/MemoryBasicTest.cs new file mode 100644 index 000000000..2bb11a02d --- /dev/null +++ b/tools/TensorFlowNET.Console/MemoryBasicTest.cs @@ -0,0 +1,176 @@ +using Tensorflow.NumPy; +using System; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine.DataAdapters; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using System.Linq; +using System.Collections.Generic; + +namespace Tensorflow +{ + class MemoryBasicTest + { + public Action Placeholder + => (epoch, iterate) => + { + var ph = array_ops.placeholder(tf.float32, (10, 512, 512, 3)); + }; + + /// + /// + /// + public Action Constant + => (epoch, iterate) => + { + var tensor = tf.constant(3112.0f); + }; + + public Action Constant2x3 + => (epoch, iterate) => + { + var nd = np.arange(1000).reshape((10, 100)); + var tensor = tf.constant(nd); + var data = tensor.numpy(); + }; + + public Action ConstantString + => (epoch, iterate) => + { + var strList = new string[] + { + "Biden immigration bill would put millions of illegal immigrants on 8-year fast-track to citizenship", + "The Associated Press, which also reported that the eight-year path is in the bill.", + "The bill would also include provisions to stem the flow of migration by addressing root causes of migration from south of the border." + }; + + var tensor = tf.constant(strList, TF_DataType.TF_STRING); + var data = tensor.StringData(); + }; + + public Action Variable + => (epoch, iterate) => + { + var nd = np.arange(1 * 256 * 256 * 3).reshape((1, 256, 256, 3)); + ResourceVariable variable = tf.Variable(nd); + }; + + public Action VariableRead + => (epoch, iterate) => + { + var nd = np.zeros(1 * 256 * 256 * 3).astype(np.float32).reshape((1, 256, 256, 3)); + ResourceVariable variable = tf.Variable(nd); + + for (int i = 0; i< 10; i++) + { + var v = variable.numpy(); + } + }; + + public Action VariableAssign + => (epoch, iterate) => + { + ResourceVariable variable = tf.Variable(3112f); + AssignVariable(variable); + for (int i = 0; i < 100; i++) + { + var v = variable.numpy(); + if ((float)v != 1984f) + throw new ValueError(""); + } + }; + + void AssignVariable(IVariableV1 v) + { + using var tensor = tf.constant(1984f); + v.assign(tensor); + } + + public Action MathAdd + => (epoch, iterate) => + { + var x = tf.constant(3112.0f); + var y = tf.constant(3112.0f); + var z = x + y; + }; + + public Action Gradient + => (epoch, iterate) => + { + var w = tf.constant(3112.0f); + using var tape = tf.GradientTape(); + tape.watch(w); + var loss = w * w; + var grad = tape.gradient(loss, w); + }; + + public Action Conv2DWithTensor + => (epoch, iterate) => + { + var input = array_ops.zeros((10, 32, 32, 3), dtypes.float32); + var filter = array_ops.zeros((3, 3, 3, 32), dtypes.float32); + var strides = new[] { 1, 1, 1, 1 }; + var dilations = new[] { 1, 1, 1, 1 }; + + var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "Conv2D", null, input, filter) + { + attrs = ConvertToDict(new + { + strides, + use_cudnn_on_gpu = true, + padding = "VALID", + explicit_paddings = new int[0], + data_format = "NHWC", + dilations + }) + }); + }; + + public Action Conv2DWithVariable + => (epoch, iterate) => + { + var input = array_ops.zeros((10, 32, 32, 3), dtypes.float32); + var filter = tf.Variable(array_ops.zeros((3, 3, 3, 32), dtypes.float32)); + var strides = new[] { 1, 1, 1, 1 }; + var dilations = new[] { 1, 1, 1, 1 }; + + var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "Conv2D", null, input, filter) + { + attrs = ConvertToDict(new + { + strides, + use_cudnn_on_gpu = true, + padding = "VALID", + explicit_paddings = new int[0], + data_format = "NHWC", + dilations + }) + }); + }; + + public Action Dataset + => (epoch, iterate) => + { + Shape shape = (16, 32, 32, 3); + var images = np.arange(shape.size).astype(np.float32).reshape(shape.dims); + var data_handler = new DataHandler(new DataHandlerArgs + { + X = images, + BatchSize = 2, + StepsPerEpoch = -1, + InitialEpoch = 0, + Epochs = 2, + MaxQueueSize = 10, + Workers = 1, + UseMultiprocessing = false, + StepsPerExecution = tf.Variable(1) + }); + + /*foreach (var (_epoch, iterator) in data_handler.enumerate_epochs()) + { + foreach (var step in data_handler.steps()) + iterator.next(); + }*/ + }; + } +} diff --git a/tools/TensorFlowNET.Console/MemoryFuncGraphTest.cs b/tools/TensorFlowNET.Console/MemoryFuncGraphTest.cs new file mode 100644 index 000000000..8c7ccaaf2 --- /dev/null +++ b/tools/TensorFlowNET.Console/MemoryFuncGraphTest.cs @@ -0,0 +1,30 @@ +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Functions; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + class MemoryFuncGraphTest + { + public Action ConcreteFunction + => (epoch, iterate) => + { + var func = new ConcreteFunction(Guid.NewGuid().ToString()); + func.Enter(); + var input = tf.placeholder(tf.float32); + var output = permutation(input); + func.ToGraph(input, output); + func.Exit(); + }; + + Tensor permutation(Tensor tensor) + { + Shape shape = (8, 64, 64, 3); + var images = np.arange(shape.size).astype(np.float32).reshape(shape.dims); + return tf.constant(images); + } + } +} diff --git a/tools/TensorFlowNET.Console/MemoryKerasTest.cs b/tools/TensorFlowNET.Console/MemoryKerasTest.cs new file mode 100644 index 000000000..5cd452ff0 --- /dev/null +++ b/tools/TensorFlowNET.Console/MemoryKerasTest.cs @@ -0,0 +1,51 @@ +using Tensorflow.NumPy; +using System; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow +{ + class MemoryKerasTest + { + public Action Conv2DLayer + => (epoch, iterate) => + { + var input_shape = new int[] { 4, 512, 512, 3 }; + var x = tf.random.normal(input_shape); + var conv2d = keras.layers.Conv2D(2, 3, activation: keras.activations.Relu); + var output = conv2d.Apply(x); + }; + + public Action InputLayer + => (epoch, iterate) => + { + Shape shape = (32, 256, 256, 3); // 48M + var images = np.arange(shape.size).astype(np.float32).reshape(shape.dims); + + var inputs = keras.Input((shape.dims[1], shape.dims[2], 3)); + var conv2d = keras.layers.Conv2D(32, kernel_size: (3, 3), + activation: keras.activations.Linear); + var outputs = conv2d.Apply(inputs); + }; + + public Action Prediction + => (epoch, iterate) => + { + Shape shape = (32, 256, 256, 3); // 48M + var images = np.arange(shape.size).astype(np.float32).reshape(shape.dims); + + var inputs = keras.Input((shape.dims[1], shape.dims[2], 3)); + var conv2d = keras.layers.Conv2D(32, kernel_size: (3, 3), + activation: keras.activations.Linear).Apply(inputs); + + var flatten = keras.layers.Flatten().Apply(inputs); + var outputs = keras.layers.Dense(10).Apply(flatten); + + var model = keras.Model(inputs, outputs, "prediction"); + for (int i = 0; i < 10; i++) + { + model.predict(images, batch_size: 8); + } + }; + } +} diff --git a/tools/TensorFlowNET.Console/MemoryMonitor.cs b/tools/TensorFlowNET.Console/MemoryMonitor.cs new file mode 100644 index 000000000..f9a6bfd1d --- /dev/null +++ b/tools/TensorFlowNET.Console/MemoryMonitor.cs @@ -0,0 +1,85 @@ +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow +{ + public class MemoryMonitor + { + public void WarmUp() + { + var x1 = tf.Variable(10, name: "x"); + + tf.compat.v1.disable_eager_execution(); + var input = np.array(4); + var nd = tf.reshape(input, new int[] { 1, 1}); + var z = nd[0, 0]; + while (true) + { + var x = tf.placeholder(tf.float64, shape: (1024, 1024)); + var log = tf.log(x); + + var sess = tf.Session(); + var ones = np.ones((1024, 1024), dtype: np.float64); + var o = sess.run(log, new FeedItem(x, ones)); + // Thread.Sleep(1); + } + + Shape shape = (1, 32, 32, 3); + np.arange(shape.size).astype(np.float32).reshape(shape.dims); + + print($"tensorflow native version: v{tf.VERSION}"); + tf.Context.ensure_initialized(); + var a = tf.constant(np.ones((10, 10))); + var b = tf.Variable(a); + var c = tf.Variable(b); + var d = b * c; + print(d.numpy()); + + GC.Collect(); + GC.WaitForPendingFinalizers(); + } + + public void Execute(int epoch, int iterate, Action process) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + var initialTotalMemory = Process.GetCurrentProcess().PrivateMemorySize64; + print($"{process.Method.Name} started..."); + + for (int i = 0; i < epoch; i++) + { + var initialMemory = Process.GetCurrentProcess().PrivateMemorySize64; + for (int j = 0; j < iterate; j++) + process(i, j); + + keras.backend.clear_session(); + + GC.Collect(); + GC.WaitForPendingFinalizers(); + var finalMemory = Process.GetCurrentProcess().PrivateMemorySize64; + print($"Epoch {i}: {Format(finalMemory - initialMemory)}."); + } + + var finalTotalMemory = Process.GetCurrentProcess().PrivateMemorySize64; + print($"Memory usage difference: {Format(finalTotalMemory - initialTotalMemory)} / {Format(Process.GetCurrentProcess().PrivateMemorySize64)}"); + } + + private string Format(long usage) + { + if (usage < 0) + return $"-{Format(0 - usage)}"; + + if (usage <= 1024 && usage >= 0) + return $"{usage} Bytes"; + else if (usage > 1024 && usage <= 1024 * 1024) + return $"{usage / 1024} KB"; + else + return $"{usage / 1024 / 1024} MB"; + } + } +} diff --git a/tools/TensorFlowNET.Console/Program.cs b/tools/TensorFlowNET.Console/Program.cs new file mode 100644 index 000000000..5f12badb0 --- /dev/null +++ b/tools/TensorFlowNET.Console/Program.cs @@ -0,0 +1,97 @@ +using System; +using Tensorflow.Keras; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + class Program + { + static void Main(string[] args) + { + var diag = new Diagnostician(); + // diag.Diagnose(@"D:\memory.txt"); + + var rnn = new SimpleRnnTest(); + rnn.Run(); + + // this class is used explor new features. + var exploring = new Exploring(); + // exploring.Run(); + + // boot .net core 10.5M. + var mm = new MemoryMonitor(); + // warm up tensorflow.net 37.3M. + mm.WarmUp(); + + BasicTest(mm); + + KerasTest(mm); + + FuncGraph(mm); + + // 65M + Console.WriteLine("Finished."); + Console.ReadLine(); + } + + static void BasicTest(MemoryMonitor mm) + { + int batchSize = 1000; + + var basic = new MemoryBasicTest(); + + // 1 million placeholder + /*tf.compat.v1.disable_eager_execution(); + mm.Execute(10, 100 * batchSize, basic.Placeholder); + tf.enable_eager_execution();*/ + + // 1 million tensor + mm.Execute(10, 100 * batchSize, basic.Constant); + + // explaination of constant + mm.Execute(10, 100 * batchSize, basic.Constant2x3); + + mm.Execute(10, batchSize, basic.ConstantString); + + // 100K float variable. + mm.Execute(10, batchSize, basic.Variable); + + mm.Execute(10, batchSize, basic.VariableRead); + + mm.Execute(10, batchSize, basic.VariableAssign); + + // 1 million math. + mm.Execute(10, 100 * batchSize, basic.MathAdd); + + // Conv2d in constant tensor + mm.Execute(10, batchSize, basic.Conv2DWithTensor); + + // Conv2d in variable + mm.Execute(10, batchSize, basic.Conv2DWithVariable); + + // 100K gradient 44M. + mm.Execute(10, 10 * batchSize, basic.Gradient); + + // memory leak when increasing the epoch + mm.Execute(10, 10, basic.Dataset); + } + + static void KerasTest(MemoryMonitor mm) + { + var keras = new MemoryKerasTest(); + + // +1M (10,50) + mm.Execute(10, 1, keras.Conv2DLayer); + + mm.Execute(10, 50, keras.InputLayer); + + mm.Execute(10, 10, keras.Prediction); + } + + static void FuncGraph(MemoryMonitor mm) + { + var func = new MemoryFuncGraphTest(); + mm.Execute(10, 100, func.ConcreteFunction); + } + } +} diff --git a/tools/TensorFlowNET.Console/SimpleRnnTest.cs b/tools/TensorFlowNET.Console/SimpleRnnTest.cs new file mode 100644 index 000000000..ae6ebb8a8 --- /dev/null +++ b/tools/TensorFlowNET.Console/SimpleRnnTest.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow +{ + public class SimpleRnnTest + { + public void Run() + { + var inputs = np.random.random((6, 10, 8)).astype(np.float32); + //var simple_rnn = tf.keras.layers.SimpleRNN(4); + //var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`. + + var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true); + + // whole_sequence_output has shape `[32, 10, 4]`. + // final_state has shape `[32, 4]`. + var (whole_sequence_output, final_states) = simple_rnn.Apply(inputs); + } + } +} diff --git a/tools/TensorFlowNET.Console/Tensorflow.Console.csproj b/tools/TensorFlowNET.Console/Tensorflow.Console.csproj new file mode 100644 index 000000000..bb60b6b63 --- /dev/null +++ b/tools/TensorFlowNET.Console/Tensorflow.Console.csproj @@ -0,0 +1,28 @@ + + + + Exe + net6.0 + Tensorflow + Tensorflow + AnyCPU;x64 + 10.0 + + + + TRACE;DEBUG + x64 + + + + DEBUG;TRACE + AnyCPU + + + + + + + + + diff --git a/tools/Tensorflow.CodeGen/DescriptionGenerator.cs b/tools/Tensorflow.CodeGen/DescriptionGenerator.cs new file mode 100644 index 000000000..0437370a1 --- /dev/null +++ b/tools/Tensorflow.CodeGen/DescriptionGenerator.cs @@ -0,0 +1,263 @@ +using Microsoft.CodeAnalysis.CSharp; +using Protobuf.Text; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection.Metadata.Ecma335; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading.Tasks; + +namespace Tensorflow.CodeGen +{ + public class DescriptionGenerator + { + private static readonly string replaceStrInner = "~~%~~"; + private static readonly string replaceStrInnerQuotationMarks = "^%^"; + Dictionary> _opDescriptions = new Dictionary>(); + Dictionary _opDescriptionDefs = new Dictionary(); + public DescriptionGenerator(string apiDefDirectory) + { + DirectoryInfo directory = new DirectoryInfo(apiDefDirectory); + + int errors = 0; + foreach (FileInfo file in directory.GetFiles()) + { + string target = file.Name.Split('.')[0].Split('_').Last(); + OpDef op = null; + try + { + op = ReadOpDefs(file.FullName).Op[0]; + } + catch + { + errors++; + continue; + } + _opDescriptionDefs[target] = op; + _opDescriptions[target] = new Dictionary(); + foreach (var arg in op.InputArg) + { + string argName = arg.Name; + var token = SyntaxFactory.ParseToken(argName); + if (token.IsKeyword()) + { + argName = $"{argName}_"; + } + _opDescriptions[target][argName] = arg.Description ?? ""; + } + foreach (var arg in op.Attr) + { + var token = SyntaxFactory.ParseToken(arg.Name); + string realKey = arg.Name; + if (token.IsKeyword()) + { + realKey += "_"; + } + _opDescriptions[target][realKey] = arg.Description ?? ""; + } + _opDescriptions[target]["SUMMARY"] = op.Summary ?? ""; + _opDescriptions[target]["DESC"] = op.Description ?? ""; + } + Console.WriteLine($"Warning: {errors} description files cannot be analyzed! Please revise it if " + + $"the failed files number is large, or ignore it."); + } + + /// + /// + /// + /// + /// + public void AppendDescription(OpDef fullOp, StringBuilder sb) + { + var opName = fullOp.Name; + if(_opDescriptions.TryGetValue(opName, out var op)) + { + var def = _opDescriptionDefs[opName]; + sb.AppendLine("/// "); + sb.AppendLine($"/// {op["SUMMARY"]}"); + sb.AppendLine("/// "); + + string totalDesc = op["DESC"]; + if (!string.IsNullOrEmpty(totalDesc)) + { + totalDesc = totalDesc.Replace(replaceStrInnerQuotationMarks, "\""); + sb.AppendLine("/// "); + string[] lines = totalDesc.Split(replaceStrInner); + foreach (var line in lines) + { + sb.AppendLine($"/// {line}"); + } + sb.AppendLine("/// "); + } + + var argNames = GetInputArgNames(fullOp); + foreach (var argName in argNames) + { + if(op.TryGetValue(argName, out var desc)) + { + desc = desc.Replace(replaceStrInnerQuotationMarks, "\""); + string[] lines = desc.Split(replaceStrInner); + sb.AppendLine($"/// "); + foreach (var line in lines) + { + sb.AppendLine($"/// {line}"); + } + sb.AppendLine("/// "); + } + else + { + sb.AppendLine($"/// "); + } + } + + List returnValueDescs = new(); + foreach (var arg in def.OutputArg) + { + if (!string.IsNullOrEmpty(arg.Description)) + { + returnValueDescs.Add($"{arg.Name}: {arg.Description}"); + } + } + string returnValueDesc = ""; + if (returnValueDescs.Count > 0) + { + returnValueDesc = string.Join(" && ", returnValueDescs); + } + sb.AppendLine($"/// {returnValueDesc}"); + } + else + { + sb.AppendLine("/// "); + sb.AppendLine($"///"); + sb.AppendLine("/// "); + + var argNames = GetInputArgNames(fullOp); + foreach (var argName in argNames) + { + sb.AppendLine($"/// "); + } + + sb.AppendLine($"/// "); + } + } + + /// + /// + /// + /// + /// + /// + /// + public List GetInputArgNames(OpDef op) + { + List names = new(); + foreach (var arg in op.InputArg) + { + string argName = arg.Name; + var token = SyntaxFactory.ParseToken(argName); + if (token.IsKeyword()) + { + argName = $"{argName}_"; + } + names.Add(argName); + } + var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); + foreach (var (key, typeStr, value) in attrValueDic) + { + var token = SyntaxFactory.ParseToken(key); + string realKey = key; + if (token.IsKeyword()) + { + realKey += "_"; + } + names.Add(realKey); + } + return names; + } + + private static OpList ReadOpDefs(string path) + { + var text = File.ReadAllText(path); + text = RemoveLintTags(text); + text = PreProcessText(text); + + string pattern = @"< { + string matchedText = match.Value; + string innerText = match.Groups[1].Value; + innerText = innerText.Replace("\"", replaceStrInnerQuotationMarks) + .Replace("\r\n", replaceStrInner).Replace("\n", replaceStrInner); // 替换内部换行符 + return replaceStrPrefix + innerText + replaceStrSuffix; // 替换首尾 + }, RegexOptions.Multiline); + + var opDefs = new TextParser(TextParser.Settings.Default.WithIgnoreUnknownFields(true)).Parse(replacedText); + return opDefs; + } + + static string PreProcessText(string input) + { + int depth = 0; + int endBlockDepth = -1; + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < input.Length; i++) + { + char c = input[i]; + if (c == '{') + { + depth++; + sb.Append(c); + } + else if (c == '}') + { + if (depth == endBlockDepth) + { + sb.Append("END\n"); + endBlockDepth = -1; + } + sb.Append(c); + depth--; + } + else if (c == '<' && i + 5 < input.Length && input.Substring(i, 5) == "< x.IsRef, null); + sb.AppendLine($"throw new RuntimeError(\"{funcName} op does not support eager execution. Arg {possibleRefArg.Name} is a ref.\");"); + } + else + { + sb.Append("try\n{\n"); + + AppendFastPathExecute(op, sb); + if (outputArgsCount == 0) + { + sb.AppendLine("return null;"); + } + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) + && string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) + { + sb.AppendLine("return _fast_path_result[0];"); + } + else + { + sb.AppendLine("return _fast_path_result;"); + } + + sb.AppendLine("}"); // try + + sb.Append("catch(NotOkStatusException ex1)\n{\n"); + sb.AppendLine("throw ex1;"); + sb.AppendLine("}"); // catch + + sb.Append("catch(InvalidArgumentError ex2)\n{\n"); + sb.AppendLine("throw ex2;"); + sb.AppendLine("}"); // catch + + sb.Append("catch(Exception)\n{\n"); + sb.AppendLine("}"); // catch + + sb.Append("try\n{\n"); + AppendEagerFallbackCall(op, sb); + sb.AppendLine("}"); // try + + sb.Append("catch(Exception)\n{\n"); + sb.AppendLine("}"); // catch + } + + sb.AppendLine("}"); // if + + foreach(var (name, type, value) in attrValueDic.Where(x => x.Item2 == "string")) + { + if(value != "NOVALUE") + { + sb.AppendLine($"if({name} is null)"); + sb.AppendLine("{"); + sb.AppendLine($"{name} = {value};"); + sb.AppendLine("}"); + } + } + + // begin to use op helper. + AppendOpHelperCall(op, sb); + sb.AppendLine("var _result = _op.outputs;"); + + // check if it needs to record gradient. + sb.Append("if(_execute.must_record_gradient())\n{\n"); + sb.Append("object[] _attrs = new object[]{"); + foreach (var attr in op.Attr) + { + string attrRealName = attr.Name; + if (SyntaxFactory.ParseToken(attrRealName).IsKeyword()) + { + attrRealName += "_"; + } + if (attr.Type == "type") + { + sb.Append($"\"{attr.Name}\", _op._get_attr_type(\"{attrRealName}\"), "); + } + else if (attr.Type == "int") + { + sb.Append($"\"{attr.Name}\", _op._get_attr_int(\"{attrRealName}\"), "); + } + else if (attr.Type == "bool") + { + sb.Append($"\"{attr.Name}\", _op._get_attr_bool(\"{attrRealName}\"), "); + } + else + { + sb.Append($"\"{attr.Name}\", _op.get_attr(\"{attr.Name}\"), "); + } + } + if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') + { + sb.Remove(sb.Length - 2, 2); + } + sb.Append("};\n"); + sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _op.inputs, _attrs, _result);"); + + sb.AppendLine("}"); // if + + if (outputArgsCount == 0) + { + sb.AppendLine("return _op;"); + } + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) + && string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) + { + sb.AppendLine("return _result[0];"); + } + else + { + sb.AppendLine("return _result;"); + } + sb.AppendLine("}"); // body + + sb.AppendLine(); + + AppendEagerFallbackDefinition(op, sb); + } + + public void AppendArgs(OpDef op, StringBuilder sb) + { + foreach (var arg in op.InputArg) + { + string argName = arg.Name; + var token = SyntaxFactory.ParseToken(argName); + if (token.IsKeyword()) + { + argName = $"{argName}_"; + } + if (!string.IsNullOrEmpty(arg.NumberAttr) || !string.IsNullOrEmpty(arg.TypeListAttr)) + { + sb.Append($"Tensors {argName}, "); + } + else + { + sb.Append($"Tensor {argName}, "); + } + } + var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); + foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 == "NOVALUE")) + { + var token = SyntaxFactory.ParseToken(key); + string realKey = key; + if (token.IsKeyword()) + { + realKey += "_"; + } + sb.Append($"{typeStr} {realKey}, "); + } + foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 != "NOVALUE")) + { + var token = SyntaxFactory.ParseToken(key); + string realKey = key; + if (token.IsKeyword()) + { + realKey += "_"; + } + sb.Append($"{typeStr} {realKey} = {value}, "); + } + sb.Append($"string? name = null"); + } + + public void AppendFastPathExecute(OpDef op, StringBuilder sb) + { + sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name)"); + sb.Append("{ args = new object[]{ "); + foreach (var arg in op.InputArg) + { + string attrArgName = arg.Name; + if (SyntaxFactory.ParseToken(attrArgName).IsKeyword()) + { + attrArgName += "_"; + } + sb.Append($"{attrArgName}, "); + } + if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') + { + sb.Remove(sb.Length - 2, 2); + } + + sb.Append("}, attrs = new Dictionary(){ "); + var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); + foreach (var (key, _, _) in attrValueDic) + { + sb.Append($"[\"{key}\"] = {key}, "); + } + + if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') + { + sb.Remove(sb.Length - 2, 2); + } + sb.Append("}});\n"); + } + + public void AppendEagerFallbackCall(OpDef op, StringBuilder sb) + { + string funcName = $"{Utils.ConvertToUnderscore(op.Name)}_eager_fallback"; + sb.Append($"return {funcName}("); + foreach (var arg in op.InputArg) + { + string inputArgRealName = arg.Name; + if (SyntaxFactory.ParseToken(inputArgRealName).IsKeyword()) + { + inputArgRealName += "_"; + } + sb.Append($"{inputArgRealName}, "); + } + var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); + foreach (var (key, _, _) in attrValueDic) + { + string keyRealName = key; + if (SyntaxFactory.ParseToken(keyRealName).IsKeyword()) + { + keyRealName += "_"; + } + sb.Append($"{key}: {keyRealName}, "); + } + sb.Append("name: name, ctx: _ctx);\n"); + } + + public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb) + { + sb.Append("public static "); + int outputArgsCount = op.OutputArg.Count; + if (outputArgsCount == 0) + { + sb.Append("Operation "); + } + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) + && string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) + { + sb.Append("Tensor "); + } + else + { + sb.Append("Tensor[] "); + } + string opName = op.Name; + string funcName = Utils.ConvertToUnderscore(op.Name); + sb.Append($" {funcName}_eager_fallback("); + AppendFallBackFunctionArgs(op, sb); + sb.Append(")\n{\n"); + + var possibleRefArg = op.InputArg.FirstOrDefault(x => x.IsRef, null); + if (possibleRefArg is not null) + { + sb.AppendLine($"throw new RuntimeError($\"{funcName} op does not support eager execution." + + $" Arg '{possibleRefArg.Name}' is a ref.\");"); + sb.AppendLine("}"); // body + return; + } + + if(op.InputArg.Any(x => !string.IsNullOrEmpty(x.NumberAttr))) + { + sb.AppendLine("List _inputs_flat_list = new();"); + foreach (var arg in op.InputArg) + { + string realArgName = arg.Name; + if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) + { + realArgName = $"{realArgName}_"; + } + if (string.IsNullOrEmpty(arg.NumberAttr)) + { + sb.AppendLine($"_inputs_flat_list.Add({realArgName});"); + } + else + { + sb.AppendLine($"_inputs_flat_list.AddRange({realArgName});"); + } + } + sb.AppendLine($"var _inputs_flat = _inputs_flat_list.ToArray();"); + } + else + { + sb.Append("Tensor[] _inputs_flat = new Tensor[]{"); + foreach (var arg in op.InputArg) + { + string realArgName = arg.Name; + if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) + { + realArgName = $"{realArgName}_"; + } + sb.Append($"{realArgName}, "); + } + if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') + { + sb.Remove(sb.Length - 2, 2); + } + sb.Append("};\n"); + } + + sb.Append("object[] _attrs = new object[]{"); + foreach (var attr in op.Attr) + { + if (attr.Type == "type") + { + bool found = false; + foreach (var arg in op.InputArg) + { + string realArgName = arg.Name; + if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) + { + realArgName = $"{realArgName}_"; + } + if (arg.TypeAttr == attr.Name) + { + sb.Append($"\"{attr.Name}\", {realArgName}.dtype, "); + found = true; + break; + } + } + if (!found) + { + string attrRealName = attr.Name; + if (SyntaxFactory.ParseToken(attrRealName).IsKeyword()) + { + attrRealName = $"{attrRealName}_"; + } + sb.Append($"\"{attr.Name}\", {attrRealName}, "); + } + } + else if(attr.Type == "list(type)") + { + if (op.InputArg.Any(x => x.TypeListAttr == attr.Name)) + { + continue; + } + } + else if(attr.Type == "int" && op.InputArg.Any(x => x.NumberAttr == attr.Name)) + { + bool found = false; + foreach (var arg in op.InputArg) + { + string realArgName = arg.Name; + if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) + { + realArgName = $"{realArgName}_"; + } + if (arg.NumberAttr == attr.Name) + { + sb.Append($"\"{attr.Name}\", {realArgName}.Length, "); + found = true; + break; + } + } + } + else + { + sb.Append($"\"{attr.Name}\", {attr.Name}, "); + } + } + if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') + { + sb.Remove(sb.Length - 2, 2); + } + sb.Append("};\n"); + + sb.AppendLine($"var _result = _execute.execute(\"{op.Name}\", {outputArgsCount}, inputs: _inputs_flat, " + + $"attrs: _attrs, ctx: ctx, name: name);"); + + sb.Append("if(_execute.must_record_gradient())\n{\n"); + + sb.AppendLine($"_execute.record_gradient(\"{op.Name}\", _inputs_flat, _attrs, _result);"); + + sb.AppendLine("}"); // if + + if (outputArgsCount == 0) + { + sb.AppendLine("return null;"); + } + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) + && string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) + { + sb.AppendLine("return _result[0];"); + } + else + { + sb.AppendLine("return _result;"); + } + + sb.AppendLine("}"); // body + } + + public void AppendFallBackFunctionArgs(OpDef op, StringBuilder sb) + { + foreach (var arg in op.InputArg) + { + string argName = arg.Name; + var token = SyntaxFactory.ParseToken(argName); + if (token.IsKeyword()) + { + argName = $"{argName}_"; + } + if (!string.IsNullOrEmpty(arg.NumberAttr)) + { + sb.Append($"Tensors {argName}, "); + } + else + { + sb.Append($"Tensor {argName}, "); + } + } + var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); + foreach (var (key, typeStr, _) in attrValueDic) + { + var token = SyntaxFactory.ParseToken(key); + string realKey = key; + if (token.IsKeyword()) + { + realKey += "_"; + } + sb.Append($"{typeStr} {realKey}, "); + } + sb.Append($"string name, Context ctx"); + } + + public void AppendOpHelperCall(OpDef op, StringBuilder sb) + { + sb.AppendLine("Dictionary keywords = new();"); + foreach (var arg in op.InputArg) + { + string realArgName = arg.Name; + if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) + { + realArgName += "_"; + } + sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); + } + var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); + foreach (var (key, _, _) in attrValueDic) + { + sb.AppendLine($"keywords[\"{key}\"] = {key};"); + } + sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); + } + + private static bool HasRefArgs(OpDef op) + { + return op.InputArg.Any(x => x.IsRef); + } + } +} diff --git a/tools/Tensorflow.CodeGen/GenOpsWriter.cs b/tools/Tensorflow.CodeGen/GenOpsWriter.cs new file mode 100644 index 000000000..9eefca07e --- /dev/null +++ b/tools/Tensorflow.CodeGen/GenOpsWriter.cs @@ -0,0 +1,81 @@ +using Protobuf.Text; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Tensorflow.CodeGen +{ + public class GenOpsWriter + { + private string _basePath; + private Dictionary _opMap; + private OpClassifier _opClassifier; + private FunctionGenerator _fg = new(); + private DescriptionGenerator _dg; + + public GenOpsWriter(string basePath, string pythonFilesDirectory, string apiDefFilesDirectory, string opDefFilename) + { + _basePath = basePath; + + var opDefs = Utils.ReadAllOpDefs(opDefFilename); + _opMap = opDefs.Op.ToDictionary( + x => Utils.ConvertToUnderscore(x.Name), x => x); + _opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name))); + _dg = new DescriptionGenerator(apiDefFilesDirectory); + } + + public void WriteAll() + { + foreach(var (target, set) in _opClassifier.OpSet) + { + StringBuilder sb = new StringBuilder(); + + // Write file header. + sb.AppendLine("/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/"); + sb.AppendLine(); + + // Add commonly used namespaces. + sb.AppendLine("using Tensorflow.Eager;"); + sb.AppendLine("using Tensorflow.Contexts;"); + sb.AppendLine("using Tensorflow.Exceptions;"); + sb.AppendLine("using static Tensorflow.Binding;"); + sb.AppendLine(); + + // Specify the namespace + sb.AppendLine("namespace Tensorflow;"); + sb.AppendLine(); + + // Write class name + sb.AppendLine($"public static class {target}"); + sb.AppendLine("{"); + + foreach(var funcName in set) + { + if(_opMap.ContainsKey(funcName)) + { + var opDef = _opMap[funcName]; + + // write the descriptions. + _dg.AppendDescription(opDef, sb); + + // write the function body. + _fg.AppendFunction(opDef, sb); + } + else if (funcName.StartsWith("_")) + { + var opDef = _opMap[funcName.Substring(1)]; + _fg.AppendFunction(opDef, sb); + } + } + + // Close class scope. + sb.AppendLine("}"); + + string fullFilePath = Path.Combine(_basePath, $"{target}.cs"); + File.WriteAllText(fullFilePath, sb.ToString()); + } + } + } +} diff --git a/tools/Tensorflow.CodeGen/OpClassifier.cs b/tools/Tensorflow.CodeGen/OpClassifier.cs new file mode 100644 index 000000000..2d22c5d22 --- /dev/null +++ b/tools/Tensorflow.CodeGen/OpClassifier.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Text.RegularExpressions; + +namespace Tensorflow.CodeGen +{ + public class OpClassifier + { + private static readonly string _filenamePattern = @"^gen_[a-z_]*_ops.py$"; + private static readonly string _pythonFunctionPattern = @"def\s+(\w+\d*\w*)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*name=None\):"; + private Dictionary> _opSet = new(); + public Dictionary> OpSet => _opSet; + public OpClassifier(string pythonFileFolder, IEnumerable funcNames) + { + DirectoryInfo directory = new DirectoryInfo(pythonFileFolder); + + Dictionary fileContentMap = new(); + foreach (FileInfo file in directory.GetFiles()) + { + if (Regex.IsMatch(file.Name, _filenamePattern)) + { + Console.WriteLine(file.Name); + string filenamePrefix = file.Name.Split('.')[0]; + string content = File.ReadAllText(file.FullName); + fileContentMap[filenamePrefix] = content; + } + } + + foreach(var funcName in funcNames) + { + Console.WriteLine(funcName); + string funcPattern = @$"^def\s+{funcName}\("; + string fallbackFuncPattern = @$"^def\s+{funcName}_eager_fallback\("; + foreach (var (target, content) in fileContentMap) + { + if(content.Contains($"def {funcName}") && content.Contains($"def {funcName}_eager_fallback")) + { + _opSet.SetDefault(target, new HashSet()).Add(funcName); + } + else if (content.Contains($"def _{funcName}") && content.Contains($"def _{funcName}_eager_fallback")) + { + _opSet.SetDefault(target, new HashSet()).Add(funcName); + } + } + } + } + } +} diff --git a/tools/Tensorflow.CodeGen/Program.cs b/tools/Tensorflow.CodeGen/Program.cs new file mode 100644 index 000000000..cea52e0b4 --- /dev/null +++ b/tools/Tensorflow.CodeGen/Program.cs @@ -0,0 +1,13 @@ +using OneOf.Types; +using Protobuf.Text; +using System.Diagnostics; +using System.Text; +using System.Xml.Linq; +using Tensorflow.CodeGen; + +GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops_v2", + @"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", + @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api", + @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); + +writer.WriteAll(); diff --git a/tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj b/tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj new file mode 100644 index 000000000..2afc68a3c --- /dev/null +++ b/tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj @@ -0,0 +1,18 @@ + + + + Exe + net6.0 + enable + enable + + + + + + + + + + + diff --git a/tools/Tensorflow.CodeGen/Utils.cs b/tools/Tensorflow.CodeGen/Utils.cs new file mode 100644 index 000000000..6c69b7f95 --- /dev/null +++ b/tools/Tensorflow.CodeGen/Utils.cs @@ -0,0 +1,271 @@ +using Protobuf.Text; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection.Metadata.Ecma335; +using System.Text; +using System.Threading.Tasks; + +namespace Tensorflow.CodeGen +{ + public static class Utils + { + public static string ConvertToUnderscore(string input) + { + if (string.IsNullOrEmpty(input)) + { + return input; + } + + StringBuilder result = new StringBuilder(); + + int state = 1; // the previous char was not lowered. + for (int i = 0; i < input.Length; i++) + { + char current = input[i]; + + // 首字母不需要添加下划线 + if (char.IsUpper(current)) + { + if(i > 0) + { + char pre = input[i - 1]; + if (char.IsDigit(pre)) + { + result.Append(char.ToLower(current)); + continue; + } + } + if (state == 0) + { + result.Append("_"); + state = 1; + } + result.Append(char.ToLower(current)); + } + else + { + result.Append(char.ToLower(current)); + state = 0; + } + } + + return result.ToString(); + } + + public static OpList ReadAllOpDefs(string path) + { + var text = File.ReadAllText(path); + var opDefs = OpList.Parser.ParseText(text); + return opDefs; + } + + // name, type string, default value + public static List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary dynamicDefaultValues) + { + dynamicDefaultValues = new(); + List<(string, string, string)> res = new(); + foreach (var attr in op.Attr) + { + if (attr.Type == "type") + { + bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); + if (!found) + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) + { + string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); + string enumPath = typeof(TF_DataType).Name + "." + name; + res.Add((attr.Name, "TF_DataType", enumPath)); + } + else + { + res.Add((attr.Name, "TF_DataType", "NOVALUE")); + } + } + } + else if (attr.Type == "int") + { + if (op.InputArg.Any(x => x.NumberAttr == attr.Name)) + { + continue; + } + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) + { + res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); + } + else + { + res.Add((attr.Name, "int", "0")); + } + } + else if (attr.Type == "float") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) + { + res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); + } + else + { + res.Add((attr.Name, "float", "NOVALUE")); + } + } + else if (attr.Type == "string") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) + { + res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); + } + else + { + res.Add((attr.Name, "string", "NOVALUE")); + } + } + else if (attr.Type == "bool") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) + { + res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); + } + else + { + res.Add((attr.Name, "bool", "NOVALUE")); + } + } + else if (attr.Type == "shape") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) + { + if (attr.DefaultValue.Shape.UnknownRank) + { + res.Add((attr.Name, "Shape", $"null")); + } + else + { + Shape shape = new Shape(attr.DefaultValue.Shape); + string expression = $"new Shape({string.Join(", ", shape.dims)})"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "Shape", $"null")); + } + } + else + { + res.Add((attr.Name, "Shape", "NOVALUE")); + } + } + else if (attr.Type == "list(type)") + { + if(op.InputArg.Any(x => x.TypeListAttr == attr.Name)) + { + continue; + } + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) + { + List values = new(); + foreach (var value in attr.DefaultValue.List.Type) + { + values.Add(value.as_tf_dtype()); + } + string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "TF_DataType[]", $"null")); + } + else + { + res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); + } + } + else if (attr.Type == "list(shape)") + { + res.Add((attr.Name, "Shape[]", "NOVALUE")); + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) + { + List exps = new(); + foreach (var value in attr.DefaultValue.List.Shape) + { + exps.Add($"new Shape({string.Join(", ", value.Dim.Select(x => x.Size))})"); + } + string expression = "new Shape[]{" + $"{string.Join(", ", exps)}" + "}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "string[]", $"null")); + } + else + { + res.Add((attr.Name, "string[]", "NOVALUE")); + } + } + else if (attr.Type == "list(string)") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) + { + List values = new(); + foreach (var value in attr.DefaultValue.List.S) + { + values.Add(value.ToStringUtf8()); + } + string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "string[]", $"null")); + } + else + { + res.Add((attr.Name, "string[]", "NOVALUE")); + } + } + else if (attr.Type == "list(int)") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) + { + List values = new(); + foreach (var value in attr.DefaultValue.List.I) + { + values.Add((int)value); + } + string expression = "new int[]{" + $"{string.Join(", ", values)}" + "}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "int[]", $"null")); + } + else + { + res.Add((attr.Name, "int[]", "NOVALUE")); + } + } + else if (attr.Type == "list(float)") + { + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) + { + List values = new(); + foreach (var value in attr.DefaultValue.List.F) + { + values.Add(value); + } + string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "float[]", $"null")); + } + else + { + res.Add((attr.Name, "float[]", "NOVALUE")); + } + } + else if (attr.Type == "func") + { + res.Add((attr.Name, "object", "NOVALUE")); + } + else if (attr.Type == "list(func)") + { + res.Add((attr.Name, "object[]", "NOVALUE")); + } + else if (attr.Type == "tensor") + { + res.Add((attr.Name, "TensorProto", "NOVALUE")); + } + else + { + throw new NotImplementedException(); + } + } + return res; + } + } +} diff --git a/tools/Tensorflow.Redist.NativeLibrarySplitter/Program.cs b/tools/Tensorflow.Redist.NativeLibrarySplitter/Program.cs new file mode 100644 index 000000000..cdc011ea9 --- /dev/null +++ b/tools/Tensorflow.Redist.NativeLibrarySplitter/Program.cs @@ -0,0 +1,212 @@ + +// =================================================================== // +// This is a tool to split the native .so file of linux gpu library // +// =================================================================== // + +using System.Security.Cryptography; + +string filename = "libtensorflow.so"; +int count = 5; +SplitFile(filename, count); + +static void SplitFile(string filename, int count) +{ + // 打开读取二进制文件的文件流 + using (FileStream input = new FileStream(filename, FileMode.Open, FileAccess.Read)) + { + long filesize = new FileInfo(filename).Length; // 获取文件大小 + long fragmentSize = (long)(filesize / count + 1); // 计算每个分片的大小 + + byte[] buffer = new byte[fragmentSize]; // 设置缓冲区大小 + int bytesRead; // 存储读取长度 + int fragmentIndex = 1; // 分片计数器 + + // 使用循环遍历分片并写入相应的文件 + while ((bytesRead = input.Read(buffer, 0, buffer.Length)) > 0) + { + string outputFileName = $"{filename}.fragment{fragmentIndex++}"; + using (FileStream output = new FileStream(outputFileName, FileMode.Create, FileAccess.Write)) + { + output.Write(buffer, 0, bytesRead); + } + } + + // 计算整个文件的 SHA-256 哈希值并写入 .sha 文件 + using (SHA256 sha256Hash = SHA256.Create()) + { + input.Seek(0, SeekOrigin.Begin); + byte[] hashValue = sha256Hash.ComputeHash(input); + + string shaFileName = $"{filename}.sha"; + using (StreamWriter writer = new StreamWriter(shaFileName, false)) + { + writer.Write(BitConverter.ToString(hashValue).Replace("-", "")); + } + } + } +} + +// Resume the file from fregments. Thanks for the code in TorchSharp! +static void Restitch(string RestitcherPackage) +{ + // !!!!!!!------------------------------NOTE------------------------------------!!!!!! + // !!!!!!! This code is manually copied into pkg\common\RestitchPackage.targets !!!!!! + // !!!!!!!------------------------------NOTE------------------------------------!!!!!! + // + // vvvvvvvvvvvvvvvvvvvvvvvvvvvvv START HERE vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv + try + { + if (Directory.Exists(RestitcherPackage)) + { + using (var writer = File.CreateText("obj/tensorflow_redist_build_log.txt")) + { + foreach (var p in Directory.EnumerateFiles(RestitcherPackage, "*", SearchOption.AllDirectories)) + { + + var primaryFile = Path.GetFullPath(p); + writer.WriteLine("Found primary file at {0}", primaryFile); + + // See if there are fragments in the parallel nuget packages. If the primary is + // some-package-primary\runtimes\....\a.so + // some-package-primary\runtimes\....\a.so.sha + // then the expected fragments are + // some-package-fragment1\fragments\....\a.so + // some-package-fragment2\fragments\....\a.so + // some-package-fragment3\fragments\....\a.so + // some-package-fragment4\fragments\....\a.so + // some-package-fragment5\fragments\....\a.so + // some-package-fragment6\fragments\....\a.so + // some-package-fragment7\fragments\....\a.so + // some-package-fragment8\fragments\....\a.so + // some-package-fragment9\fragments\....\a.so + // some-package-fragment10\fragments\....\a.so + var shaFile = primaryFile + ".sha"; + var fragmentFile1 = primaryFile.Replace("-primary", "-fragment1").Replace("runtimes", "fragments") + ".fragment1"; + var fragmentFile2 = primaryFile.Replace("-primary", "-fragment2").Replace("runtimes", "fragments") + ".fragment2"; + var fragmentFile3 = primaryFile.Replace("-primary", "-fragment3").Replace("runtimes", "fragments") + ".fragment3"; + var fragmentFile4 = primaryFile.Replace("-primary", "-fragment4").Replace("runtimes", "fragments") + ".fragment4"; + var fragmentFile5 = primaryFile.Replace("-primary", "-fragment5").Replace("runtimes", "fragments") + ".fragment5"; + + + if (File.Exists(fragmentFile1)) writer.WriteLine("Found fragment file at {0}", fragmentFile1); + if (File.Exists(fragmentFile2)) writer.WriteLine("Found fragment file at {0}", fragmentFile2); + if (File.Exists(fragmentFile3)) writer.WriteLine("Found fragment file at {0}", fragmentFile3); + if (File.Exists(fragmentFile4)) writer.WriteLine("Found fragment file at {0}", fragmentFile4); + if (File.Exists(fragmentFile5)) writer.WriteLine("Found fragment file at {0}", fragmentFile5); + + if (File.Exists(fragmentFile1)) + { + var tmpFile = Path.GetTempFileName(); + + { + writer.WriteLine("Writing restored primary file at {0}", tmpFile); + using (var os = File.OpenWrite(tmpFile)) + { + + //writer.WriteLine("Writing bytes from {0} to {1}", primaryFile, tmpFile); + //var primaryBytes = File.ReadAllBytes(primaryFile); + + //os.Write(primaryBytes, 0, primaryBytes.Length); + if (File.Exists(fragmentFile1)) + { + writer.WriteLine("Writing fragment bytes from {0} to {1}", fragmentFile1, tmpFile); + var fragmentBytes1 = File.ReadAllBytes(fragmentFile1); + os.Write(fragmentBytes1, 0, fragmentBytes1.Length); + } + if (File.Exists(fragmentFile2)) + { + writer.WriteLine("Writing fragment bytes from {0} to {1}", fragmentFile2, tmpFile); + var fragmentBytes2 = File.ReadAllBytes(fragmentFile2); + os.Write(fragmentBytes2, 0, fragmentBytes2.Length); + } + if (File.Exists(fragmentFile3)) + { + writer.WriteLine("Writing fragment bytes from {0} to {1}", fragmentFile3, tmpFile); + var fragmentBytes3 = File.ReadAllBytes(fragmentFile3); + os.Write(fragmentBytes3, 0, fragmentBytes3.Length); + } + if (File.Exists(fragmentFile4)) + { + writer.WriteLine("Writing fragment bytes from {0} to {1}", fragmentFile4, tmpFile); + var fragmentBytes4 = File.ReadAllBytes(fragmentFile4); + os.Write(fragmentBytes4, 0, fragmentBytes4.Length); + } + if (File.Exists(fragmentFile5)) + { + writer.WriteLine("Writing fragment bytes from {0} to {1}", fragmentFile5, tmpFile); + var fragmentBytes5 = File.ReadAllBytes(fragmentFile5); + os.Write(fragmentBytes5, 0, fragmentBytes5.Length); + } + } + } + + var shaExpected = File.Exists(shaFile) ? File.ReadAllText(shaFile).ToUpper() : ""; + writer.WriteLine($"real sha: {shaExpected}"); + + using (var sha256Hash = System.Security.Cryptography.SHA256.Create()) + { + using (var os2 = File.OpenRead(tmpFile)) + { + + byte[] bytes = sha256Hash.ComputeHash(os2); + var builder = new System.Text.StringBuilder(); + for (int i = 0; i < bytes.Length; i++) + { + builder.Append(bytes[i].ToString("x2")); + } + var shaReconstituted = builder.ToString().ToUpper(); + if (shaExpected != shaReconstituted) + { + string msg = + $"Error downloading and reviving packages. Reconsituted file contents have incorrect SHA\n\tExpected SHA: ${shaExpected}\n\tActual SHA: ${shaReconstituted}\n\tFile was reconstituted from:" + + $"\n\t{primaryFile} (length ${new FileInfo(primaryFile).Length})" + + (File.Exists(fragmentFile1) ? $"\n\t{fragmentFile1} (length ${new FileInfo(fragmentFile1).Length})" : "") + + (File.Exists(fragmentFile2) ? $"\n\t{fragmentFile2} (length ${new FileInfo(fragmentFile2).Length})" : "") + + (File.Exists(fragmentFile3) ? $"\n\t{fragmentFile3} (length ${new FileInfo(fragmentFile3).Length})" : "") + + (File.Exists(fragmentFile4) ? $"\n\t{fragmentFile4} (length ${new FileInfo(fragmentFile4).Length})" : "") + + (File.Exists(fragmentFile5) ? $"\n\t{fragmentFile5} (length ${new FileInfo(fragmentFile5).Length})" : ""); + writer.WriteLine(msg); + throw new Exception(msg); + } + } + + } + + writer.WriteLine("Deleting {0}", primaryFile); + File.Delete(primaryFile); + if (File.Exists(primaryFile)) + throw new Exception("wtf?"); + + writer.WriteLine("Moving {0} --> {1}", tmpFile, primaryFile); + File.Move(tmpFile, primaryFile); + + writer.WriteLine("Deleting {0}", fragmentFile1); + File.Delete(fragmentFile1); // free up space and prevent us doing this again + + writer.WriteLine("Deleting {0}", fragmentFile2); + if (File.Exists(fragmentFile2)) + File.Delete(fragmentFile2); // free up space and prevent us doing this again + + writer.WriteLine("Deleting {0}", fragmentFile3); + if (File.Exists(fragmentFile3)) + File.Delete(fragmentFile3); // free up space and prevent us doing this again + + writer.WriteLine("Deleting {0}", fragmentFile4); + if (File.Exists(fragmentFile4)) + File.Delete(fragmentFile4); // free up space and prevent us doing this again + + writer.WriteLine("Deleting {0}", fragmentFile5); + if (File.Exists(fragmentFile5)) + File.Delete(fragmentFile5); // free up space and prevent us doing this again + } + } + } + } + } + catch (Exception ex) + { + Console.Error.WriteLine(ex.ToString()); + Console.Error.WriteLine(ex.StackTrace); + } + // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ END HERE^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +} \ No newline at end of file diff --git a/tools/Tensorflow.Redist.NativeLibrarySplitter/Tensorflow.Redist.NativeLibrarySplitter.csproj b/tools/Tensorflow.Redist.NativeLibrarySplitter/Tensorflow.Redist.NativeLibrarySplitter.csproj new file mode 100644 index 000000000..74abf5c97 --- /dev/null +++ b/tools/Tensorflow.Redist.NativeLibrarySplitter/Tensorflow.Redist.NativeLibrarySplitter.csproj @@ -0,0 +1,10 @@ + + + + Exe + net6.0 + enable + enable + + + diff --git a/tools/Tensorflow.UnitTest.RedistHolder/EmptyClass.cs b/tools/Tensorflow.UnitTest.RedistHolder/EmptyClass.cs new file mode 100644 index 000000000..563f18b8f --- /dev/null +++ b/tools/Tensorflow.UnitTest.RedistHolder/EmptyClass.cs @@ -0,0 +1,3 @@ +internal class EmptyClass +{ +} diff --git a/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj b/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj new file mode 100644 index 000000000..0d1018cab --- /dev/null +++ b/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj @@ -0,0 +1,12 @@ + + + + netstandard2.0 + + + + + + + + diff --git a/tools/scripts/Copy-NativeTensorFlowLibs.ps1 b/tools/scripts/Copy-NativeTensorFlowLibs.ps1 new file mode 100644 index 000000000..cf6521ae8 --- /dev/null +++ b/tools/scripts/Copy-NativeTensorFlowLibs.ps1 @@ -0,0 +1,167 @@ +<# +.SYNOPSIS + Copy the native TensorFlow library to enable the packing a nuget to make + them available to TensorFlow.NET + +.DESCRIPTION + The TensorFlow libraries are copied for Windows and Linux and it becomes + possible to bundle a meta-package containing them. + +.PARAMETER SkipCpuLibraries + Setting this to true skips the downloading of the CPU version of the + TensorFlow libraries. + By default the CPU version of the libraries are downloaded and put in the + relevant projects. + +.PARAMETER SkipGpuLibraries + Setting this to tru skips the downloading of the GPU version of the + TensorFlow libraries. + By default the GPU version of the libraries are downloaded and put in the + releavant projects. + +#> +param( + [switch] $SkipCpuLibraries = $false, + [switch] $SkipGpuLibraries = $false +) + +function Expand-TarGzFiles { + <# + .SYNOPSIS + Expands the given list of files from the given archive into the given + target directory. + + .PARAMETER Archive + Path to the archive that should be considered. + + .PARAMETER Files + Files that should be extracted from the archive. + + .PARAMETER TargetDirectory + Directory into which the files should be expanded. + + #> + param + ( + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string] $Archive, + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string []] $Files, + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string] $TargetDirectory + ) + + & 7z e $Archive -o"$TargetDirectory" + $TarArchive = Join-Path $TargetDirectory "libtensorflow.tar" + + & 7z e $TarArchive $Files -o"$TargetDirectory" + Remove-Item $TarArchive +} + +function Expand-ZipFiles { + <# + .SYNOPSIS + Expands the given list of files from the given archive into the given target directory. + + .PARAMETER Archive + Path to the archive that should be considered. + + .PARAMETER Files + Files that should be extracted from the archive. + + .PARAMETER TargetDirectory + Directory into which the files should be expanded. + #> + param( + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string] $Archive, + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string []] $Files, + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string] $TargetDirectory + ) + + & 7z e $Archive $Files -o"$TargetDirectory" +} + +function Split-ArchiveFromUrl { + <# + .SYNOPSIS + Extracts the archive name out of the given Url. + + .PARAMETER ArchiveUrl + Url of the archive that will be downloaded. + + #> + param( + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string] $ArchiveUrl + ) + + $uriParts = $ArchiveUrl.split("/") + $ArchivePath = $uriParts[$uriParts.Count - 1] + + return $ArchivePath +} + +function Copy-Archive { + <# + .SYNOPSIS + This function copies the given binary file to the given target location. + + .PARAMETER ArchiveUrl + Url where the archive should be downloaded from. + + .PARAMETER TargetDirectory + Target directory where the archive should be downloaded. +#> + param ( + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] + [string] $ArchiveUrl, + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] + [string] $TargetDirectory + ) + + $ArchiveName = Split-ArchiveFromUrl $ArchiveUrl + + $TargetPath = [IO.Path]::Combine($PSScriptRoot, "..", "packages", $ArchiveName) + + if (Test-Path $TargetPath -PathType Leaf) { + Write-Error "$TargetPath already exists, please remove to download againg." + return $TargetPath + } + + if (-not (Test-Path $TargetDirectory -PathType Container)) { + Write-Host "Creating missing $TargetDirectory" + New-Item -Path $TargetDirectory -ItemType Directory + } + Write-Host "Downloading $ArchiveUrl, this might take a while..." + $wc = New-Object System.Net.WebClient + $wc.DownloadFile($ArchiveUrl, $TargetPath) + + return $TargetPath +} + +$LinuxGpuArchive = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.14.0.tar.gz" +$LinuxCpuArchive = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz" +$LinuxFiles = @(".\libtensorflow.tar", ".\lib\libtensorflow.so", ".\lib\libtensorflow.so.1", ".\lib\libtensorflow.so.1.14.0", ` + ".\lib\libtensorflow_framework.so", ".\lib\libtensorflow_framework.so.1", ".\lib\libtensorflow_framework.so.1.14.0") +$WindowsGpuArchive = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-1.14.0.zip" +$WindowsCpuArchive = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-1.14.0.zip" +$WindowsFiles = @("lib\tensorflow.dll") +$PackagesDirectory = [IO.Path]::Combine($PSScriptRoot, "..", "packages") + + +if (-not $SkipGpuLibraries) { + $Archive = Copy-Archive -ArchiveUrl $WindowsGpuArchive -TargetDirectory $PackagesDirectory + $TargetDirectory = [IO.Path]::Combine($PSScriptRoot, "..", "redist", "runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist") + Expand-ZipFiles $Archive $WindowsFiles $TargetDirectory + + $Archive = Copy-Archive -ArchiveUrl $LinuxGpuArchive -TargetDirectory $PackagesDirectory + $TargetDirectory = [IO.Path]::Combine($PSScriptRoot, "..", "redist", "runtime.linux-x64.SciSharp.Tensorflow-Gpu.Redist") + Expand-TarGzFiles $Archive $LinuxFiles $TargetDirectory +} + +if (-not $SkipCpuLibraries) { + $Archive = Copy-Archive -ArchiveUrl $WindowsCpuArchive -TargetDirectory $PackagesDirectory + $TargetDirectory = [IO.Path]::Combine($PSScriptRoot, "..", "redist", "runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist") + Expand-ZipFiles $Archive $WindowsFiles $TargetDirectory + + $Archive = Copy-Archive -ArchiveUrl $LinuxCpuArchive -TargetDirectory $PackagesDirectory + $TargetDirectory = [IO.Path]::Combine($PSScriptRoot, "..", "redist", "runtime.linux-x64.SciSharp.Tensorflow-Cpu.Redist") + Expand-TarGzFiles $Archive $LinuxFiles $TargetDirectory +} + diff --git a/tools/tensorflowlib/README.md b/tools/tensorflowlib/README.md new file mode 100644 index 000000000..ae04c3988 --- /dev/null +++ b/tools/tensorflowlib/README.md @@ -0,0 +1,88 @@ +TensorFlow.NET pack all required libraries in architecture-specific assemblies folders per NuGet standard. + +```powershell +PM> Install-Package TensorFlow.NET +PM> Install-Package SciSharp.TensorFlow.Redist +``` + +Add `win-x64` to a `PropertyGroup` in your `.csproj` when targeting `.NET 472`. + +### Run in Linux + +Download Linux pre-built library and unzip `libtensorflow.so` and `libtensorflow_framework.so` into current running directory. + +To run image recognition in Linux, please ensure some prerequisite libraries is install. + +```shell +sudo apt install libc6-dev +sudo apt install libgdiplus +``` + +More information about [System.Drawing on Linux](). + +### Run TensorFlow with GPU +Before running verify you installed CUDA and cuDNN (TensorFlow v1.15 is compatible with CUDA v10.0 and cuDNN v7.4 , TensorFlow v2.x is compatible with CUDA v10.2 and cuDNN v7.65), and make sure the corresponding cuda version is compatible. + +#### Mac OS +There is no GPU support for macOS, in the future TensorFlow will support [Apple M1 chip](https://github.com/apple/tensorflow_macos). + +#### GPU for Windows + +```powershell +PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU +``` + +#### GPU for Linux +```powershell +PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU +``` + +Since NuGet limits file size for 250M, we can't ship Linux GPU version as NuGet, you can download the library from [Google TensorFlow Storage](https://storage.googleapis.com/tensorflow). + +### Download prebuild binary manually + +TensorFlow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built. + + +### Build from source for Windows + +https://www.tensorflow.org/install/source_windows + +Download [Bazel 2.0.0](https://github.com/bazelbuild/bazel/releases/tag/2.0.0) to build tensorflow2.x. We build customized binary to export c_api from this [fork](https://github.com/SciSharp/tensorflow). + +Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC`. + +`pacman -S git patch unzip` + +1. Build static library + +`bazel build --output_base=C:/tmp/tfcompilation --config=opt //tensorflow:tensorflow` + +2. Build pip package + +`bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package` + +3. Generate pip installation file + +`bazel-bin\tensorflow\tools\pip_package\build_pip_package C:/tmp/tensorflow_pkg` + +4. Install from local wheel file. + +`pip install C:/tmp/tensorflow_pkg/tensorflow-1.15.0-cp36-cp36m-win_amd64.whl` + +### Build from source for MacOS + +```shell +$ cd /usr/local/lib/bazel/bin +$ curl -LO https://release.bazel.build/3.7.2/release/bazel-3.7.2-darwin-x86_64 +$ chmod +x bazel-3.7.2-darwin-x86_64 +$ cd ~/Projects/tensorflow +$ bazel build --config=opt //tensorflow:tensorflow +``` + +### Build specific version for tf.net + +https://github.com/SciSharp/tensorflow + +For Linux version, these APIs symbols should also be put into `tensorflow/c/version_script.lds` to be exported. +Please refer to commit `https://github.com/SciSharp/tensorflow/commit/58122da06be3e7707500ad889dfd5c760a3e0424` \ No newline at end of file